tl;dr
We introduce an initialization for MLP networks that works for any activation function, thus generalizing Xavier [1] and Kaiming [2] initialization (which were derived for Tanh and ReLU respectively).
This is important for implicit neural representations (INRs), as the literature uses (and keeps proposing) activation functions that are very different from the classic ones.
SIREN [4], which uses sine activations and provides an initialization tailored to them, demonstrates the importance of matching the initialization to the activation function, and we build heavily on its approach.
We show that our initialization does a better job than previous initializations (it maintains stable variance in both the forward and backward passes rather than just one), and demonstrate significant performance improvements on INR tasks.
Our Initialization
Consider a general MLP network with activation function \(f\):
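Concretely, writing out the setup these formulas assume (our restatement; \(M_i\) denotes the fan-in of layer \(i\), i.e. the dimension of \(x_i\), and any biases are taken to be initialized to zero), each layer acts as
\[ z_i = W_i x_i + b_i, \qquad x_{i+1} = f(z_i), \]
so \(x_i\) is the input to layer \(i\) and \(z_i\) its preactivations.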
A standard approach to initializing networks is to ensure that the variance of the preactivations is the same at every layer, and that the variance of the gradients of the loss w.r.t. the preactivations is the same at every layer.
We first show that the preactivations at layer \(i\) (i.e. the elements of \(z_i\)) converge in distribution to
\[ \mathcal{N}\left(0, M_i\left(\mu^2(x_i)+\sigma^2(x_i)\right)\sigma^2(W_i)\right) \]
(note that this was first shown by Kumar (2017) [3]; however, we provide a more rigorous proof by generalizing Sitzmann et al.'s proof for SIREN networks [4]).
This means that the variance of the preactivations at layer \(i\) (\(z_i\)) depends on the variance of the weights at that layer (\(W_i\)) and on the distribution
of the input to that layer (\(x_i\)). The key to our approach is to set the distribution of the preactivations at each layer to \(\mathcal{N}\left(0, \sigma_p^2\right)\),
where \(\sigma_p\) is a value we choose. To do this, we initialize the weights at layer \(i\) to have variance
\[ \sigma^2(W_i) = \frac{\sigma_p^2}{M_i\left(\mu^2(x_i)+\sigma^2(x_i)\right)} \]
where the statistics of the output of the previous layer can be computed using the fact that the preactivations in that layer have been set to have distribution
\(\mathcal{N}\left(0, \sigma_p^2\right)\):
\[ \mu(x_i) = \mu(f(z_{i-1})) = \mathbf{E}_{z\sim \mathcal{N}(0,\sigma^2_p)}\left[f(z)\right] \]
\[ \sigma^2(x_i) = \sigma^2(f(z_{i-1})) = \mathbf{Var}_{z\sim \mathcal{N}(0,\sigma^2_p)}\left[f(z)\right]. \]
Depending on \(f\), calculating these statistics analytically can be difficult.
We instead use Monte Carlo sampling, which we show is efficient and accurate, unlike previous approaches.
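As a rough sketch of what this looks like in practice (our own NumPy illustration, not the paper's reference code; `forward_init_std` and its parameters are placeholder names), note that \(\mu^2(x_i)+\sigma^2(x_i)\) is simply the second moment \(\mathbf{E}_{z\sim\mathcal{N}(0,\sigma_p^2)}[f(z)^2]\):

```python
import numpy as np

def forward_init_std(f, fan_in, sigma_p=1.0, n_samples=1_000_000, seed=0):
    """Monte Carlo estimate of the weight std satisfying the forward-pass condition."""
    rng = np.random.default_rng(seed)
    z = rng.normal(0.0, sigma_p, n_samples)   # previous-layer preactivations, z ~ N(0, sigma_p^2)
    x = f(z)                                  # inputs to the current layer, x_i = f(z_{i-1})
    second_moment = np.mean(x ** 2)           # mu^2(x_i) + sigma^2(x_i) = E[f(z)^2]
    var_w = sigma_p ** 2 / (fan_in * second_moment)
    return np.sqrt(var_w)

# e.g. a sine activation (as in SIREN) with fan-in M_i = 256 and sigma_p = 1;
# the layer's weights would then be drawn i.i.d. from N(0, w_std^2)
w_std = forward_init_std(np.sin, fan_in=256, sigma_p=1.0)
```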
Similarly, we derive the condition for the backward pass:
\[ \sigma^2(W_i) = \frac{1}{M_{i+1}\left(\mu^2(f'(z_i)) + \sigma^2(f'(z_i))\right)}. \]
Unlike previous approaches (e.g. Xavier, Kaiming), we can make the conditions for the forward and backward pass hold simultaneously, because \(\sigma_p\) is a free parameter for us to set.
Thus we need to find a \(\sigma_p\) that satisfies
\[ \sigma_p^2 \frac{M_{i+1}}{M_i} \frac{\mu^2(f'(z_i)) + \sigma^2(f'(z_i))}{\mu^2(x_i)+\sigma^2(x_i)} = 1. \]
This is non-trivial to do analytically, as many of the terms are expectations over the distribution \(\mathcal{N}\left(0, \sigma_p^2\right)\); instead, we perform a fast grid search
to find the \(\sigma_p\) that makes the left-hand side as close to 1 as possible.
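A minimal sketch of such a grid search (again our own illustration; `find_sigma_p`, the grid range, and the sample count are arbitrary placeholder choices), where `fan_out` plays the role of \(M_{i+1}\):

```python
import numpy as np

def find_sigma_p(f, f_prime, fan_in, fan_out,
                 grid=np.linspace(0.1, 5.0, 500), n_samples=200_000, seed=0):
    """Grid search for the sigma_p whose combined forward/backward condition is closest to 1."""
    rng = np.random.default_rng(seed)
    u = rng.standard_normal(n_samples)        # base samples, rescaled for each candidate sigma_p
    best_sigma_p, best_gap = None, np.inf
    for sigma_p in grid:
        z = sigma_p * u                       # z ~ N(0, sigma_p^2)
        # left-hand side: sigma_p^2 * (M_{i+1}/M_i) * E[f'(z)^2] / E[f(z)^2]
        lhs = (sigma_p ** 2) * (fan_out / fan_in) \
              * np.mean(f_prime(z) ** 2) / np.mean(f(z) ** 2)
        if abs(lhs - 1.0) < best_gap:
            best_sigma_p, best_gap = sigma_p, abs(lhs - 1.0)
    return best_sigma_p

# e.g. sine activations, where f'(z) = cos(z), and equal layer widths:
sigma_p = find_sigma_p(np.sin, np.cos, fan_in=256, fan_out=256)
```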
Deriving Xavier and Kaiming init
Our init is a more general form of Xavier and Kaiming init, so both can easily be derived from ours.
Note that when \(\sigma_p=1\) and the assumptions for Xavier init hold (\(f(x)\approx x \implies \mu(x_i)=0,\sigma^2(x_i)=1\)), our conditions reduce to Xavier init's conditions
\[ \sigma^2(W_i) = \frac{1}{M_i} \text{ (Forward Pass)}\]
\[ \sigma^2(W_i) = \frac{1}{M_{i+1}} \text{ (Backward Pass)} \]
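For completeness, a one-line check (our own substitution, using that \(f(x)\approx x\) also implies \(f'(x)\approx 1\), so \(\mu(f'(z_i))=1\) and \(\sigma^2(f'(z_i))=0\)):
\[ \sigma^2(W_i) = \frac{\sigma_p^2}{M_i\left(0+1\right)} = \frac{1}{M_i}, \qquad \sigma^2(W_i) = \frac{1}{M_{i+1}\left(1+0\right)} = \frac{1}{M_{i+1}}. \]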
Similarly, when \(\sigma_p=1\) and the activation is ReLU, we recover Kaiming init (see the paper for the full explanation; a one-line sketch follows the equations below)
\[ \sigma^2(W_i) = \frac{2}{M_i} \text{ (Forward Pass)}\]
\[ \sigma^2(W_i) = \frac{2}{M_{i+1}} \text{ (Backward Pass)}\]
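As a quick check of the ReLU case (our own one-line derivation, not reproduced from the paper): for \(z \sim \mathcal{N}(0, \sigma_p^2)\) and \(f(z)=\max(0,z)\), symmetry gives
\[ \mu^2(x_i)+\sigma^2(x_i) = \mathbf{E}\left[\max(0,z)^2\right] = \tfrac{1}{2}\sigma_p^2, \qquad \mu^2(f'(z_i))+\sigma^2(f'(z_i)) = \mathbf{E}\left[\mathbf{1}_{z>0}\right] = \tfrac{1}{2}, \]
so with \(\sigma_p=1\) the forward condition becomes \(\sigma^2(W_i)=1/(M_i\cdot\tfrac{1}{2})=2/M_i\), and the backward condition likewise becomes \(2/M_{i+1}\).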
Since neither initialization can satisfy both conditions simultaneously the way ours can, Xavier init takes the average of its two conditions, while Kaiming init suggests using either one.
Comparison to PyTorch, and proper "gain"
PyTorch [5] generalizes Xavier and Kaiming init by introducing a "gain", so that forward-pass Kaiming (which they call fan-in Kaiming) becomes
\[ \sigma^2(W_i) = \text{gain}^2(f)\frac{1}{M_i} \]
While never explicitly defined, it is motivated as a scaling term on the weights' variance that compensates for the activation function, and is often implied to be
\[ \text{gain}^2(f) \approx \frac{\sigma^2(z_i)}{\sigma^2(f(z_i))} \]
which is not well defined. As a result, it is often treated as a hyperparameter, and a value that keeps the variance through the network stable is found by brute-force search.
Our formulation makes it well defined thanks to \(\sigma_p\) (we set the preactivation variance of the current layer given that it has already been set for the previous layer);
thus, for the forward pass,
\[ \text{gain}^2(f, \sigma_p) = \frac{\sigma_p^2}{\mathbf{E}_{z\sim \mathcal{N}(0,\sigma^2_p)}\left[f(z)\right]^2+\mathbf{Var}_{z\sim \mathcal{N}(0,\sigma^2_p)}\left[f(z)\right]}. \]
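For example, estimating this gain by Monte Carlo for ReLU with \(\sigma_p=1\) recovers PyTorch's tabulated value of \(\sqrt{2}\) (a small sketch assuming NumPy and PyTorch are available; `mc_gain` is a placeholder name):

```python
import numpy as np
import torch

def mc_gain(f, sigma_p=1.0, n_samples=1_000_000, seed=0):
    """Monte Carlo estimate of gain(f, sigma_p) = sigma_p / sqrt(E[f(z)^2]), z ~ N(0, sigma_p^2)."""
    rng = np.random.default_rng(seed)
    z = rng.normal(0.0, sigma_p, n_samples)
    return sigma_p / np.sqrt(np.mean(f(z) ** 2))

def relu(z):
    return np.maximum(z, 0.0)

print(mc_gain(relu, sigma_p=1.0))              # ~1.414, i.e. sqrt(2)
print(torch.nn.init.calculate_gain('relu'))    # sqrt(2): PyTorch's tabulated ReLU gain agrees here
```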