5. Bayesian inference, Pyro, PyStan and VAEs

In this section, we give some examples of how to work with Bayesian inference and variational autoencoders using Pyro, NumPyro and PyStan.

Take a look at the VAE presentation for the theoretical background behind these examples.

This tutorial is meant to run on an NVIDIA CUDA-capable GPU. If you don’t have a GPU installed on your machine, you can download this Jupyter notebook and upload it to Google Colab.

[3]:
!pip install pyro-ppl 'pystan<3' numpyro optuna
Collecting pyro-ppl
  Downloading pyro_ppl-1.8.2-py3-none-any.whl (722 kB)
Collecting pystan<3
  Downloading pystan-2.19.1.1.tar.gz (16.2 MB)
Collecting numpyro
  Downloading numpyro-0.10.1-py3-none-any.whl (292 kB)
Collecting optuna
  Downloading optuna-3.0.3-py3-none-any.whl (348 kB)
Collecting torch>=1.11.0
  Downloading torch-1.13.0-cp310-cp310-manylinux1_x86_64.whl (890.1 MB)
Collecting jax>=0.2.13
  Downloading jax-0.3.24.tar.gz (1.1 MB)
Collecting jaxlib>=0.1.65
  Downloading jaxlib-0.3.24-cp310-cp310-manylinux2014_x86_64.whl (70.0 MB)
...
[4]:
import pyro
import pyro.distributions as dist

from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO

import numpyro
from numpyro import distributions as numdist
from numpyro.infer import MCMC, HMC, NUTS
import jax

import torch.distributions.constraints as constraints

import pystan

from statsmodels.distributions.empirical_distribution import ECDF
from sklearn.neighbors import KernelDensity
from sklearn.model_selection import GridSearchCV, ShuffleSplit

import optuna

import numpy as np
import scipy.stats as stats
import pandas as pd

import matplotlib.pyplot as plt

import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
[ ]:
!nvidia-smi
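
Besides ``nvidia-smi``, you can also confirm from Python that PyTorch sees the GPU; a minimal check:

[ ]:
# True when a CUDA-capable GPU and driver are available to PyTorch
print(torch.cuda.is_available())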
[ ]:
# Setup some data
theta = 0.6
n = 1000
y = stats.bernoulli.rvs(theta, size=n)

5.1. Get MCMC samples for this model using Stan

[ ]:
#Compile model

model_code = """
data {
    int<lower=0> n;
    int y[n];
}
parameters {
    real<lower=0, upper=1> theta;
}
model {
    // likelihood:
    y ~ bernoulli(theta);

    // prior:
    theta ~ beta(2.0, 2.0);
}
"""

sm = pystan.StanModel(model_code=model_code)
[ ]:
# Sample model
data_dict = {'y': y, 'n': n}
fit = sm.sampling(data=data_dict, iter=1000, chains=4)
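
Before using the samples, it is worth looking at Stan’s convergence diagnostics; in PyStan 2, printing the fit object shows the posterior summary together with effective sample sizes and R-hat (a minimal sketch):

[ ]:
# Summary table with n_eff and Rhat for each parameter
print(fit)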
[ ]:
# Extract samples
theta = fit.extract(permuted=True)['theta']

# Print some statistics
print("Some samples:", theta[:10])
print("Mean:", np.mean(theta, axis=0))
print("Standard deviation:", np.std(theta, axis=0))

# Prepare plots
_, ax = plt.subplots(2, 2)

# histograms
# warning: for a caveat about using histograms see
# https://stats.stackexchange.com/a/51753
ax[0, 0].hist(theta, 15)
ax[0, 1].hist(theta, 30)

# Empirical cumulative distribution
ecdf = ECDF(theta)
ax[1, 0].plot(ecdf.x, ecdf.y)

# Density estimation using KDE (bandwidth chosen by an Optuna search on a held-out split)
optuna.logging.set_verbosity(optuna.logging.WARNING)
def kde_fit(data, n_trials=30, cv=None):
    if cv is None:
        cv = ShuffleSplit(n_splits=1, test_size=0.15, random_state=0)
    param_distributions = {
        "bandwidth": optuna.distributions.FloatDistribution(1e-5, 1e3, log=True)
    }
    optuna_search = optuna.integration.OptunaSearchCV(KernelDensity(),
        param_distributions, cv=cv, n_trials=n_trials)
    optuna_search.fit(np.array(data).reshape(-1, 1))
    return optuna_search.best_estimator_

kde_est = kde_fit(theta)
x_kde = np.linspace(0.4, 0.7, 1000).reshape(-1, 1)
y_kde = np.exp(kde_est.score_samples(x_kde))
ax[1, 1].plot(x_kde, y_kde)
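
Since the Beta(2, 2) prior is conjugate to the Bernoulli likelihood, this particular posterior is also available in closed form, which gives a useful sanity check on the MCMC output; a sketch using the arrays defined above (recall that ``theta`` now holds the MCMC samples):

[ ]:
# Exact posterior: Beta(2 + k, 2 + n - k), with k the number of ones in y
k = y.sum()
exact_posterior = stats.beta(2.0 + k, 2.0 + n - k)
print("Exact posterior mean / std:", exact_posterior.mean(), exact_posterior.std())
print("MCMC sample mean / std:    ", theta.mean(), theta.std())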

5.2. Get MCMC samples for this model using NumPyro

[ ]:
def model(y):
    prior_dist = numdist.Beta(.5, .5)
    theta = numpyro.sample('theta', prior_dist)
    with numpyro.plate('observe_data', len(y)):
        numpyro.sample('obs', numdist.Bernoulli(theta), obs=y)

nuts_kernel = NUTS(model, adapt_step_size=True)
mcmc = MCMC(nuts_kernel, num_samples=500, num_warmup=300)
mcmc.run(jax.random.PRNGKey(0), y=y)
mcmc_samples = np.array(mcmc.get_samples()['theta'])
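
NumPyro’s MCMC object can print its own convergence summary (effective sample size and R-hat) for the run above; a quick check:

[ ]:
# Posterior summary with n_eff and r_hat for the NUTS run
mcmc.print_summary()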
[ ]:
print("Some samples:", np.random.choice(mcmc_samples, 4, replace=False))
print("Mean:", mcmc_samples.mean())
print("Standard deviation:", mcmc_samples.std())

5.3. Get replications (new instances similar to the data) from MCMC samples

[ ]:
n_replications = 10000

replications = stats.bernoulli.rvs(np.random.choice(mcmc_samples, n_replications))
bins = np.arange(0, replications.max() + 1.5) - 0.5
_, ax = plt.subplots()
ax.hist(replications, bins)
ax.set_xticks(bins + 0.5)
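
A simple posterior predictive check is to compare a summary statistic of the replications with the same statistic of the observed data; a minimal sketch using the arrays defined above:

[ ]:
# The proportion of ones in the replications should be close to the observed proportion
print("Observed mean of y:      ", y.mean())
print("Mean of the replications:", replications.mean())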

5.4. Approximate Bayesian inference with Pyro and stochastic variational inference

[ ]:
def model(y_tensor):
    prior_dist = dist.Beta(torch.Tensor([.5]), torch.Tensor([.5]))
    theta = pyro.sample('theta', prior_dist)
    with pyro.plate('observe_data'):
        pyro.sample('obs', dist.Bernoulli(theta), obs=y_tensor)

def guide(y_tensor):
    alpha = pyro.param("alpha", torch.Tensor([1.0]),
                       constraint=constraints.positive)
    beta = pyro.param("beta", torch.Tensor([1.0]),
                       constraint=constraints.positive)
    theta = pyro.sample('theta', dist.Beta(alpha, beta))
[ ]:
y_tensor = torch.Tensor(y)

# set up the optimizer
pyro.clear_param_store()
adam_params = {"lr": 0.2, "betas": (0.90, 0.999)}
optimizer = Adam(adam_params)

# setup the inference algorithm
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

n_steps = 100
# do gradient steps
for step in range(n_steps):
    svi.step(y_tensor)
[ ]:
alpha = pyro.param("alpha").item()
beta = pyro.param("beta").item()

inf_distribution = stats.beta(alpha, beta)
print("Some samples:", inf_distribution.rvs(10))
print("Mean:", inf_distribution.mean())
print("Standard deviation:", inf_distribution.std())

_, axes = plt.subplots(2)

# Plot the posterior
x_svi = np.linspace(0, 1, 10000)
y_svi = inf_distribution.pdf(x_svi)
axes[0].plot(x_svi, y_svi)

# Plot replications
posterior_samples_of_theta = inf_distribution.rvs(n_replications)

replications = stats.bernoulli.rvs(posterior_samples_of_theta)
bins = np.arange(0, replications.max() + 1.5) - 0.5
axes[1].hist(replications, bins)
axes[1].set_xticks(bins + 0.5)
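
Because the Beta(0.5, 0.5) prior used in this Pyro model is conjugate to the Bernoulli likelihood, the variational approximation can be compared directly against the exact posterior; a sketch using the quantities defined above:

[ ]:
# Exact posterior: Beta(0.5 + k, 0.5 + n - k), with k the number of ones in y
k = y.sum()
exact_posterior = stats.beta(0.5 + k, 0.5 + n - k)
print("Exact mean / std:", exact_posterior.mean(), exact_posterior.std())
print("SVI mean / std:  ", inf_distribution.mean(), inf_distribution.std())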

5.5. Using GPU and data subsampling with Pyro

[ ]:
# Setup some data for another model
mu = -0.6
sigma = 1.8

n2 = 10000
y2 = stats.norm.rvs(mu, sigma, size=n2)
y2_tensor = torch.as_tensor(y2, dtype=torch.float32).cuda()
[ ]:
def model(y2_tensor):
    # Priors:
    prior_dist_mu = dist.Normal(torch.Tensor([0.]).cuda(),
                                torch.Tensor([1.]).cuda())
    mu = pyro.sample('mu', prior_dist_mu)

    prior_dist_sigma = dist.Gamma(torch.Tensor([1.]).cuda(),
                                  torch.Tensor([1.]).cuda())
    sigma = pyro.sample('sigma', prior_dist_sigma)

    # Likelihood:
    with pyro.plate('observe_data', size=len(y2_tensor),
        subsample_size=5000, use_cuda=True) as ind:
        pyro.sample('obs', dist.Normal(mu, sigma),
            obs=y2_tensor.index_select(0, ind))


def guide(y2_tensor):
    alpha_mu = pyro.param("alpha_mu", torch.Tensor([0.0]).cuda())
    beta_mu = pyro.param("beta_mu", torch.Tensor([3.0]).cuda(),
        constraint=constraints.positive)
    mu = pyro.sample('mu', dist.Normal(alpha_mu, beta_mu))

    alpha_sigma = pyro.param("alpha_sigma", torch.Tensor([1.0]).cuda(),
        constraint=constraints.positive)
    beta_sigma = pyro.param("beta_sigma", torch.Tensor([1.0]).cuda(),
        constraint=constraints.positive)
    sigma = pyro.sample('sigma', dist.Gamma(alpha_sigma, beta_sigma))
[ ]:
# set up the optimizer
pyro.clear_param_store()
adam_params = {"lr": 0.2, "betas": (0.90, 0.999)}
optimizer = Adam(adam_params)

# setup the inference algorithm
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

n_steps = 10
# do gradient steps
for step in range(n_steps):
    svi.step(y2_tensor)
[ ]:
# Generate replications

alpha_mu = pyro.param("alpha_mu").item()
beta_mu = pyro.param("beta_mu").item()
alpha_sigma = pyro.param("alpha_sigma").item()
beta_sigma = pyro.param("beta_sigma").item()

mu_distribution = stats.norm(alpha_mu, beta_mu)
sigma_distribution = stats.gamma(alpha_sigma, scale=1.0 / beta_sigma)

mu_samples = mu_distribution.rvs(n_replications)
sigma_samples = sigma_distribution.rvs(n_replications)

data_replications = stats.norm(mu_samples, sigma_samples).rvs()

# Density estimation using KDE (with tuning parameter chosen by 3 fold CV)
params_for_kde_cv = {'bandwidth': np.logspace(-2, 3, 10)}
grid = GridSearchCV(KernelDensity(), params_for_kde_cv, cv=3)
grid.fit(data_replications.reshape(-1, 1))
x_kde = np.linspace(-20, 20, 10000).reshape(-1, 1)
y_kde = np.exp(grid.best_estimator_.score_samples(x_kde))
plt.plot(x_kde, y_kde)
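
Since the data were simulated from a known normal distribution, the true density makes a convenient reference curve for the replications; a sketch reusing ``x_kde`` and ``y_kde`` from the previous cell:

[ ]:
# Compare the replication density estimate with the data-generating density
plt.plot(x_kde, y_kde, label="replications (KDE)")
plt.plot(x_kde, stats.norm(mu, sigma).pdf(x_kde), label="true density")
plt.legend()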

5.6. Variational autoencoders

[ ]:
# define the PyTorch module that parameterizes the
# diagonal gaussian distribution q(z|x)
class Encoder(nn.Module):
    def __init__(self, z_dim, hidden_dim, input_dim):
        super(Encoder, self).__init__()
        # setup the three linear transformations used
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, z_dim)
        self.fc22 = nn.Linear(hidden_dim, z_dim)
        # setup the non-linearities
        self.softplus = nn.Softplus()

    def forward(self, x):
        # compute the hidden units
        hidden = self.softplus(self.fc1(x))
        # then return a mean vector and a (positive) square root covariance
        # each of size batch_size x z_dim
        z_loc = self.fc21(hidden)
        z_scale = torch.exp(self.fc22(hidden))
        return z_loc, z_scale


# define the PyTorch module that parameterizes the
# observation likelihood p(x|z)
class Decoder(nn.Module):
    def __init__(self, z_dim, hidden_dim, input_dim):
        super(Decoder, self).__init__()
        # setup the three linear transformations used
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, input_dim)
        self.fc22 = nn.Linear(hidden_dim, input_dim)
        # setup the non-linearities
        self.softplus = nn.Softplus()

    def forward(self, z):
        # define the forward computation on the latent z
        # first compute the hidden units
        hidden = self.softplus(self.fc1(z))

        mu = self.fc21(hidden)
        sigma = torch.exp(self.fc22(hidden))
        return mu, sigma


# define a PyTorch module for the VAE
class VAE(nn.Module):
    # by default our latent space is 50-dimensional
    # and we use 400 hidden units
    def __init__(self, input_dim,
        z_dim=50, hidden_dim=400, use_cuda=False):
        super(VAE, self).__init__()
        # create the encoder and decoder networks
        self.encoder = Encoder(z_dim, hidden_dim, input_dim=input_dim)
        self.decoder = Decoder(z_dim, hidden_dim, input_dim=input_dim)

        if use_cuda:
            # calling cuda() here will put all the parameters of
            # the encoder and decoder networks into gpu memory
            self.cuda()
        self.use_cuda = use_cuda
        self.z_dim = z_dim

    # define the model p(x|z)p(z)
    def model(self, x):
        # register PyTorch module `decoder` with Pyro
        pyro.module("decoder", self.decoder)
        with pyro.plate("data", x.shape[0]):
            # setup hyperparameters for prior p(z)
            z_loc = torch.zeros(x.shape[0], self.z_dim, dtype=x.dtype, device=x.device)
            z_scale = torch.ones(x.shape[0], self.z_dim, dtype=x.dtype, device=x.device)
            # sample from prior (value will be sampled by guide when computing the ELBO)
            z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
            # decode the latent code z
            mu, sigma = self.decoder.forward(z)
            # score against the observed data
            pyro.sample("obs", dist.Normal(mu, sigma).to_event(1), obs=x)

    # define the guide (i.e. variational distribution) q(z|x)
    def guide(self, x):
        # register PyTorch module `encoder` with Pyro
        pyro.module("encoder", self.encoder)
        with pyro.plate("data", x.shape[0]):
            # use the encoder to get the parameters used to define q(z|x)
            z_loc, z_scale = self.encoder.forward(x)
            # sample the latent code z
            pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))

    # define a helper function for reconstructing a data point
    def reconstruct_img(self, x):
        # encode the data point x
        z_loc, z_scale = self.encoder(x)
        # sample in latent space
        z = dist.Normal(z_loc, z_scale).sample()
        # decode (return the mean only, without sampling in data space)
        loc, _ = self.decoder(z)
        return loc

    def new_instances(self, size=1):
        # sample latent codes from the prior and push them through the decoder
        z = stats.norm.rvs(size=(size, self.z_dim))
        mu, sigma = self.decoder.forward(torch.as_tensor(
            z, device=torch.device('cuda'), dtype=torch.float32))
        return stats.norm.rvs(mu.data.cpu().numpy(), sigma.data.cpu().numpy())
[ ]:
# clear param store
pyro.clear_param_store()

no_instances = 20000
input_dim = 3
mu = stats.norm.rvs(size=input_dim)

# Generate a positive definite matrix
sigma = stats.norm.rvs(size=(input_dim, input_dim))
sigma[np.triu_indices(input_dim)] = 0
sigma += np.diag(np.abs(stats.norm.rvs(size=input_dim)))
sigma = np.matmul(sigma.transpose(), sigma)  # sigma = L^T L is symmetric positive definite

dataset = stats.multivariate_normal.rvs(mu, sigma, size=no_instances)
dataset = torch.as_tensor(dataset, dtype=torch.float32)
dataset = TensorDataset(dataset)
train_loader = DataLoader(dataset, batch_size=1000, shuffle=True,
     num_workers=1, pin_memory=True, drop_last=False)

# setup the VAE
vae = VAE(use_cuda=True, input_dim=input_dim)

adam_args = {"lr": 0.001}
optimizer = Adam(adam_args)

# setup the inference algorithm
svi = SVI(vae.model, vae.guide, optimizer, loss=Trace_ELBO())

train_elbo = []
for epoch in range(100):
    # initialize loss accumulator
    epoch_loss = 0.
    # do a training epoch over each mini-batch x returned
    # by the data loader
    for x, in train_loader:
        x = x.cuda()
        epoch_loss += svi.step(x)

    # report training diagnostics
    if not epoch % 10:
        normalizer_train = len(train_loader.dataset)
        total_epoch_loss_train = epoch_loss / normalizer_train
        train_elbo.append(total_epoch_loss_train)
        print("[epoch %03d]  average training loss: %.4f" %
             (epoch, total_epoch_loss_train))
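
The ``train_elbo`` list collected every 10 epochs can be plotted to verify that the objective is decreasing; a minimal sketch:

[ ]:
# Average training loss recorded every 10 epochs during the loop above
plt.plot(np.arange(0, 100, 10), train_elbo)
plt.xlabel("epoch")
plt.ylabel("average training loss")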
[ ]:
# Generating new instances (replications) from the trained VAE
new_instances = vae.new_instances(100000)

print("True means")
print(mu)
print("Empirical means of replications:")
print(new_instances.mean(0))

print("----------------------------------------")

print("True covariance matrix")
print(sigma)
print("Empirical covariance matrix of replications:")
print(np.cov(new_instances, rowvar=False))