Fisher Flow: Optimization on the Statistical Manifold

April 20, 2024

Standard gradient descent treats parameter space as flat. It uses Euclidean distance, which means the same step size in parameter space can produce wildly different changes in the distribution depending on where you are. Fisher Flow fixes this by optimizing along the natural geometry of probability distributions.

The Geometry of Distributions

A smoothly parametrized family of probability distributions forms a Riemannian manifold, with the parameters theta as coordinates. The Fisher information matrix provides the natural metric on this manifold:

FIM: I(theta) = E[grad log p(x|theta) grad log p(x|theta)^T]

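The Fisher information captures how sensitive the distribution is to parameter changes. The expectation is taken over x drawn from p(x|theta) itself, and grad is taken with respect to theta. Under standard regularity conditions it equals the expected Hessian of the negative log-likelihood, so it measures the average curvature of the likelihood surface. To second order, KL[p(x|theta) || p(x|theta + d)] is (1/2) d^T I(theta) d, so the metric measures the length of a step by how much the distribution changes, not by Euclidean distance in theta.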
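As a small sanity check of the definition, here is a sketch for a Bernoulli distribution with parameter p. The Bernoulli choice and all function names are illustrative assumptions, not something from a particular library. It computes the Fisher information directly as the expected squared score and compares the exact KL divergence of a small step against the quadratic form built from the Fisher information.

```python
import math

def bernoulli_score(x, p):
    """d/dp log p(x|p) for x in {0, 1}."""
    return x / p - (1 - x) / (1 - p)

def bernoulli_fisher(p):
    """E[score^2], with the expectation taken exactly over x in {0, 1}."""
    return p * bernoulli_score(1, p) ** 2 + (1 - p) * bernoulli_score(0, p) ** 2

def bernoulli_kl(p, q):
    """KL[Bernoulli(p) || Bernoulli(q)]."""
    return p * math.log(p / q) + (1 - p) * math.log((1 - p) / (1 - q))

p, delta = 0.3, 1e-3
fisher = bernoulli_fisher(p)               # closed form: 1 / (p * (1 - p))
kl_exact = bernoulli_kl(p, p + delta)      # true change in the distribution
kl_quadratic = 0.5 * fisher * delta ** 2   # second-order Fisher approximation

print(fisher, kl_exact, kl_quadratic)
```

For a small step the two KL values should agree closely, and the gap grows as delta gets larger because the quadratic form is only a local approximation.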

Natural Gradient vs. Standard Gradient

Standard gradient descent updates parameters in Euclidean space:

theta_{t+1} = theta_t - alpha grad L(theta_t)

This is inefficient because it ignores parameter correlations and scaling. A step of size epsilon in one direction might barely change the distribution, while the same step in another direction, or from a different starting point, changes it drastically.
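To make that concrete, here is a minimal sketch using a one-dimensional Gaussian, chosen purely for illustration because its KL divergence has a closed form. It takes the same Euclidean step of 0.1 in the mean at two different values of sigma and measures how far the distribution actually moved.

```python
import math

def gaussian_kl(mu1, sigma1, mu2, sigma2):
    """KL[N(mu1, sigma1^2) || N(mu2, sigma2^2)]."""
    return (math.log(sigma2 / sigma1)
            + (sigma1 ** 2 + (mu1 - mu2) ** 2) / (2 * sigma2 ** 2)
            - 0.5)

step = 0.1  # identical Euclidean step in the mean

kl_narrow = gaussian_kl(0.0, 0.1, 0.0 + step, 0.1)    # sigma = 0.1
kl_wide = gaussian_kl(0.0, 10.0, 0.0 + step, 10.0)    # sigma = 10

print(kl_narrow, kl_wide)
```

The step is the same size in parameter space, but the narrow Gaussian moves about ten thousand times further in KL than the wide one. That matches the Fisher information for the mean, 1/sigma^2, which differs by the same factor between the two cases.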

Natural gradient descent uses the Fisher metric:

theta_{t+1} = theta_t - alpha I(theta_t)^{-1} grad L(theta_t)

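Pre-multiplying by the inverse Fisher information rescales the gradient to account for the local geometry. The key property: the natural gradient direction is invariant to reparametrization. In the limit of small steps, the path through distribution space does not depend on how you choose to represent your distributions. With a finite step size alpha, that invariance holds only to first order in alpha for nonlinear reparametrizations.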
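Here is a toy comparison of the two updates, assuming a Gaussian with known sigma whose mean is fit by minimizing the average negative log-likelihood. The setup, names, and the rescaling mu = c * nu are all illustrative assumptions. Because this reparametrization is linear, the invariance is exact even for a finite step; for a nonlinear map it would only hold approximately.

```python
def nll_grad_mu(mu, xbar, sigma):
    """Gradient of the average Gaussian negative log-likelihood w.r.t. mu."""
    return (mu - xbar) / sigma ** 2

def fisher_mu(sigma):
    """Fisher information of N(mu, sigma^2) w.r.t. mu (sigma known)."""
    return 1.0 / sigma ** 2

xbar, sigma, alpha = 2.0, 5.0, 1.0   # sample mean, known std, step size
mu = 0.0

# Standard gradient step: its length depends on sigma.
mu_sgd = mu - alpha * nll_grad_mu(mu, xbar, sigma)

# Natural gradient step: the 1/sigma^2 factors cancel.
mu_nat = mu - alpha * nll_grad_mu(mu, xbar, sigma) / fisher_mu(sigma)

# Same model reparametrized as mu = c * nu. By the chain rule the gradient
# picks up a factor c and the Fisher information a factor c^2.
c = 10.0
nu = mu / c
grad_nu = c * nll_grad_mu(c * nu, xbar, sigma)
fisher_nu = c ** 2 * fisher_mu(sigma)
nu_sgd = nu - alpha * grad_nu
nu_nat = nu - alpha * grad_nu / fisher_nu

print("standard:", mu_sgd, "vs", c * nu_sgd)   # the two disagree
print("natural: ", mu_nat, "vs", c * nu_nat)   # the two agree
```

Mapped back to mu, the natural gradient step lands on the same distribution in both coordinate systems, while the standard gradient step does not.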

Fisher Flow

Fisher Flow makes this continuous:

dtheta/dt = -I(theta)^{-1} grad L(theta)

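This defines a flow on the parameter manifold. Loss decreases monotonically along the flow, because dL/dt = -grad L(theta)^T I(theta)^{-1} grad L(theta), which is never positive when I(theta) is positive definite. The trajectories follow the natural geometry. The effective step length adapts automatically: directions in which the distribution is highly sensitive, where the Fisher information is large, get proportionally smaller updates.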
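The flow can be simulated with plain Euler steps. The sketch below assumes a deliberately simple setting where everything has a closed form: a Bernoulli model with parameter p, and a loss equal to the KL divergence from a fixed target Bernoulli(q). In that case the natural gradient reduces to p - q, so the flow has the exact solution p(t) = q + (p0 - q) exp(-t), which gives something to check the integration against. The names and constants are illustrative.

```python
import math

def loss(p, q):
    """KL[Bernoulli(q) || Bernoulli(p)]: the loss being minimized over p."""
    return q * math.log(q / p) + (1 - q) * math.log((1 - q) / (1 - p))

def loss_grad(p, q):
    """Derivative of the loss with respect to p."""
    return (p - q) / (p * (1 - p))

def fisher(p):
    """Fisher information of Bernoulli(p) with respect to p."""
    return 1.0 / (p * (1 - p))

q, p0, dt, steps = 0.8, 0.1, 0.01, 500
p = p0
losses = [loss(p, q)]
for _ in range(steps):
    p -= dt * loss_grad(p, q) / fisher(p)   # Euler step of dp/dt = -(p - q)
    losses.append(loss(p, q))

p_exact = q + (p0 - q) * math.exp(-dt * steps)
print(p, p_exact, losses[0], losses[-1])
```

With a small dt the recorded losses decrease at every step and the final p stays close to the closed-form solution. A larger dt is a coarser discretization and loses the guarantees of the continuous flow.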

The Practical Problem

Computing I(theta)^{-1} is expensive. For a neural network with n parameters, the full Fisher information matrix is n x n, and inverting it is O(n^3). That is not practical for modern networks.
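A quick back-of-the-envelope calculation shows how fast this gets out of hand. The helper below is a hypothetical convenience function, assuming 4-byte floats and treating n^3 as an order-of-magnitude operation count rather than an exact figure.

```python
def dense_fisher_cost(n_params, bytes_per_entry=4):
    """Storage and rough inversion cost for a dense n x n Fisher matrix."""
    entries = n_params ** 2
    return {
        "entries": entries,
        "storage_gib": entries * bytes_per_entry / 2 ** 30,
        "inversion_flops": n_params ** 3,   # order of magnitude only
    }

for n in (10_000, 1_000_000, 100_000_000):
    cost = dense_fisher_cost(n)
    print(f"n={n:>11,}  storage={cost['storage_gib']:.3g} GiB  "
          f"flops~{cost['inversion_flops']:.1e}")
```

Even at a million parameters, the dense matrix alone needs several terabytes before any inversion is attempted.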

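Approximations help. A diagonal Fisher approximation is cheap, O(n), but crude: it ignores all correlations between parameters. A block-diagonal Fisher captures within-layer correlations. K-FAC (Kronecker-Factored Approximate Curvature) approximates each layer's Fisher block as a Kronecker product of two much smaller matrices, so only those small factors need to be inverted. For a layer with d_in inputs and d_out outputs, that costs roughly O(d_in^3 + d_out^3) instead of O((d_in d_out)^3). That makes it practical for real networks.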
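The saving from a Kronecker structure comes from a standard linear algebra identity: solving against A kron G is the same as solving against G on one side and A on the other. The sketch below checks that identity numerically with NumPy. It is not a K-FAC implementation; the random matrices A and G merely stand in for the two per-layer factors, and the sizes are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out = 4, 3

def random_spd(d):
    """Random symmetric positive-definite matrix (stand-in for a factor)."""
    m = rng.standard_normal((d, d))
    return m @ m.T + d * np.eye(d)

A = random_spd(d_in)     # stands in for the input-side factor
G = random_spd(d_out)    # stands in for the output-side factor
V = rng.standard_normal((d_out, d_in))   # gradient of a d_out x d_in weight

# Dense route: build the (d_in*d_out) x (d_in*d_out) matrix and solve.
F = np.kron(A, G)
dense = np.linalg.solve(F, V.flatten(order="F"))   # column-stacking vec

# Factored route: only the small matrices are ever solved or inverted.
factored = np.linalg.solve(G, V) @ np.linalg.inv(A)

max_diff = np.max(np.abs(dense - factored.flatten(order="F")))
print(max_diff)
```

Both routes give the same preconditioned gradient up to floating point error, but the factored route never forms the large matrix.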

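def fisher_flow_step(params, loss_fn, data, compute_gradient,
                     estimate_fisher_diagonal, learning_rate=0.01, epsilon=1e-8):
    # compute_gradient and estimate_fisher_diagonal are caller-supplied
    # callables; params and their outputs are same-shaped arrays (e.g. NumPy).
    # The learning_rate and epsilon defaults are illustrative, not tuned.

    # Compute gradient of the loss at the current parameters
    grad = compute_gradient(loss_fn, params, data)

    # Estimate Fisher information (diagonal approx): one entry per parameter
    fisher = estimate_fisher_diagonal(params, data)

    # Natural gradient step: for a diagonal Fisher the inverse is an
    # elementwise division; epsilon damps near-zero entries
    natural_grad = grad / (fisher + epsilon)

    # Update parameters (one Euler step of the Fisher Flow ODE)
    params = params - learning_rate * natural_grad

    return params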
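The step function above leaves the diagonal Fisher estimator unspecified. One possible way to fill it in is a Monte Carlo average of squared per-sample scores, sketched here for a one-dimensional Gaussian parametrized by (mu, log_sigma). The function names, the signature, and the choice of model are assumptions made for illustration; a real network would need per-example gradients from an autodiff framework.

```python
import math
import random

def gaussian_score(x, mu, log_sigma):
    """Gradient of log N(x | mu, sigma^2) w.r.t. (mu, log_sigma)."""
    sigma = math.exp(log_sigma)
    z = (x - mu) / sigma
    return (z / sigma, z * z - 1.0)

def estimate_fisher_diagonal_gaussian(mu, log_sigma, n_samples=50_000, seed=0):
    """Monte Carlo estimate of diag(I): mean squared score over model samples."""
    rng = random.Random(seed)
    sigma = math.exp(log_sigma)
    totals = [0.0, 0.0]
    for _ in range(n_samples):
        x = rng.gauss(mu, sigma)   # sample from the model, not from the data
        for i, s in enumerate(gaussian_score(x, mu, log_sigma)):
            totals[i] += s * s
    return [t / n_samples for t in totals]

fisher_diag = estimate_fisher_diagonal_gaussian(mu=1.0, log_sigma=math.log(2.0))
print(fisher_diag)   # exact values for comparison: 1/sigma^2 = 0.25 and 2
```

The samples are drawn from the model itself, which matches the expectation in the Fisher definition. Averaging squared gradients over observed data instead gives the so-called empirical Fisher, which is a different quantity and can behave differently.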

Variational Inference

Fisher Flow gives a natural framework for variational inference. You have a variational distribution q(z|phi) and you want to minimize KL[q(z|phi) || p(z|x)]. Following the Fisher geometry of q, that is, preconditioning the gradient with the Fisher information of q(z|phi) with respect to phi, means your optimization respects the structure of the distribution family you are searching over.
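Here is a toy sketch of that idea, under strong simplifying assumptions: q is a one-dimensional Gaussian N(m, s^2), and the target is a known Gaussian N(m_star, s_star^2), so the KL divergence and its gradient are available in closed form. In real variational inference the posterior is not known and the gradient would come from the ELBO instead. All names and constants are illustrative.

```python
import math

def kl_q_p(m, s, m_star, s_star):
    """KL[N(m, s^2) || N(m_star, s_star^2)]."""
    return (math.log(s_star / s)
            + (s * s + (m - m_star) ** 2) / (2 * s_star ** 2)
            - 0.5)

def natural_gradient_vi(m, s, m_star, s_star, alpha=0.5, steps=200):
    """Natural gradient descent on the KL in (m, s) coordinates."""
    history = [kl_q_p(m, s, m_star, s_star)]
    for _ in range(steps):
        # Euclidean gradient of the KL with respect to phi = (m, s)
        grad_m = (m - m_star) / s_star ** 2
        grad_s = -1.0 / s + s / s_star ** 2
        # Fisher information of q in (m, s) coordinates is diag(1/s^2, 2/s^2),
        # so multiplying by its inverse is an elementwise rescaling.
        nat_m = grad_m * s * s
        nat_s = grad_s * s * s / 2.0
        m, s = m - alpha * nat_m, s - alpha * nat_s
        history.append(kl_q_p(m, s, m_star, s_star))
    return m, s, history

m, s, history = natural_gradient_vi(m=0.0, s=1.0, m_star=3.0, s_star=2.0)
print(m, s, history[0], history[-1])
```

With these settings q converges to the target, and the size of each update in m scales with the current variance of q rather than with a fixed Euclidean step.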
