18. Local linear neural networks

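This notebook contains a PyTorch Lightning implementation of local linear neural networks (https://arxiv.org/abs/1910.05206), demonstrated on synthetic classification and regression data, including hyperparameter tuning with Optuna.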

[ ]:
#!pip install pytorch_lightning optuna mlflow
[ ]:
import numpy as np
import scipy.stats as stats
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import random_split, TensorDataset, DataLoader
import pickle
from copy import deepcopy
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import tempfile
import os
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.preprocessing import StandardScaler
import optuna
from optuna.integration import PyTorchLightningPruningCallback
from pytorch_lightning.loggers import MLFlowLogger

%matplotlib inline
[ ]:
# Synthetic data: a linear target in 10 standard-normal features, with
# 700 training and 200 test rows.
np.random.seed(1)     # scipy.stats draws from NumPy's global RNG
torch.manual_seed(1)  # network initialisation, dropout and batch shuffling

beta = stats.norm().rvs([10, 1])
train_inputv = stats.norm().rvs([700, 10])
train_target = np.matmul(train_inputv, beta)

test_inputv = stats.norm().rvs([200, 10])
test_target = np.matmul(test_inputv, beta)

# Class boundaries at the 10%, 70% and 90% training quantiles give 4 ordered
# labels 0..3; each label is the number of cutpoints strictly below the target.
cutpoints = np.quantile(train_target, [.1, .7, .9])

train_target_label = np.searchsorted(cutpoints, train_target.ravel(), side="left")
test_target_label = np.searchsorted(cutpoints, test_target.ravel(), side="left")
[ ]:
# Baseline for comparison: test misclassification rate of an ExtraTrees
# classifier fitted on the (not yet standardized) features. A fixed
# random_state makes the baseline number reproducible between runs.
clf = ExtraTreesClassifier(n_estimators=1000, random_state=1)
clf.fit(train_inputv, train_target_label)
(clf.predict(test_inputv) != test_target_label).mean()
[ ]:
# Standardize features with statistics estimated on the training set only.
# NOTE: this overwrites train_inputv/test_inputv in place; re-running the cell
# would refit `scaler` on already-standardized data (and `scaler` is used
# later to convert coefficients back to the original feature scale).
scaler = StandardScaler()
train_inputv = scaler.fit_transform(train_inputv)
test_inputv = scaler.transform(test_inputv)
[ ]:
class LitNN(pl.LightningModule):
    """Local linear neural network (https://arxiv.org/abs/1910.05206).

    A feed-forward network maps each input row ``x`` to local linear
    coefficients ``thetas``. It can optionally also produce an input-dependent
    intercept ``theta0v`` and a single learned global intercept ``theta0f``.
    The prediction is ``sum_k thetas[k] * x[k] + theta0v + theta0f``.

    With ``n_classification_labels > 1`` there is one such linear predictor
    per class. Their outputs are used as logits for a cross-entropy loss.
    With ``n_classification_labels == 0`` the model is a regressor with an
    MSE loss.

    The loss can add penalties on the squared gradients of ``thetas`` and
    ``theta0v`` with respect to the inputs. These penalties push the local
    coefficients towards being constant, i.e. towards a globally linear model.

    Parameters
    ----------
    n_classification_labels : int
        Number of classes, or 0 for regression (1 is rejected).
    nfeatures : int
        Number of input features.
    penalization_thetas : float
        Weight of the input-gradient penalty on ``thetas`` (0 disables it).
    penalization_variable_theta0 : float
        Weight of the input-gradient penalty on ``theta0v`` (0 disables it).
    hsizes : list of int
        Widths of the hidden layers.
    lr, weight_decay : float
        Adam learning rate and weight decay.
    batch_size : int
        Stored on the module; not used by this class itself.
    dropout : float
        Dropout probability applied after every hidden layer.
    varying_theta0 : bool
        Whether the network also outputs an input-dependent intercept.
    fixed_theta0 : bool
        Whether to learn one global intercept parameter.
    """

    def __init__(self, n_classification_labels, nfeatures,
                 penalization_thetas,
                 penalization_variable_theta0,
                 hsizes = [100, 50], lr=0.01, weight_decay=0, batch_size=50,
                 dropout=0.5, varying_theta0=True, fixed_theta0=True):
        super().__init__()

        # A single class cannot be trained with cross-entropy; 0 means regression.
        assert n_classification_labels != 1
        self.penalization_thetas = penalization_thetas
        self.penalization_variable_theta0 = penalization_variable_theta0
        self.nfeatures = nfeatures
        self.weight_decay = weight_decay
        self.lr = lr
        self.hsizes = hsizes
        self.batch_size = batch_size
        self.dropl = nn.Dropout(p=dropout)
        self.varying_theta0 = varying_theta0
        self.fixed_theta0 = fixed_theta0
        self.n_classification_labels = n_classification_labels

        if self.fixed_theta0:
            self.theta0f = nn.Parameter(torch.FloatTensor([.0]))

        # Hidden layers: Linear -> ELU -> BatchNorm (-> Dropout, in forward).
        llayers = []
        normllayers = []
        next_input_size = nfeatures
        for hsize in hsizes:
            llayers.append(self._initialize_layer(
                nn.Linear(next_input_size, hsize)
            ))
            normllayers.append(nn.BatchNorm1d(hsize))
            next_input_size = hsize
        self.llayers = nn.ModuleList(llayers)
        self.normllayers = nn.ModuleList(normllayers)

        # One coefficient per feature, plus one slot for theta0v when enabled
        # (bool adds as 0/1), repeated for every class in classification mode.
        out_dim = nfeatures + varying_theta0
        if self.n_classification_labels:
            out_dim = out_dim * self.n_classification_labels
        self.fc_last = nn.Linear(next_input_size, out_dim)
        self._initialize_layer(self.fc_last)

        self.elu = nn.ELU()

    def forward(self, x):
        """Compute the local coefficients for a batch of inputs.

        Parameters
        ----------
        x : Tensor of shape (batch, nfeatures)

        Returns
        -------
        theta0v : Tensor or None
            Varying intercept: (batch, 1, n_labels) in classification mode,
            (batch, 1) in regression mode; None if ``varying_theta0`` is False.
        theta0f : Parameter or None
            Global intercept of shape (1,); None if ``fixed_theta0`` is False.
        thetas : Tensor
            (batch, nfeatures, n_labels) in classification mode,
            (batch, nfeatures) in regression mode.
        """
        for fc, fcn in zip(self.llayers, self.normllayers):
            x = fcn(self.elu(fc(x)))
            x = self.dropl(x)

        thetas = self.fc_last(x)
        if self.n_classification_labels:
            # (batch, (nfeatures + varying_theta0) * n_labels)
            #   -> (batch, nfeatures + varying_theta0, n_labels)
            thetas = thetas.view(
                thetas.shape[0],
                -1,
                self.n_classification_labels
                )

        if self.varying_theta0:
            # The first coefficient slot holds the input-dependent intercept.
            theta0v = thetas[:, :1]
            thetas = thetas[:, 1:]
        else:
            theta0v = None

        theta0f = self.theta0f if self.fixed_theta0 else None

        return theta0v, theta0f, thetas

    def _initialize_layer(self, layer):
        """Zero the bias and Xavier-normal-initialise the weights (in place)."""
        nn.init.constant_(layer.bias, 0)
        gain = nn.init.calculate_gain('relu')
        nn.init.xavier_normal_(layer.weight, gain=gain)
        return layer

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        return optimizer

    def _linear_output(self, inputv, theta0v, theta0f, thetas):
        """Combine the local coefficients with the inputs into predictions.

        Returns a tensor of shape (batch, 1, n_labels) in classification mode
        and (batch, 1) in regression mode.
        """
        # In classification mode thetas carry a trailing label axis, so the
        # inputs get a matching singleton axis to broadcast against it.
        inputs_ext = inputv[..., None] if self.n_classification_labels else inputv
        output = (thetas * inputs_ext).sum(1, True)
        if theta0v is not None:
            output = output + theta0v
        if theta0f is not None:
            output = output + theta0f
        return output

    @staticmethod
    def _gradient_penalty(params, inputv):
        """Penalty on how strongly ``params`` vary with the inputs.

        ``params`` has shape (batch, k, n_labels). For every component
        ``params[:, i, j]``, the squared gradient with respect to ``inputv`` is
        averaged over the batch and summed over features. The contributions of
        all components are summed.
        """
        penalty = 0
        for i in range(params.shape[1]):
            for j in range(params.shape[2]):
                grads, = torch.autograd.grad(
                    params[:, i, j].sum(), inputv,
                    create_graph=True,  # the penalty itself is backpropagated
                    )
                penalty = penalty + (grads ** 2).mean(0).sum()
        return penalty

    def loss_function(self, theta0v, thetas, inputv, target, output):
        """Criterion loss plus the optional input-gradient penalties.

        Parameters
        ----------
        theta0v, thetas : Tensor
            Outputs of ``forward`` for ``inputv``; ``theta0v`` may be None.
        inputv : Tensor
            Batch inputs. They must require grad when a penalty is enabled,
            because the penalties differentiate with respect to them.
        target : Tensor
            Class indices (classification) or target values (regression).
        output : Tensor
            Predictions in the layout the criterion expects.

        Returns
        -------
        Tensor
            Scalar loss: criterion
            + penalization_thetas * penalty(thetas)
            + penalization_variable_theta0 * penalty(theta0v).
        """
        with torch.enable_grad():
            criterion = F.cross_entropy if self.n_classification_labels else F.mse_loss
            loss = criterion(output, target)

            if not self.n_classification_labels:
                # Give regression coefficients a trailing "label" axis so the
                # penalty can treat both modes as (batch, k, n_labels).
                thetas = thetas[..., None]
                if theta0v is not None:
                    theta0v = theta0v[..., None]

            # Penalties are added on top of the criterion loss (never replace it).
            if self.penalization_thetas:
                loss = loss + self.penalization_thetas * self._gradient_penalty(
                    thetas, inputv)

            if (self.penalization_variable_theta0
                and theta0v is not None):
                loss = loss + self.penalization_variable_theta0 * self._gradient_penalty(
                    theta0v, inputv)

            return loss

    def training_val_step(self, batch, batch_idx, train):
        """Compute the loss for one (inputs, target) batch.

        ``train`` is unused and kept for the callers' signature. The target is
        presumably (batch, 1), as produced by ``DataModule`` — verify for
        other data sources.
        """
        # Lightning evaluates validation with gradients disabled, but the
        # penalties need autograd w.r.t. the inputs, so re-enable it here.
        with torch.enable_grad():
            inputv, target = batch

            inputv.requires_grad_(True)
            theta0v, theta0f, thetas = self.forward(inputv)
            assert thetas.requires_grad
            # theta0v is None when varying_theta0 is False.
            if theta0v is not None:
                assert theta0v.requires_grad

            output = self._linear_output(inputv, theta0v, theta0f, thetas)

            if self.n_classification_labels:
                # cross_entropy wants classes on dim 1: (batch, 1, C) -> (batch, C, 1).
                output = output.transpose(1, 2)
            loss = self.loss_function(theta0v, thetas, inputv, target, output)

            return loss

    def training_step(self, train_batch, batch_idx, log=True):
        loss = self.training_val_step(train_batch, batch_idx, train=True)
        self.log('train_loss', loss.item())
        return loss

    def validation_step(self, val_batch, batch_idx):
        loss = self.training_val_step(val_batch, batch_idx, train=False)
        self.log('val_loss', loss.item())

    def test_step(self, test_batch, batch_idx):
        pass

    def _predict_all(self, x_pred, grad_out, out_probs):
        """Shared implementation of ``predict`` and ``predict_proba``.

        Runs the network in eval mode and restores the previous train/eval
        mode afterwards. Returns numpy arrays; see ``predict_proba`` for the
        layout of the gradient outputs.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.autograd.set_grad_enabled(grad_out):
                inputv = torch.as_tensor(x_pred)
                inputv = inputv.type_as(next(self.fc_last.parameters()))

                if grad_out:
                    # as_tensor/type_as may hand back the caller's own tensor;
                    # detach() so the caller's tensor is not flagged requires_grad.
                    inputv = inputv.detach().requires_grad_(True)

                theta0v, theta0f, thetas = self.forward(inputv)
                output_pred = self._linear_output(inputv, theta0v, theta0f, thetas)

                if self.n_classification_labels:
                    output_pred = output_pred[:, 0]  # (batch, n_labels)
                    if out_probs:
                        output_pred = F.softmax(output_pred, 1)
                    else:
                        output_pred = torch.max(output_pred, 1, True).indices

                output_pred = output_pred.detach().cpu().numpy()

                if not grad_out:
                    return output_pred

                # Squared d(thetas[:, i]) / d(inputs), stacked on the last axis:
                # (batch, input feature [denominator], theta index [numerator]).
                grads1 = torch.stack([
                    torch.autograd.grad(
                        thetas[:, i].sum(), inputv,
                        retain_graph=True,
                        )[0].cpu()
                    for i in range(thetas.shape[1])
                ], 2)
                grads1 = (grads1 ** 2.).numpy()

                if theta0v is not None:
                    grads2, = torch.autograd.grad(theta0v.sum(), inputv)
                    grads2 = (grads2 ** 2.).cpu().numpy()
                else:
                    grads2 = None

                return output_pred, grads1, grads2
        finally:
            self.train(was_training)

    def predict_proba(self, x_pred, grad_out=False):
        """
        Predict class probabilities (classification mode).

        Parameters
        ----------
        x_pred : array
            Matrix of features, shape (no_instances, no_features).
        grad_out : bool
            If False, return only the predictions. In classification mode these
            are probabilities of shape (no_instances, n_classification_labels).
            If True, return a tuple ``(pred, grads_thetas, grads_theta0v)``:

            - ``grads_thetas`` holds the squared gradients of each theta with
              respect to the inputs, with shape (no_instances, no_features,
              no_features). The second axis is the denominator of the
              derivative (input feature); the third is the numerator (theta).
              In classification mode each theta is summed over labels before
              differentiating.
            - ``grads_theta0v`` holds the squared gradients of the varying
              theta0, with shape (no_instances, no_features). It is None if
              ``varying_theta0`` is False.

            Both can be combined with
            `np.concatenate((res[2][:,:,None], res[1]), 2)`.
        """
        return self._predict_all(x_pred, grad_out, True)

    def predict(self, x_pred, grad_out=False):
        """
        Predict y.

        Parameters
        ----------
        x_pred : array
            Matrix of features, shape (no_instances, no_features).
        grad_out : bool
            If False, return only the predictions. These are class indices of
            shape (no_instances, 1) in classification mode, and predicted
            values of shape (no_instances, 1) in regression mode.
            If True, return a tuple ``(pred, grads_thetas, grads_theta0v)``:

            - ``grads_thetas`` holds the squared gradients of each theta with
              respect to the inputs, with shape (no_instances, no_features,
              no_features). The second axis is the denominator of the
              derivative (input feature); the third is the numerator (theta).
            - ``grads_theta0v`` holds the squared gradients of the varying
              theta0, with shape (no_instances, no_features). It is None if
              ``varying_theta0`` is False.

            Both can be combined with
            `np.concatenate((res[2][:,:,None], res[1]), 2)`.
        """
        return self._predict_all(x_pred, grad_out, False)

    def get_thetas(self, x_pred, scaler=None):
        """Return ``(theta0v, theta0f, thetas)`` for ``x_pred`` as numpy arrays.

        Without ``scaler``, the coefficients are returned as produced by
        ``forward``; ``None`` entries stay ``None``.

        With ``scaler`` (a fitted object exposing ``mean_`` and ``scale_``,
        e.g. ``StandardScaler``), ``x_pred`` is assumed to have been
        standardized with it. The coefficients are then re-expressed for the
        unscaled features:

        - ``thetas`` becomes ``thetas / scale``;
        - ``theta0v`` becomes ``theta0v - sum(mean * thetas / scale)``, and is
          always returned (the shift alone if ``varying_theta0`` is False).

        The train/eval mode of the module is restored afterwards.
        """
        def to_numpy(t):
            return None if t is None else t.detach().cpu().numpy()

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                param = next(self.fc_last.parameters())
                inputv = torch.as_tensor(x_pred).type_as(param)

                theta0v, theta0f, thetas = self.forward(inputv)

                if scaler is not None:
                    scale = torch.as_tensor(scaler.scale_).type_as(param)
                    mean = torch.as_tensor(scaler.mean_).type_as(param)

                    if self.n_classification_labels:
                        # Broadcast over the label axis of (batch, nfeatures, n_labels).
                        scale = scale[:, None]
                        mean = mean[:, None]

                    thetas = thetas / scale

                    if theta0v is None:
                        theta0v = 0
                    theta0v = - (mean * thetas).sum(1, True) + theta0v

                return to_numpy(theta0v), to_numpy(theta0f), to_numpy(thetas)
        finally:
            self.train(was_training)
[ ]:
class DataModule(pl.LightningDataModule):
    """Wrap in-memory arrays as a Lightning data module.

    The data is split 90% / 10% into training and validation subsets with a
    fixed seed (42), so every ``setup`` call produces the same split. Targets
    are reshaped to (n, 1). They are ``long`` when
    ``n_classification_labels`` is truthy and ``float32`` otherwise.
    """

    def __init__(self, train_inputv, train_target,
                 n_classification_labels,
                 batch_size = 50, num_workers=2):
        super().__init__()
        self.batch_size = min(batch_size, len(train_target))
        self.train_inputv = torch.as_tensor(train_inputv, dtype=torch.float32)
        y_dtype = torch.long if n_classification_labels else torch.float32
        self.train_target = torch.as_tensor(train_target, dtype=y_dtype)
        self.num_workers = num_workers

    def setup(self, stage=None):
        # `stage` defaults to None so the module also works when Lightning
        # (or a user) calls setup() without it; the split ignores it.
        full_dataset = TensorDataset(self.train_inputv, self.train_target.reshape(-1)[:, None])

        n_total = len(full_dataset)
        n_val = n_total // 10
        self.train_dataset, self.val_dataset = random_split(
            full_dataset, [n_total - n_val, n_val],
            generator=torch.Generator().manual_seed(42))

    def train_dataloader(self):
        # With drop_last=True, a batch size larger than the training split
        # would yield zero batches, so clamp it to the split size.
        batch_size = max(1, min(self.batch_size, len(self.train_dataset)))
        return DataLoader(self.train_dataset, batch_size=batch_size, drop_last=True, shuffle=True,
                          num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

18.1. Classification

[ ]:
#%%capture cap_out --no-stderr

# --- data and model: 4 ordered classes, both gradient penalties at 0.1 -------
datamodule = DataModule(train_inputv, train_target_label, n_classification_labels=4)
smodel = LitNN(
    nfeatures=train_inputv.shape[1],
    n_classification_labels=4,
    penalization_variable_theta0=0.1,
    penalization_thetas=0.1,
)

# --- training configuration ------------------------------------------------
# Stop once val_loss has not improved for 30 consecutive validation checks.
early_stop_callback = EarlyStopping(
    monitor='val_loss', mode='min', patience=30, min_delta=0.0, verbose=False,
)

# Log to MLflow under ./mlruns; other loggers are listed at
# https://pytorch-lightning.readthedocs.io/en/latest/common/loggers.html
# Browse the runs with: mlflow server --backend-store-uri=./mlruns
logger = MLFlowLogger(experiment_name="Default", tracking_uri="file:./mlruns")

trainer = pl.Trainer(
    precision=32,
    gpus=torch.cuda.device_count(),
    tpu_cores=None,
    logger=logger,
    val_check_interval=0.25,  # validate 4 times per training epoch
    callbacks=[early_stop_callback],
    max_epochs=100,
    progress_bar_refresh_rate=0,
)

trainer.fit(smodel, datamodule=datamodule)

# Fail early if the fitted model cannot be pickled.
_ = pickle.dumps(smodel)

smodel.trainer.callback_metrics
[ ]:
preds = smodel.predict(test_inputv)
probs = smodel.predict_proba(test_inputv)
probs, grads_thetav, grads_theta0 = smodel.predict_proba(test_inputv, grad_out=True)
theta0v, theta0f, thetas = smodel.get_thetas(test_inputv)
theta0v_descaled, theta0f_descaled, thetas_descaled = smodel.get_thetas(test_inputv, scaler=scaler)

# Sanity checks on the documented output layouts.
n_test, n_feat = test_inputv.shape
assert preds.shape == (n_test, 1)
assert probs.shape == (n_test, 4)
assert grads_thetav.shape == (n_test, n_feat, n_feat)

# Test misclassification rate, comparable with the ExtraTrees baseline.
(preds.ravel() != test_target_label).mean()

18.2. Hyperparameter optimization using Optuna

[ ]:
# Re-use an existing study and output directory when this cell is re-run,
# so earlier trials (and their pickled models) are kept.
try:
    study
except NameError:
    study = optuna.create_study(direction="minimize", pruner=optuna.pruners.SuccessiveHalvingPruner())
try:
    tempdir
except NameError:
    # mkdtemp atomically creates a unique directory that persists until it is
    # removed explicitly, which is what the pickled trial models need.
    tempdir = tempfile.mkdtemp()
print(tempdir)
[ ]:
#%%capture cap_out2
def objective(trial: optuna.trial.Trial) -> float:
    """Train one LitNN configuration and return its final validation loss.

    Only the two penalization weights are searched. The other hyperparameters
    are fixed; their commented-out search ranges are kept next to each value.
    The fitted model is pickled to ``<tempdir>/<trial.number>.pkl`` so the best
    one can be reloaded after the study.

    NOTE(review): ``val_loss`` includes the gradient penalties, which scale
    with the very weights being tuned. This objective is therefore biased
    towards small penalties; scoring the unpenalized cross-entropy instead
    would avoid that.
    """
    hsize1 = 50#trial.suggest_int("hsize1", 10, 1000)
    hsize2 = 50#trial.suggest_int("hsize2", 10, max(20, 1000 - hsize1))
    batch_size = 100#trial.suggest_int("batch_size", 50, 400)
    lr = 0.001#trial.suggest_float("lr", 1e-5, 0.1)
    dropout = 0.1#trial.suggest_float("dropout", 0.0, 0.5)
    weight_decay = 0#trial.suggest_float("weight_decay", 0.0, 0.01)
    penalization_variable_theta0 = trial.suggest_float("penalization_variable_theta0", 0.0, 1.0)
    penalization_thetas = trial.suggest_float("penalization_thetas", 0.0, 1.0)

    model = LitNN(
        nfeatures=train_inputv.shape[1],
        n_classification_labels=4,
        penalization_variable_theta0=penalization_variable_theta0,
        penalization_thetas=penalization_thetas,
        hsizes = [hsize1, hsize2], lr=lr, batch_size=batch_size, dropout=dropout,
        weight_decay = weight_decay,
    )
    datamodule = DataModule(train_inputv, train_target_label, n_classification_labels=4, batch_size=batch_size)
    early_stop_callback = EarlyStopping(
       monitor='val_loss',
       min_delta=0.00,
       patience=30,
       verbose=False,
       mode='min'
    )

    logger = MLFlowLogger(
        experiment_name="Default",
        tracking_uri="file:./mlruns"
    )

    trainer = pl.Trainer(
                         precision=32,
                         gpus=torch.cuda.device_count(),
                         logger=logger,
                         val_check_interval=0.25,
                         callbacks=[early_stop_callback,
                                    # Must monitor a metric LitNN actually logs; with an
                                    # unknown key the callback only warns and never prunes.
                                    PyTorchLightningPruningCallback(trial, monitor="val_loss")
                                   ],
                         max_epochs = 50,
                         #progress_bar_refresh_rate = 0,
                        )
    trainer.fit(model, datamodule = datamodule)

    trainer.logger.log_hyperparams(trial.params)

    with open(os.path.join(tempdir, f"{trial.number}.pkl"), "wb") as f:
        pickle.dump(model, f)

    return trainer.callback_metrics["val_loss"].item()

study.optimize(objective, n_trials=10000, timeout=600)
[ ]:
# Save the study next to the pickled per-trial models.
study_path = os.path.join(tempdir, "study.pkl")
with open(study_path, "wb") as study_file:
    pickle.dump(study, study_file)

print(f"Number of finished trials: {len(study.trials)}")
print("Best trial:", study.best_params)

# Reload the model pickled by the best trial. These files are written by this
# notebook itself; never unpickle files from untrusted sources.
best_model_path = os.path.join(tempdir, f"{study.best_trial.number}.pkl")
with open(best_model_path, "rb") as model_file:
    best_model = pickle.load(model_file)
[ ]:
# One row per trial, best (lowest loss) first; trials without a value go last.
ranked_trials = sorted(
    study.trials,
    key=lambda t: np.inf if t.value is None else t.value,
)
trials_summary = pd.DataFrame([
    dict(trial_number=t.number, loss=t.value, **t.params)
    for t in ranked_trials
])
trials_summary.iloc[:200]
[ ]:
preds = best_model.predict(test_inputv)
probs = best_model.predict_proba(test_inputv)
probs, grads_thetav, grads_theta0 = best_model.predict_proba(test_inputv, grad_out=True)
theta0v, theta0f, thetas = best_model.get_thetas(test_inputv)
theta0v_descaled, theta0f_descaled, thetas_descaled = best_model.get_thetas(test_inputv, scaler=scaler)

# Sanity checks on the documented output layouts.
n_test, n_feat = test_inputv.shape
assert preds.shape == (n_test, 1)
assert probs.shape == (n_test, 4)
assert grads_thetav.shape == (n_test, n_feat, n_feat)

# Test misclassification rate of the tuned model, comparable with the baseline.
(preds.ravel() != test_target_label).mean()

18.3. Regression

[ ]:
# --- data and model: n_classification_labels=0 selects regression (MSE) -----
datamodule = DataModule(train_inputv, train_target, n_classification_labels=0)
smodel = LitNN(
    nfeatures=train_inputv.shape[1],
    n_classification_labels=0,
    penalization_variable_theta0=0.1,
    penalization_thetas=0.1,
)

# --- training configuration ------------------------------------------------
# Stop once val_loss has not improved for 30 consecutive validation checks.
early_stop_callback = EarlyStopping(
    monitor='val_loss', mode='min', patience=30, min_delta=0.0, verbose=False,
)

logger = MLFlowLogger(experiment_name="Default", tracking_uri="file:./mlruns")

trainer = pl.Trainer(
    precision=32,
    gpus=torch.cuda.device_count(),
    tpu_cores=None,
    logger=logger,
    val_check_interval=0.25,  # validate 4 times per training epoch
    callbacks=[early_stop_callback],
    max_epochs=100,
    progress_bar_refresh_rate=0,
)

trainer.fit(smodel, datamodule=datamodule)

# Fail early if the fitted model cannot be pickled.
_ = pickle.dumps(smodel)

smodel.trainer.callback_metrics
[ ]:
preds = smodel.predict(test_inputv)
preds, grads_thetav, grads_theta0 = smodel.predict(test_inputv, grad_out=True)
theta0v, theta0f, thetas = smodel.get_thetas(test_inputv)
theta0v_descaled, theta0f_descaled, thetas_descaled = smodel.get_thetas(test_inputv, scaler=scaler)

# Sanity checks on the documented output layouts.
n_test, n_feat = test_inputv.shape
assert preds.shape == (n_test, 1)
assert grads_thetav.shape == (n_test, n_feat, n_feat)

# Test mean squared error of the regression model.
((preds - test_target) ** 2).mean()