4. Text classification

Here we use Word2Vec to generate word embeddings and an LSTM to classify documents from the DBpedia dataset.

The Word2Vec model is trained with gensim; the LSTM classifier is built with PyTorch (through PyTorch Lightning).

[1]:
#!pip install pytorch_lightning mlflow torchtext gensim
[1]:
import pickle
import os
import itertools

import numpy as np
import scipy.stats as stats
import pandas as pd
import matplotlib.pyplot as plt

import gensim
import torchtext

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import random_split, TensorDataset, DataLoader

import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

%matplotlib inline
/home/marco/.local/lib/python3.8/site-packages/gensim/similarities/__init__.py:15: UserWarning: The gensim.similarities.levenshtein submodule is disabled, because the optional Levenshtein package <https://pypi.org/project/python-Levenshtein/> is unavailable. Install Levenhstein (e.g. `pip install python-Levenshtein`) to suppress this warning.
  warnings.warn(msg)

4.1. Word2Vec

[2]:
text_corpus = [
    "humans can watch TV",
    "houses can not watch TV",
    "humans can talk about houses",
    "refrigerators can not talk",
    "humans think about strange things",
    "beds do not think about strange things",
    "I see interesting things about the next TV",
    "Do not think lightly about humans",
    "The next major might think about things in a different way about the TV",
    "This place is not unsafe for humans",
    "Yellow houses are good for humans",
    "Those are not safes houses at all",
    "Humans have houses",
    "Houses are blue",
    "Do not think about houses",
    "There is one TV inside those houses"
]

class ProcessedText:
    def __init__(self, text):
        self.text = text

    def __iter__(self):
        for line in self.text:
            yield gensim.utils.simple_preprocess(line)
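
ProcessedText tokenizes each line lazily with gensim.utils.simple_preprocess, which lowercases, strips punctuation and drops very short tokens. A quick illustrative check of what the iterator yields (not part of the original run):

[ ]:
# the first two tokenized "sentences"
list(ProcessedText(text_corpus))[:2]
# [['humans', 'can', 'watch', 'tv'], ['houses', 'can', 'not', 'watch', 'tv']]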
[3]:
wv_model = gensim.models.Word2Vec(vector_size=20, window=5, min_count=1, workers=4)
wv_model.build_vocab(ProcessedText(text_corpus), update=False)
wv_model.train(ProcessedText(text_corpus), total_examples=wv_model.corpus_count, epochs=100)
[3]:
(2074, 9400)
[4]:
# example of vector for a word

wv_model.wv['humans']
[4]:
array([-0.05358737,  0.02408984, -0.04244986, -0.02269858,  0.00658333,
       -0.00125561, -0.02724232,  0.06634317,  0.00881555,  0.04800051,
       -0.03760158,  0.02025081, -0.01671168, -0.00771837,  0.04606754,
       -0.02332198,  0.03516515, -0.03691765, -0.03200763,  0.03627926],
      dtype=float32)
[5]:
# find most similar words given a set of words

wv_model.wv.most_similar(['humans', 'can'], topn=10)
[5]:
[('not', 0.4467436969280243),
 ('strange', 0.4427088797092438),
 ('way', 0.37376558780670166),
 ('at', 0.3642667531967163),
 ('the', 0.36400458216667175),
 ('next', 0.3263437747955322),
 ('major', 0.22556430101394653),
 ('think', 0.20717817544937134),
 ('about', 0.20420345664024353),
 ('things', 0.177797332406044)]
[6]:
wv_model.wv.most_similar('can', topn=10)
[6]:
[('watch', 0.40134578943252563),
 ('at', 0.37023916840553284),
 ('next', 0.3551013469696045),
 ('things', 0.35425886511802673),
 ('those', 0.31279364228248596),
 ('refrigerators', 0.3126063644886017),
 ('not', 0.2820398807525635),
 ('major', 0.2547941207885742),
 ('all', 0.21463869512081146),
 ('blue', 0.20638896524906158)]
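
Besides most_similar, the trained KeyedVectors also expose pairwise cosine similarities and the vocabulary mapping. A minimal sketch (the exact values differ between runs, since the corpus is tiny and training is stochastic):

[ ]:
# cosine similarity between two word vectors, and the size of the learned vocabulary
wv_model.wv.similarity('humans', 'houses'), len(wv_model.wv.key_to_index)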

4.2. Doc2Vec

[7]:
class ProcessedDocs:
    def __init__(self, text):
        self.text = text

    def __iter__(self):
        for i, line in enumerate(self.text):
            # tag is 'human' if the (lowercase) substring 'human' appears in the raw line, else 'thing'
            doc_tags = ['human' if 'human' in line else 'thing']
            yield gensim.models.doc2vec.TaggedDocument(gensim.utils.simple_preprocess(line), doc_tags)
[8]:
dv_model = gensim.models.Doc2Vec()  # default vector_size is 100
dv_model.build_vocab(ProcessedDocs(text_corpus), update=False)
dv_model.train(ProcessedDocs(text_corpus), total_examples=dv_model.corpus_count, epochs=100)
[9]:
# available dictionary keys for words

dv_model.wv.key_to_index
[9]:
{'about': 0, 'houses': 1, 'not': 2, 'humans': 3, 'think': 4, 'tv': 5}
[10]:
# available dictionary keys for docs

dv_model.dv.key_to_index
[10]:
{'human': 0, 'thing': 1}
[11]:
# example of vector for a word

dv_model.wv['about'].shape, dv_model.wv['about'][:3]
[11]:
((100,), array([-0.001282  , -0.00032439,  0.0041051 ], dtype=float32))
[12]:
# example of vector for a doc

dv_model.dv['human'].shape, dv_model.dv['human'][:3]
[12]:
((100,), array([-0.01216876, -0.01115332, -0.01915727], dtype=float32))
[13]:
# infer similarity of a new document to known document tags

dv_model.dv.similar_by_vector(dv_model.infer_vector(["humans", "can"]))
[13]:
[('human', 0.3137548267841339), ('thing', 0.22879791259765625)]
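
The two tag vectors stored in dv_model.dv can also be compared to each other directly; a minimal sketch (the exact value depends on the run):

[ ]:
# cosine similarity between the learned 'human' and 'thing' document vectors
dv_model.dv.similarity('human', 'thing')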
[14]:
# infer similarity of each document in the dataset (as if they were new documents) to known document tags

print(" predict | correct")
for x in ProcessedDocs(text_corpus):
    print(" ", x.tags[0], " | ", dv_model.dv.similar_by_vector(dv_model.infer_vector(x.words))[0][0])
 correct | predict
  human  |  thing
  thing  |  thing
  human  |  thing
  thing  |  thing
  human  |  human
  thing  |  thing
  thing  |  human
  human  |  thing
  thing  |  human
  human  |  thing
  human  |  thing
  thing  |  thing
  thing  |  human
  thing  |  thing
  thing  |  thing
  thing  |  thing
[15]:
# infer similarity of each document in a new dataset to known document tags
new_text_corpus = [
    "humans can watch TV now",
    "You should not think about houses",
    "There is a TV inside those houses",
    "Good reasons to think",
]

print(" predict | correct")
for x in ProcessedDocs(new_text_corpus):
    print(" ", x.tags[0], " | ", dv_model.dv.similar_by_vector(dv_model.infer_vector(x.words))[0][0])
 correct | predict
  human  |  thing
  thing  |  thing
  thing  |  thing
  thing  |  human
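
These predictions are noisy: the corpus is tiny and infer_vector itself is stochastic, since it runs a short optimization from a random initialization. Increasing the number of inference epochs (or averaging repeated inferences) usually stabilizes the inferred vector; a minimal sketch, assuming the gensim 4 API:

[ ]:
# more inference epochs -> a less noisy inferred vector (still limited by the tiny corpus)
vec = dv_model.infer_vector(gensim.utils.simple_preprocess("humans can watch TV now"), epochs=200)
dv_model.dv.similar_by_vector(vec)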

4.3. Real data example with Word2Vec and LSTM

[16]:
db = torchtext.datasets.DBpedia(root='.data')
560000lines [00:24, 23132.59lines/s]
560000lines [00:43, 12879.42lines/s]
70000lines [00:04, 14107.03lines/s]
[17]:
# (label, tensor of token ids) for the last training document
db[0][-1]
[17]:
(13,
 tensor([     2,   2282,   6365,  25737,   2282,   6365,  25737,      6,    526,
              5,   3115,     21,      2,   2282,   6365,      9,    179,     10,
              7,   2785,   7137,     18,     19,   2064,    107,      8,    687,
             54,     33,     91,      6,    236,   6129,      3,   7137,     33,
            509,    187,    110,    696,    463,      4,    235,      3,   2786,
              2,    150,   2282,   6365,    463,    652,     48,    185,   1906,
              8,  12314,    524,      3,   1450,     21,    273,    177,      5,
         668703,   6505,      8,  32125,     12,      2,    301,     62,   7668,
           6934,    526,    792,      6,  30622,    135,      2,  43128,   2974,
             23,    347,   1884,      2,   2282,   6365,      3]))
[18]:
class DBProcessor:
    def __init__(self, text_train, text_test):
        self.text_train = text_train
        self.text_test = text_test

    def __iter__(self):
        # each instance is a (label, tensor of token ids) pair; yield the token ids as a
        # plain list, so Word2Vec learns one embedding per token id rather than per word
        for i, instance in enumerate(itertools.chain(self.text_train, self.text_test)):
            sentence = instance[1]
            yield list(sentence.numpy())
[19]:
vector_size = 100
pdb_train = DBProcessor(db[0], db[1])
dbp_model = gensim.models.Word2Vec(vector_size=vector_size, window=5, min_count=1, workers=4)
dbp_model.build_vocab(pdb_train, update=False)
dbp_model.train(pdb_train, total_examples=dbp_model.corpus_count, epochs=1)
[19]:
(25426807, 34419767)
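
Note that this Word2Vec model was trained on lists of token ids, so its vocabulary keys are the integer ids themselves rather than word strings. A quick illustrative check, reusing the id 2282 that appears in the example document above (the vector values differ between runs):

[ ]:
# the token id is a vocabulary key, and its embedding has the chosen vector_size
2282 in dbp_model.wv.key_to_index, dbp_model.wv[2282].shape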
[20]:
class EmbDatasetTorch(torch.utils.data.Dataset):
    def __init__(self, db, wv):
        self.db = db
        self.wv = wv

    def __getitem__(self, i):
        instance = self.db[i]
        doc_class = instance[0]
        sentence = instance[1]
        # map each token id to its Word2Vec vector (skipping out-of-vocabulary ids)
        # and stack them into a (sequence_length, vector_size) float tensor
        sentence_vecs = [self.wv[v.item()] for v in sentence if v.item() in self.wv.key_to_index.keys()]
        return doc_class, torch.as_tensor(np.column_stack(sentence_vecs).T)

    def __len__(self):
        return len(self.db)
[21]:
db_train = EmbDatasetTorch(db[0], dbp_model.wv)
db_test = EmbDatasetTorch(db[1], dbp_model.wv)
[22]:
next(iter(db_train))
[22]:
(0,
 tensor([[-0.6556, -0.8930,  1.3688,  ..., -1.6038, -2.4898,  2.5580],
         [-0.5155, -2.5213, -1.0673,  ..., -1.8014, -4.9374, -1.3045],
         [-1.3310, -0.9853,  2.1687,  ..., -1.5520, -1.2400,  1.2778],
         ...,
         [ 0.0188,  0.9136, -3.0507,  ..., -1.5575, -3.0111,  0.6496],
         [ 0.4291, -0.9791,  1.5837,  ..., -0.9062, -0.5298,  1.2703],
         [-0.5155, -2.5213, -1.0673,  ..., -1.8014, -4.9374, -1.3045]]))

Let’s define our PyTorch Lightning data module and dataloaders

[38]:
def collate_fn(x):
    # a batch is a list of (label, word-vector tensor) pairs of varying length;
    # pack the sequences for the LSTM and collect the labels into a LongTensor
    labels, sentences = zip(*x)
    sentences = nn.utils.rnn.pack_sequence(sentences, enforce_sorted=False)
    labels = torch.LongTensor(labels)
    return sentences, labels

class DataModule(pl.LightningDataModule):
    def __init__(self, db_train, db_test, batch_size = 50,
                 num_workers=2, train_val_split_seed=0):
        super().__init__()

        self.batch_size = min(batch_size, len(db_train))
        self.num_workers = num_workers
        self.train_val_split_seed = train_val_split_seed
        self.db_train = db_train
        self.db_test = db_test

    def setup(self, stage):
        if stage == 'fit':
            full_dataset = self.db_train

            generator = torch.Generator().manual_seed(self.train_val_split_seed)
            full_size = len(full_dataset)
            val_size = min(full_size//10, 10000)
            partitions = [full_size - val_size, val_size]
            full_dataset = torch.utils.data.random_split(full_dataset, partitions,
                                                         generator=generator)
            self.train_dataset, self.val_dataset = full_dataset

        if stage == 'test':
            if self.db_test is not None:
                self.test_dataset = self.db_test

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, drop_last=True,
                          shuffle=True, num_workers=self.num_workers, collate_fn=collate_fn)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size,
                          num_workers = self.num_workers, collate_fn=collate_fn)

    def test_dataloader(self):
        if self.db_test is None:
            raise RuntimeError("Test data not set")
        return DataLoader(self.test_dataset, batch_size=self.batch_size,
                          num_workers = self.num_workers, collate_fn=collate_fn)
[39]:
datamodule = DataModule(db_train, db_test, num_workers=0)
datamodule.setup('fit')
[40]:
sentences, labels = next(iter(datamodule.train_dataloader()))
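
The collate function returns a PackedSequence, which is the variable-length batch format nn.LSTM accepts directly. A quick illustrative inspection (shapes depend on the sampled batch):

[ ]:
# sentences.data concatenates the word vectors of the whole batch; labels has one entry per document
type(sentences).__name__, sentences.data.shape, labels.shape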

4.4. Defining and training the model

[31]:
class LitNN(pl.LightningModule):
    def __init__(self, vector_size, n_classification_labels,
                 lstm_hidden_size = 150,
                 lstm_num_layers = 2,
                 lr=0.01, weight_decay=0):
        super().__init__()

        assert n_classification_labels != 1
        self.lr = lr
        self.weight_decay = weight_decay
        self.n_classification_labels = n_classification_labels

        self.lstm = nn.LSTM(
            input_size = vector_size,
            hidden_size = lstm_hidden_size,
            num_layers = lstm_num_layers,
        )
        self.last_layer = self._initialize_layer(nn.Linear(lstm_hidden_size, n_classification_labels))

    def forward(self, x):
        # x is a PackedSequence of word vectors; nn.LSTM returns (output, (h_n, c_n))
        x = self.lstm(x)
        # take h_n of the last LSTM layer: one hidden state per sequence in the batch
        x = x[1][0][-1]
        x = self.last_layer(x)
        return x

    def _initialize_layer(self, layer):
        nn.init.constant_(layer.bias, 0)
        gain = nn.init.calculate_gain('relu')
        nn.init.xavier_normal_(layer.weight, gain=gain)
        return layer

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        return optimizer

    def training_step(self, train_batch, batch_idx):
        inputv, target = train_batch
        output = self.forward(inputv)
        # classification when n_classification_labels is set, regression otherwise
        if self.n_classification_labels:
            loss = F.cross_entropy(output, target)
            self.log('train_loss_ce', loss.item())
        else:
            loss = F.mse_loss(output, target)
            self.log('train_loss_rmse', np.sqrt(loss.item()))

        return loss

    def test_validation_step(self, batch, batch_idx, name):
        inputv, target = batch
        output = self.forward(inputv)
        if self.n_classification_labels:
            loss_ce = F.cross_entropy(output, target).item()
            # zero-one loss (classification error rate)
            loss_zo = (torch.argmax(output, 1) != target)+0.
            loss_zo = loss_zo.mean().item()
            self.log(f'{name}_loss_ce', loss_ce)
            self.log(f'{name}_loss_zo', loss_zo)
        else:
            loss_mse = F.mse_loss(output, target).item()
            loss_mae = F.l1_loss(output, target).item()
            self.log(f'{name}_loss_rmse', np.sqrt(loss_mse))
            self.log(f'{name}_loss_mae', loss_mae)

    def validation_step(self, val_batch, batch_idx):
        self.test_validation_step(val_batch, batch_idx, 'val')

    def test_step(self, test_batch, batch_idx):
        self.test_validation_step(test_batch, batch_idx, 'test')

    def predict_step(self, predict_batch, batch_idx, dataloader_idx):
        inputv, target = predict_batch
        output = self.forward(inputv)
        return output
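
As a quick smoke test of the forward pass (illustrative only, on a hypothetical packed batch of two random "documents"): h_n[-1] contains one hidden state per sequence, so the output should have shape (batch_size, n_classification_labels).

[ ]:
toy_batch = nn.utils.rnn.pack_sequence(
    [torch.randn(5, vector_size), torch.randn(3, vector_size)], enforce_sorted=False)
LitNN(vector_size=vector_size, n_classification_labels=14)(toy_batch).shape
# expected: torch.Size([2, 14])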
[42]:
datamodule = DataModule(db_train, db_test, batch_size=2048)
smodel = LitNN(vector_size=vector_size, n_classification_labels=14)

early_stop_callback = EarlyStopping(
   monitor='val_loss_ce',
   min_delta=0.00,
   patience=10,
   verbose=False,
   mode='min'
)

# use MLFlow as logger if available, see other options at
# https://pytorch-lightning.readthedocs.io/en/latest/common/loggers.html
# you can start MLFLow server with:
# mlflow server --backend-store-uri=./mlruns
try:
    from pytorch_lightning.loggers import MLFlowLogger
    logger = MLFlowLogger(
        experiment_name="Default",
        tracking_uri="file:./mlruns"
    )
except ImportError:
    # default: Tensorboard, you can start with:
    # tensorboard --logdir lightning_logs
    logger = True

trainer = pl.Trainer(
                     precision=32,
                     gpus=torch.cuda.device_count(),
                     #tpu_cores=None,
                     logger=logger,
                     val_check_interval=0.1, # do validation check 10 times for each epoch
                     #auto_scale_batch_size=True,
                     auto_lr_find=True,
                     callbacks=early_stop_callback,
                     max_epochs = 100,
                    )
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
[43]:
trainer.tune(smodel, datamodule = datamodule)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type   | Params
--------------------------------------
0 | lstm       | LSTM   | 332 K
1 | last_layer | Linear | 2.1 K
--------------------------------------
334 K     Trainable params
0         Non-trainable params
334 K     Total params
1.338     Total estimated model params size (MB)
/home/marco/.local/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:69: UserWarning: The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  warnings.warn(*args, **kwargs)
/home/marco/.local/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:69: UserWarning: The dataloader, train dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  warnings.warn(*args, **kwargs)
--------------------------------------------------------------------------------
DATALOADER:0 VALIDATE RESULTS
{}
--------------------------------------------------------------------------------
Restored states from the checkpoint file at /home/marco/Documents/projects/python-intro/sections/lr_find_temp_model.ckpt
Learning rate set to 0.19054607179632482
[43]:
{'lr_find': <pytorch_lightning.tuner.lr_finder._LRFinder at 0x7f1bfae19400>}
[44]:
# fit smodel
trainer.fit(smodel, datamodule = datamodule)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type   | Params
--------------------------------------
0 | lstm       | LSTM   | 332 K
1 | last_layer | Linear | 2.1 K
--------------------------------------
334 K     Trainable params
0         Non-trainable params
334 K     Total params
1.338     Total estimated model params size (MB)
[45]:
trainer.test(smodel, datamodule = datamodule)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/marco/.local/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:69: UserWarning: The dataloader, test dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  warnings.warn(*args, **kwargs)
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss_ce': 0.5150100588798523, 'test_loss_zo': 0.1482437252998352}
--------------------------------------------------------------------------------
[45]:
[{'test_loss_ce': 0.5150100588798523, 'test_loss_zo': 0.1482437252998352}]
[46]:
# predict with smodel on the first 10 training documents
subset = torch.utils.data.Subset(db_train, range(10))
data_loader = DataLoader(subset, collate_fn=collate_fn)
test_pred = trainer.predict(smodel, data_loader)
test_pred = [F.softmax(t, 1).cpu() for t in test_pred]
test_pred = np.vstack(test_pred)
test_pred
/home/marco/.local/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:69: DeprecationWarning: The `LightningModule.datamodule` property is deprecated in v1.3 and will be removed in v1.5. Access the datamodule through using `self.trainer.datamodule` instead.
  warnings.warn(*args, **kwargs)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/marco/.local/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:69: UserWarning: The dataloader, predict dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  warnings.warn(*args, **kwargs)
[46]:
array([[4.09552157e-01, 1.06071964e-01, 8.41988176e-02, 2.59270594e-02,
        3.22150514e-02, 5.33441489e-04, 5.33908457e-02, 3.97227984e-03,
        1.80906039e-02, 3.13700511e-05, 7.64723518e-04, 9.98726246e-05,
        4.46949855e-07, 2.65151352e-01],
       [9.75617692e-02, 1.64559879e-07, 4.24981059e-04, 7.47480681e-06,
        6.20739684e-06, 5.22396527e-03, 1.70907925e-03, 2.12730301e-04,
        1.85641170e-06, 1.91530702e-03, 5.31282611e-02, 1.37102649e-01,
        3.89551185e-02, 6.63750350e-01],
       [8.68686259e-01, 1.33589652e-04, 1.16329047e-05, 1.16774821e-08,
        1.95751545e-06, 2.35673203e-03, 5.91394305e-02, 3.88082638e-02,
        3.63222134e-05, 7.82081827e-07, 1.42352932e-04, 2.43386975e-03,
        1.07618331e-10, 2.82487869e-02],
       [7.53526568e-01, 2.33273354e-07, 2.88229767e-05, 5.08333073e-07,
        4.38038005e-05, 4.07082103e-02, 4.40885797e-02, 9.62672383e-02,
        1.27165418e-04, 1.53450109e-03, 1.51878968e-02, 4.36797366e-03,
        7.46173373e-06, 4.41109948e-02],
       [2.87303030e-01, 1.16851628e-02, 1.08578301e-04, 3.28232352e-09,
        6.70920999e-05, 6.95538940e-03, 6.83973789e-01, 5.42083988e-04,
        3.53601922e-06, 3.26784333e-09, 5.43062072e-07, 7.72331332e-05,
        3.48732050e-11, 9.28363111e-03],
       [6.40867278e-02, 1.21500306e-01, 9.95871029e-04, 5.26127842e-05,
        3.09540446e-05, 3.12927623e-05, 4.73763747e-03, 5.42093294e-05,
        1.59608258e-04, 2.08304823e-06, 3.71216920e-05, 4.27920138e-04,
        5.87089471e-06, 8.07877779e-01],
       [4.58942205e-01, 5.02900593e-06, 4.93392348e-04, 3.42837564e-04,
        2.99833203e-03, 2.20704935e-02, 4.25901487e-02, 2.71155648e-02,
        4.55437973e-03, 6.51715994e-02, 3.30148749e-02, 7.66062527e-04,
        3.78116369e-02, 3.04123491e-01],
       [3.61741275e-01, 2.80804306e-05, 2.67065698e-05, 3.83088200e-10,
        2.48303586e-05, 5.98860420e-02, 5.72525144e-01, 2.20764731e-03,
        6.49348181e-07, 2.63451909e-08, 5.55272118e-06, 3.56301782e-04,
        9.63145605e-11, 3.19772563e-03],
       [8.64959776e-01, 1.35151431e-05, 5.78636946e-06, 3.87738730e-09,
        1.05855952e-06, 4.55121743e-03, 5.02642728e-02, 5.95776401e-02,
        1.58676849e-05, 1.23718848e-06, 2.75576807e-04, 3.97440372e-03,
        9.71813185e-11, 1.63596980e-02],
       [2.42577583e-01, 4.70198188e-07, 9.49133828e-05, 3.02213117e-08,
        6.46146887e-04, 3.61857921e-01, 3.85332048e-01, 2.62155617e-03,
        3.56351916e-06, 1.71235315e-05, 1.89578263e-04, 3.00951971e-04,
        3.18828302e-06, 6.35486608e-03]], dtype=float32)
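
Each row of test_pred is a probability distribution over the 14 DBpedia classes, so the hard prediction for each of the ten documents is the arg max of its row (an illustrative follow-up, not part of the original run):

[ ]:
test_pred.argmax(axis=1)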
[47]:
smodel.trainer.callback_metrics
[47]:
{'val_loss_ce': tensor(0.4905),
 'val_loss_zo': tensor(0.1475),
 'train_loss_ce': tensor(0.5414, device='cuda:0'),
 'test_loss_ce': tensor(0.5150),
 'test_loss_zo': tensor(0.1482)}
[48]:
# check that smodel is picklable
pickle.dumps(smodel);
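
Since the model pickles cleanly, it can also be round-tripped, e.g. to hand it to another process (a minimal sketch; for ordinary checkpointing, trainer.save_checkpoint is the more idiomatic route):

[ ]:
restored = pickle.loads(pickle.dumps(smodel))
type(restored).__name__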