import torch
import torch.nn as nn
import torch.nn.functional as F
class VPSDE(nn.Module):
    """Variance-preserving SDE (the continuous-time limit of DDPM).

    Forward SDE on t in [0, 1]:
        dx = -1/2 * beta(t) * x dt + sqrt(beta(t)) dw,
        beta(t) = beta_min + (beta_max - beta_min) * t.
    Perturbation kernel: p(x_t | x_0) = N(alpha_t * x_0, sigma_t^2 I) with
        alpha_t = exp(-1/2 * int_0^t beta(s) ds),
        sigma_t = sqrt(1 - exp(-int_0^t beta(s) ds)).

    Args:
        beta_min: Noise rate beta(0).
        beta_max: Noise rate beta(1).
        eps: Smallest time used for training and sampling. Needed because
            sigma_0 = 0, so the score target -noise / sigma_t diverges at t = 0.
        score_model: Callable ``score_model(x_t, t, y)`` returning an estimate of
            grad_x log p_t(x_t) shaped like ``x_t``. ``t`` is passed with shape
            (batch, 1, 1, 1) and ``y`` is forwarded unchanged.
    """

    def __init__(self, beta_min, beta_max, eps, score_model):
        super().__init__()
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.eps = eps
        self.score_model = score_model

    @staticmethod
    def _append_dims(v, ndim):
        """Reshape a scalar or per-sample tensor ``v`` to (batch, 1, ..., 1) with ``ndim`` dims."""
        return v.reshape(-1, *([1] * (ndim - 1)))

    def get_integrated_beta_t(self, t):
        """Return int_0^t beta(s) ds = beta_min * t + 1/2 * (beta_max - beta_min) * t^2."""
        return self.beta_min * t + 0.5 * (self.beta_max - self.beta_min) * t**2

    def get_beta_t(self, t):
        """Return the noise rate beta(t) = beta_min + (beta_max - beta_min) * t."""
        return self.beta_min + (self.beta_max - self.beta_min) * t

    def get_integrated_alpha_t(self, t):
        """Return the kernel mean coefficient alpha_t = exp(-1/2 * int_0^t beta(s) ds)."""
        return torch.exp(-0.5 * self.get_integrated_beta_t(t))

    def get_integrated_sigma_t(self, t):
        """Return the kernel std sigma_t = sqrt(1 - exp(-int_0^t beta(s) ds))."""
        return torch.sqrt(1 - torch.exp(-self.get_integrated_beta_t(t)))

    def perturbation_kernel(self, x_0, t, noise=None):
        """Sample x_t = alpha_t * x_0 + sigma_t * noise.

        Args:
            x_0: Clean data; the leading dimension is the batch.
            t: Per-sample times, shape (batch,).
            noise: Optional noise shaped like ``x_0``; standard normal if None.

        Returns:
            ``(x_t, noise, sigma_t)``; ``sigma_t`` is shaped (batch, 1, ..., 1)
            so it broadcasts against ``x_0``.
        """
        alpha_t = self._append_dims(self.get_integrated_alpha_t(t), x_0.dim())
        sigma_t = self._append_dims(self.get_integrated_sigma_t(t), x_0.dim())
        if noise is None:
            noise = torch.randn_like(x_0)
        x_t = alpha_t * x_0 + sigma_t * noise
        return x_t, noise, sigma_t

    def sde(self, x, t):
        """Return the forward drift f(x, t) = -1/2 beta(t) x and diffusion g(t) = sqrt(beta(t)).

        ``t`` may be a 0-d tensor or a per-sample tensor of shape (batch,).
        """
        beta_t = self._append_dims(self.get_beta_t(t), x.dim())
        drift = -0.5 * beta_t * x
        diffusion = torch.sqrt(beta_t)
        return drift, diffusion

    @torch.no_grad()
    def sample(self, shape, y, num_steps, device=None):
        """Generate samples by integrating the reverse-time SDE with Euler-Maruyama.

        Starts from the prior N(0, I) at t = 1 and visits ``num_steps`` equally
        spaced times down to t = eps, applying
            x <- x + (f(x, t) - g(t)^2 * score) * dt + g(t) * sqrt(-dt) * z
        with a fixed dt = -1 / num_steps and z ~ N(0, I).

        Args:
            shape: Output shape; ``shape[0]`` is the batch size.
            y: Conditioning forwarded to ``score_model``.
            num_steps: Number of solver steps.
            device: Device on which to create the samples (None = PyTorch default).

        Returns:
            Tensor of shape ``shape``, computed with gradients disabled.
        """
        batch_size = shape[0]
        x_t = torch.randn(shape, device=device)
        ts = torch.linspace(1, self.eps, num_steps, device=device)
        delta_t = -1 / num_steps
        # delta_t is a Python float; torch.sqrt only accepts tensors.
        sqrt_abs_dt = (-delta_t) ** 0.5
        for t in ts:
            # One time value per sample, matching the (batch, 1, 1, 1) layout
            # the model sees in ``loss``.
            t_batch = t.expand(batch_size)
            score = self.score_model(x_t, t_batch.reshape(-1, 1, 1, 1), y)
            drift, diffusion = self.sde(x_t, t_batch)
            x_t = (
                x_t
                + (drift - diffusion**2 * score) * delta_t
                + diffusion * sqrt_abs_dt * torch.randn_like(x_t)
            )
        return x_t

    @torch.no_grad()
    def sample_pf_ode(self, shape, y, num_steps, device=None):
        """Generate samples by integrating the probability-flow ODE with Euler steps.

        Deterministic counterpart of ``sample``:
            x <- x + (f(x, t) - 1/2 * g(t)^2 * score) * dt, dt = -1 / num_steps.

        Args:
            shape: Output shape; ``shape[0]`` is the batch size.
            y: Conditioning forwarded to ``score_model``.
            num_steps: Number of solver steps.
            device: Device on which to create the samples (None = PyTorch default).

        Returns:
            Tensor of shape ``shape``, computed with gradients disabled.
        """
        batch_size = shape[0]
        x_t = torch.randn(shape, device=device)
        ts = torch.linspace(1, self.eps, num_steps, device=device)
        delta_t = -1 / num_steps
        for t in ts:
            t_batch = t.expand(batch_size)
            score = self.score_model(x_t, t_batch.reshape(-1, 1, 1, 1), y)
            drift, diffusion = self.sde(x_t, t_batch)
            x_t = x_t + (drift - 0.5 * diffusion**2 * score) * delta_t
        return x_t

    def loss(self, batch):
        """Denoising score-matching loss with uniform weighting over t.

        Draws t ~ U(eps, 1) per example, perturbs ``x_0``, and regresses the
        model output onto the conditional score -noise / sigma_t with MSE.
        The target is unweighted, so it grows like 1 / sigma_t for small t.

        Args:
            batch: ``(x_0, y)`` pair; ``y`` is forwarded to ``score_model``.

        Returns:
            Scalar loss tensor.
        """
        x_0, y = batch
        # Draw t on the data's device so the kernel does not mix CPU/GPU tensors.
        t = self.eps + (1 - self.eps) * torch.rand(x_0.shape[0], device=x_0.device)
        x_t, noise, sigma_t = self.perturbation_kernel(x_0, t)
        s_theta = self.score_model(x_t, t.view(-1, 1, 1, 1), y)
        return F.mse_loss(s_theta, -noise / sigma_t, reduction="mean")


# Score Matching
Deep Learning
Notes on Score Matching
Tip: References
- Minimal PyTorch implementation accompanying these notes:
Motivation
- Objective: Given a complex data distribution, how can we move from regions of low probability to regions of high probability?
- Solution: Use gradient, which will guide to go towards high density
- So if we somehow had access to the gradient of the data density p(x), we could follow it.
- p(x) = \frac{f(x)}{Z}
- But probabilities must integrate to one, so Z is the normalizing constant
- Z = \int f(x)dx (this is what makes \int p_{data}(x)dx = 1), and this integral is intractable in high dimensions
- Thus \triangledown_xp_{data}(x) = \triangledown_x\frac{f(x)}{Z} is also intractable and also numerically unstable in low density regions
- So instead of that we can consider \triangledown_xlog(p_{data}(x))
- It is tractable
- \triangledown_xlog(p_{data}(x)) = \triangledown_xlog(f(x)) - \triangledown_xlog(Z)
- \triangledown_xlog(p_{data}(x)) = \triangledown_xlog(f(x)), since Z does not depend on x and its gradient is therefore zero.
- It points in the same direction as \triangledown_xp_{data}(x)
- \triangledown log(p) = \frac{\triangledown p}{p}
- Numerically stable
- Since the gradient is a ratio of two quantities (\triangledown p and p) of similar magnitude
- They have similar magnitude because where \triangledown p is small, p is often small as well
- It is tractable
- So we can follow the score, i.e. the gradient of log(p(x)) with respect to x, to move towards data points
- However, by following the score alone we would always end up at points of high probability density, and we might never reach data points of low probability.
Note: Langevin Sampling
- One solution is to use Langevin Sampling
- x_t = x_{t-1} + \frac{\alpha}{2}\triangledown_{x}log(p_{data}(x_{t-1})) + \sqrt{\alpha}\epsilon_t
- where:
- \triangledown_{x}log(p_{data}(x_{t-1})) helps in going towards high density area
- \sqrt{\alpha}\epsilon_t allows to explore the region
But we don’t have \triangledown_{x}log(p_{data}(x_{t-1})), so we try to estimate the score with S_{\theta}(x)
which gives:
- L_{SM} = \mathbb{E}_{x}[||S_{\theta}(x) - \triangledown_{x}log(p_{data}(x))||^2]
But we don’t have access to \triangledown_{x}log(p_{data}(x))
There are various approaches to estimating the score; let’s focus on the most common one, Denoising Score Matching.
Denoising Score Matching
Note: Score of a Gaussian Distribution
- x \sim N(\mu, \sigma^2)
- Probability distribution function is
- p(x) = \frac{1}{\sqrt{2\pi\sigma^2}}e^{\frac{-(x-\mu)^2}{2\sigma^2}}
- Score is:
- S(x) = -\frac{x-\mu}{\sigma^2}
- The idea is to add noise to data, since we can get score of gaussian
- \tilde{x} = x + \sigma\epsilon, \quad \epsilon \sim N(0, I)
- Therefore:
- q_\sigma(\tilde{x} \mid x) = N(\tilde{x}; x, \sigma^2I)
- \triangledown_{\tilde{x}}log(q_{\sigma}(\tilde{x} \mid x)) = -\frac{\tilde{x}-x}{\sigma^2}
- Adding noise to distribution
- q_{\sigma}(\tilde{x}) = \int{q_{\sigma}(\tilde{x} \mid x)p_{data}(x)dx}
- where q_{\sigma}(\tilde{x} \mid x) is perturbation kernel
- It takes every data point and turns it into a noisy sample, where \sigma controls how much noise is added
- Estimating the score of the noised distribution
- Definition of score matching
- L_{SM}(q_{\sigma}) = \mathbb{E}_{\tilde{x}}[||S_{\theta}(\tilde{x}) - \triangledown_{\tilde{x}}log(q_{\sigma}(\tilde{x}))||^2]
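- With some algebra (Vincent, 2011), the intractable marginal score can be replaced by the tractable conditional score:
- L_{SM}(q_{\sigma}) = \mathbb{E}_{x \sim p_{data}, \tilde{x} \sim q_{\sigma}(\tilde{x} \mid x)}[||S_{\theta}(\tilde{x}) - \triangledown_{\tilde{x}}log(q_{\sigma}(\tilde{x} \mid x))||^2] + C, where C does not depend on \theta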
- where \triangledown_{\tilde{x}}log(q_{\sigma}(\tilde{x} \mid x)) is tractable and equal to -\frac{\tilde{x}-x}{\sigma^2}
- Definition of score matching
- We now have a tractable loss
- However in L_{SM}(q_{\sigma}), q_{\sigma} is not exactly p_{data}
- let’s say we used \sigma << 1:
- q_{\sigma} will be close to p_{data} but poor estimation in low density regions.
- if we use \sigma >> 1:
- q_{\sigma} will be far from p_{data} but good estimations in low density regions.
- let’s say we used \sigma << 1:
- So how do we achieve best of both worlds
- The idea is to have score matching with varying noise
- \sigma_{1} < \sigma_{2} ... < \sigma_{T}
- S_{\theta}(x) -> S_{\theta}(x, \sigma_{i})
- which gives loss as:
- L_{NCSN} = \sum_{i=1}^{T} \lambda(i)\mathbb{E}_{x}[||S_{\theta}(x, \sigma_i) - \triangledown_{x}log(q_{\sigma_i}(x))||^2], where NCSN stands for Noise Conditional Score Network
- Now we got the score approximator as S_{\theta}(x, \sigma_i), how can we use this to generate samples
Note: Annealed Langevin Dynamics
- We can use Annealed Langevin Dynamics (Sampling with varied noise levels)
- Sample the initial x \sim N(0, \sigma_{T}^2I)
- For each \sigma_i, going from the largest (\sigma_T) to the smallest (\sigma_1), perform the following update for K steps
- x <- x + \frac{\alpha_i}{2}S_{\theta}(x, \sigma_i) + \sqrt{\alpha_i}\epsilon
- Obtain final image x_0
- But in above cases, first of all we need to define T, which will be responsible for discretizing the noising process.
- And also if somehow we can make this process evolve in continuous time instead of discrete ones then we can use differential equations efficiently.
From Discrete to Continuous
Note: Wiener Processes
- Continuous equivalent of adding incremental gaussian noise
- Stochastic process with following properties:
- w_0 = 0
- w_t - w_s \sim N(0, (t-s)I)
- Independent increments: increments over non-overlapping time intervals are independent (what happens next does not depend on what happened before)
- Let's turn the discrete DDPM into a continuous-time process
- x_{t} = \sqrt{1-\beta_{t}}x_{t-1} + \sqrt{\beta_{t}} \epsilon
- We want to express the change in x as a function of everything else.
- x_{t} - x_{t-1}= [\sqrt{1-\beta_{t}}-1]x_{t-1} + \sqrt{\beta_{t}} \epsilon
- Let's make \beta_t a function of t
- \beta_t = \beta(t)dt
- where
- \beta_t: Noise we add between two discrete steps
- \beta(t): rate of noise at t
- dt: Increment of time
- Now replace \beta_t with \beta(t)dt, which gives
- x_{t} - x_{t-1}= [\sqrt{1-\beta(t)dt}-1]x_{t-1} + \sqrt{\beta(t)dt} \epsilon
- As \triangle t \to 0, x_{t} - x_{t-1} becomes dx (a differential)
- Using Taylor expansion and wiener process, we can approximate first and second term respectively
- \sqrt{1-\beta(t)dt} \approx 1-\frac{1}{2}\beta(t)dt
- \sqrt{dt} \epsilon \sim dw
- dx = -\frac{1}{2}\beta(t)xdt + \sqrt{\beta(t)}dw
- dx = f(x, t)dt + g(t)dw
- Where
- f(x, t) is drift coefficient which is deterministic
- g(t) is the diffusion coefficient; it is itself deterministic and scales the stochastic term dw
- Where
- SDE Variants
Note: Common SDE Variants
| Variance Preserving | Variance Exploding |
|---|---|
| Ex: DDPM | Ex:NCSN |
| f(x, t) = -\frac{1}{2}\beta(t)x | f(x, t) = 0 |
| g(t) = \sqrt{\beta(t)} | g(t) = \sqrt{\frac{d[\sigma^2(t)]}{dt}} |
- Generalised training objective
- L_{DSM} = \mathbb{E}_{t, x_0, x_t}[\lambda_t||S_{\theta}(x_t, t) - \triangledown_{x_t}log(p(x_t | x_0))||^2]
- where \triangledown_{x_t}log(p(x_t | x_0)) = -\frac{\epsilon}{\sigma_t}
- L_{DSM} = \mathbb{E}_{t, x_0, x_t}[\lambda_t||S_{\theta}(x_t, t) - \triangledown_{x_t}log(p(x_t | x_0))||^2]
- Training a generalised score model
- Sample x_0(data) \sim p_{data}, \epsilon(noise) \sim N(0, I) and t(timestep) \sim U(0, T)
- Create a noisy sample
- x_t = \alpha_t x_0 + \sigma_t \epsilon, where \alpha_t and \sigma_t depend on the chosen SDE variant
- Use x_t and t to predict -\frac{\epsilon}{\sigma_t} via S_{\theta}(x_t, t)
- Compute Loss and backpropagate
- L = \lambda_t||S_{\theta}(x_t, t) + \frac{\epsilon}{\sigma_t}||^2
- Inference of trained model
- Forward process: Perturb data into noise
- dx = f(x, t)dt + g(t)dw
- Reverse process: noise back to data
- dx = [f(x, t) - g(t)^2\triangledown_xlog(p_{t}(x))]dt + g(t)d\bar{w}, where time runs backwards (dt < 0). This reverse-time SDE can be derived using the Fokker-Planck equation
- Intuition of above equation
- f(x, t): forward drift coefficient
- g(t): forward diffusion coefficient
- d\bar{w}: a Wiener process running backwards in time, distinct from the forward one
- g(t)^2\triangledown_xlog(p_{t}(x)): correction to drift coefficient
- Basically in forward process we diffused from data to noise and to go back we need some corrections.
- Forward process: Perturb data into noise
- Sampling using Euler-Maruyama
- Sample Noise i.e. x_T \sim N(0, \sigma_T^2I)
- Use Euler-Maruyama to go through reverse SDE
- x_{t_{i-1}} = x_{t_i} + [f(x_{t_i}, t_i) - g(t_i)^2S_\theta(x_{t_i}, t_i)]\triangle t + g(t_i)\sqrt{|\triangle t|}\epsilon, where \triangle t = t_{i-1} - t_i < 0
- Obtain final sample x_0
- getting \beta(t)
- \beta(t) = \beta_{min} + t * (\beta_{max} - \beta_{min})
- So \int_{0}^{t}\beta(s)ds = \beta_{min}*t + 0.5 * (\beta_{max} - \beta_{min}) * t^2
- Limitations of SDE
- Stochastic term as g(t)d\bar{w} means:
- Slower solver: needs 1000-2000 steps
- More source of error: Discretization as well as injected stochastic noise
- Stochastic term as g(t)d\bar{w} means:
Probability Flow ODE
- Hypothetically write SDE as an ODE
- dx = [Something]dt, i.e. with no stochastic term
- No stochastic term means:
- Faster solver: a deterministic ODE can be integrated with higher-order, adaptive-step solvers that need fewer steps
- Focused source of error: Now we will have error from discretization only
- Lets derive ODE
- Forward SDE: dx = f(x, t)dt + g(t)dw
- Probability flow of x at t using Fokker-Planck Equation on forward SDE gives:
- \frac{\partial p_t(x)}{\partial t} = -\triangledown_x \cdot (f(x, t)p_t(x)) + \frac{1}{2}g(t)^2 \triangledown_x^2 p_t(x)
- Using \triangledown_x^2 p_t(x) = \triangledown_x \cdot (p_t(x)\triangledown_x log(p_t(x))), we can rewrite it as a continuity equation:
- \frac{\partial p_t(x)}{\partial t} = -\triangledown_x \cdot \left([f(x, t) - \frac{1}{2}g(t)^2 \triangledown_x log(p_t(x))]p_t(x)\right)
- where
- [f(x, t) - \frac{1}{2}g(t)^2 \triangledown_x log(p_t(x))] is velocity of probability flow
- So we got PF-ODE as dx = v(x, t)dt and v(x, t) = [f(x, t) - \frac{1}{2}g(t)^2 \triangledown_x log(p_t(x))]
- where:
- f(x, t) and g(t)^2 are modeling assumptions
- \triangledown_x log(p_t(x)) is score approximated by the model
- where:
- In the above formulation, the reverse SDE and the reverse ODE share the same marginal densities p_t(x) at every t, but their individual trajectories differ
Implementation
VPSDE
- How to get x_t
- x_t = x_0 * \alpha_t + \epsilon * \sigma_t
- Where:
- \alpha_t = e^{-\frac{1}{2} \int_{0}^{t}\beta(s)ds}
- \sigma_t = \sqrt{1 - e^{-\int_{0}^{t}\beta(s)ds}}
- \beta(t) = \beta_{min} + t*(\beta_{max} - \beta_{min})
- \int_{0}^{t}\beta(s)ds = \beta_{min} * t + 0.5 * (\beta_{max} - \beta_{min}) * t^2
- Drift and Diffusion Coefficient
- f(x,t) = -\frac{1}{2}\beta(t)x
- g(t) = \sqrt{\beta(t)}
sub-VPSDE
- How to get x_t
- x_t = x_0 * \alpha_t + \epsilon * \sigma_t
- Where:
- \alpha_t = e^{-\frac{1}{2} \int_{0}^{t}\beta(s)ds}
- \sigma_t = 1 - e^{-\int_{0}^{t}\beta(s)ds}
- \beta(t) = \beta_{min} + t*(\beta_{max} - \beta_{min})
- \int_{0}^{t}\beta(s)ds = \beta_{min} * t + 0.5 * (\beta_{max} - \beta_{min}) * t^2
- Drift and Diffusion Coefficient
- f(x,t) = -\frac{1}{2}\beta(t)x
- g(t) = \sqrt{\beta(t) (1-e^{-2\int_{0}^{t}\beta(s)ds})}
import torch
import torch.nn as nn
import torch.nn.functional as F
class subVPSDE(nn.Module):
    """Sub-variance-preserving SDE with a linear beta schedule.

    Forward SDE on t in [0, 1]:
        dx = -1/2 * beta(t) * x dt + sqrt(beta(t) * (1 - exp(-2 * int_0^t beta(s) ds))) dw,
        beta(t) = beta_min + (beta_max - beta_min) * t.
    Perturbation kernel: p(x_t | x_0) = N(alpha_t * x_0, sigma_t^2 I) with
        alpha_t = exp(-1/2 * int_0^t beta(s) ds),
        sigma_t = 1 - exp(-int_0^t beta(s) ds).

    Args:
        beta_min: Noise rate beta(0).
        beta_max: Noise rate beta(1).
        eps: Smallest time used for training and sampling. Needed because
            sigma_0 = 0, so the score target -noise / sigma_t diverges at t = 0.
        score_model: Callable ``score_model(x_t, t, y)`` returning an estimate of
            grad_x log p_t(x_t) shaped like ``x_t``. ``t`` is passed with shape
            (batch, 1, 1, 1) and ``y`` is forwarded unchanged.
    """

    def __init__(self, beta_min, beta_max, eps, score_model):
        super().__init__()
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.eps = eps
        self.score_model = score_model

    @staticmethod
    def _append_dims(v, ndim):
        """Reshape a scalar or per-sample tensor ``v`` to (batch, 1, ..., 1) with ``ndim`` dims."""
        return v.reshape(-1, *([1] * (ndim - 1)))

    def get_integrated_beta_t(self, t):
        """Return int_0^t beta(s) ds = beta_min * t + 1/2 * (beta_max - beta_min) * t^2."""
        return self.beta_min * t + 0.5 * (self.beta_max - self.beta_min) * t**2

    def get_beta_t(self, t):
        """Return the noise rate beta(t) = beta_min + (beta_max - beta_min) * t."""
        return self.beta_min + (self.beta_max - self.beta_min) * t

    def get_integrated_alpha_t(self, t):
        """Return the kernel mean coefficient alpha_t = exp(-1/2 * int_0^t beta(s) ds)."""
        return torch.exp(-0.5 * self.get_integrated_beta_t(t))

    def get_integrated_sigma_t(self, t):
        """Return the kernel std sigma_t = 1 - exp(-int_0^t beta(s) ds) (no square root for sub-VP)."""
        return 1 - torch.exp(-self.get_integrated_beta_t(t))

    def perturbation_kernel(self, x_0, t, noise=None):
        """Sample x_t = alpha_t * x_0 + sigma_t * noise.

        Args:
            x_0: Clean data; the leading dimension is the batch.
            t: Per-sample times, shape (batch,).
            noise: Optional noise shaped like ``x_0``; standard normal if None.

        Returns:
            ``(x_t, noise, sigma_t)``; ``sigma_t`` is shaped (batch, 1, ..., 1)
            so it broadcasts against ``x_0``.
        """
        alpha_t = self._append_dims(self.get_integrated_alpha_t(t), x_0.dim())
        sigma_t = self._append_dims(self.get_integrated_sigma_t(t), x_0.dim())
        if noise is None:
            noise = torch.randn_like(x_0)
        x_t = alpha_t * x_0 + sigma_t * noise
        return x_t, noise, sigma_t

    def sde(self, x, t):
        """Return the forward drift and diffusion of the sub-VP SDE.

        drift = -1/2 * beta(t) * x,
        diffusion = sqrt(beta(t) * (1 - exp(-2 * int_0^t beta(s) ds))).
        ``t`` may be a 0-d tensor or a per-sample tensor of shape (batch,).
        """
        beta_t = self.get_beta_t(t)
        discount = 1 - torch.exp(-2 * self.get_integrated_beta_t(t))
        # Reshape the time-dependent coefficient (not the drift itself) so it
        # broadcasts over the data dimensions of x.
        drift = -0.5 * self._append_dims(beta_t, x.dim()) * x
        diffusion = self._append_dims(torch.sqrt(beta_t * discount), x.dim())
        return drift, diffusion

    @torch.no_grad()
    def sample(self, shape, y, num_steps, device=None):
        """Generate samples by integrating the reverse-time SDE with Euler-Maruyama.

        Starts from N(0, I) at t = 1 and visits ``num_steps`` equally spaced
        times down to t = eps, applying
            x <- x + (f(x, t) - g(t)^2 * score) * dt + g(t) * sqrt(-dt) * z
        with a fixed dt = -1 / num_steps and z ~ N(0, I).

        Args:
            shape: Output shape; ``shape[0]`` is the batch size.
            y: Conditioning forwarded to ``score_model``.
            num_steps: Number of solver steps.
            device: Device on which to create the samples (None = PyTorch default).

        Returns:
            Tensor of shape ``shape``, computed with gradients disabled.
        """
        batch_size = shape[0]
        x_t = torch.randn(shape, device=device)
        ts = torch.linspace(1, self.eps, num_steps, device=device)
        delta_t = -1 / num_steps
        # delta_t is a Python float; torch.sqrt only accepts tensors.
        sqrt_abs_dt = (-delta_t) ** 0.5
        for t in ts:
            # One time value per sample, matching the (batch, 1, 1, 1) layout
            # the model sees in ``loss``.
            t_batch = t.expand(batch_size)
            score = self.score_model(x_t, t_batch.reshape(-1, 1, 1, 1), y)
            drift, diffusion = self.sde(x_t, t_batch)
            x_t = (
                x_t
                + (drift - diffusion**2 * score) * delta_t
                + diffusion * sqrt_abs_dt * torch.randn_like(x_t)
            )
        return x_t

    @torch.no_grad()
    def sample_pf_ode(self, shape, y, num_steps, device=None):
        """Generate samples by integrating the probability-flow ODE with Euler steps.

        Deterministic counterpart of ``sample``:
            x <- x + (f(x, t) - 1/2 * g(t)^2 * score) * dt, dt = -1 / num_steps.

        Args:
            shape: Output shape; ``shape[0]`` is the batch size.
            y: Conditioning forwarded to ``score_model``.
            num_steps: Number of solver steps.
            device: Device on which to create the samples (None = PyTorch default).

        Returns:
            Tensor of shape ``shape``, computed with gradients disabled.
        """
        batch_size = shape[0]
        x_t = torch.randn(shape, device=device)
        ts = torch.linspace(1, self.eps, num_steps, device=device)
        delta_t = -1 / num_steps
        for t in ts:
            t_batch = t.expand(batch_size)
            score = self.score_model(x_t, t_batch.reshape(-1, 1, 1, 1), y)
            drift, diffusion = self.sde(x_t, t_batch)
            x_t = x_t + (drift - 0.5 * diffusion**2 * score) * delta_t
        return x_t

    def loss(self, batch):
        """Denoising score-matching loss with uniform weighting over t.

        Draws t ~ U(eps, 1) per example, perturbs ``x_0``, and regresses the
        model output onto the conditional score -noise / sigma_t with MSE.
        The target is unweighted, so it grows like 1 / sigma_t for small t.

        Args:
            batch: ``(x_0, y)`` pair; ``y`` is forwarded to ``score_model``.

        Returns:
            Scalar loss tensor.
        """
        x_0, y = batch
        # Draw t on the data's device so the kernel does not mix CPU/GPU tensors.
        t = self.eps + (1 - self.eps) * torch.rand(x_0.shape[0], device=x_0.device)
        x_t, noise, sigma_t = self.perturbation_kernel(x_0, t)
        s_theta = self.score_model(x_t, t.view(-1, 1, 1, 1), y)
        return F.mse_loss(s_theta, -noise / sigma_t, reduction="mean")


# VESDE
- How to get x_t
- x_t = x_0 + \epsilon * \sigma_t
- Where:
- \sigma_t = \sigma_{min}(\frac{\sigma_{max}}{\sigma_{min}})^t
- Drift and Diffusion Coefficient
- f(x,t) = 0
- g(t) = \sqrt{\frac{d[\sigma^2(t)]}{dt}} = \sigma_t \sqrt{2log\frac{\sigma_{max}}{\sigma_{min}}}
import torch
import torch.nn as nn
import torch.nn.functional as F
class VESDE(nn.Module):
    """Variance-exploding SDE (the continuous-time limit of NCSN).

    Forward SDE on t in [0, 1]: dx = g(t) dw with zero drift, where
        sigma_t = sigma_min * (sigma_max / sigma_min) ** t,
        g(t) = sqrt(d[sigma_t^2]/dt) = sigma_t * sqrt(2 * log(sigma_max / sigma_min)).
    Perturbation kernel: p(x_t | x_0) = N(x_0, sigma_t^2 I).

    Args:
        sigma_min: Noise scale at t = 0.
        sigma_max: Noise scale at t = 1; also the std of the sampling prior.
        eps: Smallest time used for training and sampling.
        score_model: Callable ``score_model(x_t, t, y)`` returning an estimate of
            grad_x log p_t(x_t) shaped like ``x_t``. ``t`` is passed with shape
            (batch, 1, 1, 1) and ``y`` is forwarded unchanged.
    """

    def __init__(self, sigma_min, sigma_max, eps, score_model):
        super().__init__()
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.score_model = score_model
        self.eps = eps

    @staticmethod
    def _append_dims(v, ndim):
        """Reshape a scalar or per-sample tensor ``v`` to (batch, 1, ..., 1) with ``ndim`` dims."""
        return v.reshape(-1, *([1] * (ndim - 1)))

    def get_sigma_t(self, t):
        """Return the noise scale sigma_t = sigma_min * (sigma_max / sigma_min) ** t."""
        return self.sigma_min * ((self.sigma_max / self.sigma_min) ** t)

    def perturbation_kernel(self, x_0, t, noise=None):
        """Sample x_t = x_0 + sigma_t * noise.

        Args:
            x_0: Clean data; the leading dimension is the batch.
            t: Per-sample times, shape (batch,).
            noise: Optional noise shaped like ``x_0``; standard normal if None.

        Returns:
            ``(x_t, noise, sigma_t)``; ``sigma_t`` is shaped (batch, 1, ..., 1)
            so it broadcasts against ``x_0``.
        """
        sigma_t = self._append_dims(self.get_sigma_t(t), x_0.dim())
        if noise is None:
            noise = torch.randn_like(x_0)
        x_t = x_0 + sigma_t * noise
        return x_t, noise, sigma_t

    def sde(self, x, t):
        """Return the forward drift f(x, t) = 0 and diffusion g(t).

        ``t`` may be a 0-d tensor or a per-sample tensor of shape (batch,).
        """
        import math

        drift = torch.zeros_like(x)
        # sigma_min / sigma_max may be plain floats, which torch.log / torch.sqrt
        # reject; math.log / math.sqrt accept floats and 1-element tensors.
        scale = math.sqrt(2 * math.log(self.sigma_max / self.sigma_min))
        diffusion = self._append_dims(self.get_sigma_t(t), x.dim()) * scale
        return drift, diffusion

    @torch.no_grad()
    def sample(self, shape, y, num_steps, device=None):
        """Generate samples by integrating the reverse-time SDE with Euler-Maruyama.

        Starts from the VE prior N(0, sigma_max^2 I) at t = 1 and visits
        ``num_steps`` equally spaced times down to t = eps, applying
            x <- x + (f(x, t) - g(t)^2 * score) * dt + g(t) * sqrt(-dt) * z
        with a fixed dt = -1 / num_steps and z ~ N(0, I).

        Args:
            shape: Output shape; ``shape[0]`` is the batch size.
            y: Conditioning forwarded to ``score_model``.
            num_steps: Number of solver steps.
            device: Device on which to create the samples (None = PyTorch default).

        Returns:
            Tensor of shape ``shape``, computed with gradients disabled.
        """
        batch_size = shape[0]
        # sigma_1 = sigma_max, so the marginal at t = 1 has std ~sigma_max, not 1.
        x_t = self.sigma_max * torch.randn(shape, device=device)
        ts = torch.linspace(1, self.eps, num_steps, device=device)
        delta_t = -1 / num_steps
        # delta_t is a Python float; torch.sqrt only accepts tensors.
        sqrt_abs_dt = (-delta_t) ** 0.5
        for t in ts:
            # One time value per sample, matching the (batch, 1, 1, 1) layout
            # the model sees in ``loss``.
            t_batch = t.expand(batch_size)
            score = self.score_model(x_t, t_batch.reshape(-1, 1, 1, 1), y)
            drift, diffusion = self.sde(x_t, t_batch)
            x_t = (
                x_t
                + (drift - diffusion**2 * score) * delta_t
                + diffusion * sqrt_abs_dt * torch.randn_like(x_t)
            )
        return x_t

    @torch.no_grad()
    def sample_pf_ode(self, shape, y, num_steps, device=None):
        """Generate samples by integrating the probability-flow ODE with Euler steps.

        Starts from N(0, sigma_max^2 I) at t = 1 and applies
            x <- x + (f(x, t) - 1/2 * g(t)^2 * score) * dt, dt = -1 / num_steps.

        Args:
            shape: Output shape; ``shape[0]`` is the batch size.
            y: Conditioning forwarded to ``score_model``.
            num_steps: Number of solver steps.
            device: Device on which to create the samples (None = PyTorch default).

        Returns:
            Tensor of shape ``shape``, computed with gradients disabled.
        """
        batch_size = shape[0]
        x_t = self.sigma_max * torch.randn(shape, device=device)
        ts = torch.linspace(1, self.eps, num_steps, device=device)
        delta_t = -1 / num_steps
        for t in ts:
            t_batch = t.expand(batch_size)
            score = self.score_model(x_t, t_batch.reshape(-1, 1, 1, 1), y)
            drift, diffusion = self.sde(x_t, t_batch)
            x_t = x_t + (drift - 0.5 * diffusion**2 * score) * delta_t
        return x_t

    def loss(self, batch):
        """Denoising score-matching loss with uniform weighting over t.

        Draws t ~ U(eps, 1) per example, perturbs ``x_0``, and regresses the
        model output onto the conditional score -noise / sigma_t with MSE.

        Args:
            batch: ``(x_0, y)`` pair; ``y`` is forwarded to ``score_model``.

        Returns:
            Scalar loss tensor.
        """
        x_0, y = batch
        # Draw t on the data's device so the kernel does not mix CPU/GPU tensors.
        t = self.eps + (1 - self.eps) * torch.rand(x_0.shape[0], device=x_0.device)
        x_t, noise, sigma_t = self.perturbation_kernel(x_0, t)
        s_theta = self.score_model(x_t, t.view(-1, 1, 1, 1), y)
        return F.mse_loss(s_theta, -noise / sigma_t, reduction="mean")