## 1 Forward and Backward Iterations in SDE

The basic idea comes from the gradient descent algorithm:
$$
\begin{aligned}
\mathbf{x}_i &= \mathbf{x}_{i-1}-\tau'\nabla f(\mathbf{x}_{i-1})+\mathbf{z}_{i-1}\\
\Longrightarrow\quad \mathbf{x}(t+\Delta t) &= \mathbf{x}(t)-\tau\nabla f(\mathbf{x}(t))+\mathbf{z}(t).
\end{aligned}
$$

Now, let's define a random process $\mathbf{w}(t)$ such that $\mathbf{z}(t)=\mathbf{w}(t+\Delta t)-\mathbf{w}(t)\approx\frac{d\mathbf{w}(t)}{dt}\Delta t$ for a very small $\Delta t$. In computation, we can generate such a $\mathbf{w}(t)$ by integrating $\mathbf{z}(t)$ (which is a Wiener process). With $\mathbf{w}(t)$ defined, we can write
$$
\begin{aligned}
\mathbf{x}(t+\Delta t) &= \mathbf{x}(t)-\tau\nabla f(\mathbf{x}(t))+\mathbf{z}(t)\\
\Longrightarrow\quad \mathbf{x}(t+\Delta t)-\mathbf{x}(t) &= -\tau\nabla f(\mathbf{x}(t))+\mathbf{w}(t+\Delta t)-\mathbf{w}(t)\\
\Longrightarrow\quad d\mathbf{x} &= -\tau\nabla f(\mathbf{x})\,dt+d\mathbf{w}.
\end{aligned}
$$

Note that we often use $d\mathbf{w}=\sqrt{\Delta t}\,\mathbf{z}(t)$, which is different from this equation.
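To make the correspondence concrete, here is a minimal NumPy sketch of the noisy gradient-descent iteration viewed as an Euler-Maruyama discretization of $d\mathbf{x}=-\tau\nabla f(\mathbf{x})\,dt+d\mathbf{w}$. The toy objective `grad_f`, the step size `dt`, and `tau` are illustrative assumptions, not part of the original text.

```python
import numpy as np

def grad_f(x):
    # Toy quadratic objective f(x) = 0.5 * ||x||^2, so grad f(x) = x (assumed example).
    return x

def euler_maruyama(x0, tau=1.0, dt=1e-3, n_steps=1000, seed=0):
    """Simulate dx = -tau * grad_f(x) dt + dw, with Wiener increments dw = sqrt(dt) * z."""
    rng = np.random.default_rng(seed)
    x = np.array(x0, dtype=float)
    for _ in range(n_steps):
        dw = np.sqrt(dt) * rng.standard_normal(x.shape)   # dw ~ N(0, dt * I)
        x = x - tau * grad_f(x) * dt + dw
    return x

print(euler_maruyama(x0=[5.0, -3.0]))
```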
**Forward Diffusion**

$$
d\mathbf{x}=\underbrace{\mathbf{f}(\mathbf{x},t)}_{\text{drift}}\,dt+\underbrace{g(t)}_{\text{diffusion}}\,d\mathbf{w}.
$$

**Reverse SDE**

$$
d\mathbf{x}=\Big[\underbrace{\mathbf{f}(\mathbf{x},t)}_{\text{drift}}-g(t)^2\underbrace{\nabla_\mathbf{x}\log p_t(\mathbf{x})}_{\text{score function}}\Big]\,dt+\underbrace{g(t)\,d\overline{\mathbf{w}}}_{\text{reverse-time diffusion}},
$$

where $p_t(\mathbf{x})$ is the probability distribution of $\mathbf{x}$ at time $t$, and $\overline{\mathbf{w}}$ is the Wiener process when time flows backward.
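Both equations can be simulated with the same Euler-Maruyama template. Below is a minimal sketch, assuming a uniform time grid; `drift`, `diffusion`, and `score` are placeholder callables (in practice the score would come from a trained network), and the example drift/score at the bottom are illustrative only.

```python
import numpy as np

def sde_sample(x0, drift, diffusion, score=None, n_steps=1000, T=1.0, reverse=False, seed=0):
    """Euler-Maruyama for dx = f(x,t) dt + g(t) dw (forward), or for the reverse-time SDE
    dx = [f(x,t) - g(t)^2 * score(x,t)] dt + g(t) dw_bar when reverse=True."""
    rng = np.random.default_rng(seed)
    dt = T / n_steps
    x = np.array(x0, dtype=float)
    ts = np.linspace(T, 0.0, n_steps, endpoint=False) if reverse else np.linspace(0.0, T, n_steps, endpoint=False)
    for t in ts:
        z = rng.standard_normal(x.shape)
        f, g = drift(x, t), diffusion(t)
        if reverse:
            # x(t - dt) = x(t) - [f - g^2 * score] dt + g * sqrt(dt) * z
            x = x - (f - g**2 * score(x, t)) * dt + g * np.sqrt(dt) * z
        else:
            # x(t + dt) = x(t) + f dt + g * sqrt(dt) * z
            x = x + f * dt + g * np.sqrt(dt) * z
    return x

# Illustration only: f(x,t) = -x/2, g(t) = 1, and a standard-normal score -x.
x_T = sde_sample(np.zeros(2), drift=lambda x, t: -x / 2, diffusion=lambda t: 1.0)
x_0 = sde_sample(x_T, drift=lambda x, t: -x / 2, diffusion=lambda t: 1.0,
                 score=lambda x, t: -x, reverse=True)
```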
## 2 Stochastic Differential Equation for DDPM

We consider the discrete-time DDPM iteration. For $i=1,2,\dots,N$:
$$
\mathbf{x}_i=\sqrt{1-\beta_i}\,\mathbf{x}_{i-1}+\sqrt{\beta_i}\,\mathbf{z}_{i-1},\qquad \mathbf{z}_{i-1}\sim\mathcal{N}(0,\mathbf{I}).
$$

**The forward sampling equation of DDPM can be written as an SDE via**
$$
d\mathbf{x}=\underbrace{-\frac{\beta(t)}{2}\mathbf{x}}_{=\mathbf{f}(\mathbf{x},t)}\,dt+\underbrace{\sqrt{\beta(t)}}_{=g(t)}\,d\mathbf{w}.
$$

Note that here $\mathbf{w}$ is the Wiener process.
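As a concrete sanity check, here is a minimal sketch of the discrete forward chain $\mathbf{x}_i=\sqrt{1-\beta_i}\,\mathbf{x}_{i-1}+\sqrt{\beta_i}\,\mathbf{z}_{i-1}$; the linear schedule `betas` is an assumed example, not prescribed by the text.

```python
import numpy as np

def ddpm_forward(x0, betas, seed=0):
    """Run the DDPM forward chain x_i = sqrt(1 - beta_i) x_{i-1} + sqrt(beta_i) z_{i-1}."""
    rng = np.random.default_rng(seed)
    x = np.array(x0, dtype=float)
    for beta in betas:
        z = rng.standard_normal(x.shape)
        x = np.sqrt(1.0 - beta) * x + np.sqrt(beta) * z
    return x

N = 1000
betas = np.linspace(1e-4, 0.02, N)        # assumed schedule; beta_i plays the role of beta(t) * dt
x_N = ddpm_forward(np.array([3.0, -1.0]), betas)   # approximately N(0, I) for large N
```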
The DDPM iteration itself is solving this SDE (it is a first-order method).

**The reverse sampling equation of DDPM can be written as an SDE via**
$$
d\mathbf{x}=-\beta(t)\left[\frac{\mathbf{x}}{2}+\nabla_{\mathbf{x}}\log p_t(\mathbf{x})\right]dt+\sqrt{\beta(t)}\,d\overline{\mathbf{w}}.
$$

Discretizing this reverse SDE with step size $\Delta t$ gives

$$
\begin{aligned}
\mathbf{x}(t)-\mathbf{x}(t-\Delta t) &= -\beta(t)\Delta t\left[\frac{\mathbf{x}(t)}{2}+\nabla_\mathbf{x}\log p_t(\mathbf{x}(t))\right]-\sqrt{\beta(t)\Delta t}\,\mathbf{z}(t)\\
\Longrightarrow\quad \mathbf{x}(t-\Delta t) &= \mathbf{x}(t)+\beta(t)\Delta t\left[\frac{\mathbf{x}(t)}{2}+\nabla_{\mathbf{x}}\log p_t(\mathbf{x}(t))\right]+\sqrt{\beta(t)\Delta t}\,\mathbf{z}(t).
\end{aligned}
$$

By grouping the terms, and assuming that $\beta(t)\Delta t\ll 1$, we recognize that
$$
\begin{aligned}
\mathbf{x}(t-\Delta t) &= \mathbf{x}(t)\left[1+\frac{\beta(t)\Delta t}{2}\right]+\beta(t)\Delta t\,\nabla_\mathbf{x}\log p_t(\mathbf{x}(t))+\sqrt{\beta(t)\Delta t}\,\mathbf{z}(t)\\
&\approx \mathbf{x}(t)\left[1+\frac{\beta(t)\Delta t}{2}\right]+\beta(t)\Delta t\,\nabla_\mathbf{x}\log p_t(\mathbf{x}(t))+\frac{(\beta(t)\Delta t)^2}{2}\nabla_\mathbf{x}\log p_t(\mathbf{x}(t))+\sqrt{\beta(t)\Delta t}\,\mathbf{z}(t)\\
&= \left[1+\frac{\beta(t)\Delta t}{2}\right]\Big(\mathbf{x}(t)+\beta(t)\Delta t\,\nabla_\mathbf{x}\log p_t(\mathbf{x}(t))\Big)+\sqrt{\beta(t)\Delta t}\,\mathbf{z}(t).
\end{aligned}
$$

Then, following the discretization scheme, we can show that
$$
\begin{aligned}
\mathbf{x}_{i-1} &= \left(1+\frac{\beta_i}{2}\right)\Big[\mathbf{x}_i+\beta_i\nabla_\mathbf{x}\log p_i(\mathbf{x}_i)\Big]+\sqrt{\beta_i}\,\mathbf{z}_i\\
&\approx \frac{1}{\sqrt{1-\beta_i}}\Big[\mathbf{x}_i+\beta_i\nabla_\mathbf{x}\log p_i(\mathbf{x}_i)\Big]+\sqrt{\beta_i}\,\mathbf{z}_i,
\end{aligned}
$$

where $p_i(\mathbf{x})$ is the probability density function of $\mathbf{x}$ at time $i$. For practical implementation, we can replace $\nabla_\mathbf{x}\log p_i(\mathbf{x}_i)$ by the estimated score function $\mathbf{s}_{\boldsymbol{\theta}}(\mathbf{x}_i)$.
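In code, the discrete reverse update reads as the following sketch; `score_fn` stands in for the learned $\mathbf{s}_{\boldsymbol{\theta}}(\mathbf{x}_i, i)$, and the schedule and the exact-Gaussian score in the usage example are assumptions for illustration.

```python
import numpy as np

def ddpm_reverse(x_N, betas, score_fn, seed=0):
    """Reverse DDPM sampling:
    x_{i-1} = (x_i + beta_i * score(x_i, i)) / sqrt(1 - beta_i) + sqrt(beta_i) * z_i."""
    rng = np.random.default_rng(seed)
    x = np.array(x_N, dtype=float)
    for i in range(len(betas) - 1, -1, -1):
        beta = betas[i]
        z = rng.standard_normal(x.shape)
        x = (x + beta * score_fn(x, i)) / np.sqrt(1.0 - beta) + np.sqrt(beta) * z
    return x

# Illustration only: the exact score of a standard Gaussian is -x.
betas = np.linspace(1e-4, 0.02, 1000)
x_N = np.random.default_rng(1).standard_normal(2)
sample = ddpm_reverse(x_N, betas, score_fn=lambda x, i: -x)
```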
## 3 Stochastic Differential Equation for SMLD

Although there isn't a forward diffusion step, if we divide the noise scale (e.g., $\mathbf{x}+\sigma\mathbf{z}$) in the SMLD training into $N$ levels, then the recursion should follow a Markov chain
$$
\mathbf{x}_i=\mathbf{x}_{i-1}+\sqrt{\sigma_i^2-\sigma_{i-1}^2}\,\mathbf{z}_{i-1},\qquad i=1,2,\ldots,N.
$$

If we assume that the variance of $\mathbf{x}_{i-1}$ is $\sigma_{i-1}^2$, then we can show that
$$
\begin{aligned}
\mathrm{Var}[\mathbf{x}_{i}] &= \mathrm{Var}[\mathbf{x}_{i-1}]+(\sigma_{i}^{2}-\sigma_{i-1}^{2})\\
&= \sigma_{i-1}^{2}+(\sigma_{i}^{2}-\sigma_{i-1}^{2})=\sigma_{i}^{2}.
\end{aligned}
$$

Therefore, given a sequence of noise levels, the above recursion will indeed generate estimates $\mathbf{x}_i$ whose noise statistics satisfy the desired property.
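A quick NumPy check of this variance bookkeeping, assuming a geometric noise schedule `sigmas` (the specific schedule and starting point $\sigma_0=0$ are illustrative choices):

```python
import numpy as np

rng = np.random.default_rng(0)
N = 500
sigmas = np.geomspace(0.01, 50.0, N)     # assumed noise levels sigma_1 < ... < sigma_N
x = np.zeros(100_000)                    # clean "data", i.e. sigma_0 = 0

prev_var = 0.0
for i in range(N):
    gap = sigmas[i] ** 2 - prev_var      # sigma_i^2 - sigma_{i-1}^2
    x = x + np.sqrt(gap) * rng.standard_normal(x.shape)
    prev_var = sigmas[i] ** 2

print(x.var(), sigmas[-1] ** 2)          # empirical variance is close to sigma_N^2
```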
Assuming that in the limit $\{\sigma_i\}_{i=1}^N$ becomes the continuous-time $\sigma(t)$ for $0\leq t\leq 1$, and $\{\mathbf{x}_i\}_{i=1}^N$ becomes $\mathbf{x}(t)$ where $\mathbf{x}_i=\mathbf{x}(\tfrac{i}{N})$ if we let $t\in\{0,\tfrac{1}{N},\ldots,\tfrac{N-1}{N}\}$, then we have
$$
\begin{aligned}
\mathbf{x}(t+\Delta t) &= \mathbf{x}(t)+\sqrt{\sigma(t+\Delta t)^2-\sigma(t)^2}\,\mathbf{z}(t)\\
&\approx \mathbf{x}(t)+\sqrt{\frac{d[\sigma(t)^2]}{dt}\Delta t}\,\mathbf{z}(t).
\end{aligned}
$$

In the limit as $\Delta t\to 0$, the equation converges to
$$
d\mathbf{x}=\sqrt{\frac{d[\sigma(t)^2]}{dt}}\,d\mathbf{w}.
$$

**The forward sampling equation of SMLD can be written as an SDE via**
$$
d\mathbf{x}=\sqrt{\frac{d[\sigma(t)^2]}{dt}}\,d\mathbf{w}.
$$

**The reverse sampling equation of SMLD can be written as an SDE via**
$$
d\mathbf{x}=-\left(\frac{d[\sigma(t)^2]}{dt}\nabla_\mathbf{x}\log p_t(\mathbf{x}(t))\right)dt+\sqrt{\frac{d[\sigma(t)^2]}{dt}}\,d\overline{\mathbf{w}}.
$$

For the discrete-time iterations, we first define $\alpha(t)=\frac{d[\sigma(t)^2]}{dt}$. Then, using the same discretization setup as in the DDPM case, we can show that
$$
\begin{aligned}
\mathbf{x}(t+\Delta t)-\mathbf{x}(t) &= -\Big(\alpha(t)\nabla_{\mathbf{x}}\log p_{t}(\mathbf{x})\Big)\Delta t-\sqrt{\alpha(t)\Delta t}\,\mathbf{z}(t)\\
\Longrightarrow\quad \mathbf{x}(t) &= \mathbf{x}(t+\Delta t)+\alpha(t)\Delta t\,\nabla_{\mathbf{x}}\log p_{t}(\mathbf{x})+\sqrt{\alpha(t)\Delta t}\,\mathbf{z}(t)\\
\Longrightarrow\quad \mathbf{x}_{i-1} &= \mathbf{x}_{i}+\alpha_{i}\nabla_{\mathbf{x}}\log p_{i}(\mathbf{x}_{i})+\sqrt{\alpha_{i}}\,\mathbf{z}_{i}\\
\Longrightarrow\quad \mathbf{x}_{i-1} &= \mathbf{x}_{i}+(\sigma_{i}^{2}-\sigma_{i-1}^{2})\nabla_{\mathbf{x}}\log p_{i}(\mathbf{x}_{i})+\sqrt{\sigma_{i}^{2}-\sigma_{i-1}^{2}}\,\mathbf{z}_{i},
\end{aligned}
$$

which is identical to the SMLD reverse update equation.
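In code, the last line above becomes the following update; `score_fn` again stands in for the learned score, and the noise levels and exact-Gaussian score in the usage example are assumptions for illustration.

```python
import numpy as np

def smld_reverse(x_N, sigmas, score_fn, seed=0):
    """SMLD reverse update:
    x_{i-1} = x_i + (sigma_i^2 - sigma_{i-1}^2) * score(x_i, i)
                  + sqrt(sigma_i^2 - sigma_{i-1}^2) * z_i."""
    rng = np.random.default_rng(seed)
    x = np.array(x_N, dtype=float)
    for i in range(len(sigmas) - 1, 0, -1):
        gap = sigmas[i] ** 2 - sigmas[i - 1] ** 2
        z = rng.standard_normal(x.shape)
        x = x + gap * score_fn(x, i) + np.sqrt(gap) * z
    return x

# Illustration only: for data ~ N(0, I), the score at noise level sigma_i is -x / (1 + sigma_i^2).
sigmas = np.geomspace(0.01, 50.0, 500)
x_N = np.sqrt(1 + sigmas[-1] ** 2) * np.random.default_rng(1).standard_normal(2)
x_0 = smld_reverse(x_N, sigmas, score_fn=lambda x, i: -x / (1 + sigmas[i] ** 2))
```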
## 4 Solving the SDE

**Predictor-Corrector Algorithm:** If we have already trained the score function $s_{\boldsymbol{\theta}}(\mathbf{x}_i, i)$, we can use it for correction: at each of the $N$ reverse-time (predictor) steps, we can run the score-based update $M$ times to correct the sample, as sketched below.
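A minimal sketch of such a Predictor-Corrector loop, assuming the predictor is the DDPM reverse update from Section 2 and the corrector is a Langevin-style score update (one common choice); `score_fn`, the schedule `betas`, and the corrector step size `eps` are illustrative assumptions.

```python
import numpy as np

def predictor_corrector(x_N, betas, score_fn, M=3, eps=1e-3, seed=0):
    """At each of the N reverse steps: one predictor step (reverse DDPM update),
    followed by M corrector steps that reuse the score at the same noise level."""
    rng = np.random.default_rng(seed)
    x = np.array(x_N, dtype=float)
    for i in range(len(betas) - 1, -1, -1):
        beta = betas[i]
        # Predictor: reverse DDPM update.
        z = rng.standard_normal(x.shape)
        x = (x + beta * score_fn(x, i)) / np.sqrt(1.0 - beta) + np.sqrt(beta) * z
        # Corrector: M score-based (Langevin-style) refinement steps.
        for _ in range(M):
            z = rng.standard_normal(x.shape)
            x = x + eps * score_fn(x, i) + np.sqrt(2 * eps) * z
    return x
```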
**Accelerating the SDE Solver:**
**Theorem [Variation of Constants].** Consider the ODE over the range $[t_0, t]$:
$$
\frac{dx(t)}{dt}=a(t)x(t)+b(t),\qquad \text{where } x(t_0)=x_0.
$$

The solution is given by
$$
x(t)=x_0e^{A(t)}+e^{A(t)}\int_{t_0}^{t}e^{-A(\tau)}b(\tau)\,d\tau,
$$

where $A(t)=\int_{t_0}^{t}a(\tau)\,d\tau.$
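A quick numerical sanity check of this formula; the specific choices $a(t)=-0.5$, $b(t)=\cos t$, and the use of SciPy's `solve_ivp` and `trapezoid` are assumptions made only for illustration.

```python
import numpy as np
from scipy.integrate import solve_ivp, trapezoid

a_const = -0.5                      # a(t) = -0.5 (constant), so A(t) = a_const * (t - t0)
b = np.cos                          # b(t) = cos(t)
t0, x0, t1 = 0.0, 2.0, 3.0

# Variation-of-constants formula, with the inner integral done by quadrature.
taus = np.linspace(t0, t1, 100_001)
A_t1 = a_const * (t1 - t0)
inner = trapezoid(np.exp(-a_const * (taus - t0)) * b(taus), taus)
x_closed = x0 * np.exp(A_t1) + np.exp(A_t1) * inner

# Direct numerical integration of dx/dt = a(t) x + b(t) for comparison.
sol = solve_ivp(lambda t, x: a_const * x + b(t), (t0, t1), [x0], rtol=1e-10, atol=1e-12)
print(x_closed, sol.y[0, -1])       # the two values should agree to high accuracy
```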