3. Image classification with neural networks
Here we use PyTorch and a ResNeXt neural network to classify images from the Fruits 360 dataset on Kaggle.
[1]:
!pip install -q -U "pytorch_lightning<1.3" mlflow
|████████████████████████████████| 849kB 28.3MB/s
|████████████████████████████████| 184kB 39.2MB/s
|████████████████████████████████| 276kB 48.1MB/s
Building wheel for PyYAML (setup.py) ... done
[2]:
import os
import numpy as np
import scipy.stats as stats
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.nn import functional as F
import torch.optim as optim
import pickle
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import mlflow
from pytorch_lightning.loggers import MLFlowLogger
from torch.utils.data import random_split, DataLoader
from torch import get_num_threads
# download the ResNeXt weights once so later cells can load them from the local cache
torch.hub.load('pytorch/vision:v0.9.1', 'resnext50_32x4d')
import torchvision
from torchvision import transforms
%matplotlib inline
Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.9.1
[7]:
# If on Colab, copy archive.zip from Google Drive to the working directory
!cp /gdrive/MyDrive/datasets/kaggle-fruits/archive.zip .
[8]:
!unzip -q -u -d ./data ./archive.zip
[3]:
preprocess = transforms.Compose([
transforms.ToTensor(),
# normalize with the ImageNet statistics expected by the pretrained model
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
train_dataset = torchvision.datasets.ImageFolder(
'./data/fruits-360/Training/',
transform=preprocess
)
test_dataset = torchvision.datasets.ImageFolder(
'./data/fruits-360/Test/',
transform=preprocess
)
print(len(train_dataset))
print(train_dataset[0][0].shape)
67692
torch.Size([3, 100, 100])
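As a quick sanity check, we can undo the normalization and look at a sample image together with its class label:
[ ]:
# Invert the ImageNet normalization and display the first training image.
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
img, label = train_dataset[0]
plt.imshow((img * std + mean).clamp(0, 1).permute(1, 2, 0).numpy())
plt.title(train_dataset.classes[label])
plt.show()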
3.1. Using PyTorch
[ ]:
class NeuralNetEstimator():
def __init__(self, lr=0.001, random_state=None,
train_on_a_small_subset_of_data=False):
self.lr = lr
self.random_state = random_state
self.train_on_a_small_subset_of_data = train_on_a_small_subset_of_data
def fit(self, dataset):
self.dataset = dataset
if self.random_state is not None:
torch.manual_seed(self.random_state)
self.net = torch.hub.load('pytorch/vision:v0.9.1', 'resnext50_32x4d')
self.nclasses = len(np.unique(dataset.targets))
self.net.fc = nn.Linear(self.net.fc.in_features, self.nclasses)
cuda = torch.cuda.is_available()
if cuda:
self.net.cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(self.net.parameters(), lr=self.lr)
print("Starting optimization.")
self.train_losses = []
self.val_losses = []
# db split
train_idx = torch.randperm(len(dataset))
val_idx = train_idx[:len(dataset)//10]
train_idx = train_idx[len(dataset)//10:]
train_dataset = torch.utils.data.Subset(dataset, train_idx)
val_dataset = torch.utils.data.Subset(dataset, val_idx)
if self.train_on_a_small_subset_of_data:
train_dataset = torch.utils.data.Subset(train_dataset, range(200))
val_dataset = torch.utils.data.Subset(val_dataset, range(100))
# initial values for early stopping decision state parameters
last_val_loss = np.inf
es_tries = 0
for epoch in range(100_000):
try:
# network training step
self.net.train()
dataset_loader_train = torch.utils.data.DataLoader(
train_dataset, batch_size=100, shuffle=True,
pin_memory=cuda, drop_last=True
)
batch_losses = []
for batch_inputv, batch_target in dataset_loader_train:
if cuda:
batch_inputv = batch_inputv.cuda()
batch_target = batch_target.cuda()
optimizer.zero_grad()
output = self.net(batch_inputv)
batch_loss = criterion(output, batch_target)
batch_loss.backward()
optimizer.step()
batch_losses.append(batch_loss.item())
loss = np.mean(batch_losses)
self.train_losses.append(loss)
print('\rTrain loss', np.round(loss.item(), 2), end='')
# network evaluation step
self.net.eval()
with torch.no_grad():
dataset_loader_val = torch.utils.data.DataLoader(
val_dataset, batch_size=100, shuffle=False,
pin_memory=cuda, drop_last=False,
)
batch_losses = []
batch_sizes = []
for batch_inputv, batch_target in dataset_loader_val:
if cuda:
batch_inputv = batch_inputv.cuda()
batch_target = batch_target.cuda()
output = self.net(batch_inputv)
batch_loss = criterion(output, batch_target)
batch_losses.append(batch_loss.item())
batch_sizes.append(len(batch_inputv))
loss = np.average(batch_losses, weights=batch_sizes)
self.val_losses.append(loss)
print(' | Validation loss', np.round(loss.item(), 2),
'in epoch', epoch + 1, end='')
# Decisions based on the evaluated values
if loss < last_val_loss:
best_state_dict = self.net.state_dict()
best_state_dict = pickle.dumps(best_state_dict)
es_tries = 0
last_val_loss = loss
else:
# after 20 or 40 epochs without improvement, restart from the
# best weights seen so far; give up entirely after 60
if es_tries in [20, 40]:
self.net.load_state_dict(pickle.loads(best_state_dict))
if es_tries >= 60:
self.net.load_state_dict(pickle.loads(best_state_dict))
break
es_tries += 1
print(' | es_tries', es_tries, end='', flush=True)
except KeyboardInterrupt:
if epoch > 0:
print("\nKeyboard interrupt detected.",
"Switching weights to lowest validation loss",
"and exiting")
self.net.load_state_dict(pickle.loads(best_state_dict))
break
print(f"\nOptimization finished in {epoch+1} epochs.")
def get_loss(self, dataset):
criterion = nn.CrossEntropyLoss()
cuda = torch.cuda.is_available()
self.net.eval()
with torch.no_grad():
dataset_loader_val = torch.utils.data.DataLoader(
dataset, batch_size=100, shuffle=False,
pin_memory=cuda, drop_last=False,
)
cv_batch_losses = []
zo_batch_losses = []
batch_sizes = []
for batch_inputv, batch_target in dataset_loader_val:
if cuda:
batch_inputv = batch_inputv.cuda()
batch_target = batch_target.cuda()
output = self.net(batch_inputv)
cv_batch_loss = criterion(output, batch_target)
zo_batch_loss = (torch.argmax(output, 1) != batch_target).cpu()
zo_batch_loss = np.array(zo_batch_loss).mean()
cv_batch_losses.append(cv_batch_loss.item())
zo_batch_losses.append(zo_batch_loss.item())
batch_sizes.append(len(batch_inputv))
cv_loss = np.average(cv_batch_losses, weights=batch_sizes)
zo_loss = np.average(zo_batch_losses, weights=batch_sizes)
return cv_loss, zo_loss
nn_estimator = NeuralNetEstimator(
random_state=0,
#train_on_a_small_subset_of_data=True,
)
nn_estimator.fit(train_dataset)
# save results
with open('/gdrive/MyDrive/datasets/kaggle-fruits/model.pkl', 'wb') as f:
pickle.dump(nn_estimator, f)
loss_on_test = nn_estimator.get_loss(test_dataset)
print(f"Loss on test dataset for default parameters: {loss_on_test}")
Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.9.1
Starting optimization.
Train loss 0.01 | Validation loss 0.0 in epoch 17 | es_tries 6
Keyboard interrupt detected. Switching weights to lowest validation loss and exiting
Optimization finished in 18 epochs.
Loss on test dataset for default parameters: (0.08706481696378363, 0.01564703808180536)
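get_loss returns a pair of values: the cross-entropy loss and the zero-one loss (misclassification rate), so the model misclassifies about 1.6% of the test images. The estimator also records per-epoch loss curves in train_losses and val_losses, which can be plotted to inspect convergence:
[ ]:
# Plot the training and validation loss curves recorded during fit().
plt.plot(nn_estimator.train_losses, label='train')
plt.plot(nn_estimator.val_losses, label='validation')
plt.xlabel('epoch')
plt.ylabel('cross-entropy loss')
plt.legend()
plt.show()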
3.2. Using PyTorch Lightning
[4]:
class LitNN(pl.LightningModule):
def __init__(self, n_classification_labels, lr=0.01):
super().__init__()
self.net = torch.hub.load('pytorch/vision:v0.9.1', 'resnext50_32x4d')
self.net.fc = self._initialize_layer(nn.Linear(self.net.fc.in_features, n_classification_labels))
self.lr = lr
def forward(self, x):
x = self.net(x)
return x
def _initialize_layer(self, layer):
nn.init.constant_(layer.bias, 0)
gain = nn.init.calculate_gain('relu')
nn.init.xavier_normal_(layer.weight, gain=gain)
return layer
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
return optimizer
def training_step(self, train_batch, batch_idx):
inputv, target = train_batch
output = self.forward(inputv)
loss = F.cross_entropy(output, target)
self.log('train_loss_ce', loss.item())
return loss
def test_validation_step(self, batch, batch_idx, name):
inputv, target = batch
output = self.forward(inputv)
loss_ce = F.cross_entropy(output, target).item()
loss_zo = (torch.argmax(output, 1) != target).float()
loss_zo = loss_zo.mean().item()
self.log(f'{name}_loss_ce', loss_ce)
self.log(f'{name}_loss_zo', loss_zo)
def validation_step(self, val_batch, batch_idx):
self.test_validation_step(val_batch, batch_idx, 'val')
def test_step(self, test_batch, batch_idx):
self.test_validation_step(test_batch, batch_idx, 'test')
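A quick shape check confirms that the replaced head maps a batch of images to one logit per class:
[ ]:
# Forward a dummy batch through the module and verify the output shape.
model = LitNN(n_classification_labels=len(train_dataset.classes))
dummy_batch = torch.randn(2, 3, 100, 100)
with torch.no_grad():
    logits = model(dummy_batch)
print(logits.shape)  # torch.Size([2, n_classification_labels])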
[5]:
class DataModule(pl.LightningDataModule):
def __init__(self, train_val_dataset, test_dataset,
batch_size = 50, train_val_split_seed=0):
super().__init__()
self.train_val_dataset = train_val_dataset
self.test_dataset = test_dataset
self.batch_size = min(batch_size, len(train_val_dataset.targets))
self.num_workers = get_num_threads()
self.train_val_split_seed = train_val_split_seed
def setup(self, stage):
if stage == 'fit':
generator = torch.Generator().manual_seed(self.train_val_split_seed)
full_dataset = self.train_val_dataset
partitions = [len(full_dataset) - len(full_dataset)//10, len(full_dataset) // 10]
full_dataset = random_split(full_dataset, partitions,
generator=generator)
self.train_dataset, self.val_dataset = full_dataset
def train_dataloader(self):
return DataLoader(self.train_dataset, batch_size=self.batch_size, drop_last=True,
shuffle=True, num_workers=self.num_workers)
def val_dataloader(self):
return DataLoader(self.val_dataset, batch_size=self.batch_size,
num_workers = self.num_workers)
def test_dataloader(self):
return DataLoader(self.test_dataset, batch_size=self.batch_size,
num_workers = self.num_workers)
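The setup hook carves a fixed 10% validation split out of the training folder, seeded for reproducibility. A quick check of the resulting split sizes:
[ ]:
# Verify that setup('fit') produces the expected 90/10 train/validation split.
dm = DataModule(train_dataset, test_dataset)
dm.setup('fit')
print(len(dm.train_dataset), len(dm.val_dataset))  # 60923 6769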
[ ]:
datamodule = DataModule(train_dataset, test_dataset)
n_classification_labels = len(np.unique(train_dataset.targets))
smodel = LitNN(n_classification_labels=n_classification_labels)
early_stop_callback = EarlyStopping(
monitor='val_loss_ce',
min_delta=0.00,
patience=30,
verbose=False,
mode='min'
)
logger = MLFlowLogger(
experiment_name="fruits_classification",
tracking_uri="file:./mlruns",
)
trainer = pl.Trainer(
precision=32,
gpus=torch.cuda.device_count(),
tpu_cores=None,
logger=logger,
val_check_interval=0.1, # run the validation loop 10 times per epoch
auto_scale_batch_size=True,
auto_lr_find=True,
callbacks=[early_stop_callback],
max_epochs = 100,
)
# find "best" batch_size and lr
trainer.tune(smodel, datamodule = datamodule)
# fit smodel
trainer.fit(smodel, datamodule = datamodule)
# test smodel
trainer.test(smodel, datamodule = datamodule)
[8]:
smodel.trainer.callback_metrics
[8]:
{'test_loss_ce': tensor(0.0790),
'test_loss_zo': tensor(0.0150),
'train_loss_ce': tensor(5.5617e-05, device='cuda:0'),
'val_loss_ce': tensor(4.0881e-05),
'val_loss_zo': tensor(0.)}
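Since the run was logged to MLflow, the full metric history is also available after training. One way to retrieve and plot the validation loss, using mlflow's client API against the local file store configured above:
[ ]:
# Read the logged validation loss history back from the local MLflow store.
from mlflow.tracking import MlflowClient
client = MlflowClient(tracking_uri="file:./mlruns")
history = client.get_metric_history(logger.run_id, 'val_loss_ce')
plt.plot([m.step for m in history], [m.value for m in history])
plt.xlabel('step')
plt.ylabel('val_loss_ce')
plt.show()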