# Stochastic Differential Equation (SDE)

## 1 Forward and Backward Iterations in SDE

The basic idea comes from the gradient descent algorithm, with a noise term $\mathbf{z}$ added to each iteration.

$$
\begin{aligned}
\mathbf{x}_i&=\mathbf{x}_{i-1}-\tau'\nabla f(\mathbf{x}_{i-1})+\mathbf{z}_{i-1}\\
\Longrightarrow\quad\mathbf{x}(t+\Delta t)&=\mathbf{x}(t)-\tau\nabla f(\mathbf{x}(t))+\mathbf{z}(t).
\end{aligned}
$$

Now, let's define a random process $\mathbf{w}(t)$ such that $\mathbf{z}(t)=\mathbf{w}(t+\Delta t)-\mathbf{w}(t)\approx\frac{d\mathbf{w}(t)}{dt}\Delta t$ for a very small $\Delta t$. In computation, we can generate such a $\mathbf{w}(t)$ (which is a Wiener process) by integrating $\mathbf{z}(t)$. With $\mathbf{w}(t)$ defined, we can write

$$
\begin{aligned}
\mathbf{x}(t+\Delta t)&=\mathbf{x}(t)-\tau\nabla f(\mathbf{x}(t))+\mathbf{z}(t)\\
\Longrightarrow\quad\mathbf{x}(t+\Delta t)-\mathbf{x}(t)&=-\tau\nabla f(\mathbf{x}(t))+\mathbf{w}(t+\Delta t)-\mathbf{w}(t)\\
\Longrightarrow\quad d\mathbf{x}&=-\tau\nabla f(\mathbf{x})\,dt+d\mathbf{w}.
\end{aligned}
$$

Note that in practice we often use $d\mathbf{w}=\sqrt{\Delta t}\,\mathbf{z}(t)$, which differs from the increment $\mathbf{w}(t+\Delta t)-\mathbf{w}(t)$ defined above.
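To make the discretization concrete, here is a minimal Euler–Maruyama sketch of $d\mathbf{x}=-\tau\nabla f(\mathbf{x})\,dt+d\mathbf{w}$. The quadratic objective, step size, and step count are illustrative assumptions, not from the original post:

```python
import numpy as np

def grad_f(x):
    # Hypothetical objective f(x) = 0.5 * ||x||^2, so grad f(x) = x.
    return x

def noisy_gradient_descent(x0, tau=1.0, dt=1e-3, n_steps=5000, seed=0):
    """Euler-Maruyama simulation of dx = -tau * grad f(x) dt + dw."""
    rng = np.random.default_rng(seed)
    x = np.asarray(x0, dtype=float).copy()
    for _ in range(n_steps):
        z = rng.standard_normal(x.shape)
        # Wiener increment dw = sqrt(dt) * z, as noted above.
        x += -tau * grad_f(x) * dt + np.sqrt(dt) * z
    return x

print(noisy_gradient_descent([5.0, -3.0]))  # hovers near the minimizer 0
```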

WARNING

Forward Diffusion

$$
d\mathbf{x}=\underbrace{\mathbf{f}(\mathbf{x},t)}_{\text{drift}}\,dt+\underbrace{g(t)}_{\text{diffusion}}\,d\mathbf{w}.
$$
WARNING

Reverse SDE

$$
d\mathbf{x}=\Big[\underbrace{\mathbf{f}(\mathbf{x},t)}_{\text{drift}}-g(t)^2\underbrace{\nabla_\mathbf{x}\log p_t(\mathbf{x})}_{\text{score function}}\Big]\,dt+\underbrace{g(t)\,d\overline{\mathbf{w}}}_{\text{reverse-time diffusion}},
$$

where $p_t(\mathbf{x})$ is the probability distribution of $\mathbf{x}$ at time $t$, and $\overline{\mathbf{w}}$ is the Wiener process when time flows backward.
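A hedged sketch of how one might discretize this reverse SDE with Euler–Maruyama; `f`, `g`, and `score` are placeholder callables (in practice the score would come from a trained network):

```python
import numpy as np

def reverse_sde_sample(x_T, f, g, score, T=1.0, n_steps=1000, seed=0):
    """Euler-Maruyama integration of the reverse SDE from t = T down to 0:
    dx = [f(x, t) - g(t)^2 * score(x, t)] dt + g(t) dw_bar.
    The callables f, g, and score are user-supplied placeholders."""
    rng = np.random.default_rng(seed)
    dt = T / n_steps
    x = np.asarray(x_T, dtype=float).copy()
    for i in range(n_steps, 0, -1):
        t = i * dt
        z = rng.standard_normal(x.shape)
        drift = f(x, t) - g(t) ** 2 * score(x, t)
        # Moving backward in time: subtract the drift, add fresh noise.
        x = x - drift * dt + g(t) * np.sqrt(dt) * z
    return x
```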

## 2 Stochastic Differential Equation for DDPM

We consider the discrete-time DDPM iteration. For $i=1,2,\dots,N$:

$$
\mathbf{x}_i=\sqrt{1-\beta_i}\,\mathbf{x}_{i-1}+\sqrt{\beta_i}\,\mathbf{z}_{i-1},\qquad\mathbf{z}_{i-1}\sim\mathcal{N}(0,\mathbf{I}).
$$
WARNING

The forward sampling equation of DDPM can be written as an SDE via

$$
d\mathbf{x}=\underbrace{-\frac{\beta(t)}{2}\mathbf{x}}_{=\mathbf{f}(\mathbf{x},t)}\,dt+\underbrace{\sqrt{\beta(t)}}_{=g(t)}\,d\mathbf{w}.
$$

Note that here $\mathbf{w}$ is the Wiener process.

- The DDPM iteration itself is solving this SDE; it is a first-order method (see the numerical check below).
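To see the correspondence numerically, the following sketch compares one discrete DDPM forward step with one Euler–Maruyama step of the SDE above; the schedule value is an arbitrary illustration:

```python
import numpy as np

def ddpm_step(x, beta_i, z):
    # Exact discrete DDPM forward iteration.
    return np.sqrt(1.0 - beta_i) * x + np.sqrt(beta_i) * z

def sde_euler_step(x, beta_i, z):
    # Euler-Maruyama step of dx = -(beta/2) x dt + sqrt(beta) dw,
    # with beta_i standing in for beta(t) * dt.
    return (1.0 - beta_i / 2.0) * x + np.sqrt(beta_i) * z

# sqrt(1 - b) = 1 - b/2 + O(b^2), so the two agree for small beta_i:
beta_i = 0.01
print(np.sqrt(1.0 - beta_i))   # 0.994987...
print(1.0 - beta_i / 2.0)      # 0.995
```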
WARNING

The reverse sampling equation of DDPM can be written as an SDE via

$$
d\mathbf{x}=-\beta(t)\left[\frac{\mathbf{x}}{2}+\nabla_{\mathbf{x}}\log p_t(\mathbf{x})\right]dt+\sqrt{\beta(t)}\,d\overline{\mathbf{w}}.
$$
Discretizing this reverse SDE (with time flowing backward) gives

$$
\begin{aligned}
\mathbf{x}(t)-\mathbf{x}(t-\Delta t)&=-\beta(t)\Delta t\left[\frac{\mathbf{x}(t)}{2}+\nabla_\mathbf{x}\log p_t(\mathbf{x}(t))\right]-\sqrt{\beta(t)\Delta t}\,\mathbf{z}(t)\\
\Longrightarrow\quad\mathbf{x}(t-\Delta t)&=\mathbf{x}(t)+\beta(t)\Delta t\left[\frac{\mathbf{x}(t)}{2}+\nabla_{\mathbf{x}}\log p_t(\mathbf{x}(t))\right]+\sqrt{\beta(t)\Delta t}\,\mathbf{z}(t).
\end{aligned}
$$

By grouping the terms, and assuming that $\beta(t)\Delta t\ll 1$, we recognize that

$$
\begin{aligned}
\mathbf{x}(t-\Delta t)&=\mathbf{x}(t)\left[1+\frac{\beta(t)\Delta t}{2}\right]+\beta(t)\Delta t\,\nabla_\mathbf{x}\log p_t(\mathbf{x}(t))+\sqrt{\beta(t)\Delta t}\,\mathbf{z}(t)\\
&\approx\mathbf{x}(t)\left[1+\frac{\beta(t)\Delta t}{2}\right]+\beta(t)\Delta t\,\nabla_\mathbf{x}\log p_t(\mathbf{x}(t))+\frac{(\beta(t)\Delta t)^2}{2}\nabla_\mathbf{x}\log p_t(\mathbf{x}(t))+\sqrt{\beta(t)\Delta t}\,\mathbf{z}(t)\\
&=\left[1+\frac{\beta(t)\Delta t}{2}\right]\Big(\mathbf{x}(t)+\beta(t)\Delta t\,\nabla_\mathbf{x}\log p_t(\mathbf{x}(t))\Big)+\sqrt{\beta(t)\Delta t}\,\mathbf{z}(t).
\end{aligned}
$$

Then, following the discretization scheme, we can show that

$$
\begin{aligned}
\mathbf{x}_{i-1}&=\left(1+\frac{\beta_i}{2}\right)\Big[\mathbf{x}_i+\beta_i\nabla_\mathbf{x}\log p_i(\mathbf{x}_i)\Big]+\sqrt{\beta_i}\,\mathbf{z}_i\\
&\approx\frac{1}{\sqrt{1-\beta_i}}\Big[\mathbf{x}_i+\beta_i\nabla_\mathbf{x}\log p_i(\mathbf{x}_i)\Big]+\sqrt{\beta_i}\,\mathbf{z}_i,
\end{aligned}
$$

where $p_i(\mathbf{x})$ is the probability density function of $\mathbf{x}$ at time $i$. For practical implementation, we can replace $\nabla_\mathbf{x}\log p_i(\mathbf{x}_i)$ by the estimated score function $\mathbf{s}_{\boldsymbol{\theta}}(\mathbf{x}_i)$.
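A minimal sketch of this reverse update, assuming a trained score network exposed as `score_fn(x, i)` (a placeholder name); omitting the noise at the final step follows common practice:

```python
import numpy as np

def ddpm_reverse_sample(x_N, betas, score_fn, seed=0):
    """Iterate x_{i-1} = (x_i + beta_i * score) / sqrt(1 - beta_i)
    + sqrt(beta_i) * z_i from i = N down to 1.
    `score_fn(x, i)` stands in for the trained score estimate."""
    rng = np.random.default_rng(seed)
    x = np.asarray(x_N, dtype=float).copy()
    for i in range(len(betas) - 1, -1, -1):
        # Noise is commonly omitted at the final step.
        z = rng.standard_normal(x.shape) if i > 0 else 0.0
        x = (x + betas[i] * score_fn(x, i)) / np.sqrt(1.0 - betas[i]) + np.sqrt(betas[i]) * z
    return x
```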

## 3 Stochastic Differential Equation for SMLD

Although there isn't an explicit forward diffusion step in SMLD, if we divide the noise scale (e.g., $\mathbf{x}+\sigma\mathbf{z}$) used in SMLD training into $N$ levels, then the recursion should follow a Markov chain

$$
\mathbf{x}_i=\mathbf{x}_{i-1}+\sqrt{\sigma_i^2-\sigma_{i-1}^2}\,\mathbf{z}_{i-1},\qquad i=1,2,\ldots,N.
$$

If we assume that the variance of $\mathbf{x}_{i-1}$ is $\sigma_{i-1}^2$, then we can show that

$$
\begin{aligned}
\mathrm{Var}[\mathbf{x}_{i}]&=\mathrm{Var}[\mathbf{x}_{i-1}]+(\sigma_{i}^{2}-\sigma_{i-1}^{2})\\
&=\sigma_{i-1}^{2}+(\sigma_{i}^{2}-\sigma_{i-1}^{2})=\sigma_{i}^{2}.
\end{aligned}
$$

Therefore, given a sequence of noise levels, the above equation will indeed generate estimates $\mathbf{x}_i$ whose noise statistics satisfy the desired property.
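The variance bookkeeping can be verified with a quick Monte Carlo check; the noise schedule below is an arbitrary illustration:

```python
import numpy as np

# Empirically check Var[x_i] = sigma_i^2 under the SMLD forward recursion.
rng = np.random.default_rng(0)
sigmas = [0.0, 0.1, 0.2, 0.4, 0.8]   # sigma_0 = 0: x_0 is deterministic
x = np.zeros(200_000)                # 200k samples of x_0

for i in range(1, len(sigmas)):
    dvar = sigmas[i] ** 2 - sigmas[i - 1] ** 2
    x = x + np.sqrt(dvar) * rng.standard_normal(x.shape)
    print(f"i={i}: empirical Var = {x.var():.4f}, target = {sigmas[i]**2:.4f}")
```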

Assume that in the limit, $\{\sigma_i\}_{i=1}^N$ becomes the continuous-time $\sigma(t)$ for $0\leq t\leq 1$, and $\{\mathbf{x}_i\}_{i=1}^N$ becomes $\mathbf{x}(t)$, where $\mathbf{x}_i=\mathbf{x}(\tfrac{i}{N})$ if we let $t\in\{0,\tfrac{1}{N},\ldots,\tfrac{N-1}{N}\}$. Then we have

$$
\begin{aligned}
\mathbf{x}(t+\Delta t)&=\mathbf{x}(t)+\sqrt{\sigma(t+\Delta t)^2-\sigma(t)^2}\,\mathbf{z}(t)\\
&\approx\mathbf{x}(t)+\sqrt{\frac{d[\sigma(t)^2]}{dt}\Delta t}\,\mathbf{z}(t).
\end{aligned}
$$

In the limit as $\Delta t\to 0$, the equation converges to

$$
d\mathbf{x}=\sqrt{\frac{d[\sigma(t)^2]}{dt}}\,d\mathbf{w}.
$$
NOTE

The forward sampling equation of SMLD can be written as an SDE via

$$
d\mathbf{x}=\sqrt{\frac{d[\sigma(t)^2]}{dt}}\,d\mathbf{w}.
$$

NOTE

The reverse sampling equation of SMLD can be written as an SDE via

$$
d\mathbf{x}=-\left(\frac{d[\sigma(t)^2]}{dt}\nabla_\mathbf{x}\log p_t(\mathbf{x}(t))\right)dt+\sqrt{\frac{d[\sigma(t)^2]}{dt}}\,d\overline{\mathbf{w}}.
$$

For the discrete-time iterations, we first define $\alpha(t)=\frac{d[\sigma(t)^2]}{dt}$. Then, using the same discretization setup as in the DDPM case, we can show that

$$
\begin{aligned}
\mathbf{x}(t+\Delta t)-\mathbf{x}(t)&=-\alpha(t)\nabla_{\mathbf{x}}\log p_{t}(\mathbf{x})\,\Delta t-\sqrt{\alpha(t)\Delta t}\,\mathbf{z}(t)\\
\Longrightarrow\quad\mathbf{x}(t)&=\mathbf{x}(t+\Delta t)+\alpha(t)\Delta t\,\nabla_{\mathbf{x}}\log p_{t}(\mathbf{x})+\sqrt{\alpha(t)\Delta t}\,\mathbf{z}(t)\\
\Longrightarrow\quad\mathbf{x}_{i-1}&=\mathbf{x}_{i}+\alpha_{i}\nabla_{\mathbf{x}}\log p_{i}(\mathbf{x}_{i})+\sqrt{\alpha_{i}}\,\mathbf{z}_{i}\\
\Longrightarrow\quad\mathbf{x}_{i-1}&=\mathbf{x}_{i}+(\sigma_{i}^{2}-\sigma_{i-1}^{2})\nabla_{\mathbf{x}}\log p_{i}(\mathbf{x}_{i})+\sqrt{\sigma_{i}^{2}-\sigma_{i-1}^{2}}\,\mathbf{z}_{i},
\end{aligned}
$$

which is identical to the SMLD reverse update equation.
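In code, the last line above maps directly onto a sampling loop; this is a minimal sketch with `score_fn(x, i)` again a placeholder for $\mathbf{s}_{\boldsymbol{\theta}}$:

```python
import numpy as np

def smld_reverse_sample(x_N, sigmas, score_fn, seed=0):
    """Iterate x_{i-1} = x_i + (sigma_i^2 - sigma_{i-1}^2) * score
    + sqrt(sigma_i^2 - sigma_{i-1}^2) * z_i from i = N down to 1.
    `score_fn(x, i)` stands in for the trained score estimate."""
    rng = np.random.default_rng(seed)
    x = np.asarray(x_N, dtype=float).copy()
    for i in range(len(sigmas) - 1, 0, -1):
        dvar = sigmas[i] ** 2 - sigmas[i - 1] ** 2
        z = rng.standard_normal(x.shape)
        x = x + dvar * score_fn(x, i) + np.sqrt(dvar) * z
    return x
```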

## 4 Solving SDE

Predictor-Corrector Algorithm: If we have already trained the score function $\mathbf{s}_{\boldsymbol{\theta}}(\mathbf{x}_i, i)$, we can use the score-matching equation to correct the solver's output. For example, at each of the $N$ reverse time steps (the predictor), we can run the score-matching correction $M$ times, as sketched below.
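A hedged sketch of one such predictor-corrector loop, using the DDPM reverse update of Section 2 as the predictor and a Langevin-style update as the corrector; `score_fn`, `M`, and `eps` are illustrative placeholders:

```python
import numpy as np

def pc_sample(x_N, betas, score_fn, M=3, eps=1e-4, seed=0):
    """Each of the N reverse steps runs one DDPM reverse update
    (predictor) followed by M Langevin-style corrector steps."""
    rng = np.random.default_rng(seed)
    x = np.asarray(x_N, dtype=float).copy()
    for i in range(len(betas) - 1, -1, -1):
        # Predictor: reverse-diffusion step from Section 2.
        z = rng.standard_normal(x.shape)
        x = (x + betas[i] * score_fn(x, i)) / np.sqrt(1.0 - betas[i]) + np.sqrt(betas[i]) * z
        # Corrector: M Langevin updates x <- x + eps * score + sqrt(2 eps) * z.
        for _ in range(M):
            z = rng.standard_normal(x.shape)
            x = x + eps * score_fn(x, i) + np.sqrt(2.0 * eps) * z
    return x
```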

Accelerating the SDE Solver: a key tool is the variation-of-constants formula below.

NOTE

Theorem [Variation of Constants]. Consider the ODE over the range $[t_0,t]$:

$$
\frac{dx(t)}{dt}=a(t)x(t)+b(t),\quad\text{where}\ x(t_0)=x_0.
$$

The solution is given by

$$
x(t)=x_0e^{A(t)}+e^{A(t)}\int_{t_0}^{t}e^{-A(\tau)}b(\tau)\,d\tau,
$$

where $A(t)=\int_{t_0}^{t}a(\tau)\,d\tau$.
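As a quick sanity check (a worked special case, not part of the original theorem statement), take constant coefficients $a(t)=a\neq 0$ and $b(t)=b$:

```latex
% Constant coefficients: a(t) = a, b(t) = b, so A(t) = a (t - t_0).
\[
x(t) = x_0 e^{a(t-t_0)} + e^{a(t-t_0)} \int_{t_0}^{t} e^{-a(\tau-t_0)} b \, d\tau
     = x_0 e^{a(t-t_0)} + \frac{b}{a}\left(e^{a(t-t_0)} - 1\right),
\]
% which satisfies x'(t) = a x(t) + b and x(t_0) = x_0, as required.
```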
