Revisiting Word2Vec skip-gram model

In this post, we will train a Word2Vec skip-gram model from scratch on some text and inspect the trained embeddings at the end.

The first step in any NLP pipeline is to vectorize the text. Traditionally, this was done with one-hot vector representations, which have several downsides:

  • Vectors tend to be very long as their size depends on the vocabulary size of the corpus (which grows with the corpus size).
  • They don’t capture any notion of meaning or similarity between words.
  • They are sparse and can’t be used for comparison, since any two distinct one-hot encoded vectors are orthogonal (see the short sketch below).
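
As a quick illustration of the last point, here is a minimal sketch with a made-up five-token vocabulary, showing that distinct one-hot vectors are always orthogonal:

import numpy as np

# A made-up five-token vocabulary; each token gets a one-hot vector.
vocab = ["wand", "owl", "broom", "spell", "castle"]
one_hot = np.eye(len(vocab))

# The dot product (and hence cosine similarity) of two distinct one-hot
# vectors is always 0, so "wand" looks exactly as unrelated to "spell"
# as it does to "owl" -- no notion of similarity is captured.
print(one_hot[0] @ one_hot[3])   # 0.0
print(one_hot[0] @ one_hot[0])   # 1.0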

Word2vec tackles all of these challenges. The architecture we define later lets us learn a dense, distributed representation for each word/token in our corpus. The key idea is that we can represent a word using the contextual information around it, where context refers to the words/tokens neighbouring the word/token of interest to us.

Of course this technique is not perfect and has its own downsides. The key downside is that we lose the word/token ordering while training the word2vec model, so we are essentially dealing with a bag-of-words model. This makes the resulting embeddings unsuitable for sentence-level representations. We will cover other techniques later (RNNs, Transformers) that produce embeddings more suitable for sentences. For now our sole focus will be learning good word/token-level embeddings.

First, we will import the necessary libraries:

from collections import Counter
import re
from typing import List

import numpy as np
import requests
from sklearn.metrics.pairwise import cosine_similarity
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

The whole process of training the embeddings can be broken down into the following key steps:

  1. We define a class TextProcessor that has methods to pre-process the text.
  2. fetch_text() fetches the texts directly from urls containing our text data.
  3. process_text() applies basic pre-processing steps to the text and creates tokens using whitespace tokenization.
  4. prepare_vocab() creates token to index mapping and vice versa.
  5. Create a torch Dataset and DataLoader where each item in the dataset is a tuple (center_word_ix, context_word_ix), i.e. a pair of indices for a center word and one of its context words, based on a predefined window size.
    1. The window size determines the number of context words taken on either side of a center word.
    2. For example, a window size of 5 will create the following center and context pairs for the sentence “a quick brown fox jumps” with “brown” being the center word (see the sketch after this overview) -
      [("brown", "a"), ("brown", "quick"), ("brown", "fox"), ("brown", "jumps")]
  6. Create our SkipGramModel by subclassing torch.nn.Module.
  7. Define our loss function, optimizer and scheduler.
  8. Train the model for a certain number of epochs.

Finally, we will create a similarity matrix (cosine similarity) using the trained embeddings and explore the most similar embeddings for a few word queries.
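
To make step 5 concrete, here is a minimal sketch of the windowing logic. make_pairs is a throwaway helper used only for illustration; it is not part of the pipeline below and handles sentence edges slightly differently from the Dataset class we define later:

def make_pairs(tokens: List[str], window_size: int):
    # Pair every center token with each token at most `window_size`
    # positions away on either side.
    pairs = []
    for i, center in enumerate(tokens):
        left = tokens[max(0, i - window_size):i]
        right = tokens[i + 1:i + window_size + 1]
        pairs += [(center, context) for context in left + right]
    return pairs

pairs = make_pairs("a quick brown fox jumps".split(), window_size=5)
# The pairs with "brown" as the center word are exactly
# [("brown", "a"), ("brown", "quick"), ("brown", "fox"), ("brown", "jumps")]
print([p for p in pairs if p[0] == "brown"])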

Fetching and preparing the data

# We will be using the first four Harry Potter books.
BASE_URL = "https://raw.githubusercontent.com/formcept/"
BASE_URL += "whiteboard/master/nbviewer/notebooks/data/harrypotter"

BOOK_URLS = [
    f"{BASE_URL}/Book%201%20-%20The%20Philosopher's%20Stone.txt",
    f"{BASE_URL}/Book%202%20-%20The%20Chamber%20of%20Secrets.txt",
    f"{BASE_URL}/Book%203%20-%20The%20Prisoner%20of%20Azkaban.txt",
    f"{BASE_URL}/Book%204%20-%20The%20Goblet%20of%20Fire.txt"
]
class TextProcessor:
    def __init__(self, urls: List[str]):
        self.urls = urls

    def fetch_text(self):
        self.text = ""
        for i, url in enumerate(self.urls):
            r = requests.get(url)
            self.text += "\n\n\n" + r.text
            print(f"Fetched {i+1}/{len(self.urls)} urls ...")

    def process_text(self):
        print("Processing text ...")
        sentences = re.split("\n{2,}", self.text)
        print(f"{len(sentences)} sentences")
        self.clean_sentences = []
        for txt in sentences:
            txt = txt.strip().lower()
            txt = txt.replace("\n", "")
            txt = " ".join(re.findall("\w+", txt))
            if txt:
                self.clean_sentences.append(txt)

        # Tokens and token counter will be used later to create a vocabulary for
        # us which will be used to map words/tokens to indices and vice versa, to
        # create training data and recreate the texts from predictions respectively.
        print(f"{len(sentences)} filtered sentences")
        tokens = " ".join(self.clean_sentences).split(" ")
        print(f"{len(tokens)} tokens")
        self.token_counter = Counter(tokens)
        print(f"{len(self.token_counter)} unique tokens")

    def prepare_vocab(self, min_count: int):
        # min_count is the minimum number of times a token should appear in the
        # text to be considered in the vocabulary else they are assigned to a
        # default index which is equal to `len(vocab)`.
        self.w2ix = {}
        for i, (token, count) in enumerate(
            self.token_counter.most_common(len(self.token_counter))
        ):
            if count < min_count:
                break
            else:
                self.w2ix[token] = i

        self.ix2w = {ix: w for w, ix in self.w2ix.items()}

        # Assign default index to rest of the tokens
        n = len(self.w2ix)
        for token, _ in self.token_counter.most_common(len(self.token_counter)):
            if token not in self.w2ix:
                self.w2ix[token] = n
        self.ix2w[n] = "<unk>"

        self.vocab_size = n + 1
        print(f"Vocab size: {self.vocab_size}")
text_processor = TextProcessor(BOOK_URLS)
text_processor.fetch_text()
text_processor.process_text()
text_processor.prepare_vocab(min_count=20)
Fetched 1/4 urls ...
Fetched 2/4 urls ...
Fetched 3/4 urls ...
Fetched 4/4 urls ...

Processing text ...
20033 sentences
20033 filtered sentences
501854 tokens
15078 unique tokens

Vocab size: 2103
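
One detail of prepare_vocab worth calling out: every token that appears fewer than min_count times shares a single index, which maps back to the "<unk>" placeholder. A quick check (the index 2102 follows from the vocab size of 2103 printed above):

unk_ix = text_processor.vocab_size - 1        # 2102 for the run above
print(text_processor.ix2w[unk_ix])            # <unk>
print(text_processor.w2ix["harry"] < unk_ix)  # True: frequent tokens keep their own index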

Preparing the torch Dataset and Dataloader

The DataLoader will let us feed data to the model in batches during training.

class HarryPotterDataset(Dataset):
    def __init__(self, text_processor: TextProcessor, window_size: int = 3):
        self.data = []
        self.text_processor = text_processor
        self.window_size = window_size
        self._init_data()

    def _init_data(self):
        for txt in self.text_processor.clean_sentences:
            splits = txt.split(" ")
            for i in range(self.window_size, len(splits) - 1):
                center_word = splits[i]
                window_words = (splits[i - self.window_size:i]
                                + splits[i + 1:i + self.window_size + 1])
                for context in window_words:
                    # Each data point in self.data is a tuple: index 0 holds the
                    # vocab index (w2ix) of the center word and index 1 the vocab
                    # index of the context word.
                    self.data.append(
                        (self.text_processor.w2ix[center_word],
                         self.text_processor.w2ix[context])
                    )

    def __getitem__(self, ix: int):
        return self.data[ix]
  
    def __len__(self):
        return len(self.data)
dataset = HarryPotterDataset(text_processor, 3)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
len(dataset)
2489461
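
As a quick sanity check of what the DataLoader yields (shapes only; the actual values depend on shuffling), we can peek at a single batch:

center_batch, context_batch = next(iter(dataloader))
# Both are 1-D int64 tensors of length batch_size (64 here):
# center_batch[j] holds the vocab index of a center word and
# context_batch[j] the index of one of its context words.
print(center_batch.shape, context_batch.shape)  # torch.Size([64]) torch.Size([64])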

Creating the Word2Vec model using torch.nn.Module

class SkipGramModel(nn.Module):
    def __init__(self, vocab_size: int, embedding_size: int=50):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.linear = nn.Linear(embedding_size, vocab_size)

    def forward(self, x):
        return self.linear(self.embedding(x))

The model architecture is Embedding -> Linear -> Softmax.

Embedding is nothing but “A simple lookup table that stores embeddings of a fixed dictionary and size” - pytorch documentation.

Based on our model architecture, this layer will contain the trained token/word embeddings of interest to us and we will discard the weight matrix from the Linear layer.

In each forward pass we:

  1. Pass indices of center word as input x.
  2. These indices are looked up in the self.embedding table.
  3. We transform the vectors from the last step using the Linear layer to get vectors of dimension vocab_size.
  4. Finally we apply softmax and calculate the loss (in practice the softmax is folded into nn.CrossEntropyLoss, so the model itself only outputs logits).

A trained model should output high values for the context indices for a given center word.
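
To make the shapes concrete, here is a small sketch of a single forward pass using a throwaway model and made-up indices (the vocab size of 2103 comes from the run above); during training the softmax itself is handled by nn.CrossEntropyLoss:

with torch.no_grad():
    dummy_model = SkipGramModel(text_processor.vocab_size, embedding_size=30)
    center_ixs = torch.tensor([5, 17, 42])    # a batch of 3 made-up center-word indices
    emb = dummy_model.embedding(center_ixs)   # shape: (3, 30), one embedding per index
    logits = dummy_model(center_ixs)          # shape: (3, 2103), one score per vocab token
    probs = torch.softmax(logits, dim=-1)     # each row is a distribution over the vocab
    # The embedding layer is literally a row lookup into its weight matrix.
    assert torch.allclose(emb, dummy_model.embedding.weight[center_ixs])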

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SkipGramModel(text_processor.vocab_size, embedding_size=30).to(device)

Defining our loss function, optimizer and scheduler

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
N_EPOCHS = 10
STEP_TO_LOG = 10000

Training the model

epoch_losses, batch_losses = [], []

for epoch in range(N_EPOCHS):
    losses = []
    for i, data in enumerate(dataloader):
        inputs, outputs = data[0].long().to(device), data[1].long().to(device)
        pred = model(inputs)
        loss = criterion(pred, outputs)
        losses.append(loss.item())
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if i % STEP_TO_LOG == 0:
            print(f"epoch {epoch + 1}; steps={i:>5}/{len(dataloader)}; loss={loss:<.5f}")
    scheduler.step()
    batch_losses += losses
    epoch_losses.append(np.mean(losses))
    print(f"epoch {epoch + 1} ; lr={scheduler._last_lr[0]}; loss={epoch_losses[-1]:<.5f}\n")
epoch 1; steps=    0/38898; loss=7.80020
epoch 1; steps=10000/38898; loss=5.55375
epoch 1; steps=20000/38898; loss=5.58139
epoch 1; steps=30000/38898; loss=5.27452
epoch 1 ; lr=0.005; loss=5.60440

epoch 2; steps=    0/38898; loss=4.85012
epoch 2; steps=10000/38898; loss=5.81531
epoch 2; steps=20000/38898; loss=5.07261
epoch 2; steps=30000/38898; loss=5.62864
epoch 2 ; lr=0.005; loss=5.52841

epoch 3; steps=    0/38898; loss=5.27015
epoch 3; steps=10000/38898; loss=5.55336
epoch 3; steps=20000/38898; loss=5.67527
epoch 3; steps=30000/38898; loss=5.60372
epoch 3 ; lr=0.005; loss=5.52233

epoch 4; steps=    0/38898; loss=5.50681
epoch 4; steps=10000/38898; loss=5.48357
epoch 4; steps=20000/38898; loss=5.18897
epoch 4; steps=30000/38898; loss=5.74382
epoch 4 ; lr=0.005; loss=5.52111

epoch 5; steps=    0/38898; loss=5.67365
epoch 5; steps=10000/38898; loss=5.66463
epoch 5; steps=20000/38898; loss=5.82155
epoch 5; steps=30000/38898; loss=5.77744
epoch 5 ; lr=0.0025; loss=5.52081

epoch 6; steps=    0/38898; loss=5.54513
epoch 6; steps=10000/38898; loss=5.40478
epoch 6; steps=20000/38898; loss=5.71839
epoch 6; steps=30000/38898; loss=5.23344
epoch 6 ; lr=0.0025; loss=5.47719

epoch 7; steps=    0/38898; loss=4.93150
epoch 7; steps=10000/38898; loss=5.09374
epoch 7; steps=20000/38898; loss=5.47905
epoch 7; steps=30000/38898; loss=5.92455
epoch 7 ; lr=0.0025; loss=5.47128

epoch 8; steps=    0/38898; loss=5.22414
epoch 8; steps=10000/38898; loss=5.67435
epoch 8; steps=20000/38898; loss=5.55853
epoch 8; steps=30000/38898; loss=5.28791
epoch 8 ; lr=0.0025; loss=5.47099

epoch 9; steps=    0/38898; loss=5.33284
epoch 9; steps=10000/38898; loss=5.80998
epoch 9; steps=20000/38898; loss=6.04806
epoch 9; steps=30000/38898; loss=5.58979
epoch 9 ; lr=0.0025; loss=5.47106

epoch 10; steps=    0/38898; loss=5.09184
epoch 10; steps=10000/38898; loss=5.18391
epoch 10; steps=20000/38898; loss=5.66316
epoch 10; steps=30000/38898; loss=5.44824
epoch 10 ; lr=0.00125; loss=5.47111
weights = model.embedding.weight.detach().cpu().numpy()
weights.shape
(2103, 30)

Creating the similarity matrix

similarity = cosine_similarity(weights, weights)
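
For reference, cosine similarity is just the dot product of length-normalized vectors, so the same matrix can be computed by hand (a sketch only; the sklearn call above is what we actually use):

norms = np.linalg.norm(weights, axis=1, keepdims=True)
normalized = weights / norms
manual_similarity = normalized @ normalized.T          # shape: (2103, 2103)
assert np.allclose(manual_similarity, similarity, atol=1e-5)
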
def get_similar_words(input_word: str, n: int):
    if input_word not in text_processor.w2ix:
        print("word not in vocab")
    else:
        input_word_ix = text_processor.w2ix[input_word]
        similarity_vector = similarity[input_word_ix]
        most_similar_ixs = np.argsort(similarity_vector)[-n:][::-1]
        most_similar_words = [(text_processor.ix2w[ix], similarity_vector[ix])
                              for ix in most_similar_ixs]
        for word, score in most_similar_words:
            print(f" > {word}, score={score: .2f}")

Exploring similar embeddings for some queries

get_similar_words("snape", 10)
 > snape, score= 1.00
 > karkaroff, score= 0.81
 > dumbledore, score= 0.79
 > lupin, score= 0.77
 > mcgonagall, score= 0.76
 > moody, score= 0.74
 > professor, score= 0.72
 > flitwick, score= 0.71
 > trelawney, score= 0.69
 > quirrell, score= 0.69
get_similar_words("vernon", 10)
 > vernon, score= 1.00
 > aunt, score= 0.81
 > uncle, score= 0.79
 > dudley, score= 0.78
 > petunia, score= 0.70
 > marge, score= 0.62
 > furious, score= 0.59
 > telephone, score= 0.53
 > sister, score= 0.53
 > company, score= 0.52
get_similar_words("muggle", 10)
 > muggle, score= 1.00
 > wizarding, score= 0.81
 > old, score= 0.65
 > yer, score= 0.63
 > international, score= 0.61
 > important, score= 0.61
 > our, score= 0.61
 > is, score= 0.61
 > young, score= 0.60
 > yourself, score= 0.60
get_similar_words("harry", 10)
 > harry, score= 1.00
 > lupin, score= 0.64
 > ron, score= 0.63
 > hermione, score= 0.61
 > she, score= 0.57
 > he, score= 0.57
 > colin, score= 0.54
 > dumbledore, score= 0.52
 > furiously, score= 0.51
 > moody, score= 0.49
get_similar_words("slytherin", 10)
 > slytherin, score= 1.00
 > gryffindor, score= 0.79
 > hufflepuff, score= 0.75
 > team, score= 0.71
 > heir, score= 0.71
 > ravenclaw, score= 0.66
 > goal, score= 0.64
 > seeker, score= 0.64
 > wood, score= 0.60
 > irish, score= 0.60
get_similar_words("malfoy", 10)
 > malfoy, score= 1.00
 > goyle, score= 0.77
 > crabbe, score= 0.64
 > draco, score= 0.61
 > diggory, score= 0.59
 > lucius, score= 0.57
 > weasley, score= 0.51
 > laughing, score= 0.50
 > pettigrew, score= 0.50
 > percy, score= 0.50

There are various intrinsic and extrinsic evaluation criteria for these embeddings. For evaluation and interpretation of the embeddings, take a look here.
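
As one example of an intrinsic check, the classic analogy test (“a is to b as c is to ?”) can be run directly on the trained weights. The analogy helper below is a minimal sketch for illustration only; with a corpus this small, the results are not guaranteed to be meaningful:

def analogy(a: str, b: str, c: str, n: int = 5):
    # Solve a : b :: c : ? by finding the nearest neighbours of (b - a + c).
    vec = (weights[text_processor.w2ix[b]]
           - weights[text_processor.w2ix[a]]
           + weights[text_processor.w2ix[c]])
    scores = cosine_similarity(vec.reshape(1, -1), weights)[0]
    for ix in np.argsort(scores)[-n:][::-1]:
        print(f" > {text_processor.ix2w[ix]}, score={scores[ix]: .2f}")

analogy("harry", "hermione", "ron")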