## 1 Forward and Backward Iterations in SDE

The basic idea comes from the gradient descent algorithm:
$$
\begin{aligned}
\mathbf{x}_i&=\mathbf{x}_{i-1}-\tau'\nabla f(\mathbf{x}_{i-1})+\mathbf{z}_{i-1}\\
\Longrightarrow\quad\mathbf{x}(t+\Delta t)&=\mathbf{x}(t)-\tau\nabla f(\mathbf{x}(t))+\mathbf{z}(t).
\end{aligned}
$$

Now, let's define a random process $\mathbf{w}(t)$ such that $\mathbf{z}(t)=\mathbf{w}(t+\Delta t)-\mathbf{w}(t)\approx\frac{d\mathbf{w}(t)}{dt}\Delta t$ for a very small $\Delta t$; here $\mathbf{w}(t)$ is a Wiener process, so that $\mathbf{z}(t)$ is white Gaussian noise. Substituting $\mathbf{w}(t)$ into the iteration above gives:
$$
\begin{aligned}
\mathbf{x}(t+\Delta t)&=\mathbf{x}(t)-\tau\nabla f(\mathbf{x}(t))+\mathbf{z}(t)\\
\Longrightarrow\quad\mathbf{x}(t+\Delta t)-\mathbf{x}(t)&=-\tau\nabla f(\mathbf{x}(t))+\mathbf{w}(t+\Delta t)-\mathbf{w}(t)\\
\Longrightarrow\quad d\mathbf{x}&=-\tau\nabla f(\mathbf{x})\,dt+d\mathbf{w}.
\end{aligned}
$$

Note that in computation we often use $d\mathbf{w}=\sqrt{\Delta t}\,\mathbf{z}(t)$ with $\mathbf{z}(t)\sim\mathcal{N}(0,\mathbf{I})$.
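To see the discretization in action, here is a minimal sketch, assuming a toy quadratic $f(\mathbf{x})=\frac{1}{2}\|\mathbf{x}\|^2$ and hypothetical step sizes, of the update $\mathbf{x}\leftarrow\mathbf{x}-\tau\nabla f(\mathbf{x})\Delta t+\sqrt{\Delta t}\,\mathbf{z}$:

```python
import numpy as np

# Minimal sketch: simulate dx = -tau * grad_f(x) dt + dw with dw = sqrt(dt) z.
# f(x) = 0.5 * ||x||^2 and the step sizes are hypothetical choices.
rng = np.random.default_rng(0)

def grad_f(x):
    return x  # gradient of f(x) = 0.5 * ||x||^2

tau, dt, n_steps = 1.0, 1e-2, 1000
x = np.array([5.0])                 # start far from the minimizer at 0
for _ in range(n_steps):
    z = rng.standard_normal(x.shape)
    x = x - tau * grad_f(x) * dt + np.sqrt(dt) * z  # Euler-Maruyama step

print(x)  # jitters around 0 instead of converging exactly
```

Without the $d\mathbf{w}$ term this is plain gradient descent; with it, the iterate keeps sampling around the minimizer.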
**Forward Diffusion.**

$$
d\mathbf{x}=\underbrace{\mathbf{f}(\mathbf{x},t)}_{\mathrm{drift}}\,dt+\underbrace{g(t)}_{\mathrm{diffusion}}\,d\mathbf{w}.
$$

**Reverse SDE.**

$$
d\mathbf{x}=\Big[\underbrace{\mathbf{f}(\mathbf{x},t)}_{\mathrm{drift}}-g(t)^2\underbrace{\nabla_\mathbf{x}\log p_t(\mathbf{x})}_{\text{score function}}\Big]\,dt+\underbrace{g(t)\,d\overline{\mathbf{w}}}_{\text{reverse-time diffusion}},
$$

where $p_t(\mathbf{x})$ is the probability distribution of $\mathbf{x}$ at time $t$, and $\overline{\mathbf{w}}$ is the Wiener process when time flows backward.
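Both equations can be integrated numerically with the Euler–Maruyama scheme. Below is a minimal sketch; `f`, `g`, and `score` are placeholder callables (in practice `score` would be a trained network):

```python
import numpy as np

rng = np.random.default_rng(1)

def euler_maruyama_forward(x0, f, g, n_steps=1000, T=1.0):
    """Simulate dx = f(x, t) dt + g(t) dw forward in time."""
    dt = T / n_steps
    x = np.array(x0, dtype=float)
    for i in range(n_steps):
        t = i * dt
        x = x + f(x, t) * dt + g(t) * np.sqrt(dt) * rng.standard_normal(x.shape)
    return x

def euler_maruyama_reverse(xT, f, g, score, n_steps=1000, T=1.0):
    """Simulate dx = [f(x, t) - g(t)^2 score(x, t)] dt + g(t) dw_bar,
    integrating backward from t = T to t = 0."""
    dt = T / n_steps
    x = np.array(xT, dtype=float)
    for i in reversed(range(n_steps)):
        t = (i + 1) * dt
        drift = f(x, t) - g(t) ** 2 * score(x, t)
        x = x - drift * dt + g(t) * np.sqrt(dt) * rng.standard_normal(x.shape)
    return x
```

The samplers in the following sections are special cases of these two loops for specific choices of $\mathbf{f}$ and $g$.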
## 2 Stochastic Differential Equation for DDPM

We consider the discrete-time DDPM forward iteration. For $i=1,2,\dots,N$:
$$
\mathbf{x}_i=\sqrt{1-\beta_i}\,\mathbf{x}_{i-1}+\sqrt{\beta_i}\,\mathbf{z}_{i-1},\qquad\mathbf{z}_{i-1}\sim\mathcal{N}(0,\mathbf{I}).
$$

**Forward SDE for DDPM.** The forward sampling equation of DDPM can be written as an SDE via
$$
d\mathbf{x}=\underbrace{-\frac{\beta(t)}{2}\mathbf{x}}_{=\mathbf{f}(\mathbf{x},t)}\,dt+\underbrace{\sqrt{\beta(t)}}_{=g(t)}\,d\mathbf{w}.
$$

Note that here $\mathbf{w}$ is the standard (forward-time) Wiener process.
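As a sanity check, here is a minimal sketch comparing one forward DDPM step with one Euler–Maruyama step of this SDE; the linear $\beta(t)$ schedule and the correspondence $\beta_i=\beta(t)\Delta t$ are assumptions for illustration:

```python
import numpy as np

# One forward DDPM step vs. one Euler-Maruyama step of the VP-SDE above,
# assuming the correspondence beta_i = beta(t) * dt (hypothetical schedule).
rng = np.random.default_rng(2)
N = 1000
dt = 1.0 / N
beta = lambda t: 0.1 + 19.9 * t          # assumed linear beta(t) on [0, 1]

i = 500
t = i * dt
beta_i = beta(t) * dt
x = rng.standard_normal(2)               # current iterate x(t)
z = rng.standard_normal(2)               # shared noise for the comparison

x_ddpm = np.sqrt(1.0 - beta_i) * x + np.sqrt(beta_i) * z        # DDPM step
x_sde = x - 0.5 * beta(t) * x * dt + np.sqrt(beta(t) * dt) * z  # SDE step

print(np.abs(x_ddpm - x_sde).max())  # tiny: sqrt(1 - b) ~ 1 - b/2 for small b
```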
The DDPM iteration itself solves this SDE: it is a first-order discretization of it.

**Reverse SDE for DDPM.** The reverse sampling equation of DDPM can be written as an SDE via
$$
d\mathbf{x}=-\beta(t)\left[\frac{\mathbf{x}}{2}+\nabla_{\mathbf{x}}\log p_t(\mathbf{x})\right]dt+\sqrt{\beta(t)}\,d\overline{\mathbf{w}}.
$$

Discretizing over a small step $\Delta t$ (with time flowing backward) gives

$$
\begin{aligned}
\mathbf{x}(t)-\mathbf{x}(t-\Delta t)&=-\beta(t)\Delta t\left[\frac{\mathbf{x}(t)}{2}+\nabla_\mathbf{x}\log p_t(\mathbf{x}(t))\right]-\sqrt{\beta(t)\Delta t}\,\mathbf{z}(t)\\
\implies\quad\mathbf{x}(t-\Delta t)&=\mathbf{x}(t)+\beta(t)\Delta t\left[\frac{\mathbf{x}(t)}{2}+\nabla_{\mathbf{x}}\log p_t(\mathbf{x}(t))\right]+\sqrt{\beta(t)\Delta t}\,\mathbf{z}(t).
\end{aligned}
$$

By grouping the terms, and assuming that $\beta(t)\Delta t\ll1$, we obtain
$$
\begin{aligned}
\mathbf{x}(t-\Delta t)&=\mathbf{x}(t)\left[1+\frac{\beta(t)\Delta t}{2}\right]+\beta(t)\Delta t\,\nabla_\mathbf{x}\log p_t(\mathbf{x}(t))+\sqrt{\beta(t)\Delta t}\,\mathbf{z}(t)\\
&\approx\mathbf{x}(t)\left[1+\frac{\beta(t)\Delta t}{2}\right]+\beta(t)\Delta t\,\nabla_\mathbf{x}\log p_t(\mathbf{x}(t))+\frac{(\beta(t)\Delta t)^2}{2}\nabla_\mathbf{x}\log p_t(\mathbf{x}(t))+\sqrt{\beta(t)\Delta t}\,\mathbf{z}(t)\\
&=\left[1+\frac{\beta(t)\Delta t}{2}\right]\Big(\mathbf{x}(t)+\beta(t)\Delta t\,\nabla_\mathbf{x}\log p_t(\mathbf{x}(t))\Big)+\sqrt{\beta(t)\Delta t}\,\mathbf{z}(t),
\end{aligned}
$$

where the $\approx$ step inserts the higher-order term $\frac{(\beta(t)\Delta t)^2}{2}\nabla_\mathbf{x}\log p_t(\mathbf{x}(t))$, which is negligible when $\beta(t)\Delta t\ll1$, so that the expression factors. Then, following the discretization scheme (letting $t=\frac{i}{N}$, $\Delta t=\frac{1}{N}$, and $\beta_i\triangleq\beta(t)\Delta t$), we can show that
$$
\begin{aligned}
\mathbf{x}_{i-1}&=\left(1+\frac{\beta_i}{2}\right)\Big[\mathbf{x}_i+\beta_i\nabla_\mathbf{x}\log p_i(\mathbf{x}_i)\Big]+\sqrt{\beta_i}\,\mathbf{z}_i\\
&\approx\frac{1}{\sqrt{1-\beta_i}}\Big[\mathbf{x}_i+\beta_i\nabla_\mathbf{x}\log p_i(\mathbf{x}_i)\Big]+\sqrt{\beta_i}\,\mathbf{z}_i,
\end{aligned}
$$

where $p_i(\mathbf{x})$ is the probability density function of $\mathbf{x}$ at step $i$, and we used $1+\frac{\beta_i}{2}\approx\frac{1}{\sqrt{1-\beta_i}}$ for small $\beta_i$. In practice, $\nabla_\mathbf{x}\log p_i(\mathbf{x}_i)$ is replaced by the trained score estimator $\mathbf{s}_{\boldsymbol{\theta}}(\mathbf{x}_i)$.
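A minimal sketch of this reverse iteration; the stand-in `score_fn` (the exact score of $\mathcal{N}(0,\mathbf{I})$) and the constant $\beta_i$ are hypothetical, so the loop runs without a trained network:

```python
import numpy as np

# Sketch of the reverse DDPM iteration, with the trained score replaced by
# a stand-in: for p_i = N(0, I) the exact score is simply -x.
rng = np.random.default_rng(3)

def score_fn(x, i):
    return -x  # hypothetical placeholder for a trained s_theta(x_i, i)

def ddpm_reverse_step(x_i, i, beta_i):
    z = rng.standard_normal(x_i.shape)
    return (x_i + beta_i * score_fn(x_i, i)) / np.sqrt(1.0 - beta_i) \
        + np.sqrt(beta_i) * z

x = rng.standard_normal(2)          # x_N drawn from N(0, I)
for i in range(1000, 0, -1):
    x = ddpm_reverse_step(x, i, beta_i=1e-3)

print(x)  # stays distributed as N(0, I), consistent with the stand-in score
```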
## 3 Stochastic Differential Equation for SMLD

Although SMLD does not have an explicit forward diffusion step, if we divide the noise scale (e.g., in the perturbation $\mathbf{x}+\sigma\mathbf{z}$) into $N$ levels, the recursion should follow a Markov chain:
$$
\mathbf{x}_i=\mathbf{x}_{i-1}+\sqrt{\sigma_i^2-\sigma_{i-1}^2}\,\mathbf{z}_{i-1},\quad i=1,2,\ldots,N.
$$

If we assume that the variance of $\mathbf{x}_{i-1}$ is $\sigma_{i-1}^2$, then we can show that
$$
\begin{aligned}
\mathrm{Var}[\mathbf{x}_{i}]&=\mathrm{Var}[\mathbf{x}_{i-1}]+(\sigma_{i}^{2}-\sigma_{i-1}^{2})\\
&=\sigma_{i-1}^{2}+(\sigma_{i}^{2}-\sigma_{i-1}^{2})=\sigma_{i}^{2}.
\end{aligned}
$$

Therefore, given a sequence of noise levels, the above equation will indeed generate estimates $\mathbf{x}_i$ whose variance is $\sigma_i^2$.
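The variance claim is easy to check numerically; here is a minimal sketch with an assumed linear sequence of noise levels:

```python
import numpy as np

# Numerical check of the variance claim: running the forward chain from
# x_0 = 0 (so Var[x_0] = sigma_0^2 = 0) should yield Var[x_i] = sigma_i^2.
rng = np.random.default_rng(4)
sigmas = np.linspace(0.0, 10.0, 11)      # sigma_0 = 0, ..., sigma_N = 10

x = np.zeros(100_000)                    # many independent chains at once
for i in range(1, len(sigmas)):
    x = x + np.sqrt(sigmas[i] ** 2 - sigmas[i - 1] ** 2) * rng.standard_normal(x.shape)

print(x.var(), sigmas[-1] ** 2)          # both close to 100
```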
Assume that in the limit of $N\to\infty$, the sequence $\{\sigma_i\}_{i=1}^N$ becomes a continuous-time noise schedule $\sigma(t)$ for $0\leq t\leq1$, and $\{\mathbf{x}_i\}_{i=1}^N$ becomes a continuous process $\mathbf{x}(t)$ with $\mathbf{x}_i=\mathbf{x}(\frac{i}{N})$ for $t\in\{0,\frac{1}{N},\ldots,\frac{N-1}{N}\}$. Then we have

$$
\begin{aligned}
\mathbf{x}(t+\Delta t)&=\mathbf{x}(t)+\sqrt{\sigma(t+\Delta t)^2-\sigma(t)^2}\,\mathbf{z}(t)\\
&\approx\mathbf{x}(t)+\sqrt{\frac{d[\sigma(t)^2]}{dt}\Delta t}\,\mathbf{z}(t).
\end{aligned}
$$

In the limit as $\Delta t\to0$, this converges to

$$
d\mathbf{x}=\sqrt{\frac{d[\sigma(t)^2]}{dt}}\,d\mathbf{w}.
$$

**Forward SDE for SMLD.** The forward sampling equation of SMLD can be written as an SDE via

$$
d\mathbf{x}=\sqrt{\frac{d[\sigma(t)^2]}{dt}}\,d\mathbf{w}.
$$
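A minimal Euler–Maruyama simulation of this variance-exploding SDE, under the assumed schedule $\sigma(t)=t$ (so $\frac{d[\sigma(t)^2]}{dt}=2t$):

```python
import numpy as np

# Euler-Maruyama on the forward SMLD SDE with the assumed schedule
# sigma(t) = t, so that d[sigma(t)^2]/dt = 2t and Var[x(1)] should be 1.
rng = np.random.default_rng(5)
n_steps, T = 1000, 1.0
dt = T / n_steps

x = np.zeros(100_000)                    # all samples start at x(0) = 0
for k in range(n_steps):
    t = k * dt
    x = x + np.sqrt(2.0 * t * dt) * rng.standard_normal(x.shape)

print(x.var(), 1.0)                      # empirical variance ~ sigma(1)^2 = 1
```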
**Reverse SDE for SMLD.** The reverse sampling equation of SMLD can be written as an SDE via

$$
d\mathbf{x}=-\left(\frac{d[\sigma(t)^2]}{dt}\nabla_\mathbf{x}\log p_t(\mathbf{x}(t))\right)dt+\sqrt{\frac{d[\sigma(t)^2]}{dt}}\,d\overline{\mathbf{w}}.
$$

For the discrete-time iterations, we first define $\alpha(t)=\frac{d[\sigma(t)^2]}{dt}$. Then the discretization gives
$$
\begin{aligned}
\mathbf{x}(t+\Delta t)-\mathbf{x}(t)&=-\Big(\alpha(t)\nabla_{\mathbf{x}}\log p_{t}(\mathbf{x})\Big)\Delta t-\sqrt{\alpha(t)\Delta t}\,\mathbf{z}(t)\\
\Rightarrow\quad\mathbf{x}(t)&=\mathbf{x}(t+\Delta t)+\alpha(t)\Delta t\,\nabla_{\mathbf{x}}\log p_{t}(\mathbf{x})+\sqrt{\alpha(t)\Delta t}\,\mathbf{z}(t)\\
\Rightarrow\quad\mathbf{x}_{i-1}&=\mathbf{x}_{i}+\alpha_{i}\nabla_{\mathbf{x}}\log p_{i}(\mathbf{x}_{i})+\sqrt{\alpha_{i}}\,\mathbf{z}_{i}\\
\Rightarrow\quad\mathbf{x}_{i-1}&=\mathbf{x}_{i}+(\sigma_{i}^{2}-\sigma_{i-1}^{2})\nabla_{\mathbf{x}}\log p_{i}(\mathbf{x}_{i})+\sqrt{\sigma_{i}^{2}-\sigma_{i-1}^{2}}\,\mathbf{z}_{i},
\end{aligned}
$$

where $\alpha_i\triangleq\sigma_i^2-\sigma_{i-1}^2$. This is identical to the SMLD reverse update equation.
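A minimal sketch of this reverse chain; assuming a point-mass data distribution at $\mathbf{0}$, the exact score $-\mathbf{x}/\sigma_i^2$ stands in for a trained $\mathbf{s}_{\boldsymbol{\theta}}$:

```python
import numpy as np

# Sketch of the SMLD reverse update. The score is a stand-in: if the data
# distribution is a point mass at 0, then p_i = N(0, sigma_i^2 I) and the
# exact score is -x / sigma_i^2, so the chain should contract toward 0.
rng = np.random.default_rng(6)
sigmas = np.linspace(1e-3, 10.0, 101)    # assumed noise schedule

def score_fn(x, i):
    return -x / sigmas[i] ** 2           # placeholder for s_theta(x_i, i)

x = sigmas[-1] * rng.standard_normal(2)  # x_N ~ N(0, sigma_N^2 I)
for i in range(len(sigmas) - 1, 0, -1):
    alpha_i = sigmas[i] ** 2 - sigmas[i - 1] ** 2
    z = rng.standard_normal(x.shape)
    x = x + alpha_i * score_fn(x, i) + np.sqrt(alpha_i) * z

print(x)  # small magnitude: the chain collapses toward the data point at 0
```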
## 4 Solving SDE

**Predictor-Corrector Algorithm:** if we have already trained the score function $\mathbf{s}_{\boldsymbol{\theta}}(\mathbf{x}_i,i)$, we can solve the reverse-time SDE numerically: the predictor runs the $N$ reverse-SDE steps, and after each step the corrector applies $M$ iterations of score-based Langevin dynamics to pull the sample back toward the correct marginal distribution.
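Here is a minimal predictor-corrector sketch in the SMLD setting; the noise schedule, the stand-in `score_fn`, and the corrector step-size rule are illustrative assumptions, not tuned choices:

```python
import numpy as np

# Predictor-corrector sketch (SMLD flavor): each reverse-SDE predictor step
# is followed by M Langevin corrector steps at the new noise level.
rng = np.random.default_rng(7)
sigmas = np.linspace(1e-3, 10.0, 101)    # assumed noise schedule

def score_fn(x, i):
    return -x / sigmas[i] ** 2           # placeholder for s_theta(x, i)

M = 5                                    # corrector steps per predictor step
x = sigmas[-1] * rng.standard_normal(2)
for i in range(len(sigmas) - 1, 0, -1):
    # Predictor: one reverse SMLD step (the iteration derived above).
    alpha_i = sigmas[i] ** 2 - sigmas[i - 1] ** 2
    x = x + alpha_i * score_fn(x, i) + np.sqrt(alpha_i) * rng.standard_normal(x.shape)
    # Corrector: M Langevin steps, with step size scaled to the noise level.
    for _ in range(M):
        eps = 0.1 * sigmas[i - 1] ** 2
        x = x + eps * score_fn(x, i - 1) + np.sqrt(2 * eps) * rng.standard_normal(x.shape)
```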
**Accelerate the SDE Solver:** the drift of the reverse SDE is linear in $\mathbf{x}$ plus a score-dependent remainder, so the linear part can be integrated exactly using the classical result below.
**Theorem (Variation of Constants).** Consider the ODE over the range $[t_0,t]$:
$$
\frac{dx(t)}{dt}=a(t)x(t)+b(t),\quad\mathrm{where}\ x(t_0)=x_0.
$$

The solution is given by
$$
x(t)=x_0e^{A(t)}+e^{A(t)}\int_{t_0}^te^{-A(\tau)}b(\tau)\,d\tau,
$$

where $A(t)=\int_{t_{0}}^{t}a(\tau)\,d\tau$.
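A quick numerical sanity check of the theorem on a toy choice $a(t)=-2$, $b(t)=\sin(t)$ (both assumptions for illustration), comparing the closed form against a simple Euler integration:

```python
import numpy as np

# Verify x(t) = x0 e^{A(t)} + e^{A(t)} int_{t0}^{t} e^{-A(tau)} b(tau) dtau
# numerically on a(t) = -2, b(t) = sin(t).
a = lambda t: -2.0 + 0.0 * t             # constant a(t), written vectorized
b = lambda t: np.sin(t)
t0, x0, T, n = 0.0, 1.0, 3.0, 30_000

ts = np.linspace(t0, T, n)
dts = np.diff(ts)
A = np.concatenate([[0.0], np.cumsum(a(ts[:-1]) * dts)])      # A(t) = int a
integrand = np.exp(-A) * b(ts)
integral = np.concatenate(
    [[0.0], np.cumsum(0.5 * (integrand[1:] + integrand[:-1]) * dts)]
)
x_formula = x0 * np.exp(A) + np.exp(A) * integral             # the theorem

x_euler = np.empty(n)                                         # Euler baseline
x_euler[0] = x0
for k in range(n - 1):
    x_euler[k + 1] = x_euler[k] + (a(ts[k]) * x_euler[k] + b(ts[k])) * dts[k]

print(np.max(np.abs(x_formula - x_euler)))  # small; shrinks as n grows
```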