5. Bayesian inference, Pyro, PyStan and VAEs
In this section, we give some examples of how to work with variational autoencoders and Bayesian inference using Pyro, NumPyro and PyStan.
Take a look at the VAE presentation for some theoretical details on the matter.
This tutorial is meant to run on an Nvidia CUDA-capable GPU. If you don’t have a GPU installed in your computer, you can download this Jupyter notebook and upload it to Google Colab.
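Before starting, it can be useful to confirm that PyTorch actually sees a CUDA device; the short sketch below (assuming torch is already installed in the environment) does exactly that.
[ ]:
import torch

# Report whether PyTorch can see a CUDA device; the rest of the
# notebook assumes this prints True.
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device:", torch.cuda.get_device_name(0))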
[3]:
!pip install pyro-ppl 'pystan<3' numpyro optuna
Collecting pyro-ppl
  Downloading pyro_ppl-1.8.2-py3-none-any.whl (722 kB)
Collecting pystan<3
  Downloading pystan-2.19.1.1.tar.gz (16.2 MB)
  Preparing metadata (setup.py) ... done
Collecting numpyro
  Downloading numpyro-0.10.1-py3-none-any.whl (292 kB)
Collecting optuna
  Downloading optuna-3.0.3-py3-none-any.whl (348 kB)
Collecting torch>=1.11.0
  Downloading torch-1.13.0-cp310-cp310-manylinux1_x86_64.whl (890.1 MB)
Collecting tqdm>=4.36
  Downloading tqdm-4.64.1-py2.py3-none-any.whl (78 kB)
Collecting numpy>=1.7
  Downloading numpy-1.23.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.1 MB)
Collecting opt-einsum>=2.3.2
  Downloading opt_einsum-3.3.0-py3-none-any.whl (65 kB)
Collecting pyro-api>=0.1.1
  Downloading pyro_api-0.1.2-py3-none-any.whl (11 kB)
Collecting Cython!=0.25.1,>=0.22
  Downloading Cython-0.29.32-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl (1.9 MB)
Collecting jaxlib>=0.1.65
  Downloading jaxlib-0.3.24-cp310-cp310-manylinux2014_x86_64.whl (70.0 MB)
Collecting jax>=0.2.13
  Downloading jax-0.3.24.tar.gz (1.1 MB)
  Preparing metadata (setup.py) ... done
Collecting multipledispatch
  Downloading multipledispatch-0.6.0-py3-none-any.whl (11 kB)
Collecting cliff
  Downloading cliff-4.0.0-py3-none-any.whl (80 kB)
Collecting sqlalchemy>=1.3.0
  Downloading SQLAlchemy-1.4.42-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.6 MB)
Collecting importlib-metadata<5.0.0
  Downloading importlib_metadata-4.13.0-py3-none-any.whl (23 kB)
Collecting PyYAML
  Downloading PyYAML-6.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (682 kB)
Collecting packaging>=20.0
  Using cached packaging-21.3-py3-none-any.whl (40 kB)
Collecting cmaes>=0.8.2
  Downloading cmaes-0.8.2-py3-none-any.whl (15 kB)
Collecting colorlog
  Downloading colorlog-6.7.0-py2.py3-none-any.whl (11 kB)
Collecting scipy<1.9.0,>=1.7.0
  Downloading scipy-1.8.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (42.2 MB)
Collecting alembic>=1.5.0
  Downloading alembic-1.8.1-py3-none-any.whl (209 kB)
Collecting Mako
  Downloading Mako-1.2.3-py3-none-any.whl (78 kB)
Collecting zipp>=0.5
  Downloading zipp-3.10.0-py3-none-any.whl (6.2 kB)
Collecting typing_extensions
  Downloading typing_extensions-4.4.0-py3-none-any.whl (26 kB)
Collecting pyparsing!=3.0.5,>=2.0.2
  Downloading pyparsing-3.0.9-py3-none-any.whl (98 kB)
Collecting greenlet!=0.4.17
  Downloading greenlet-2.0.0.post0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (536 kB)
Collecting nvidia-cublas-cu11==11.10.3.66
  Downloading nvidia_cublas_cu11-11.10.3.66-py3-none-manylinux1_x86_64.whl (317.1 MB)
[4]:
import pyro
import pyro.distributions as dist
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO
import numpyro
from numpyro import distributions as numdist
from numpyro.infer import MCMC, HMC, NUTS
import jax
import torch.distributions.constraints as constraints
import pystan
from statsmodels.distributions.empirical_distribution import ECDF
from sklearn.neighbors import KernelDensity
from sklearn.model_selection import GridSearchCV, ShuffleSplit
import optuna
import numpy as np
import scipy.stats as stats
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
[ ]:
!nvidia-smi
[ ]:
# Setup some data
theta = 0.6
n = 1000
y = stats.bernoulli.rvs(theta, size=n)
5.1. Get MCMC samples for this model using Stan
[ ]:
# Compile model
model_code = """
data {
int<lower=0> n;
int y[n];
}
parameters {
real<lower=0, upper=1> theta;
}
model {
// likelihood:
y ~ bernoulli(theta);
// prior:
theta ~ beta(2.0, 2.0);
}
"""
sm = pystan.StanModel(model_code=model_code)
[ ]:
# Sample model
data_dict = {'y': y, 'n': n}
fit = sm.sampling(data=data_dict, iter=1000, chains=4)
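Before using the draws, it is worth glancing at Stan's convergence diagnostics (R-hat and effective sample size). A minimal check on the fit object produced above:
[ ]:
# Print per-parameter summaries for the fit, including n_eff and Rhat.
print(fit)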
[ ]:
# Extract samples
theta = fit.extract(permuted=True)['theta']
# Print some statistics
print("Some samples:", theta[:10])
print("Mean:", np.mean(theta, axis=0))
print("Standard deviation:", np.std(theta, axis=0))
# Prepare plots
_, ax = plt.subplots(2, 2)
# histograms
# warning: for a caveat about using histograms see
# https://stats.stackexchange.com/a/51753
ax[0, 0].hist(theta, 15)
ax[0, 1].hist(theta, 30)
# Empirical cumulative distribution
ecdf = ECDF(theta)
ax[1, 0].plot(ecdf.x, ecdf.y)
# Density estimation using KDE (bandwidth chosen with Optuna,
# by default on a single hold-out split)
optuna.logging.set_verbosity(optuna.logging.WARNING)
def kde_fit(data, n_trials=30, cv=None):
    if cv is None:
        cv = ShuffleSplit(n_splits=1, test_size=0.15, random_state=0)
    param_distributions = {
        "bandwidth": optuna.distributions.FloatDistribution(1e-5, 1e3, log=True)
    }
    optuna_search = optuna.integration.OptunaSearchCV(KernelDensity(),
        param_distributions, cv=cv, n_trials=n_trials)
    optuna_search.fit(np.array(data).reshape(-1, 1))
    return optuna_search.best_estimator_
kde_est = kde_fit(theta)
x_kde = np.linspace(0.4, 0.7, 1000).reshape(-1, 1)
y_kde = np.exp(kde_est.score_samples(x_kde))
ax[1, 1].plot(x_kde, y_kde)
5.2. Get MCMC samples for this model using NumPyro
[ ]:
def model(y):
prior_dist = numdist.Beta(.5, .5)
theta = numpyro.sample('theta', prior_dist)
with numpyro.plate('observe_data', len(y)):
numpyro.sample('obs', numdist.Bernoulli(theta), obs=y)
nuts_kernel = NUTS(model, adapt_step_size=True)
mcmc = MCMC(nuts_kernel, num_samples=500, num_warmup=300)
mcmc.run(jax.random.PRNGKey(0), y=y)
mcmc_samples = np.array(mcmc.get_samples()['theta'])
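As with Stan, we can ask the sampler for summary diagnostics; NumPyro exposes them directly on the MCMC object:
[ ]:
# Posterior summary (mean, std, quantiles, n_eff, r_hat) for the run above.
mcmc.print_summary()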
[ ]:
print("Some samples:", np.random.choice(mcmc_samples, 4, replace=False))
print("Mean:", mcmc_samples.mean())
print("Standard deviation:", mcmc_samples.std())
5.3. Get replications (new instances similar to the data) from MCMC samples
[ ]:
n_replications = 10000
replications = stats.bernoulli.rvs(np.random.choice(mcmc_samples, n_replications))
bins = np.arange(0, replications.max() + 1.5) - 0.5
_, ax = plt.subplots()
ax.hist(replications, bins)
ax.set_xticks(bins + 0.5)
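A simple sanity check on these replications is to compare their empirical mean with the mean of the observed data; for a reasonable model the two should be close. A minimal sketch reusing the arrays defined above:
[ ]:
# Posterior predictive check: observed proportion of ones vs. replicated proportion.
print("Observed mean:", y.mean())
print("Replication mean:", replications.mean())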
5.4. Get approximate Bayesian inference with Pyro and stochastic variational inference
[ ]:
def model(y_tensor):
prior_dist = dist.Beta(torch.Tensor([.5]), torch.Tensor([.5]))
theta = pyro.sample('theta', prior_dist)
with pyro.plate('observe_data'):
pyro.sample('obs', dist.Bernoulli(theta), obs=y_tensor)
def guide(y_tensor):
alpha = pyro.param("alpha", torch.Tensor([1.0]),
constraint=constraints.positive)
beta = pyro.param("beta", torch.Tensor([1.0]),
constraint=constraints.positive)
theta = pyro.sample('theta', dist.Beta(alpha, beta))
[ ]:
y_tensor = torch.Tensor(y)
# set up the optimizer
pyro.clear_param_store()
adam_params = {"lr": 0.2, "betas": (0.90, 0.999)}
optimizer = Adam(adam_params)
# setup the inference algorithm
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
n_steps = 100
# do gradient steps
for step in range(n_steps):
svi.step(y_tensor)
[ ]:
alpha = pyro.param("alpha").item()
beta = pyro.param("beta").item()
inf_distribution = stats.beta(alpha, beta)
print("Some samples:", inf_distribution.rvs(10))
print("Mean:", inf_distribution.mean())
print("Standard deviation:", inf_distribution.std())
_, axes = plt.subplots(2)
# Plot the posterior
x_svi = np.linspace(0, 1, 10000)
y_svi = inf_distribution.pdf(x_svi)
axes[0].plot(x_svi, y_svi)
# Plot replications
posterior_samples_of_theta = inf_distribution.rvs(n_replications)
replications = stats.bernoulli.rvs(posterior_samples_of_theta)
bins = np.arange(0, replications.max() + 1.5) - 0.5
axes[1].hist(replications, bins)
axes[1].set_xticks(bins + 0.5)
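Because the Beta prior is conjugate to the Bernoulli likelihood, this model actually has a closed-form posterior, Beta(0.5 + Σy, 0.5 + n − Σy), which gives us a reference to judge how good the SVI approximation is. A small comparison using the variables defined above:
[ ]:
# Exact conjugate posterior for the Beta(0.5, 0.5) prior used in the model.
exact_posterior = stats.beta(0.5 + y.sum(), 0.5 + n - y.sum())
print("Exact posterior mean:", exact_posterior.mean())
print("SVI approximate mean:", inf_distribution.mean())
print("Exact posterior std:", exact_posterior.std())
print("SVI approximate std:", inf_distribution.std())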
5.5. Using GPU and data subsampling with Pyro
[ ]:
# Setup some data for another model
mu = -0.6
sigma = 1.8
n2 = 10000
y2 = stats.norm.rvs(mu, sigma, size=n2)
y2_tensor = torch.as_tensor(y2, dtype=torch.float32).cuda()
[ ]:
def model(y2_tensor):
# Priors:
prior_dist_mu = dist.Normal(torch.Tensor([0.]).cuda(),
torch.Tensor([1.]).cuda())
mu = pyro.sample('mu', prior_dist_mu)
prior_dist_sigma = dist.Gamma(torch.Tensor([1.]).cuda(),
torch.Tensor([1.]).cuda())
sigma = pyro.sample('sigma', prior_dist_sigma)
# Likelihood:
with pyro.plate('observe_data', size=len(y2_tensor),
subsample_size=5000, use_cuda=True) as ind:
pyro.sample('obs', dist.Normal(mu, sigma),
obs=y2_tensor.index_select(0, ind))
def guide(y2_tensor):
alpha_mu = pyro.param("alpha_mu", torch.Tensor([0.0]).cuda())
beta_mu = pyro.param("beta_mu", torch.Tensor([3.0]).cuda(),
constraint=constraints.positive)
mu = pyro.sample('mu', dist.Normal(alpha_mu, beta_mu))
alpha_sigma = pyro.param("alpha_sigma", torch.Tensor([1.0]).cuda(),
constraint=constraints.positive)
beta_sigma = pyro.param("beta_sigma", torch.Tensor([1.0]).cuda(),
constraint=constraints.positive)
sigma = pyro.sample('sigma', dist.Gamma(alpha_sigma, beta_sigma))
[ ]:
# set up the optimizer
pyro.clear_param_store()
adam_params = {"lr": 0.2, "betas": (0.90, 0.999)}
optimizer = Adam(adam_params)
# setup the inference algorithm
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
n_steps = 10
# do gradient steps
for step in range(n_steps):
svi.step(y2_tensor)
[ ]:
# Generate replications
alpha_mu = pyro.param("alpha_mu").item()
beta_mu = pyro.param("beta_mu").item()
alpha_sigma = pyro.param("alpha_sigma").item()
beta_sigma = pyro.param("beta_sigma").item()
mu_distribution = stats.norm(alpha_mu, beta_mu)
# note: scipy's gamma is parameterized by a scale, i.e. 1 / rate
sigma_distribution = stats.gamma(alpha_sigma, scale=1.0 / beta_sigma)
mu_samples = mu_distribution.rvs(n_replications)
sigma_samples = sigma_distribution.rvs(n_replications)
data_replications = stats.norm(mu_samples, sigma_samples).rvs()
# Density estimation using KDE (with tuning parameter chosen by 3 fold CV)
params_for_kde_cv = {'bandwidth': np.logspace(-2, 3, 10)}
grid = GridSearchCV(KernelDensity(), params_for_kde_cv, cv=3)
grid.fit(data_replications.reshape(-1, 1))
x_kde = np.linspace(-20, 20, 10000).reshape(-1, 1)
y_kde = np.exp(grid.best_estimator_.score_samples(x_kde))
plt.plot(x_kde, y_kde)
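With only 10 SVI steps the fit will be rough, but we can still compare the guide's posterior means with the true values used to simulate the data (mu = -0.6, sigma = 1.8). A minimal sketch using the parameters extracted above:
[ ]:
# Posterior means implied by the fitted guide.
print("Posterior mean of mu:", alpha_mu, "(true value: -0.6)")
print("Posterior mean of sigma:", alpha_sigma / beta_sigma, "(true value: 1.8)")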
5.6. Variational autoencoders
[ ]:
# define the PyTorch module that parameterizes the
# diagonal gaussian distribution q(z|x)
class Encoder(nn.Module):
def __init__(self, z_dim, hidden_dim, input_dim):
super(Encoder, self).__init__()
# setup the three linear transformations used
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.fc21 = nn.Linear(hidden_dim, z_dim)
self.fc22 = nn.Linear(hidden_dim, z_dim)
# setup the non-linearities
self.softplus = nn.Softplus()
    def forward(self, x):
        # compute the hidden units
        hidden = self.softplus(self.fc1(x))
# then return a mean vector and a (positive) square root covariance
# each of size batch_size x z_dim
z_loc = self.fc21(hidden)
z_scale = torch.exp(self.fc22(hidden))
return z_loc, z_scale
# define the PyTorch module that parameterizes the
# observation likelihood p(x|z)
class Decoder(nn.Module):
def __init__(self, z_dim, hidden_dim, input_dim):
super(Decoder, self).__init__()
# setup the two linear transformations used
self.fc1 = nn.Linear(z_dim, hidden_dim)
self.fc21 = nn.Linear(hidden_dim, input_dim)
self.fc22 = nn.Linear(hidden_dim, input_dim)
# setup the non-linearities
self.softplus = nn.Softplus()
def forward(self, z):
# define the forward computation on the latent z
# first compute the hidden units
hidden = self.softplus(self.fc1(z))
mu = self.fc21(hidden)
sigma = torch.exp(self.fc22(hidden))
return mu, sigma
# define a PyTorch module for the VAE
class VAE(nn.Module):
# by default our latent space is 50-dimensional
# and we use 400 hidden units
def __init__(self, input_dim,
z_dim=50, hidden_dim=400, use_cuda=False):
super(VAE, self).__init__()
# create the encoder and decoder networks
self.encoder = Encoder(z_dim, hidden_dim, input_dim=input_dim)
self.decoder = Decoder(z_dim, hidden_dim, input_dim=input_dim)
if use_cuda:
# calling cuda() here will put all the parameters of
# the encoder and decoder networks into gpu memory
self.cuda()
self.use_cuda = use_cuda
self.z_dim = z_dim
# define the model p(x|z)p(z)
def model(self, x):
# register PyTorch module `decoder` with Pyro
pyro.module("decoder", self.decoder)
with pyro.plate("data", x.shape[0]):
# setup hyperparameters for prior p(z)
z_loc = torch.zeros(x.shape[0], self.z_dim, dtype=x.dtype, device=x.device)
z_scale = torch.ones(x.shape[0], self.z_dim, dtype=x.dtype, device=x.device)
# sample from prior (value will be sampled by guide when computing the ELBO)
z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
# decode the latent code z
mu, sigma = self.decoder.forward(z)
# score against actual images
pyro.sample("obs", dist.Normal(mu, sigma).to_event(1), obs=x)
# return the loc so we can visualize it later
#return loc_img
# define the guide (i.e. variational distribution) q(z|x)
def guide(self, x):
# register PyTorch module `encoder` with Pyro
pyro.module("encoder", self.encoder)
with pyro.plate("data", x.shape[0]):
# use the encoder to get the parameters used to define q(z|x)
z_loc, z_scale = self.encoder.forward(x)
# sample the latent code z
pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
    # define a helper function for reconstructing instances
    def reconstruct_img(self, x):
        # encode instance x
        z_loc, z_scale = self.encoder(x)
        # sample in latent space
        z = dist.Normal(z_loc, z_scale).sample()
        # decode (note we return the mean and don't sample in data space)
        loc_img, _ = self.decoder(z)
        return loc_img
    def new_instances(self, size=1):
        # sample latent codes from the prior and push them through the decoder
        z = stats.norm.rvs(size=(size, self.z_dim))
        device = torch.device('cuda') if self.use_cuda else torch.device('cpu')
        mu, sigma = self.decoder.forward(torch.as_tensor(z, device=device,
                                                         dtype=torch.float32))
        return stats.norm.rvs(mu.data.cpu().numpy(), sigma.data.cpu().numpy())
[ ]:
# clear param store
pyro.clear_param_store()
no_instances = 20000
input_dim = 3
mu = stats.norm.rvs(size=input_dim)
# Generate a positive definite matrix
sigma = stats.norm.rvs(size=(input_dim, input_dim))
sigma[np.triu_indices(input_dim)] = 0
sigma += np.diag(np.abs(stats.norm.rvs(size=input_dim)))
sigma = np.matmul(sigma.transpose(), sigma)  # L^T L is symmetric positive definite
dataset = stats.multivariate_normal.rvs(mu, sigma, size=no_instances)
dataset = torch.as_tensor(dataset, dtype=torch.float32)
dataset = TensorDataset(dataset)
train_loader = DataLoader(dataset, batch_size=1000, shuffle=True,
num_workers=1, pin_memory=True, drop_last=False)
# setup the VAE
vae = VAE(use_cuda=True, input_dim=input_dim)
adam_args = {"lr": 0.001}
optimizer = Adam(adam_args)
# setup the inference algorithm
svi = SVI(vae.model, vae.guide, optimizer, loss=Trace_ELBO())
train_elbo = []
for epoch in range(100):
# initialize loss accumulator
epoch_loss = 0.
# do a training epoch over each mini-batch x returned
# by the data loader
for x, in train_loader:
x = x.cuda()
epoch_loss += svi.step(x)
# report training diagnostics
if not epoch % 10:
normalizer_train = len(train_loader.dataset)
total_epoch_loss_train = epoch_loss / normalizer_train
train_elbo.append(total_epoch_loss_train)
print("[epoch %03d] average training loss: %.4f" %
(epoch, total_epoch_loss_train))
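The loop above stores the average loss every 10 epochs in train_elbo but never displays it; plotting that list is a quick way to see whether the objective is still decreasing. A minimal sketch:
[ ]:
# Plot the recorded average training loss (every 10th epoch).
_, ax = plt.subplots()
ax.plot(np.arange(len(train_elbo)) * 10, train_elbo)
ax.set_xlabel("epoch")
ax.set_ylabel("average training loss")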
[ ]:
# Generating new instances (replications) from the trained VAE
new_instances = vae.new_instances(100000)
print("True means")
print(mu)
print("Empirical means of replications:")
print(new_instances.mean(0))
print("----------------------------------------")
print("True covariance matrix")
print(sigma)
print("Empirical covariance matrix of replications:")
print(np.cov(new_instances, rowvar=False))
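Finally, we can also push a few training instances through the encoder/decoder pair with the reconstruct_img helper defined earlier; this sketch assumes the VAE was trained with use_cuda=True as above:
[ ]:
# Reconstruct a handful of training instances and compare them with the originals.
x_batch = dataset.tensors[0][:5].cuda()
reconstructions = vae.reconstruct_img(x_batch)
print("Original instances:")
print(x_batch.cpu().numpy())
print("Reconstructions (decoder means):")
print(reconstructions.data.cpu().numpy())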