In this post, we will train a Word2Vec skip-gram model from scratch on some text and inspect the trained embeddings at the end.
The first step in any NLP pipeline is to vectorize the text. Traditionally, this was done with one-hot vector representations, which have several downsides:
Vectors tend to be very long, as their size depends on the vocabulary size of the corpus (which grows with the corpus size).
They carry no understanding of the text's meaning.
They are sparse and can't be used for meaningful comparison, since any two distinct one-hot encoded vectors from a set are orthogonal.
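To make the sparsity and orthogonality point concrete, here is a tiny illustrative snippet with a made-up four-word vocabulary:

import numpy as np

vocab = ["quick", "brown", "fox", "jumps"]                 # toy vocabulary
one_hot = {w: np.eye(len(vocab))[i] for i, w in enumerate(vocab)}

# Each vector is as long as the vocabulary and mostly zeros ...
print(one_hot["fox"])                       # [0. 0. 1. 0.]
# ... and any two distinct words have zero dot product, hence zero cosine similarity.
print(one_hot["fox"] @ one_hot["brown"])    # 0.0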
Word2vec is able to tackle all these challenges. The architecture we define later will let us learn a distributed representation for each word/token in our corpus. The key idea is that we can represent a word by adding contextual information to it, where context refers to the words/tokens neighbouring the word/token of interest.
Of course this technique is not perfect and has its own downsides. The key one is that we lose word/token ordering while training the word2vec model, so we are essentially dealing with a bag-of-words model. This makes the resulting embeddings unsuitable for sentence-level representations. We will cover other techniques later (RNNs, Transformers) that produce embeddings better suited for sentences. For now our sole focus will be learning good word/token level embeddings.
The whole process of training the embeddings can be broken down into the following key steps:
We define a class TextProcessor that has methods to pre-process the text.
fetch_text() fetches the texts directly from urls containing our text data.
process_text() applies basic pre-processing steps to the text and creates tokens using whitespace tokenization.
prepare_vocab() creates token to index mapping and vice versa.
Create a torch Dataset and DataLoader where each item in the dataset is a tuple (center_word_ix, context_word_ix), i.e. pairs of indices of a center word and a context word based on a pre-defined window size.
The window size determines the number of context words taken around a center word.
For example, a window size of 5 will create the following center and context pairs for the sentence “a quick brown fox jumps” with “brown” being the center word - [("brown", "a"), ("brown", "quick"), ("brown", "fox"), ("brown", "jumps")] (see the short sketch after this list).
Create our SkipGramModel by defining a custom model on top of torch.nn.Module.
Define our loss function, optimizer and scheduler.
Train the model for a certain number of epochs.
Finally, we will create a similarity matrix (cosine similarity) using the trained embeddings and explore the most similar words for some query words.
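For a quick sanity check of the pairing described above, here is a tiny sketch that reproduces the example pairs (using the same per-side windowing as the dataset class defined later):

sentence = "a quick brown fox jumps".split()
window_size = 5                      # context words taken on each side of the center word
i = sentence.index("brown")
context_words = sentence[max(0, i - window_size):i] + sentence[i + 1:i + window_size + 1]
print([("brown", w) for w in context_words])
# [('brown', 'a'), ('brown', 'quick'), ('brown', 'fox'), ('brown', 'jumps')]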
Fetching and preparing the data
# We will be using the first four Harry Potter books.
BASE_URL = "https://raw.githubusercontent.com/formcept/"
BASE_URL += "whiteboard/master/nbviewer/notebooks/data/harrypotter"
BOOK_URLS = [
    f"{BASE_URL}/Book%201%20-%20The%20Philosopher's%20Stone.txt",
    f"{BASE_URL}/Book%202%20-%20The%20Chamber%20of%20Secrets.txt",
    f"{BASE_URL}/Book%203%20-%20The%20Prisoner%20of%20Azkaban.txt",
    f"{BASE_URL}/Book%204%20-%20The%20Goblet%20of%20Fire.txt",
]
import re
from collections import Counter
from typing import List

import requests


class TextProcessor:
    def __init__(self, urls: List[str]):
        self.urls = urls

    def fetch_text(self):
        self.text = ""
        for i, url in enumerate(self.urls):
            r = requests.get(url)
            self.text += "\n\n\n" + r.text
            print(f"Fetched {i+1}/{len(self.urls)} urls ...")

    def process_text(self):
        print("Processing text ...")
        sentences = re.split("\n{2,}", self.text)
        print(f"{len(sentences)} sentences")
        self.clean_sentences = []
        for txt in sentences:
            txt = txt.strip().lower()
            txt = txt.replace("\n", "")
            txt = " ".join(re.findall(r"\w+", txt))
            if txt:
                self.clean_sentences.append(txt)
        # Tokens and token counter will be used later to create a vocabulary for
        # us which will be used to map words/tokens to indices and vice versa, to
        # create training data and recreate the texts from predictions respectively.
        print(f"{len(self.clean_sentences)} filtered sentences")
        tokens = " ".join(self.clean_sentences).split(" ")
        print(f"{len(tokens)} tokens")
        self.token_counter = Counter(tokens)
        print(f"{len(self.token_counter)} unique tokens")

    def prepare_vocab(self, min_count: int):
        # min_count is the minimum number of times a token should appear in the
        # text to be considered in the vocabulary else they are assigned to a
        # default index which is equal to `len(vocab)`.
        self.w2ix = {}
        for i, (token, count) in enumerate(
            self.token_counter.most_common(len(self.token_counter))
        ):
            if count < min_count:
                break
            else:
                self.w2ix[token] = i
        self.ix2w = {ix: w for w, ix in self.w2ix.items()}
        # Assign default index to rest of the tokens
        n = len(self.w2ix)
        for token, _ in self.token_counter.most_common(len(self.token_counter)):
            if token not in self.w2ix:
                self.w2ix[token] = n
        self.ix2w[n] = "<unk>"
        self.vocab_size = n + 1
        print(f"Vocab size: {self.vocab_size}")
from torch.utils.data import Dataset


class HarryPotterDataset(Dataset):
    def __init__(self, text_processor: TextProcessor, window_size: int = 3):
        self.data = []
        self.text_processor = text_processor
        self.window_size = window_size
        self._init_data()

    def _init_data(self):
        for txt in self.text_processor.clean_sentences:
            splits = txt.split(" ")
            for i in range(self.window_size, len(splits) - 1):
                center_word = splits[i]
                window_words = (
                    splits[i - self.window_size : i]
                    + splits[i + 1 : i + self.window_size + 1]
                )
                for context in window_words:
                    # Each data point under self.data will be a tuple with index 0
                    # containing the index(w2ix) for center_word and
                    # index 1 containing the index(w2ix) for context_word
                    self.data.append(
                        (
                            self.text_processor.w2ix[center_word],
                            self.text_processor.w2ix[context],
                        )
                    )

    def __getitem__(self, ix: int):
        return self.data[ix]

    def __len__(self):
        return len(self.data)
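Putting these pieces together, the data preparation might be wired up roughly as follows; the min_count, window_size, and batch_size values here are illustrative assumptions rather than the post's exact settings.

from torch.utils.data import DataLoader

text_processor = TextProcessor(BOOK_URLS)
text_processor.fetch_text()
text_processor.process_text()
text_processor.prepare_vocab(min_count=5)        # assumed frequency threshold

dataset = HarryPotterDataset(text_processor, window_size=3)
dataloader = DataLoader(dataset, batch_size=512, shuffle=True)   # assumed batch size

# Peek at one (center, context) pair.
center_ix, context_ix = dataset[0]
print(text_processor.ix2w[center_ix], "->", text_processor.ix2w[context_ix])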
The model architecture is Embedding -> Linear -> Softmax.
Embedding is nothing but “A simple lookup table that stores embeddings of a fixed
dictionary and size” - PyTorch documentation.
Based on our model architecture, this layer will contain the trained token/word
embeddings of interest to us and we will discard the weight matrix from
the Linear layer.
In each forward pass we:
Pass the indices of the center words as input x.
These indices are looked up in the self.embedding table.
We transform the vector from the last step using the Linear layer to get another vector with dimension vocab_size.
Finally, we apply softmax and calculate the loss.
A trained model should output high values for the context indices for a given
center word.
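Below is a minimal sketch of such a model together with the loss, optimizer, scheduler, and training loop mentioned earlier. It assumes the text_processor and dataloader built above; the embedding dimension, learning rate, scheduler settings, and epoch count are illustrative assumptions, not values prescribed by the post. Note that nn.CrossEntropyLoss applies log-softmax internally, so the forward pass returns raw logits.

import torch
import torch.nn as nn


class SkipGramModel(nn.Module):
    def __init__(self, vocab_size: int, embedding_dim: int = 128):
        super().__init__()
        # Lookup table holding the word embeddings we actually care about.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # Projects an embedding back onto the vocabulary; its weight matrix is
        # discarded after training.
        self.linear = nn.Linear(embedding_dim, vocab_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch,) tensor of center-word indices -> (batch, vocab_size) logits
        return self.linear(self.embedding(x))


# Illustrative training setup; hyperparameters are assumptions.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SkipGramModel(text_processor.vocab_size).to(device)
criterion = nn.CrossEntropyLoss()        # softmax + negative log-likelihood
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

for epoch in range(5):                   # assumed number of epochs
    total_loss = 0.0
    for center_ix, context_ix in dataloader:
        center_ix, context_ix = center_ix.to(device), context_ix.to(device)
        optimizer.zero_grad()
        logits = model(center_ix)
        loss = criterion(logits, context_ix)   # the context word acts as the label
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    scheduler.step()
    print(f"epoch {epoch}: loss={total_loss / len(dataloader):.4f}")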
def get_similar_words(input_word: str, n: int):
    if input_word not in text_processor.w2ix:
        print("word not in vocab")
    else:
        input_word_ix = text_processor.w2ix[input_word]
        similarity_vector = similarity[input_word_ix]
        most_similar_ixs = np.argsort(similarity_vector)[-n:][::-1]
        most_similar_words = [
            (text_processor.ix2w[ix], similarity_vector[ix]) for ix in most_similar_ixs
        ]
        for word, score in most_similar_words:
            print(f" > {word}, score={score: .2f}")
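The similarity matrix used by get_similar_words is not defined above; one straightforward way to build the cosine-similarity matrix from the trained embedding table (assuming the model sketch shown earlier) is:

import numpy as np

# Trained embedding table: (vocab_size, embedding_dim)
embeddings = model.embedding.weight.detach().cpu().numpy()

# Row-normalise, then one matrix product yields all pairwise cosine similarities.
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
normalized = embeddings / np.clip(norms, 1e-8, None)
similarity = normalized @ normalized.T          # (vocab_size, vocab_size)

get_similar_words("harry", n=5)                 # example query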
There are various intrinsic and extrinsic evaluation criteria for these embeddings.
For evaluation and interpretation of the embeddings, take a look here.