In this notebook, I give a brief introduction to denoising diffusion models as well as their continuous-time limit, score-based generative models. Note that many people use these names (diffusion models and score-based generative models) interchangeably for both formulations.

- Assume we are given samples $\{x_i\}_{i=1}^N$ (think images) from some unknown **target distribution $\mu_\text{data}$**.
- For example, $\mu_\text{data}$ could be the distribution of all dog images.

- Our goal is to generate more samples (images) from the same measure (that look similar).

The main idea is to **first noise** the data (1) and **then denoise** it (2).

Since sampling white noise is easy on a computer, if we can learn the transformation (2), we can sample from $\mu_\text{data}$ by first sampling white noise and then applying the learned transformation (2).
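As a minimal sketch of this two-step pipeline (with a hypothetical placeholder function standing in for the learned transformation (2), which we only train later in this notebook):

```python
import jax.numpy as jnp
import jax.random as random

rng = random.PRNGKey(0)

# (1) noising: corrupt the data with white noise -- easy to do on a computer
def noise(rng, x):
    return x + random.normal(rng, x.shape)

# (2) denoising: map noise back towards data; here just a placeholder
# identity function standing in for the transformation we will learn
def denoise(x_noisy):
    return x_noisy

x = jnp.ones((4, 2))                               # pretend these are data samples
x_noisy = noise(rng, x)                            # transformation (1)
x_generated = denoise(random.normal(rng, (4, 2)))  # sample noise, apply (2)
print(x_generated.shape)  # (4, 2)
```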

In [18]:

```
import jax.numpy as jnp
import jax
import matplotlib.pyplot as plt
from jax.lax import scan
from jax import grad, jit, vmap
import jax.random as random
from functools import partial
import numpy as onp
import flax.linen as nn
import optax
import scipy
import seaborn as sns

rng = random.PRNGKey(2022)
sns.set_style("darkgrid")
cm = sns.color_palette("mako_r", as_cmap=True)

def heatmap_data(positions, area_min=-2, area_max=2):
    def small_kernel(z, area_min, area_max):
        a = jnp.linspace(area_min, area_max, 512)
        x, y = jnp.meshgrid(a, a)
        dist = (x - z[0])**2 + (y - z[1])**2
        hm = jnp.exp(-350*dist)
        return hm

    # again we try to jit most of the code, but use the helper functions,
    # since we cannot jit all of it because of the plt functions
    @jit
    def produce_heatmap(positions, area_min, area_max):
        return jnp.sum(vmap(small_kernel, in_axes=(0, None, None))(positions, area_min, area_max), axis=0)

    hm = produce_heatmap(positions, area_min, area_max)
    return hm

def plot_heatmap(positions, area_min=-2, area_max=2):
    """
    positions: locations of all particles in R^2, array (J, 2)
    area_min: lowest x and y coordinate
    area_max: highest x and y coordinate
    Plots a heatmap of all particles in the area [area_min, area_max] x [area_min, area_max].
    """
    hm = heatmap_data(positions, area_min, area_max)
    extent = [area_min, area_max, area_max, area_min]
    im = plt.imshow(hm, cmap=cm, interpolation='nearest', extent=extent)
    ax = plt.gca()
    ax.invert_yaxis()
    return im
```

To visualize our algorithm's output, we now create a toy dataset consisting of two disconnected spheres (circles in $R^2$). We assume that we have 10 samples from each sphere, i.e. 20 training samples overall. We then want to train the algorithm to generate more samples from the underlying distribution.

Thinking of image datasets, each of the $20$ points would now represent one image and the two spheres would be the abstract distribution of all images.

In [2]:

```
def sample_sphere(J):
    alphas = jnp.linspace(0, 2*jnp.pi * (1 - 1/J), J)
    xs = jnp.cos(alphas)
    ys = jnp.sin(alphas)
    mf = jnp.stack([xs, ys], axis=1)
    return mf

J = 20
sphere1 = sample_sphere(J//2) * 0.5 + 0.7
sphere2 = sample_sphere(J//2) * 0.5 - 0.7
mf = jnp.concatenate((sphere1, sphere2))
plt.scatter(mf[:, 0], mf[:, 1]);
```

The most intuitive idea to achieve our goal is the following. We just take our training dataset, add lots of noise to it (this would be transformation (1)), and then train a neural network to predict the original datapoint again (transformation (2)). We will now try to do exactly that.

This is one of the most straightforward neural network architectures there is. We just apply linear functions to the input and interleave these linear functions with nonlinear activation functions. The specific choice of the nonlinear activation function is up to the user; we use `relu`, which is one of the most popular ones.

In [3]:

```
class FullyConnected(nn.Module):
    @nn.compact
    def __call__(self, x):
        in_size = x.shape[1]
        n_hidden = 256
        act = nn.relu
        x = nn.Dense(n_hidden)(x)
        x = act(x)
        x = nn.Dense(n_hidden)(x)
        x = act(x)
        x = nn.Dense(n_hidden)(x)
        x = act(x)
        x = nn.Dense(in_size)(x)
        return x

# some dummy input data; Flax is able to infer all the dimensions of the weights
# if we supply it with the kind of input data it has to expect
x = jnp.zeros(20).reshape((10, 2))
# initialize the model weights
denoiser = FullyConnected()
rng, srng = random.split(rng)
params = denoiser.init(srng, x)
# initialize the optimizer
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(params)
```

The loss is as described above. We take some input `data` and add standard normal noise. However, optimally, we want there to be nearly no signal left in the noisy version: the neural network should learn to predict the original data from nearly white noise. Therefore, we also multiply the data with a small scalar (`0.01` in our case). Also, we predict the noise used to corrupt the data, instead of the data itself. Since the data can be recovered by taking the noisy version and subtracting the noise, this is equivalent. One can also directly predict the data, and some algorithms do so.

The update function then takes the gradient of the loss function and applies it to the parameters of the neural network. Except for some improvements made by the `adam` optimizer (which we use), this is very close to plain gradient descent.

In [4]:

```
def loss_fn(params, model, rng, data):
    # noise the data
    rng, step_rng = random.split(rng)
    noise = random.normal(step_rng, data.shape)
    noised_data = 0.01 * data + noise
    # predict the noise from the noised data
    output = model.apply(params, noised_data)
    loss = jnp.mean((noise - output)**2)
    return loss

@partial(jit, static_argnums=[4])
def update_step(params, rng, batch, opt_state, model):
    val, grads = jax.value_and_grad(loss_fn)(params, model, rng, batch)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return val, params, opt_state
```
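The claimed equivalence between predicting the noise and predicting the data can be checked directly: with the `0.01` scaling used in `loss_fn`, a perfect noise prediction recovers the original data exactly (up to float32 round-off).

```python
import jax.numpy as jnp
import jax.random as random

rng = random.PRNGKey(0)
rng, step_rng = random.split(rng)
data = random.normal(rng, (4, 2))
noise = random.normal(step_rng, data.shape)

# forward noising as in loss_fn
noised_data = 0.01 * data + noise
# if the network predicted the noise perfectly, inverting the
# noising map recovers the original data:
recovered = 100.0 * (noised_data - noise)
print(jnp.max(jnp.abs(recovered - data)))  # small (float32 round-off)
```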

Training is now basically running gradient descent with the above loss function.

In [5]:

```
N_epochs = 60_000
train_size = mf.shape[0]
losses = []
for k in range(N_epochs):
    rng, step_rng = random.split(rng)
    loss, params, opt_state = update_step(params, step_rng, mf, opt_state, denoiser)
    losses.append(loss)
    if (k+1) % 5_000 == 0:
        mean_loss = onp.mean(onp.array(losses))
        losses = []
        print("Epoch %d,\t Loss %f " % (k+1, mean_loss))
```

We now sample standard normal noise and apply our learned denoiser from above to it. Afterwards, we plot a heatmap. Ideally, the result should look similar to the training data above, generalizing from the 20 samples to the two spheres.

In [6]:

```
def sample(rng, N_samples, model, params):
    rng, step_rng = random.split(rng)
    noised_data = random.normal(step_rng, (N_samples, 2))
    predicted_noise = model.apply(params, noised_data)
    # invert the noising map: data = (noised_data - noise) / 0.01
    data = 100*(noised_data - predicted_noise)
    return data

N_samples = 1000
rng, srng = random.split(rng)
samples = sample(srng, N_samples, denoiser, params)
plot_heatmap(samples)
```


The figure above depicts a heatmap of 1000 samples generated by our neural network. These look nothing like the two spheres from above, so we should try something else.
