Introduction to score-based generative models

With a focus on generalization and overfitting

The following Jupyter notebook is a tutorial on the theory and implementation of score-based generative models (SGMs), also called diffusion models in their continuous-time formulation. It guides you through implementing an SGM in JAX. We then study the generalization capabilities of SGMs. By the end of the notebook you will have implemented an SGM. You will also see how it first generalizes and then, with further training, starts to memorize the training data. From this we conclude that the implicit regularization induced by stopping training early enough is crucial for the performance of SGMs.
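To make the training objective concrete, here is a minimal JAX sketch of denoising score matching. It was written for this summary and is not the notebook's code. It assumes an Ornstein-Uhlenbeck forward process dX_t = -X_t dt + sqrt(2) dW_t, so that X_t given X_0 = x_0 is Gaussian with mean e^{-t} x_0 and variance 1 - e^{-2t}. The function names (`marginal_params`, `dsm_loss`, `empirical_score`) and the toy data are hypothetical.

The last function computes the exact score of the noised empirical distribution of a finite training set. Over all possible functions, this score is the minimizer of the loss when x_0 is drawn from that training set. Near t = 0 it pulls every point onto its nearest training example, which is one way to see what memorization looks like.

```python
import jax
import jax.numpy as jnp


# Assumed forward SDE (Ornstein-Uhlenbeck): dX_t = -X_t dt + sqrt(2) dW_t,
# so X_t | X_0 = x0 ~ N(exp(-t) * x0, (1 - exp(-2t)) * I). Requires t > 0.
def marginal_params(t):
    mean_coef = jnp.exp(-t)
    std = jnp.sqrt(1.0 - jnp.exp(-2.0 * t))
    return mean_coef, std


def dsm_loss(score_fn, x0, t, noise):
    """Denoising score-matching loss, weighted by std_t**2.

    x0: (n, d) training batch, t: (n,) times, noise: (n, d) standard normal draws.
    """
    mean_coef, std = marginal_params(t)
    xt = mean_coef[:, None] * x0 + std[:, None] * noise
    # The conditional score of X_t given X_0 is -noise / std, so
    # std * score + noise vanishes for a model that matches it.
    residual = std[:, None] * score_fn(xt, t) + noise
    return jnp.mean(jnp.sum(residual ** 2, axis=-1))


def empirical_score(train, xt, t):
    """Exact score of the noised empirical distribution of `train` at time t.

    train: (m, d) training set, xt: (n, d) query points, t: (n,) times.
    """
    mean_coef, std = marginal_params(t)
    centers = mean_coef[:, None, None] * train[None, :, :]      # (n, m, d)
    diff = centers - xt[:, None, :]                             # (n, m, d)
    logits = -jnp.sum(diff ** 2, axis=-1) / (2.0 * std[:, None] ** 2)  # (n, m)
    weights = jax.nn.softmax(logits, axis=-1)                   # posterior over training points
    return jnp.sum(weights[:, :, None] * diff, axis=1) / std[:, None] ** 2


# Usage on a tiny, hypothetical training set with fixed noise.
train = jnp.array([[0.0, 0.0], [5.0, 0.0], [0.0, 5.0]])
noise = jnp.array([[0.3, -0.5], [1.0, 0.2], [-0.7, 0.4]])
t = jnp.full((3,), 0.01)

loss_memorizing = dsm_loss(lambda x, s: empirical_score(train, x, s), train, t, noise)
loss_zero = dsm_loss(lambda x, s: jnp.zeros_like(x), train, t, noise)
print(float(loss_memorizing), float(loss_zero))
```

With well-separated training points and a small t, the memorizing score drives the training loss to essentially zero. The trivial zero score leaves a loss equal to the mean squared noise norm, about 0.68 here. A neural network never sees this closed form. Roughly speaking, though, the further it pushes its training loss toward this minimizer, the more its samples resemble copies of the training data. Stopping early keeps it out of that regime.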

If you are interested in a more basic tutorial on the discrete-time derivation of diffusion models, there is also this Jupyter notebook. It concludes by deriving the continuous-time formulation from the discrete-time algorithm. It can therefore also serve as a starting point for this notebook.
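For orientation, the standard (variance-preserving) version of that limit is sketched below. The notation, in particular the noise schedule \beta_k = \beta(t_k)\,\Delta t, is an assumption of this summary and may differ from the linked notebook.

```latex
\begin{align}
x_k &= \sqrt{1-\beta_k}\,x_{k-1} + \sqrt{\beta_k}\,z_k,
  \qquad z_k \sim \mathcal{N}(0, I), \quad \beta_k = \beta(t_k)\,\Delta t, \\
x_k - x_{k-1} &\approx -\tfrac{1}{2}\beta(t_k)\,x_{k-1}\,\Delta t + \sqrt{\beta(t_k)\,\Delta t}\;z_k
  \qquad \text{(using } \sqrt{1-\beta_k} \approx 1 - \tfrac{1}{2}\beta_k\text{)}, \\
\mathrm{d}X_t &= -\tfrac{1}{2}\beta(t)\,X_t\,\mathrm{d}t + \sqrt{\beta(t)}\,\mathrm{d}W_t
  \qquad (\Delta t \to 0), \\
\mathrm{d}X_t &= \Big[-\tfrac{1}{2}\beta(t)\,X_t - \beta(t)\,\nabla_x \log p_t(X_t)\Big]\mathrm{d}t + \sqrt{\beta(t)}\,\mathrm{d}\bar{W}_t
  \qquad \text{(reverse time, } t: T \to 0\text{)}.
\end{align}
```

The only unknown in the reverse-time equation is the score \nabla_x \log p_t, which is why the model is trained to approximate it. The constant choice \beta \equiv 2 gives the Ornstein-Uhlenbeck process \mathrm{d}X_t = -X_t\,\mathrm{d}t + \sqrt{2}\,\mathrm{d}W_t.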

If you would like to learn more about when SGMs are able to generalize or to identify the support of the data distribution, this is the topic of the NeurIPS article Score-based generative models detect manifolds. The code is based on experiments made for that article. I created the notebook for a challenge at the Accelerating generative models and nonconvex optimisation workshop held at the Alan Turing Institute.

The Jupyter notebook can be downloaded here or viewed in your web browser below. If you have any questions, feel free to contact me.
For non-code introductions to SGMs, I highly recommend the blog posts by Yang Song and Lilian Weng. A full-fledged JAX implementation of SGMs is also available here.