Score-Matching Langevin Dynamics (SMLD)

1 Langevin Dynamics

NOTE

How to sample from a distribution

The Langevin dynamics for sampling from a known distribution $p(\mathbf{x})$ is an iterative procedure for $t=1,\ldots,T$:

$$\mathbf{x}_{t+1}=\mathbf{x}_t+\tau\nabla_\mathbf{x}\log p(\mathbf{x}_t)+\sqrt{2\tau}\,\mathbf{z},\quad\mathbf{z}\sim\mathcal{N}(0,\mathbf{I}),$$

where $\tau$ is a user-controlled step size and $\mathbf{x}_0$ is initialized as white noise.

  • Without the noise term, Langevin dynamics reduces to gradient ascent on $\log p(\mathbf{x})$, or equivalently gradient descent on $-\log p(\mathbf{x})$.
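As a minimal sketch of the recursion above, the following NumPy snippet runs many independent Langevin chains on a 1-D Gaussian whose score is known in closed form. The target parameters (`mu=2.0`, `sigma=0.5`), the step size, and the function names are illustrative assumptions, not part of the original post.

```python
import numpy as np

def gaussian_score(x, mu=2.0, sigma=0.5):
    """Score of N(mu, sigma^2): d/dx log p(x) = -(x - mu) / sigma^2."""
    return -(x - mu) / sigma**2

def langevin_sample(score, n_chains=5000, n_steps=500, tau=1e-2, seed=0):
    """Run n_chains independent Langevin chains and return their final states."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(n_chains)        # x_0 is white noise
    for _ in range(n_steps):
        z = rng.standard_normal(n_chains)    # fresh z ~ N(0, 1) at every step
        x = x + tau * score(x) + np.sqrt(2.0 * tau) * z
    return x

samples = langevin_sample(gaussian_score)
print("sample mean:", samples.mean())
print("sample std: ", samples.std())
```

With a small step size, the empirical mean and standard deviation should land close to the target values; a larger `tau` converges faster but adds discretization bias.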

The intuition is that if we want to sample $\mathbf{x}$ from a distribution, certainly the “optimal” location for $\mathbf{x}$ is where $p(\mathbf{x})$ is maximized (seeing the peak as a representation of the distribution). So the goal of sampling is equivalent to solving the optimization

$$\mathbf{x}^*=\underset{\mathbf{x}}{\operatorname{argmax}}\ \log p(\mathbf{x}).$$

distribution <=> peak <=> optimal solution

WARNING

Langevin dynamics is stochastic gradient descent.

  • We use stochastic gradient descent because we want to sample from the distribution rather than solve the optimization problem; the noise keeps the iterates from collapsing onto the peak.
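The role of the noise term can be illustrated with a small experiment: the same update is run with and without the $\sqrt{2\tau}\,\mathbf{z}$ term. This is only a sketch on an assumed 1-D Gaussian target; the helper `iterate` and all parameter values are hypothetical.

```python
import numpy as np

def score(x, mu=2.0, sigma=0.5):
    """Score of the illustrative target N(mu, sigma^2)."""
    return -(x - mu) / sigma**2

def iterate(noise, n_chains=2000, n_steps=500, tau=1e-2, seed=0):
    """Apply the Langevin update, optionally dropping the noise term."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(n_chains)
    for _ in range(n_steps):
        x = x + tau * score(x)               # deterministic gradient step
        if noise:
            x = x + np.sqrt(2.0 * tau) * rng.standard_normal(n_chains)
    return x

ascent = iterate(noise=False)     # every chain ends at the peak
langevin = iterate(noise=True)    # chains spread out like the target
print("no noise   -> std:", ascent.std())
print("with noise -> std:", langevin.std())
```

Without noise every chain converges to the single maximizer, so the spread vanishes; with noise the spread stays close to the standard deviation of the target.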

2 (Stein’s) Score Function

The gradient term $\nabla_{\mathbf{x}}\log p(\mathbf{x})$ in the Langevin dynamics equation is formally known as Stein’s score function, denoted by

$$\mathbf{s}_{\boldsymbol{\theta}}(\mathbf{x})\overset{\mathrm{def}}{=}\nabla_{\mathbf{x}}\log p_{\boldsymbol{\theta}}(\mathbf{x}).$$

The way to understand the score function is to remember that it is the gradient with respect to the data $\mathbf{x}$, not the parameters $\boldsymbol{\theta}$. For any multi-dimensional distribution $p(\mathbf{x})$, the gradient gives us a vector field. In two dimensions, with $\mathbf{x}=[x,y]^T$,

$$\nabla_\mathbf{x}\log p(\mathbf{x})=\begin{bmatrix}\frac{\partial\log p(\mathbf{x})}{\partial x},&\frac{\partial\log p(\mathbf{x})}{\partial y}\end{bmatrix}^T.$$

Geometric Interpretations of the Score Function:

  • The vectors have the largest magnitude where $\log p(\mathbf{x})$ changes fastest. Consequently, in regions close to the peak of $\log p(\mathbf{x})$, the gradient is mostly very weak.
  • The vector field indicates how a data point should travel in the contour map.
  • In physics, the score function is equivalent to the “drift”. This name suggests how the diffusion particles should flow to the lowest energy state.
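To make the vector-field picture concrete, here is a small sketch for a 2-D Gaussian, where the score has the closed form $-\boldsymbol{\Sigma}^{-1}(\mathbf{x}-\boldsymbol{\mu})$. The mean `MU`, covariance `SIGMA`, and the finite-difference helper are assumptions chosen for illustration.

```python
import numpy as np

MU = np.array([1.0, -1.0])                     # illustrative mean
SIGMA = np.array([[1.0, 0.3], [0.3, 0.5]])     # illustrative covariance
SIGMA_INV = np.linalg.inv(SIGMA)

def log_p(x):
    """Log-density of N(MU, SIGMA) at a 2-D point x."""
    d = x - MU
    _, logdet = np.linalg.slogdet(2.0 * np.pi * SIGMA)
    return -0.5 * d @ SIGMA_INV @ d - 0.5 * logdet

def score(x):
    """Analytic score: grad_x log p(x) = -SIGMA^{-1} (x - MU)."""
    return -SIGMA_INV @ (x - MU)

def numerical_score(x, eps=1e-5):
    """Central finite differences of log_p, one coordinate at a time."""
    grad = np.zeros_like(x)
    for i in range(x.size):
        e = np.zeros_like(x)
        e[i] = eps
        grad[i] = (log_p(x + e) - log_p(x - e)) / (2.0 * eps)
    return grad

point = np.array([2.0, 0.5])
print("analytic :", score(point))
print("numerical:", numerical_score(point))
print("norm at the peak:", np.linalg.norm(score(MU)))
```

The two gradients agree, each vector points back toward the mean, and the magnitude is zero exactly at the peak, matching the weak-gradient observation in the list above.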

3 Score Matching Techniques

Since the true distribution is not known, we need some method to approximate it.

Explicit Score-Matching

Suppose that we are given a dataset $\mathcal{X}=\{\mathbf{x}_1,\ldots,\mathbf{x}_M\}$. One solution is to use classical kernel density estimation, defining the distribution

$$q(\mathbf{x})=\frac1M\sum_{m=1}^M\frac1hK\left(\frac{\mathbf{x}-\mathbf{x}_m}h\right),$$

where $h$ is a hyperparameter (the bandwidth) of the kernel function $K(\cdot)$, and $\mathbf{x}_m$ is the $m$-th sample in the training set.

  • The sum of all these individual kernels gives us the overall kernel density estimate $q(\mathbf{x})$.
  • The idea is similar to a Gaussian mixture model.
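A minimal 1-D sketch of the kernel density estimate, assuming a standard normal kernel $K$. The synthetic data, the bandwidth `h=0.5`, and the helper names `kde` and `kde_score` are illustrative and not taken from the post.

```python
import numpy as np

def kde(x, data, h):
    """q(x) = (1/M) sum_m (1/h) K((x - x_m)/h) with a standard normal kernel K."""
    u = (x[:, None] - data[None, :]) / h
    kernel = np.exp(-0.5 * u**2) / np.sqrt(2.0 * np.pi)
    return kernel.mean(axis=1) / h

def kde_score(x, data, h):
    """d/dx log q(x): a weighted average of the per-kernel scores -(x - x_m)/h^2."""
    diff = x[:, None] - data[None, :]
    logits = -0.5 * (diff / h) ** 2
    w = np.exp(logits - logits.max(axis=1, keepdims=True))  # stable softmax weights
    w /= w.sum(axis=1, keepdims=True)
    return (w * (-diff / h**2)).sum(axis=1)

rng = np.random.default_rng(0)
data = rng.normal(loc=1.0, scale=2.0, size=200)   # synthetic training samples
grid = np.linspace(-15.0, 17.0, 4001)
q = kde(grid, data, h=0.5)
area = np.sum(q) * (grid[1] - grid[0])            # Riemann sum of the density
print("area under q:", area)
print("score at x=0:", kde_score(np.array([0.0]), data, h=0.5)[0])
```

With a Gaussian kernel the estimate is literally an equal-weight Gaussian mixture centered on the data points, which is why its score is a softmax-weighted average of per-component scores.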

Since $q(\mathbf{x})$ is an approximation to $p(\mathbf{x})$, which is never accessible, we can learn $\mathbf{s}_{\boldsymbol{\theta}}(\mathbf{x})$ based on $q(\mathbf{x})$. This leads to the following definition of a loss function that can be used to train a network.

WARNING

The explicit score matching loss is

$$J_{\mathrm{ESM}}(\boldsymbol{\theta})\stackrel{\mathrm{def}}{=}\mathbb{E}_{q(\mathbf{x})}\|\mathbf{s}_{\boldsymbol{\theta}}(\mathbf{x})-\nabla_{\mathbf{x}}\log q(\mathbf{x})\|^2$$

By substituting the kernel density estimate, we can show that the loss is

$$\begin{aligned} J_{\mathrm{ESM}}(\boldsymbol{\theta})&\stackrel{\mathrm{def}}{=}\mathbb{E}_{q(\mathbf{x})}\|\mathbf{s}_{\boldsymbol{\theta}}(\mathbf{x})-\nabla_{\mathbf{x}}\log q(\mathbf{x})\|^2 \\ &=\int\|\mathbf{s}_{\boldsymbol{\theta}}(\mathbf{x})-\nabla_{\mathbf{x}}\log q(\mathbf{x})\|^2\left[\frac1M\sum_{m=1}^M\frac1hK\left(\frac{\mathbf{x}-\mathbf{x}_m}h\right)\right]d\mathbf{x} \\ &=\frac1M\sum_{m=1}^M\int\|\mathbf{s}_{\boldsymbol{\theta}}(\mathbf{x})-\nabla_{\mathbf{x}}\log q(\mathbf{x})\|^2\frac1hK\left(\frac{\mathbf{x}-\mathbf{x}_m}h\right)d\mathbf{x}. \end{aligned}$$
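The last line of the derivation suggests a sampling recipe: pick a training point uniformly, draw $\mathbf{x}$ from the kernel centered on it, and average the squared error. The sketch below does this in 1-D with a Gaussian kernel. The candidate models (scores of two Gaussians), the data, and the function names are hypothetical choices for illustration only.

```python
import numpy as np

def kde_score(x, data, h):
    """Score of a Gaussian-kernel KDE: weighted average of -(x - x_m)/h^2."""
    diff = x[:, None] - data[None, :]
    logits = -0.5 * (diff / h) ** 2
    w = np.exp(logits - logits.max(axis=1, keepdims=True))
    w /= w.sum(axis=1, keepdims=True)
    return (w * (-diff / h**2)).sum(axis=1)

def esm_loss(s_theta, data, h, n_draws=5000, seed=0):
    """Monte Carlo estimate of E_q ||s_theta(x) - grad log q(x)||^2."""
    rng = np.random.default_rng(seed)
    m = rng.integers(0, data.size, size=n_draws)       # pick a kernel uniformly
    x = data[m] + h * rng.standard_normal(n_draws)     # sample from that kernel
    return np.mean((s_theta(x) - kde_score(x, data, h)) ** 2)

def gaussian_score(mean, var):
    """Candidate model: the score of N(mean, var)."""
    return lambda x: -(x - mean) / var

rng = np.random.default_rng(1)
data = rng.normal(1.0, 2.0, size=300)                  # synthetic training set
h = 0.5

good = esm_loss(gaussian_score(data.mean(), data.var() + h**2), data, h)
bad = esm_loss(gaussian_score(-3.0, 1.0), data, h)
print("loss of a well-matched model:", good)
print("loss of a mismatched model:  ", bad)
```

A model whose mean and variance match the smoothed data yields a much smaller loss than a badly placed one, which is the behavior a training loss should have.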

Once we train the network $\mathbf{s}_{\boldsymbol{\theta}}$, we can substitute it into the Langevin dynamics equation to obtain the recursion:

$$\mathbf{x}_{t+1}=\mathbf{x}_t+\tau\mathbf{s}_{\boldsymbol{\theta}}(\mathbf{x}_t)+\sqrt{2\tau}\,\mathbf{z}.$$

Issues:

  • Kernel density estimation is a fairly poor non-parametric estimate of the true distribution.
  • With a limited number of samples living in a high-dimensional space, kernel density estimation can perform poorly.

Denoising Score Matching

In DSM, the loss function is defined as follows.

$$J_{\mathrm{DSM}}(\boldsymbol{\theta})\stackrel{\mathrm{def}}{=}\mathbb{E}_{q(\mathbf{x},\mathbf{x}^{\prime})}\left[\frac12\left\|\mathbf{s}_{\boldsymbol{\theta}}(\mathbf{x})-\nabla_{\mathbf{x}}\log q(\mathbf{x}|\mathbf{x}^{\prime})\right\|^2\right]$$

The idea comes from the denoising autoencoder approach of using pairs of clean and corrupted examples. Here $\mathbf{x}'$ is the clean example and $\mathbf{x}$ is its corrupted version. In the generative model, $\mathbf{x}$ plays the role of the noisy iterate $\mathbf{x}_t$.

The conditional distribution $q(\mathbf{x}|\mathbf{x}^{\prime})$ does not require an approximation. In the special case where $q(\mathbf{x}|\mathbf{x}^{\prime})=\mathcal{N}(\mathbf{x}\mid\mathbf{x}^{\prime},\sigma^2\mathbf{I})$, we can let $\mathbf{x}=\mathbf{x}^\prime+\sigma\mathbf{z}$ with $\mathbf{z}\sim\mathcal{N}(0,\mathbf{I})$. This will give us

$$\begin{aligned} \nabla_{\mathbf{x}}\log q(\mathbf{x}|\mathbf{x}^{\prime})&=\nabla_{\mathbf{x}}\log\frac1{(\sqrt{2\pi\sigma^2})^d}\exp\left\{-\frac{\|\mathbf{x}-\mathbf{x}^{\prime}\|^2}{2\sigma^2}\right\}\\&=\nabla_\mathbf{x}\left\{-\frac{\|\mathbf{x}-\mathbf{x}^{\prime}\|^2}{2\sigma^2}-\log(\sqrt{2\pi\sigma^2})^d\right\}\\&=-\frac{\mathbf{x}-\mathbf{x}^{\prime}}{\sigma^2}=-\frac{\mathbf{z}}{\sigma}. \end{aligned}$$

The last equality uses $\mathbf{x}-\mathbf{x}^{\prime}=\sigma\mathbf{z}$. As a result, the loss function of denoising score matching becomes

$$\begin{aligned}J_{\mathrm{DSM}}(\boldsymbol{\theta})&\stackrel{\mathrm{def}}{=}\mathbb{E}_{q(\mathbf{x},\mathbf{x}^{\prime})}\left[\frac12\left\|\mathbf{s}_{\boldsymbol{\theta}}(\mathbf{x})-\nabla_{\mathbf{x}}\log q(\mathbf{x}|\mathbf{x}^{\prime})\right\|^2\right]\\&=\mathbb{E}_{q(\mathbf{x}^{\prime})}\left[\frac12\left\|\mathbf{s}_{\boldsymbol{\theta}}(\mathbf{x}^{\prime}+\sigma\mathbf{z})+\frac{\mathbf{z}}{\sigma}\right\|^2\right].\end{aligned}$$

The normalization term does not depend on $\mathbf{x}$, so it vanishes under the gradient, and after substituting $\mathbf{x}=\mathbf{x}^{\prime}+\sigma\mathbf{z}$ the regression target depends only on the noise $\mathbf{z}$.

WARNING

Denoising score matching has a loss function defined as

$$J_{\mathrm{DSM}}(\boldsymbol{\theta})=\mathbb{E}_{p(\mathbf{x})}\left[\frac12\left\|\mathbf{s}_{\boldsymbol{\theta}}(\mathbf{x}+\sigma\mathbf{z})+\frac{\mathbf{z}}{\sigma}\right\|^2\right]$$

  • The quantity $\mathbf{x}+\sigma\mathbf{z}$ is effectively adding noise $\sigma\mathbf{z}$ to a clean image $\mathbf{x}$.
  • The score function $\mathbf{s}_{\boldsymbol{\theta}}$ is supposed to take this noisy image and predict the scaled noise $-\frac{\mathbf{z}}{\sigma}$.
  • Predicting noise is equivalent to denoising, because any denoised image plus the predicted noise will give us the noisy observation.

SMLD

The training step can be described simply as follows: given a training dataset $\{\mathbf{x}^{(\ell)}\}_{\ell=1}^L$, we train a network $\mathbf{s}_{\boldsymbol{\theta}}$ with the goal

$$\boldsymbol{\theta}^*=\underset{\boldsymbol{\theta}}{\operatorname{argmin}}\quad\frac1L\sum_{\ell=1}^L\frac12\left\|\mathbf{s}_{\boldsymbol{\theta}}\left(\mathbf{x}^{(\ell)}+\sigma\mathbf{z}^{(\ell)}\right)+\frac{\mathbf{z}^{(\ell)}}{\sigma}\right\|^2,\quad\mathrm{where}\quad\mathbf{z}^{(\ell)}\sim\mathcal{N}(0,\mathbf{I}).$$
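The training step can be sketched without a neural network by using a linear score model $s_{\boldsymbol{\theta}}(x)=ax+b$ on 1-D Gaussian data, for which minimizing the squared loss is an ordinary least-squares problem. The regression target is computed as $-(\mathbf{x}-\mathbf{x}')/\sigma^2$, the Gaussian conditional score derived above, directly from the noisy and clean samples. The data distribution, noise level, and linear model are all assumptions made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
mu, s, sigma = 1.0, 2.0, 0.5             # data mean/std and noise level (illustrative)
L = 200_000                              # number of training samples

x_clean = rng.normal(mu, s, size=L)      # training set {x^(l)}
z = rng.standard_normal(L)               # z^(l) ~ N(0, 1)
x_noisy = x_clean + sigma * z            # corrupted inputs fed to the model
target = -(x_noisy - x_clean) / sigma**2 # grad_x log q(x_noisy | x_clean)

# Linear model s_theta(x) = a*x + b: the DSM objective is a least-squares fit.
design = np.stack([x_noisy, np.ones(L)], axis=1)
(a, b), *_ = np.linalg.lstsq(design, target, rcond=None)

# For Gaussian data the noisy marginal is N(mu, s^2 + sigma^2), whose score is
# -(x - mu) / (s^2 + sigma^2); the fitted coefficients should be close to it.
print("fitted a, b:   ", a, b)
print("reference a, b:", -1.0 / (s**2 + sigma**2), mu / (s**2 + sigma**2))
```

Although each individual target is a noisy quantity, the fit recovers the score of the noise-smoothed data distribution, which is the property DSM relies on.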

The remaining question is why the DSM loss function makes sense.

WARNING

Theorem. Up to a constant $C$ that is independent of $\boldsymbol{\theta}$, it holds that

$$J_{\mathrm{DSM}}(\boldsymbol{\theta})=J_{\mathrm{ESM}}(\boldsymbol{\theta})+C.$$

The proof is not hard.
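The statement can also be checked numerically. The sketch below evaluates both losses by grid integration in 1-D for several parameter settings of a linear model $s_{\boldsymbol{\theta}}(x)=ax+b$ and compares the gaps. Two assumptions are made here: $J_{\mathrm{ESM}}$ is taken with the same factor $\frac12$ as $J_{\mathrm{DSM}}$ (without it the relation holds only up to an extra factor of 2, which does not change the minimizer), and $q(\mathbf{x})$ is the Gaussian-smoothed empirical distribution of a tiny made-up dataset.

```python
import numpy as np

data = np.array([-1.0, 0.5, 2.0])        # tiny illustrative training set
sigma = 0.7                              # illustrative noise level
x = np.linspace(-10.0, 10.0, 20001)      # integration grid
dx = x[1] - x[0]

def normal_pdf(x, mean, std):
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2.0 * np.pi))

cond = np.stack([normal_pdf(x, m, sigma) for m in data])  # q(x | x_m), shape (M, N)
q = cond.mean(axis=0)                                     # marginal q(x)
cond_score = -(x[None, :] - data[:, None]) / sigma**2     # grad log q(x | x_m)
marg_score = (cond * cond_score).mean(axis=0) / q         # grad log q(x) = q'(x)/q(x)

def j_esm(a, b):
    """(1/2) E_{q(x)} |s(x) - grad log q(x)|^2 by a Riemann sum."""
    s = a * x + b
    return np.sum(q * 0.5 * (s - marg_score) ** 2) * dx

def j_dsm(a, b):
    """(1/2) E_{q(x, x')} |s(x) - grad log q(x | x')|^2 by a Riemann sum."""
    s = a * x + b
    return np.sum(cond * 0.5 * (s[None, :] - cond_score) ** 2) * dx / len(data)

thetas = [(-1.0, 0.0), (-0.3, 0.4), (0.5, -2.0)]
gaps = [j_dsm(a, b) - j_esm(a, b) for a, b in thetas]
print(gaps)   # the same constant C for every theta
```

The gap is identical across parameter settings because the cross terms of the two losses coincide, which is the key step of the proof.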

For inference, we assume that we have already trained the score estimator $\mathbf{s}_{\boldsymbol{\theta}}$. To generate an image, we perform the following procedure for $t=1,\ldots,T$:

$$\mathbf{x}_{t+1}=\mathbf{x}_t+\tau\mathbf{s}_{\boldsymbol{\theta}}(\mathbf{x}_t)+\sqrt{2\tau}\,\mathbf{z}_t,\quad\mathrm{where}\quad\mathbf{z}_t\sim\mathcal{N}(0,\mathbf{I}).$$
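Putting training and inference together, a toy end-to-end run might look like the following. A least-squares fit of a linear score model stands in for the trained network, and the data are 1-D Gaussian rather than images; all names and parameter values (`tau`, `T`, `n_samples`, the data distribution) are assumptions of this sketch.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for training: fit s_theta(x) = a*x + b by denoising score matching.
sigma = 0.5
x_clean = rng.normal(1.0, 2.0, size=200_000)              # hypothetical training data
x_noisy = x_clean + sigma * rng.standard_normal(x_clean.size)
target = -(x_noisy - x_clean) / sigma**2
design = np.stack([x_noisy, np.ones_like(x_noisy)], axis=1)
(a, b), *_ = np.linalg.lstsq(design, target, rcond=None)

def s_theta(x):
    """Trained score estimator (a linear model in this sketch)."""
    return a * x + b

# Inference: run the Langevin recursion from white noise using s_theta.
tau, T, n_samples = 0.05, 1000, 5000
x = rng.standard_normal(n_samples)                        # x_0 is white noise
for _ in range(T):
    z_t = rng.standard_normal(n_samples)                  # z_t ~ N(0, 1)
    x = x + tau * s_theta(x) + np.sqrt(2.0 * tau) * z_t
samples = x

print("generated mean:", samples.mean())
print("generated std: ", samples.std())
```

The generated samples follow the noise-smoothed data distribution, here approximately a Gaussian with mean 1 and variance $2^2+\sigma^2$, so a smaller $\sigma$ brings them closer to the clean data at the cost of a harder score-estimation problem.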

Reference

[1] Chan, Stanley H. “Tutorial on Diffusion Models for Imaging and Vision.” arXiv preprint arXiv:2403.18103 (2024).

Score-Matching Langevin Dynamics (SMLD)
https://fuwari.vercel.app/posts/score-matching-langevin-dynamics/
Author: pride7
Published on: 2024-08-25
License: CC BY-NC-SA 4.0