{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Bayesian inference, Pyro, PyStan and VAEs\n",
    "\n",
    "In this section, we give some examples of how to work with variational autoencoders and Bayesian inference using Pyro, NumPyro and PyStan.\n",
    "\n",
    "[Take a look at the VAE presentation for some theoretical details on the matter](https://marcoinacio.com/downloads/vae.pdf)\n",
    "\n",
    "This tutorial is meant to run on an NVIDIA CUDA-capable GPU. If your computer doesn't have one, you can download this Jupyter notebook and upload it to [Google Colab](https://colab.research.google.com) (choose a GPU runtime)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "qOt-sj9JSJEU",
    "outputId": "37533146-5669-4c17-a512-5f43de1971d2"
   },
   "outputs": [],
   "source": [
    "# `%pip` (unlike `!pip`) installs into the environment of the running kernel.\n",
    "# Optuna is kept below 4.0 because `optuna.integration.OptunaSearchCV` (used\n",
    "# below) moved out of the core package in Optuna 4.\n",
    "%pip install pyro-ppl 'pystan<3' numpyro 'optuna>=3,<4'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "k6mZkupmSSRx"
   },
   "outputs": [],
   "source": [
    "import pyro\n",
    "import pyro.distributions as dist\n",
    "\n",
    "from pyro.optim import Adam\n",
    "from pyro.infer import SVI, Trace_ELBO\n",
    "\n",
    "import numpyro\n",
    "from numpyro import distributions as numdist\n",
    "from numpyro.infer import MCMC, HMC, NUTS\n",
    "import jax\n",
    "\n",
    "import torch.distributions.constraints as constraints\n",
    "\n",
    "import pystan\n",
    "\n",
    "from statsmodels.distributions.empirical_distribution import ECDF\n",
    "from sklearn.model_selection import ShuffleSplit\n",
    "from sklearn.neighbors import KernelDensity\n",
    "\n",
    "import optuna\n",
    "\n",
    "import numpy as np\n",
    "import scipy.stats as stats\n",
    "import pandas as pd\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "from torch import Tensor\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "# Seed Python's `random`, NumPy's global RNG (used by scipy.stats `rvs`) and\n",
    "# PyTorch so that a Restart & Run All reproduces the same results\n",
    "pyro.set_rng_seed(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "u9KyyZvSJwpF",
    "outputId": "3778b422-01ab-4a5f-f08c-aa020d4af847"
   },
   "outputs": [],
   "source": [
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "iuXefrGrXSpa"
   },
   "outputs": [],
   "source": [
    "# Setup some data\n",
    "theta = 0.6\n",
    "n = 1000\n",
    "y = stats.bernoulli.rvs(theta, size=n)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Xs1w4UJwYj49"
   },
   "source": [
    "## Get MCMC samples for this model using Stan"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "vJ5RL6khBrMg"
   },
   "outputs": [],
   "source": [
    "# Compile model\n",
    "\n",
    "model_code = \"\"\"\n",
    "data {\n",
    "  int<lower=0> n;\n",
    "  int<lower=0, upper=1> y[n];\n",
    "}\n",
    "parameters {\n",
    "  // theta is a Bernoulli probability, so it must live in [0, 1];\n",
    "  // without these bounds the sampler can propose invalid values\n",
    "  real<lower=0, upper=1> theta;\n",
    "}\n",
    "model {\n",
    "  // likelihood:\n",
    "  y ~ bernoulli(theta);\n",
    "\n",
    "  // prior:\n",
    "  theta ~ beta(2.0, 2.0);\n",
    "}\n",
    "\"\"\"\n",
    "\n",
    "sm = pystan.StanModel(model_code=model_code)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "c_BJJVHhGJ_h"
   },
   "outputs": [],
   "source": [
    "# Sample model\n",
    "data_dict = {'y': y, 'n': n}\n",
    "fit = sm.sampling(data=data_dict, iter=1000, chains=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 359
    },
    "id": "AxBYSv1QDwqB",
    "outputId": "c5f409c8-290a-428a-cc6a-8305dfa26ef7"
   },
   "outputs": [],
   "source": [
    "# Extract posterior samples (a new name, so the true `theta` above is kept)\n",
    "theta_samples = fit.extract(permuted=True)['theta']\n",
    "\n",
    "# Print some statistics\n",
    "print(\"Some samples:\", theta_samples[:10])\n",
    "print(\"Mean:\", np.mean(theta_samples, axis=0))\n",
    "print(\"Standard deviation:\", np.std(theta_samples, axis=0))\n",
    "\n",
    "# Prepare plots\n",
    "_, ax = plt.subplots(2, 2)\n",
    "\n",
    "# histograms\n",
    "# warning: for a caveat about using histograms see\n",
    "# https://stats.stackexchange.com/a/51753\n",
    "ax[0, 0].hist(theta_samples, 15)\n",
    "ax[0, 1].hist(theta_samples, 30)\n",
    "\n",
    "# Empirical cumulative distribution\n",
    "ecdf = ECDF(theta_samples)\n",
    "ax[1, 0].plot(ecdf.x, ecdf.y)\n",
    "\n",
    "# Density estimation using KDE, with the bandwidth tuned by Optuna\n",
    "optuna.logging.set_verbosity(optuna.logging.WARNING)\n",
    "\n",
    "def kde_fit(data, n_trials=30, cv=None):\n",
    "    \"\"\"Fit a 1-D KernelDensity whose bandwidth is tuned with OptunaSearchCV.\n",
    "\n",
    "    data: 1-D array-like of samples.\n",
    "    n_trials: number of candidate bandwidths Optuna evaluates.\n",
    "    cv: any scikit-learn `cv` argument; defaults to a single 85%/15%\n",
    "        shuffle split.\n",
    "\n",
    "    Returns the search's `best_estimator_` (a KernelDensity), not the\n",
    "    search object itself.\n",
    "    \"\"\"\n",
    "    if cv is None:\n",
    "        cv = ShuffleSplit(n_splits=1, test_size=0.15, random_state=0)\n",
    "    param_distributions = {\n",
    "        \"bandwidth\": optuna.distributions.FloatDistribution(1e-5, 1e3, log=True)\n",
    "    }\n",
    "    optuna_search = optuna.integration.OptunaSearchCV(\n",
    "        KernelDensity(), param_distributions, cv=cv, n_trials=n_trials,\n",
    "        random_state=0)\n",
    "    optuna_search.fit(np.array(data).reshape(-1, 1))\n",
    "    return optuna_search.best_estimator_\n",
    "\n",
    "kde_est = kde_fit(theta_samples)\n",
    "# Evaluate the density a little beyond the range covered by the samples\n",
    "pad = 0.1 * np.ptp(theta_samples)\n",
    "x_kde = np.linspace(theta_samples.min() - pad, theta_samples.max() + pad,\n",
    "                    1000).reshape(-1, 1)\n",
    "# kde_fit already returns the fitted KernelDensity, so score with it directly\n",
    "y_kde = np.exp(kde_est.score_samples(x_kde))\n",
    "ax[1, 1].plot(x_kde, y_kde);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gHrplkFWY3ld"
   },
   "source": [
    "## Get MCMC samples for this model using NumPyro"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "CZZUnlYYYRuj",
    "outputId": "8ec85885-817e-4777-b9e5-3e52a808ca8c"
   },
   "outputs": [],
   "source": [
    "def model(y):\n",
    "    prior_dist = numdist.Beta(.5, .5)\n",
    "    theta = numpyro.sample('theta', prior_dist)\n",
    "    with numpyro.plate('observe_data', len(y)):\n",
    "        numpyro.sample('obs', numdist.Bernoulli(theta), obs=y)\n",
    "\n",
    "nuts_kernel = NUTS(model, adapt_step_size=True)\n",
    "mcmc = MCMC(nuts_kernel, num_samples=500, num_warmup=300)\n",
    "mcmc.run(jax.random.PRNGKey(0), y=y)\n",
    "mcmc_samples = np.array(mcmc.get_samples()['theta'])"
   ]
  },
  {
   "cell_type": "code",
"execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "liU25QagUqms",
    "outputId": "b12cab46-42dc-4924-8d67-597b8c73e73a"
   },
   "outputs": [],
   "source": [
    "print(\"Some samples:\", np.random.choice(mcmc_samples, 4, replace=False))\n",
    "print(\"Mean:\", mcmc_samples.mean())\n",
    "print(\"Standard deviation:\", mcmc_samples.std())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "OMfd4QjNnATb"
   },
   "source": [
    "## Get replications (new instances similar to the data) from MCMC samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 320
    },
    "id": "o5BKWMvHnJK6",
    "outputId": "69bc8267-a5d3-4108-dc83-93e4413d8972"
   },
   "outputs": [],
   "source": [
    "n_replications = 10000\n",
    "\n",
    "# Posterior predictive: draw theta from the MCMC samples, then one\n",
    "# Bernoulli observation per drawn theta\n",
    "replications = stats.bernoulli.rvs(np.random.choice(mcmc_samples, n_replications))\n",
    "bins = np.arange(0, replications.max() + 1.5) - 0.5\n",
    "_, ax = plt.subplots()\n",
    "ax.hist(replications, bins)\n",
    "ax.set_xticks(bins + 0.5);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "pSFfXqFzY6rt"
   },
   "source": [
    "## Approximate Bayesian inference with Pyro and stochastic variational inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "3OEK2MFJUvEx"
   },
   "outputs": [],
   "source": [
    "def model(y_tensor):\n",
    "    prior_dist = dist.Beta(torch.Tensor([.5]), torch.Tensor([.5]))\n",
    "    theta = pyro.sample('theta', prior_dist)\n",
    "    # Declare the plate size explicitly so Pyro can check it against the data\n",
    "    with pyro.plate('observe_data', len(y_tensor)):\n",
    "        pyro.sample('obs', dist.Bernoulli(theta), obs=y_tensor)\n",
    "\n",
    "def guide(y_tensor):\n",
    "    # Variational family q(theta) = Beta(alpha, beta), both kept positive\n",
    "    alpha = pyro.param(\"alpha\", torch.Tensor([1.0]),\n",
    "                       constraint=constraints.positive)\n",
    "    beta = pyro.param(\"beta\", torch.Tensor([1.0]),\n",
    "                      constraint=constraints.positive)\n",
    "    pyro.sample('theta', dist.Beta(alpha, beta))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "yHyBUAE9ax1R"
   },
   "outputs": [],
   "source": [
    "y_tensor = torch.Tensor(y)\n",
    "\n",
    "# set up the optimizer\n",
    "pyro.clear_param_store()\n",
    "adam_params = {\"lr\": 0.2, \"betas\": (0.90, 0.999)}\n",
    "optimizer = Adam(adam_params)\n",
    "\n",
    "# setup the inference algorithm\n",
    "svi = SVI(model, guide, optimizer, loss=Trace_ELBO())\n",
    "\n",
    "n_steps = 100\n",
    "# do gradient steps\n",
    "for step in range(n_steps):\n",
    "    svi.step(y_tensor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 393
    },
    "id": "2il2SgT5aumi",
    "outputId": "f1ba4fe2-8a99-4370-a69d-44d1e917a3af"
   },
   "outputs": [],
   "source": [
    "alpha = pyro.param(\"alpha\").item()\n",
    "beta = pyro.param(\"beta\").item()\n",
    "\n",
    "inf_distribution = stats.beta(alpha, beta)\n",
    "print(\"Some samples:\", inf_distribution.rvs(10))\n",
    "print(\"Mean:\", inf_distribution.mean())\n",
    "print(\"Standard deviation:\", inf_distribution.std())\n",
    "\n",
    "_, axes = plt.subplots(2)\n",
    "\n",
    "# Plot the posterior\n",
    "x_svi = np.linspace(0, 1, 10000)\n",
    "y_svi = inf_distribution.pdf(x_svi)\n",
    "axes[0].plot(x_svi, y_svi)\n",
    "\n",
    "# Plot replications\n",
    "posterior_samples_of_theta = inf_distribution.rvs(n_replications)\n",
    "\n",
    "replications = stats.bernoulli.rvs(posterior_samples_of_theta)\n",
    "bins = np.arange(0, replications.max() + 1.5) - 0.5\n",
    "axes[1].hist(replications, bins)\n",
    "axes[1].set_xticks(bins + 0.5);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Skv0GLRriN5m"
   },
   "source": [
    "## Using a GPU and data subsampling with Pyro"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "eHr1j4ZeiSXa"
   },
   "outputs": [],
   "source": [
    "# Setup some data for another model\n",
    "mu = -0.6\n",
    "sigma = 1.8\n",
    "\n",
    "n2 = 10000\n",
    "y2 = stats.norm.rvs(mu, sigma, size=n2)\n",
    "y2_tensor = torch.as_tensor(y2, dtype=torch.float32).cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "eT_xpEVMiffT"
   },
   "outputs": [],
   "source": [
    "def model(y2_tensor):\n",
    "    # Create every tensor on the same device as the data (the GPU here)\n",
    "    device = y2_tensor.device\n",
    "\n",
    "    # Priors:\n",
    "    prior_dist_mu = dist.Normal(torch.tensor([0.], device=device),\n",
    "                                torch.tensor([1.], device=device))\n",
    "    mu = pyro.sample('mu', prior_dist_mu)\n",
    "\n",
    "    prior_dist_sigma = dist.Gamma(torch.tensor([1.], device=device),\n",
    "                                  torch.tensor([1.], device=device))\n",
    "    sigma = pyro.sample('sigma', prior_dist_sigma)\n",
    "\n",
    "    # Likelihood: each step scores a random subsample of 5000 points, and\n",
    "    # Pyro scales that subsampled log-likelihood up to the full data size.\n",
    "    # `device=` replaces the deprecated `use_cuda=` argument of pyro.plate.\n",
    "    with pyro.plate('observe_data', size=len(y2_tensor),\n",
    "                    subsample_size=5000, device=device) as ind:\n",
    "        pyro.sample('obs', dist.Normal(mu, sigma),\n",
    "                    obs=y2_tensor.index_select(0, ind))\n",
    "\n",
    "\n",
    "def guide(y2_tensor):\n",
    "    device = y2_tensor.device\n",
    "\n",
    "    # q(mu) = Normal(alpha_mu, beta_mu)\n",
    "    alpha_mu = pyro.param(\"alpha_mu\", torch.tensor([0.0], device=device))\n",
    "    beta_mu = pyro.param(\"beta_mu\", torch.tensor([3.0], device=device),\n",
    "                         constraint=constraints.positive)\n",
    "    pyro.sample('mu', dist.Normal(alpha_mu, beta_mu))\n",
    "\n",
    "    # q(sigma) = Gamma(alpha_sigma, beta_sigma); Pyro's Gamma takes\n",
    "    # (concentration, rate), so beta_sigma is a rate\n",
    "    alpha_sigma = pyro.param(\"alpha_sigma\", torch.tensor([1.0], device=device),\n",
    "                             constraint=constraints.positive)\n",
    "    beta_sigma = pyro.param(\"beta_sigma\", torch.tensor([1.0], device=device),\n",
    "                            constraint=constraints.positive)\n",
    "    pyro.sample('sigma', dist.Gamma(alpha_sigma, beta_sigma))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "HDEiFzsFkyPO"
   },
   "outputs": [],
   "source": [
    "# set up the optimizer\n",
    "pyro.clear_param_store()\n",
    "adam_params = {\"lr\": 0.2, \"betas\": (0.90, 0.999)}\n",
    "optimizer = Adam(adam_params)\n",
    "\n",
    "# setup the inference algorithm\n",
    "svi = SVI(model, guide, optimizer, loss=Trace_ELBO())\n",
    "\n",
    "# A handful of steps at lr=0.2 barely moves the variational parameters away\n",
    "# from their initial values, so run enough steps for SVI to converge\n",
    "n_steps = 1000\n",
    "# do gradient steps\n",
    "for step in range(n_steps):\n",
    "    svi.step(y2_tensor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 283
    },
    "id": "fiq-YmJdmOU4",
    "outputId": "ab6a6332-9d10-4d03-e8e1-43c0ca65a96e"
   },
   "outputs": [],
   "source": [
    "# Generate replications\n",
    "\n",
    "alpha_mu = pyro.param(\"alpha_mu\").item()\n",
    "beta_mu = pyro.param(\"beta_mu\").item()\n",
    "alpha_sigma = pyro.param(\"alpha_sigma\").item()\n",
"beta_sigma = pyro.param(\"beta_sigma\").item()\n",
    "\n",
    "# The guide's q(mu) is Normal(alpha_mu, beta_mu)\n",
    "mu_distribution = stats.norm(alpha_mu, beta_mu)\n",
    "# Pyro's Gamma is parameterized by (concentration, rate), whereas scipy's\n",
    "# second positional argument is `loc`; pass the rate as scale = 1 / rate\n",
    "sigma_distribution = stats.gamma(alpha_sigma, scale=1.0 / beta_sigma)\n",
    "\n",
    "mu_samples = mu_distribution.rvs(n_replications)\n",
    "sigma_samples = sigma_distribution.rvs(n_replications)\n",
    "\n",
    "# Posterior predictive: one new observation per (mu, sigma) posterior draw\n",
    "data_replications = stats.norm(mu_samples, sigma_samples).rvs()\n",
    "\n",
    "# Density estimation using KDE, with the bandwidth chosen by 3-fold CV via\n",
    "# the `kde_fit` helper defined in the Stan section\n",
    "kde_replications = kde_fit(data_replications, n_trials=10, cv=3)\n",
    "x_kde = np.linspace(-20, 20, 10000).reshape(-1, 1)\n",
    "y_kde = np.exp(kde_replications.score_samples(x_kde))\n",
    "\n",
    "_, ax = plt.subplots()\n",
    "ax.plot(x_kde, y_kde)\n",
    "ax.set(title=\"Posterior predictive density\", xlabel=\"y\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "rA1jNJOJ92fu"
   },
   "source": [
    "## Variational autoencoders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "lk1-rqM-HiCu"
   },
   "outputs": [],
   "source": [
    "# define the PyTorch module that parameterizes the\n",
    "# diagonal gaussian distribution q(z|x)\n",
    "class Encoder(nn.Module):\n",
    "    def __init__(self, z_dim, hidden_dim, input_dim):\n",
    "        super(Encoder, self).__init__()\n",
    "        # setup the three linear transformations used\n",
    "        self.fc1 = nn.Linear(input_dim, hidden_dim)\n",
    "        self.fc21 = nn.Linear(hidden_dim, z_dim)\n",
    "        self.fc22 = nn.Linear(hidden_dim, z_dim)\n",
    "        # setup the non-linearities\n",
    "        self.softplus = nn.Softplus()\n",
    "\n",
    "    def forward(self, x):\n",
    "        # compute the hidden units\n",
    "        hidden = self.softplus(self.fc1(x))\n",
    "        # then return a mean vector and a (positive) square root covariance\n",
    "        # each of size batch_size x z_dim\n",
    "        z_loc = self.fc21(hidden)\n",
    "        z_scale = torch.exp(self.fc22(hidden))\n",
    "        return z_loc, z_scale\n",
    "\n",
    "\n",
    "# define the PyTorch module that parameterizes the\n",
    "# diagonal gaussian observation likelihood p(x|z)\n",
    "class Decoder(nn.Module):\n",
    "    def __init__(self, z_dim, hidden_dim, input_dim):\n",
    "        super(Decoder, self).__init__()\n",
    "        # setup the three linear transformations used\n",
    "        self.fc1 = nn.Linear(z_dim, hidden_dim)\n",
    "        self.fc21 = nn.Linear(hidden_dim, input_dim)\n",
    "        self.fc22 = nn.Linear(hidden_dim, input_dim)\n",
    "        # setup the non-linearities\n",
    "        self.softplus = nn.Softplus()\n",
    "\n",
    "    def forward(self, z):\n",
    "        # define the forward computation on the latent z\n",
    "        # first compute the hidden units\n",
    "        hidden = self.softplus(self.fc1(z))\n",
    "        # then return the mean and the (positive) standard deviation of x,\n",
    "        # each with last dimension input_dim\n",
    "        mu = self.fc21(hidden)\n",
    "        sigma = torch.exp(self.fc22(hidden))\n",
    "        return mu, sigma\n",
    "\n",
    "\n",
    "# define a PyTorch module for the VAE\n",
    "class VAE(nn.Module):\n",
    "    # by default our latent space is 50-dimensional\n",
    "    # and we use 400 hidden units\n",
    "    def __init__(self, input_dim,\n",
    "                 z_dim=50, hidden_dim=400, use_cuda=False):\n",
    "        super(VAE, self).__init__()\n",
    "        # create the encoder and decoder networks\n",
    "        self.encoder = Encoder(z_dim, hidden_dim, input_dim=input_dim)\n",
    "        self.decoder = Decoder(z_dim, hidden_dim, input_dim=input_dim)\n",
    "\n",
    "        if use_cuda:\n",
    "            # calling cuda() here will put all the parameters of\n",
    "            # the encoder and decoder networks into gpu memory\n",
    "            self.cuda()\n",
    "        self.use_cuda = use_cuda\n",
    "        self.z_dim = z_dim\n",
    "\n",
    "    # define the model p(x|z)p(z)\n",
    "    def model(self, x):\n",
    "        # register PyTorch module `decoder` with Pyro\n",
    "        pyro.module(\"decoder\", self.decoder)\n",
    "        with pyro.plate(\"data\", x.shape[0]):\n",
    "            # setup hyperparameters for prior p(z)\n",
    "            z_loc = torch.zeros(x.shape[0], self.z_dim, dtype=x.dtype, device=x.device)\n",
    "            z_scale = torch.ones(x.shape[0], self.z_dim, dtype=x.dtype, device=x.device)\n",
    "            # sample from prior (value will be sampled by guide when computing the ELBO)\n",
    "            z = pyro.sample(\"latent\", dist.Normal(z_loc, z_scale).to_event(1))\n",
    "            # decode the latent code z\n",
    "            mu, sigma = self.decoder(z)\n",
    "            # score against the actual data\n",
    "            pyro.sample(\"obs\", dist.Normal(mu, sigma).to_event(1), obs=x)\n",
    "\n",
    "    # define the guide (i.e. variational distribution) q(z|x)\n",
    "    def guide(self, x):\n",
    "        # register PyTorch module `encoder` with Pyro\n",
    "        pyro.module(\"encoder\", self.encoder)\n",
    "        with pyro.plate(\"data\", x.shape[0]):\n",
    "            # use the encoder to get the parameters used to define q(z|x)\n",
    "            z_loc, z_scale = self.encoder(x)\n",
    "            # sample the latent code z\n",
    "            pyro.sample(\"latent\", dist.Normal(z_loc, z_scale).to_event(1))\n",
    "\n",
    "    # define a helper function for reconstructing images\n",
    "    def reconstruct_img(self, x):\n",
    "        # encode image x\n",
    "        z_loc, z_scale = self.encoder(x)\n",
    "        # sample in latent space\n",
    "        z = dist.Normal(z_loc, z_scale).sample()\n",
    "        # decode the image (note we don't sample in image space); the decoder\n",
    "        # returns (mu, sigma), so keep only the mean\n",
    "        loc_img, _ = self.decoder(z)\n",
    "        return loc_img\n",
    "\n",
    "    def new_instances(self, size=1):\n",
    "        \"\"\"Draw `size` new observations from the generative model p(z)p(x|z).\n",
    "\n",
    "        Returns a NumPy array with one row per new instance.\n",
    "        \"\"\"\n",
    "        # Put z on the device the decoder lives on (GPU only when use_cuda)\n",
    "        device = next(self.decoder.parameters()).device\n",
    "        z = stats.norm.rvs(size=(size, self.z_dim))\n",
    "        # Pure sampling: no gradients needed\n",
    "        with torch.no_grad():\n",
    "            mu, sigma = self.decoder(torch.as_tensor(z, device=device,\n",
    "                                                     dtype=torch.float32))\n",
    "        return stats.norm.rvs(mu.cpu().numpy(), sigma.cpu().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "6wQWH15L91jW",
    "outputId": "f1d16e09-89ea-4124-8487-b80049591e00"
   },
   "outputs": [],
   "source": [
    "# clear param store\n",
    "pyro.clear_param_store()\n",
    "\n",
    "no_instances = 20000\n",
    "input_dim = 3\n",
    "true_mean = stats.norm.rvs(size=input_dim)\n",
    "\n",
    "# Generate a random positive definite covariance matrix as L^T L, where L is\n",
    "# lower triangular with a strictly positive diagonal (hence invertible)\n",
    "chol_factor = stats.norm.rvs(size=(input_dim, input_dim))\n",
    "chol_factor[np.triu_indices(input_dim)] = 0\n",
    "chol_factor += np.diag(np.abs(stats.norm.rvs(size=input_dim)))\n",
    "true_cov = chol_factor.T @ chol_factor\n",
    "\n",
    "dataset = stats.multivariate_normal.rvs(true_mean, true_cov, size=no_instances)\n",
    "dataset = torch.as_tensor(dataset, dtype=torch.float32)\n",
    "dataset = TensorDataset(dataset)\n",
    "train_loader = DataLoader(dataset, batch_size=1000, shuffle=True,\n",
    "                          num_workers=1, pin_memory=True, drop_last=False)\n",
    "\n",
    "# setup the VAE\n",
    "vae = VAE(use_cuda=True, input_dim=input_dim)\n",
    "\n",
    "adam_args = {\"lr\": 0.001}\n",
    "optimizer = Adam(adam_args)\n",
    "\n",
    "# setup the inference algorithm\n",
    "svi = SVI(vae.model, vae.guide, optimizer, loss=Trace_ELBO())\n",
    "\n",
    "train_elbo = []\n",
    "for epoch in range(100):\n",
    "    # initialize loss accumulator\n",
    "    epoch_loss = 0.\n",
    "    # do a training epoch over each mini-batch x returned\n",
    "    # by the data loader\n",
    "    for x, in train_loader:\n",
    "        x = x.cuda()\n",
    "        epoch_loss += svi.step(x)\n",
    "\n",
    "    # report training diagnostics\n",
    "    if not epoch % 10:\n",
    "        normalizer_train = len(train_loader.dataset)\n",
    "        total_epoch_loss_train = epoch_loss / normalizer_train\n",
    "        train_elbo.append(total_epoch_loss_train)\n",
    "        print(\"[epoch %03d]  average training loss: %.4f\" %\n",
    "              (epoch, total_epoch_loss_train))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "RgPXUrLyQSZv",
    "outputId": "3b5a7e3f-5b88-4672-a592-b7f03aa44aa8"
   },
   "outputs": [],
   "source": [
    "# Generating new instances (replications) from the trained VAE\n",
    "new_instances = vae.new_instances(100000)\n",
    "\n",
    "print(\"True means\")\n",
    "print(true_mean)\n",
    "print(\"Empirical means of replications:\")\n",
    "print(new_instances.mean(0))\n",
    "\n",
    "print(\"----------------------------------------\")\n",
    "\n",
    "print(\"True covariance matrix\")\n",
    "print(true_cov)\n",
    "print(\"Empirical covariance matrix of replications:\")\n",
    "print(np.cov(new_instances, rowvar=False))"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
"codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 1 }