This post is an introduction on diffusion written for my younger self. If you are new to diffusion and already familiar with VAE, this may be a good entry point for you.
Diffusion has blown up over the past few years, with a huge amount of literature produced, and many works differ in how they frame this modeling paradigm1. I do not intend to approach this via the most standardized or unified theory. I am not the best person to do that, nor do I think that angle helped me the most when approaching the subject (see my previous post). While many explanations derive diffusion from first principles, I found it useful to start by understanding how the objective is shaped into something we can actually train.
Diffusion and Variational Autoencoder (VAE)2 models are both generative models. There are typically two goals when building a generative model: to generate unseen natural samples of the learned data source, and to apply controls (conditioning) on which subtype of samples to generate. I find that the construction of VAE provides some useful intuitions for understanding the training dynamics of diffusion models.
Variational Autoencoder
Model construction
VAE is a latent variable model. A latent variable model assumes that the data you observe, \(\mathbf{x}\), can be generated from some hidden (latent) variables, \(\mathbf{z}\). In VAE, \(\mathbf{z}\) is typically assumed to follow a simpler prior distribution \(p(\mathbf{z})\) (e.g. unit Gaussian). Given the above assumptions, we then train a neural network encoder (or approximate posterior) \(q_\phi(\mathbf{z}|\mathbf{x})\) and decoder \(p_\theta(\mathbf{x}|\mathbf{z})\) to map from the data space to the latent space and vice versa. Both the encoder and decoder are typically assumed to model Gaussian distributions with mean and variance parameterization \(\mu_{\phi}\), \(\Sigma_{\phi}\), \(\mu_{\theta}\), \(\Sigma_{\theta}\).
Figure 1: The directed graph model of VAE.
With a trained VAE model, we will sample \(\mathbf{z}\) from the Gaussian prior \(p(\mathbf{z})\), and then sample \(\mathbf{x}\) by passing \(\mathbf{z}\) through the decoder \(p_\theta(\mathbf{x}|\mathbf{z})\).
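To make this two-step recipe concrete, here is a tiny numpy sketch; the linear-Gaussian decoder (`W`, `b`, `sigma`) is a hypothetical stand-in for a trained network, not the method of any particular paper:

```python
import numpy as np

rng = np.random.default_rng(0)

latent_dim, data_dim = 2, 4

# Hypothetical "trained" decoder parameters: mu_theta(z) = W z + b.
W = rng.standard_normal((data_dim, latent_dim))
b = np.zeros(data_dim)
sigma = 0.1  # fixed decoder standard deviation (an assumption for this sketch)

# Step 1: sample z from the unit Gaussian prior p(z).
z = rng.standard_normal(latent_dim)

# Step 2: sample x from the decoder p_theta(x | z) = N(mu_theta(z), sigma^2 I).
x = W @ z + b + sigma * rng.standard_normal(data_dim)
```

In practice the decoder is a deep network, but the sampling recipe is exactly this: prior sample, then one pass through the decoder distribution.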
Loss function
The VAE loss, \(L(\mathbf{x})\), is derived as the negative of the variational lower bound.
\[ \begin{aligned} D_{\mathrm{KL}}(q_\phi(\mathbf{z}|\mathbf{x})\,\|\,p_\theta(\mathbf{z}|\mathbf{x})) &= \mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})} \left[ \log \frac{q_\phi(\mathbf{z}|\mathbf{x})}{p_\theta(\mathbf{z}|\mathbf{x})} \right] \\ &= \mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})} \left[ \log \frac{q_\phi(\mathbf{z}|\mathbf{x})}{p_\theta(\mathbf{x}|\mathbf{z})p(\mathbf{z})/p_\theta(\mathbf{x})} \right] \\ &= \log p_\theta(\mathbf{x}) - \mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})}[\log p_\theta(\mathbf{x}|\mathbf{z})] + D_{\mathrm{KL}}(q_\phi(\mathbf{z}|\mathbf{x})\,\|\,p(\mathbf{z})) \\ \end{aligned} \]\[ \begin{aligned} \Rightarrow \quad -\log p_\theta(\mathbf{x}) &\le \underbrace{D_{\mathrm{KL}}(q_\phi(\mathbf{z}|\mathbf{x})\,\|\,p(\mathbf{z}))}_{\text{regularization}} - \underbrace{\mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})}[\log p_\theta(\mathbf{x}|\mathbf{z})]}_{\text{reconstruction}} \\ &= \mathcal{L}(\mathbf{x}) \end{aligned} \]This derivation is where the engineering ingenuity really comes in. The math only guarantees a lower bound on the likelihood, but what makes VAEs work in practice is rewriting the objective into a form that looks like reconstruction + regularization. The reconstruction term says that, given a real sample \(\mathbf{x}\), if we sample a latent \(\mathbf{z}\) based on its posterior \(q_\phi(\mathbf{z} | \mathbf{x})\), its likelihood under \(p_\theta(\mathbf{x} | \mathbf{z})\) should be maximized. This term provides a stable training signal for learning how to reconstruct data from latents, while the regularization term encourages the posterior latents to stay close to a Gaussian prior. We will see a similar pattern show up in the diffusion loss later.
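A minimal numpy sketch of a single-sample estimate of \(\mathcal{L}(\mathbf{x})\), assuming a diagonal-Gaussian posterior and a fixed-variance Gaussian decoder; `decode_mean` is a hypothetical stand-in for the decoder network, and the reconstruction term is written only up to an additive constant:

```python
import numpy as np

rng = np.random.default_rng(0)

def vae_loss(x, mu_q, logvar_q, decode_mean, sigma_x=1.0):
    # Reparameterization trick: z = mu + sigma * eps keeps the sample
    # differentiable with respect to the posterior parameters.
    eps = rng.standard_normal(mu_q.shape)
    z = mu_q + np.exp(0.5 * logvar_q) * eps
    # Closed-form KL(q_phi(z|x) || N(0, I)) for a diagonal Gaussian posterior.
    kl = 0.5 * np.sum(np.exp(logvar_q) + mu_q**2 - 1.0 - logvar_q)
    # One-sample Monte Carlo estimate of -E_q[log p_theta(x|z)] for a
    # Gaussian decoder, up to an additive constant.
    recon = np.sum((x - decode_mean(z)) ** 2) / (2.0 * sigma_x**2)
    return kl + recon

# Toy check: identity decoder, posterior equal to the prior (so KL = 0).
x = np.array([0.5, -0.5])
loss = vae_loss(x, mu_q=np.zeros(2), logvar_q=np.zeros(2), decode_mean=lambda z: z)
```

A real implementation would average this over a minibatch and backpropagate through the reparameterized sample, but the two-term structure is the same.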
Denoising diffusion probabilistic models
The high level idea of diffusion models is to define a diffusion process which gradually destroys structure in the original data and its reverse process that gradually restores structure. The motivating belief is that estimating small perturbations is more tractable than learning to describe the full data distribution with a single function. (paraphrased from foundational diffusion paper3)
This section is largely based on the construction in Denoising Diffusion Probabilistic Models (DDPM)4. I found the original paper a bit dense at first read, so I will attempt to inject some intuitions by linking this work to the previous section.
Model construction
Diffusion models can be thought of as deep latent variable models with multiple levels of latents, whereby \(\mathbf{x}_0\) is observed data, and \(\mathbf{x}_1\), \(\mathbf{x}_2\), …, \(\mathbf{x}_T\) are the latents.
Figure 2: The directed graph model of a denoising diffusion probabilistic model (Image source: Ho et al. 2020).
There are a few ways in which it differs from a VAE. Firstly, in VAE, \(\mathbf{z}\) is typically of a much smaller dimension. On the other hand, in diffusion models the latents are of the same size as the original data \(\mathbf{x}_0\). Next, \(q(\mathbf{x}_t | \mathbf{x}_{t-1})\) is not learnable but a predefined noise perturbation process based on a variance schedule \(\beta_1\), …, \(\beta_T\):
\[ \begin{aligned} q(\mathbf{x}_t | \mathbf{x}_{t-1}) := N(\mathbf{x}_t; \sqrt{1-\beta_t}\mathbf{x}_{t-1}, \beta_t\mathbf{I}) \end{aligned} \]\(q(\mathbf{x}_{1:t}|\mathbf{x}_0)\), also called the forward process, is thus a Markov chain. We then define \(p_\theta(\mathbf{x}_{0:T})\), the reverse process (or so-called denoising process), also as a Markov chain:
\[ \begin{aligned} p(\mathbf{x}_{T}) := N(\mathbf{x}_T; \mathbf{0}, \mathbf{I}) \end{aligned} \]\[ \begin{aligned} p_\theta(\mathbf{x}_{t-1} | \mathbf{x}_{t}) := N(\mathbf{x}_{t-1}; \mu_\theta(\mathbf{x}_{t},t), \Sigma_\theta(\mathbf{x}_t, t)) \end{aligned} \]
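The forward process is easy to simulate. A small numpy sketch using the linear \(\beta_t\) schedule from the DDPM paper (the specific values are one common choice, not a requirement):

```python
import numpy as np

rng = np.random.default_rng(0)

# Linear variance schedule as in DDPM: 1e-4 to 0.02 over T steps.
T = 1000
betas = np.linspace(1e-4, 0.02, T)
alpha_bars = np.cumprod(1.0 - betas)

def forward_step(x_prev, beta_t):
    # One step of q(x_t | x_{t-1}) = N(sqrt(1 - beta_t) x_{t-1}, beta_t I).
    return np.sqrt(1.0 - beta_t) * x_prev + np.sqrt(beta_t) * rng.standard_normal(x_prev.shape)

x = np.ones(4)  # stand-in for a data sample x_0
for beta_t in betas:
    x = forward_step(x, beta_t)
# By t = T almost no signal survives: the coefficient sqrt(alpha_bar_T)
# on x_0 is tiny, so x_T is essentially N(0, I) noise.
```

Because each step both shrinks the signal and adds noise, the marginal variance stays close to 1, which is why \(\mathbf{x}_T\) lands near the unit Gaussian prior.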
After the model is trained, to obtain a natural sample, we will sample \(\mathbf{x}_T\) from a unit Gaussian prior, and pass it through the chain of \(p_\theta(\mathbf{x}_{t-1} | \mathbf{x}_{t})\) to arrive at a sample \(\mathbf{x}_0\).
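A minimal numpy sketch of this ancestral sampling loop; `mu_model` and the fixed `sigmas` are hypothetical stand-ins for the trained \(\mu_\theta\) network and a chosen constant \(\Sigma_\theta\):

```python
import numpy as np

rng = np.random.default_rng(0)

def ddpm_sample(mu_model, sigmas, shape):
    """Ancestral sampling through the reverse chain p_theta(x_{t-1} | x_t).
    mu_model(x_t, t) stands in for the learned mean, sigmas[t-1] for the
    standard deviation of the reverse step at time t."""
    T = len(sigmas)
    x = rng.standard_normal(shape)              # x_T ~ N(0, I)
    for t in range(T, 0, -1):
        mean = mu_model(x, t)
        # No noise is added on the final step (t = 1), following common practice.
        noise = rng.standard_normal(shape) if t > 1 else 0.0
        x = mean + sigmas[t - 1] * noise        # sample x_{t-1}
    return x

# Exercise the loop with a dummy "model" that just shrinks x_t a little.
sigmas = np.full(50, 0.05)
x0 = ddpm_sample(lambda x, t: 0.9 * x, sigmas, shape=(4,))
```

The dummy model obviously produces nothing meaningful; the point is only the control flow: start from pure noise and repeatedly sample the next, slightly less noisy latent.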
Loss function
Similar to VAEs, we derive the loss function based on the variational lower bound, though this time we need to break down the loss terms into each subsequent timestep:
\[ \begin{aligned} -\log p(\mathbf{x}_0) &\le D_{\mathrm{KL}}(q(\mathbf{x}_{1:T}|\mathbf{x}_0)\,\|\,p(\mathbf{x}_{1:T})) - \mathbb{E}_{q(\mathbf{x}_{1:T}|\mathbf{x}_0)}[\log p_\theta(\mathbf{x}_0|\mathbf{x}_{1:T})] \\ &= \mathbb{E}_{q(\mathbf{x}_{1:T}|\mathbf{x}_0)}\bigg[-\log p(\mathbf{x}_T)-\sum\limits_{t \ge 1}\log\frac{p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)}{q(\mathbf{x}_t|\mathbf{x}_{t-1})}\bigg] \\ &= \mathbb{E}_{q(\mathbf{x}_{1:T}|\mathbf{x}_0)} \Bigg[- \log p(\mathbf{x}_T) - \sum_{t>1} \log \frac{p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)}{q(\mathbf{x}_t \mid \mathbf{x}_{t-1})} - \log \frac{p_\theta(\mathbf{x}_0 \mid \mathbf{x}_1)}{q(\mathbf{x}_1 \mid \mathbf{x}_0)}\Bigg] \\ &= \mathbb{E}_{q(\mathbf{x}_{1:T}|\mathbf{x}_0)} \Bigg[- \log p(\mathbf{x}_T) - \sum_{t>1} \log \left(\frac{p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)}{q(\mathbf{x}_{t-1}\mid \mathbf{x}_t, \mathbf{x}_0)} \cdot \frac{q(\mathbf{x}_{t-1} \mid \mathbf{x}_0)}{q(\mathbf{x}_t \mid \mathbf{x}_0)} \right)- \log \frac{p_\theta(\mathbf{x}_0 \mid \mathbf{x}_1)}{q(\mathbf{x}_1 \mid \mathbf{x}_0)} \Bigg] \\ &= \mathbb{E}_{q(\mathbf{x}_{1:T}|\mathbf{x}_0)} \Bigg[- \log \frac{p(\mathbf{x}_T)}{q(\mathbf{x}_T \mid \mathbf{x}_0)} - \sum_{t>1} \log \frac{p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)}{q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)} - \log p_\theta(\mathbf{x}_0 \mid \mathbf{x}_1) \Bigg] \\ &= \underbrace{\mathbb{E}_{q(\mathbf{x}_T|\mathbf{x}_0)} \Bigg[D_{\mathrm{KL}}\big(q(\mathbf{x}_T \mid \mathbf{x}_0)\,\|\,p(\mathbf{x}_T)\big)\Bigg]}_{L_T} + \sum_{t>1} \underbrace{\mathbb{E}_{q(\mathbf{x}_t|\mathbf{x}_0)}\Bigg[D_{\mathrm{KL}}\big(q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)\,\|\,p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)\big)\Bigg]}_{L_{t-1}} \\ & \quad - \underbrace{\mathbb{E}_{q(\mathbf{x}_1|\mathbf{x}_0)}\log p_\theta(\mathbf{x}_0 \mid \mathbf{x}_1)}_{L_0} \\ &= \mathcal{L}(\mathbf{x}_0) \\ \end{aligned} \]\(L_T\) is like the regularization term in VAE loss, but can 
be ignored here since \(q\) is predefined. It is also negligible in this case: given the right \(\beta_t\) schedule, \(q(\mathbf{x}_T | \mathbf{x}_0)\) should be close to \(N(0, \mathbf{I})\). \(L_0\) is like the reconstruction term, maximizing the observed data likelihood given the predictive distribution conditioned on latent samples from \(q(\mathbf{x}_1 | \mathbf{x}_0)\). I would argue that \(L_1, \ldots, L_{T-1}\) function somewhat like reconstruction terms as well. Given \(\mathbf{x}_t\) sampled from \(q(\mathbf{x}_t | \mathbf{x}_0)\), these KL terms push the predictive latent distribution \(p_\theta(\mathbf{x}_{t-1} | \mathbf{x}_t)\) to match the posterior \(q(\mathbf{x}_{t-1} | \mathbf{x}_t, \mathbf{x}_0)\). As the forward process adds noise gradually, samples from \(q(\mathbf{x}_t|\mathbf{x}_0)\) remain closely tied to the original data, so the posterior is both tractable and naturally provides a relevant training signal. In contrast to VAEs, we do not need to learn an inference model to produce relevant latent samples during training. Minimizing \(\mathcal{L}\) then inherently maps out a discrete-time probability distribution path from \(\mathbf{x}_T\) towards plausible \(\mathbf{x}_0\) samples following the reverse process.
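One more detail that makes these KL terms friendly: when both distributions are Gaussians with the same isotropic variance (DDPM later fixes \(\Sigma_\theta = \sigma_t^2\mathbf{I}\)), the KL reduces to a scaled squared distance between the means. A quick numpy sketch of that identity:

```python
import numpy as np

def gaussian_kl_same_var(mu_q, mu_p, sigma2):
    # KL( N(mu_q, sigma2 I) || N(mu_p, sigma2 I) ) = ||mu_q - mu_p||^2 / (2 sigma2):
    # the variance terms cancel, leaving only the mean mismatch.
    return np.sum((mu_q - mu_p) ** 2) / (2.0 * sigma2)

kl = gaussian_kl_same_var(np.ones(2), np.zeros(2), 1.0)
```

This is exactly why the per-timestep KL terms will later reduce to mean-squared errors between means.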
Objective reparameterization as noise estimation
Another engineering detail worth mentioning is the parameterization of the loss. A good reason to derive the loss terms as above is that \(q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)\) is tractable and is in fact Gaussian, making the KL divergence computable in closed form (you can find its derivation in the paper):
\[ \text{Given} \quad \alpha_t := 1 - \beta_t \quad \text{and}\quad \bar{\alpha}_t := \prod_{s=1}^{t} \alpha_s, \]\[ q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0) = \mathcal{N}\big(\mathbf{x}_{t-1}; \tilde{\mu}_t(\mathbf{x}_t, \mathbf{x}_0), \tilde{\beta}_t \mathbf{I}\big), \]\[ \text{where} \quad \tilde{\mu}_t(\mathbf{x}_t, \mathbf{x}_0) := \frac{\sqrt{\bar{\alpha}_{t-1}} \, \beta_t}{1 - \bar{\alpha}_t} \, \mathbf{x}_0 + \frac{\sqrt{\alpha_t}\,(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t} \, \mathbf{x}_t \quad\quad \text{and} \quad \tilde{\beta}_t := \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \, \beta_t. \]For the reverse step parameterization, the authors decide to set \(\Sigma_\theta(\mathbf{x}_t, t):=\sigma^2_t \mathbf{I}\) as a constant in this work. Since we are trying to match each reverse step \(p_\theta(\mathbf{x}_{t-1} | \mathbf{x}_{t})\) to the posterior \(q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)\), it is reasonable to try to parameterize the \(\mu_\theta(\mathbf{x}_t, t)\) using the functional form of the posterior mean:
\[ \begin{aligned} \mu_\theta(\mathbf{x}_t, t) &:= \tilde{\mu}_t(\mathbf{x}_t, \mathbf{x}_{0\theta}(\mathbf{x}_t, t)) \end{aligned} \]So in the functional form above, our network will have to compute a quantity \(\mathbf{x}_{0\theta}\) given \(\mathbf{x}_t\), \(t\). Before we proceed, we note that \(q(\mathbf{x}_t | \mathbf{x}_0)\) is actually a Gaussian with the form below:
\[ q(\mathbf{x}_t \mid \mathbf{x}_0) = \mathcal{N}\big(\mathbf{x}_t;\, \sqrt{\bar{\alpha}_t}\,\mathbf{x}_0,\,(1 - \bar{\alpha}_t)\mathbf{I}\big) \]If we were to sample from \(q(\mathbf{x}_t \mid \mathbf{x}_0)\), we would compute \(\mathbf{x}_t = \sqrt{\bar{\alpha}_t}\,\mathbf{x}_0 + \sqrt{(1 - \bar{\alpha}_t)}\epsilon\) where \(\epsilon \sim N(0,\mathbf{I})\). This inspires an alternative parameterization below where the network computes a quantity \(\epsilon_\theta(\mathbf{x}_t, t)\) instead:
\[ \begin{aligned} \mu_\theta(\mathbf{x}_t, t) &= \tilde{\mu}_t(\mathbf{x}_t, \frac{\mathbf{x}_t-\sqrt{(1 - \bar{\alpha}_t)}\epsilon_\theta(\mathbf{x}_t, t)}{\sqrt{\bar{\alpha}_t}}) \end{aligned} \]With the above parameterization, each of the loss terms \(L_{t-1}\) gets simplified beautifully to a weighted MSE of noise estimation:
\[ \begin{aligned} L_{t-1} &= \mathbb{E}_{q(\mathbf{x}_t | \mathbf{x}_0)} \left[ \frac{1}{2\sigma_t^2} \left\| \tilde{\mu}_t(\mathbf{x}_t, \mathbf{x}_0) - \mu_\theta(\mathbf{x}_t, t) \right\|^2 \right] + C \\ &= \mathbb{E}_{q(\mathbf{x}_t | \mathbf{x}_0)} \left[ \frac{1}{2\sigma_t^2} \left\| \tilde{\mu}_t(\mathbf{x}_t, \mathbf{x}_0) - \tilde{\mu}_t(\mathbf{x}_t, \frac{\mathbf{x}_t-\sqrt{(1 - \bar{\alpha}_t)}\epsilon_\theta(\mathbf{x}_t, t)}{\sqrt{\bar{\alpha}_t}}) \right\|^2 \right] + C \\ &= \mathbb{E}_{q(\mathbf{x}_t | \mathbf{x}_0)} \left[\frac{1}{2\sigma_t^2} \left\|\frac{\sqrt{\bar{\alpha}_{t-1}} \, \beta_t}{1 - \bar{\alpha}_t} \, \left[\mathbf{x}_0 - \frac{\mathbf{x}_t-\sqrt{(1 - \bar{\alpha}_t)}\epsilon_\theta(\mathbf{x}_t, t)}{\sqrt{\bar{\alpha}_t}}\right]\right\|^2 \right] + C \\ &= \mathbb{E}_{\epsilon \sim N(0, \mathbf{I})} \left[\frac{1}{2\sigma_t^2} \left\|\frac{\sqrt{\bar{\alpha}_{t-1}} \, \beta_t}{1 - \bar{\alpha}_t} \, \left[\mathbf{x}_0 - \frac{\sqrt{\bar{\alpha}_t}\,\mathbf{x}_0 + \sqrt{(1 - \bar{\alpha}_t)}\epsilon-\sqrt{(1 - \bar{\alpha}_t)}\epsilon_\theta(\sqrt{\bar{\alpha}_t}\,\mathbf{x}_0 + \sqrt{(1 - \bar{\alpha}_t)}\epsilon, t)}{\sqrt{\bar{\alpha}_t}}\right]\right\|^2 \right] + C \\ &= \mathbb{E}_{\epsilon \sim N(0, \mathbf{I})} \left[ \frac{\beta_t^2}{2\sigma_t^2 \alpha_t (1 - \bar{\alpha}_t)} \left\| \epsilon - \epsilon_\theta\left(\sqrt{\bar{\alpha}_t}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t}\,\epsilon,\, t\right) \right\|^2 \right] + C \end{aligned} \]How about the decoder likelihood term \(L_0\)? For discrete input like images, the author scales the \(\mathbf{x}_0\) to the range \([-1,1]\), and assumes that each dimension gets cut into 256 equal bins. The likelihood \(p_\theta(\mathbf{x}_0 |\mathbf{x}_1)\) is then estimated by multiplying the pdf value of \(N(\mu_\theta({\mathbf{x}_1}, 1), \sigma^2 \mathbf{I})\) at \(\mathbf{x}_0\) by the bin size. 
Under this discretized Gaussian likelihood assumption, \(L_0\) can also be written in a form equivalent to a mean squared error, up to a constant factor.
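Circling back to the closed-form posterior \(q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)\): here is a small numpy sketch of \(\tilde{\mu}_t\) and \(\tilde{\beta}_t\), with a sanity check that at \(t = 1\) (where \(\bar{\alpha}_0 := 1\)) the posterior collapses exactly onto \(\mathbf{x}_0\):

```python
import numpy as np

def posterior_params(x0, xt, t, betas):
    """Mean and variance of q(x_{t-1} | x_t, x_0) from the closed form above.
    t is 1-indexed as in the paper, and alpha_bar_0 is taken to be 1."""
    alphas = 1.0 - betas
    alpha_bars = np.concatenate(([1.0], np.cumprod(alphas)))
    ab_t, ab_prev = alpha_bars[t], alpha_bars[t - 1]
    beta_t = betas[t - 1]
    mean = (np.sqrt(ab_prev) * beta_t / (1.0 - ab_t)) * x0 \
         + (np.sqrt(alphas[t - 1]) * (1.0 - ab_prev) / (1.0 - ab_t)) * xt
    var = (1.0 - ab_prev) / (1.0 - ab_t) * beta_t
    return mean, var

betas = np.linspace(1e-4, 0.02, 10)
x0, x1 = np.ones(3), np.zeros(3)
# At t = 1: 1 - alpha_bar_1 = beta_1, so the x0 coefficient is 1 and var is 0.
mean, var = posterior_params(x0, x1, t=1, betas=betas)
```

The degenerate \(t=1\) case is a useful unit test when implementing this: the posterior over \(\mathbf{x}_0\) given \(\mathbf{x}_0\) itself must be a point mass.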
In practice, the authors drop the multiplying constants and assume equal weighting on the MSE of noise estimation across \(L_0, \ldots, L_{T-1}\). This reweighting effectively reprioritizes the prediction task at each timestep, down-weighting the terms at small \(t\) relative to the exact variational bound. During training, for each real data sample \(\mathbf{x}_0\), we randomly pick \(t \in \{1, \ldots, T\}\) and minimize its corresponding \(L_{t-1}\).
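Putting it together, a sketch of one training step in numpy; `eps_model` is a hypothetical stand-in for the noise-prediction network, and \(t\) is 0-indexed here rather than 1-indexed as in the paper:

```python
import numpy as np

rng = np.random.default_rng(0)

betas = np.linspace(1e-4, 0.02, 1000)
alpha_bars = np.cumprod(1.0 - betas)

def training_loss(eps_model, x0):
    # 1. Pick a timestep uniformly at random.
    t = int(rng.integers(len(betas)))
    # 2. Jump straight to x_t with q(x_t | x_0), recording the noise used.
    eps = rng.standard_normal(x0.shape)
    x_t = np.sqrt(alpha_bars[t]) * x0 + np.sqrt(1.0 - alpha_bars[t]) * eps
    # 3. Unweighted MSE between the true and predicted noise.
    return np.mean((eps - eps_model(x_t, t)) ** 2)

# A zero predictor scores roughly E[eps^2] per dimension; an oracle that
# returned the true noise would drive the loss to zero.
loss = training_loss(lambda x_t, t: np.zeros_like(x_t), np.ones(8))
```

The closed-form \(q(\mathbf{x}_t \mid \mathbf{x}_0)\) is what makes this cheap: no need to run the forward chain step by step during training.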
Conclusion
Diffusion is like a VAE with a deep hierarchy of high dimensional latents, whereby each hierarchy is connected by a noise/denoise relationship. Just like in VAE, the diffusion loss is derived from variational lower bound and can be viewed as mostly a reconstruction driven objective. It gets simplified nicely into a weighted noise estimation objective with a well-engineered parameterization.
If you were as lost as I was at some point about the training mechanics of diffusion models, I hope this post provides some useful intuitions on what it is really optimizing towards.
Other interpretation angles
Besides the VAE perspective, I find the perspective of noise-conditioned score networks (NCSN) with annealed Langevin dynamics (ALD) sampling really neat as well. Interested readers should check out this great blog post5 written by the OG. There was also a great follow-up work showing that DDPM and NCSN+ALD both correspond to instances of a stochastic differential equation formulation6.
References
1. Dieleman, S. (2023). Perspectives on diffusion. https://sander.ai/2023/07/20/perspectives.html
2. Kingma, D. P., & Welling, M. (2013). Auto-encoding variational Bayes. arXiv:1312.6114.
3. Sohl-Dickstein, J., Weiss, E. A., Maheswaranathan, N., & Ganguli, S. (2015). Deep unsupervised learning using nonequilibrium thermodynamics. arXiv:1503.03585.
4. Ho, J., Jain, A., & Abbeel, P. (2020). Denoising diffusion probabilistic models. arXiv:2006.11239.
5. Song, Y. (2021). Generative modeling by estimating gradients of the data distribution. https://yang-song.net/blog/2021/score/
6. Song, Y., Sohl-Dickstein, J., Kingma, D. P., Kumar, A., Ermon, S., & Poole, B. (2020). Score-based generative modeling through stochastic differential equations. arXiv:2011.13456.