
Wristband Gaussian Loss: From $O(N^2)$ to $O(N)$

Story

Mikhail Parakhin (CTO of Shopify) published the Wristband Gaussian Loss, a training regularizer that forces a neural encoder's outputs to be distributed exactly as a standard Gaussian. It is available on GitHub.

Work

This is a continuation of Wristband Gaussian Loss: Formalization and Proof. The original work computes the repulsion term, which measures how uniformly the encoded points cover the wristband space, from all pairs of samples. By rearranging the equations, I was able to compute the repulsion term through a spectral decomposition on the wristband space instead. The original formulation splits the kernel into an angular part and a radial part:

$$K\bigl((u,t),(u',t')\bigr) = k_\text{ang}(u, u') \cdot k_\text{rad}(t, t'),$$

where $u$ is the angular coordinate and $t$ the radial coordinate of a point $(u, t)$ on the wristband space. The kernel energy to be minimized is

$$\mathcal{E}(P) = \mathbb{E}_{W, W' \stackrel{\text{iid}}{\sim} P}\bigl[K(W, W')\bigr].$$
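For reference, the pairwise estimator can be sketched as below. The text does not give $k_\text{ang}$ or $k_\text{rad}$ explicitly, so the kernels here are hypothetical stand-ins: a constant-plus-linear angular kernel and a truncated cosine-series radial kernel. The parameter names `a_tilde`, `lam0`, `lam1` are also made up for illustration. The expectation is replaced by the mean over all $N^2$ ordered pairs of samples, including $i = j$. The sampler assumes $t \in [0, 1]$ and $u$ uniform on the unit sphere.

```python
import math
import random


# Hypothetical kernels: the source does not specify k_ang or k_rad.
def k_ang(u, v, lam0, lam1):
    """Angular kernel: a constant plus a term linear in <u, v>, scaled by d."""
    d = len(u)
    return lam0 + lam1 * d * sum(a * b for a, b in zip(u, v))


def k_rad(t, s, a_tilde):
    """Radial kernel: truncated cosine series on [0, 1] with coefficients a_tilde."""
    return sum(a * math.cos(k * math.pi * t) * math.cos(k * math.pi * s)
               for k, a in enumerate(a_tilde))


def pairwise_energy(us, ts, a_tilde, lam0, lam1):
    """O(N^2) estimate of E[K(W, W')]: mean of the product kernel over all ordered pairs."""
    n = len(ts)
    total = 0.0
    for i in range(n):
        for j in range(n):
            total += k_ang(us[i], us[j], lam0, lam1) * k_rad(ts[i], ts[j], a_tilde)
    return total / (n * n)


def random_wristband_points(n, d, seed=0):
    """Hypothetical sampler: u uniform on the unit sphere in R^d, t uniform on [0, 1)."""
    rng = random.Random(seed)
    us, ts = [], []
    for _ in range(n):
        g = [rng.gauss(0.0, 1.0) for _ in range(d)]
        norm = math.sqrt(sum(x * x for x in g))
        us.append([x / norm for x in g])
        ts.append(rng.random())
    return us, ts


us, ts = random_wristband_points(n=200, d=4)
print(pairwise_energy(us, ts, a_tilde=[1.0, 0.5, 0.25], lam0=1.0, lam1=1.0))
```

The double loop visits every ordered pair, which is exactly the quadratic cost in $N$ that the spectral rewrite avoids.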

I rewrote the energy in spectral form, "merging" the angular and radial kernels into sums of squared sample means over the batch:

$$\boxed{\mathcal{E}_\text{sp} = \lambda_0 \sum_{k=0}^{K} \tilde{a}_k \left(\frac{1}{N}\sum_{i=1}^{N} \cos k\pi t_i\right)^{\!2} + \lambda_1 \sum_{m=1}^{d} \sum_{k=0}^{K} \tilde{a}_k \left(\frac{\sqrt{d}}{N}\sum_{i=1}^{N} u_{im} \cos k\pi t_i\right)^{\!2}},$$

where $N$ is the number of samples, $d$ is the latent-space dimension, $u_{im}$ is the $m$-th coordinate of $u_i$, and $K$ is the truncation order of the spectral approximation.
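One way to see where the boxed formula comes from is sketched here. It rests on an assumption the text does not spell out: that both kernels are truncated to expansions of the form below, with nonnegative $\tilde a_k$, $\lambda_0$, $\lambda_1$. Substituting them into the pairwise mean over a batch turns each double sum into the square of a single sum, via $\sum_{i,j} x_i x_j = \bigl(\sum_i x_i\bigr)^2$ and $d\langle u_i, u_j\rangle = \sum_m (\sqrt{d}\,u_{im})(\sqrt{d}\,u_{jm})$:

$$
\begin{aligned}
k_\text{rad}(t,t') &\approx \sum_{k=0}^{K} \tilde{a}_k \cos k\pi t \,\cos k\pi t', \qquad
k_\text{ang}(u,u') \approx \lambda_0 + \lambda_1\, d\,\langle u, u'\rangle, \\
\frac{1}{N^2}\sum_{i,j} K\bigl((u_i,t_i),(u_j,t_j)\bigr)
&= \lambda_0 \sum_{k=0}^{K} \tilde{a}_k \left(\frac{1}{N}\sum_{i=1}^{N} \cos k\pi t_i\right)^{\!2}
 + \lambda_1 \sum_{m=1}^{d}\sum_{k=0}^{K} \tilde{a}_k \left(\frac{\sqrt{d}}{N}\sum_{i=1}^{N} u_{im}\cos k\pi t_i\right)^{\!2}.
\end{aligned}
$$

Under the same assumption, replace the sample means by expectations. Suppose $t \sim \mathcal{U}[0,1]$ independently of $u$, and $u$ is uniform on the unit sphere. Then, for $k \ge 1$ and all $m$:

$$
\mathbb{E}[\cos k\pi t] = \int_0^1 \cos k\pi t \, dt = 0, \qquad
\mathbb{E}[u_m \cos k\pi t] = \mathbb{E}[u_m]\,\mathbb{E}[\cos k\pi t] = 0,
$$

and $\mathbb{E}[u_m] = 0$ also kills the $k = 0$ terms of the second sum. The $k = 0$ term of the first sum equals $\lambda_0 \tilde a_0$ for every distribution, and all other terms are nonnegative. So $\lambda_0 \tilde a_0$ is a lower bound that the uniform distribution attains. This is only an informal version of the minimality property, not the formal proof.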

This replaces the $O(N^2)$ pairwise energy computation with an $O(Nd)$ one (for a fixed truncation order $K$; $O(NdK)$ in general), which is much faster even for small $N$. The result was also formalized in Lean, proving that the spectral energy, like the original, attains its minimum at the uniform distribution, which is the main purpose of the kernel.
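The linear-time evaluation can be sketched directly from the boxed formula. This is a pure-Python sketch, not the author's implementation. `a_tilde`, `lam0` and `lam1` are hypothetical names for $\tilde a_k$, $\lambda_0$, $\lambda_1$, and the sampler assumes $t \in [0, 1]$ with $u$ uniform on the unit sphere. Each frequency $k$ costs one pass over the batch for the $\lambda_0$ term plus $d$ passes for the $\lambda_1$ term, so the total work is $O(NdK)$.

```python
import math
import random


def spectral_energy(us, ts, a_tilde, lam0, lam1):
    """Evaluate the spectral energy E_sp in O(N * d * K) time.

    us: N unit vectors of length d (angular coordinates u_i)
    ts: N radial coordinates t_i in [0, 1]
    a_tilde, lam0, lam1: hypothetical names for a~_0..a~_K, lambda_0, lambda_1
    """
    n, d = len(ts), len(us[0])
    energy = 0.0
    for k, a in enumerate(a_tilde):
        # cos(k*pi*t_i) for every sample: one O(N) pass per frequency k.
        cos_k = [math.cos(k * math.pi * t) for t in ts]
        # lambda_0 term: squared sample mean of cos(k*pi*t_i).
        mean0 = sum(cos_k) / n
        energy += lam0 * a * mean0 ** 2
        # lambda_1 term: squared sqrt(d)/N-scaled sums of u_im * cos(k*pi*t_i), m = 1..d.
        for m in range(d):
            mean1 = math.sqrt(d) / n * sum(u[m] * c for u, c in zip(us, cos_k))
            energy += lam1 * a * mean1 ** 2
    return energy


def random_wristband_points(n, d, seed=0):
    """Hypothetical sampler: u uniform on the unit sphere in R^d, t uniform on [0, 1)."""
    rng = random.Random(seed)
    us, ts = [], []
    for _ in range(n):
        g = [rng.gauss(0.0, 1.0) for _ in range(d)]
        norm = math.sqrt(sum(x * x for x in g))
        us.append([x / norm for x in g])
        ts.append(rng.random())
    return us, ts


a_tilde = [1.0, 0.5, 0.25]
us, ts = random_wristband_points(n=2000, d=4)
print(spectral_energy(us, ts, a_tilde, lam0=1.0, lam1=1.0))

# Every sample at the same point: a maximally non-uniform batch.
clustered_us = [[1.0, 0.0, 0.0, 0.0]] * 2000
clustered_ts = [0.1] * 2000
print(spectral_energy(clustered_us, clustered_ts, a_tilde, lam0=1.0, lam1=1.0))
```

On the uniform batch the value should approach $\lambda_0 \tilde a_0$ (here 1.0) as $N$ grows, since every other squared mean tends to zero. The clustered batch scores far higher. With the product kernel written out explicitly, this sum agrees with the mean over all $N^2$ ordered pairs up to floating-point error.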

To-do: expand the full theoretical derivation

Notes

All theoretical ideas for the Wristband Space and the original implementation are Mikhail Parakhin's.
My contribution is a faster computation of the kernel energy.

Meta
Topics: Mathematics