I implemented a word n-gram model using a character ternary search tree. It is intended to be passed a generator that yields a long sequence of words (from a corpus) and its requirements are that it
- can return frequencies and probabilities for word n-grams
- allows providing a vocabulary, such that n-grams containing words not in the vocabulary are not counted
- allows providing a list of targets, so that only n-grams ending in a target are counted (as probabilities for these are needed)
I find that it works as expected, but it is quite slow and consumes a lot of memory. For n-grams of length 4, trained on a corpus of about 1 billion words, it currently consumes >120 GB of memory, despite my providing a vocabulary restricted to words with a minimum frequency of 5, and it has been running for nearly 30 hours already. I know that Python requires a lot of memory, but I'm wondering if I'm missing something that would make it faster and maybe less memory intensive.
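For reference, this is roughly how I use it (a toy example with made-up words, not my real corpus, which is streamed from disk by a generator):

from language_model import LanguageModel

# toy corpus; in practice the tokens come from a generator over a large corpus
corpus = ["the", "cat", "sat", "</s>", "the", "dog", "sat", "</s>"]

lm = LanguageModel(2, vocabulary={"the", "cat", "dog", "sat"}, targets={"sat"})
lm.train(iter(corpus))

print(lm.frequency(["cat", "sat"]))           # 1
print(lm.probability(["the", "cat", "sat"]))  # P(sat | cat) = 1.0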
File tst.py:
class Node():
    def __init__(self, char):
        self.char = char
        self.count = 0
        self.lo = None
        self.eq = None
        self.hi = None


class TernarySearchTree():
    """Ternary search tree that stores counts for n-grams
    and their subsequences.
    """

    def __init__(self, splitchar=None):
        """Initializes TST.
        Parameters
        ----------
        splitchar : str
            Character that separates tokens in n-gram.
            Counts are stored for complete n-grams and
            each sub-sequence ending in this character
        """
        self._root = None
        self._splitchar = splitchar
        self._total = 0

    def insert(self, string):
        """Insert string into Tree.
        Parameters
        ----------
        string : str
            String to be inserted.
        """
        self._root = self._insert(string, self._root)
        self._total += 1

    def frequency(self, string):
        """Return frequency of string.
        Parameters
        ----------
        string : str
        Returns
        -------
        int
            Frequency
        """
        if not string:
            return self._total
        node = self._search(string, self._root)
        if not node:
            return 0
        return node.count

    def _insert(self, string, node):
        """Insert string at a given node.
        """
        if not string:
            return node
        char, *rest = string
        if node is None:
            node = Node(char)
        if char == node.char:
            if not rest:
                node.count += 1
                return node
            else:
                if rest[0] == self.splitchar:
                    node.count += 1
                node.eq = self._insert(rest, node.eq)
        elif char < node.char:
            node.lo = self._insert(string, node.lo)
        else:
            node.hi = self._insert(string, node.hi)
        return node

    def _search(self, string, node):
        """Return node that string ends in.
        """
        if not string or not node:
            return node
        char, *rest = string
        if char == node.char:
            if not rest:
                return node
            return self._search(rest, node.eq)
        elif char < node.char:
            return self._search(string, node.lo)
        else:
            return self._search(string, node.hi)

    def __contains__(self, string):
        """Adds "string in TST" syntactic sugar.
        """
        node = self._search(string, self._root)
        if node:
            return node.count
        return False

    @property
    def splitchar(self):
        return self._splitchar
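To illustrate how counts are stored for sub-sequences ending in the split character (a toy example, assuming "#" as the split character): inserting one 3-gram also makes its 1- and 2-gram prefixes retrievable.

from tst import TernarySearchTree

tst = TernarySearchTree("#")
tst.insert("the#cat#sat")

print(tst.frequency("the#cat#sat"))  # 1
print(tst.frequency("the#cat"))      # 1 (prefix ending before a splitchar)
print(tst.frequency("the"))          # 1
print(tst.frequency(""))             # 1 (total number of insertions)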
File language_model.py:
from collections import deque

from tst import TernarySearchTree


class ContainsEverything:
    """Dummy container that mimics containing everything.
    Has .add() method to mimic set.
    """

    def __contains__(self, _):
        return True

    def add(self, _):
        pass


class LanguageModel():
    """N-gram (Markov) model that uses a ternary search tree.
    Tracks frequencies and calculates probabilities.
    Attributes
    ----------
    n : int
        Size of n-grams to be tracked.
    vocabulary : set
        If provided, n-grams containing words not in vocabulary are skipped.
        Can be other container than set, if it has add method.
    targets : container
        If provided, n-grams not ending in target are counted as
        ending in "OOV" (OutOfVocabulary) instead, so probabilities
        can still be calculated.
    boundary : str
        N-grams crossing boundary will not be counted,
        e.g. sentence </s> or document </doc> meta tags
    splitchar : str
        String that separates tokens in n-grams
    """

    def __init__(self, n, boundary="</s>", splitchar="#",
                 vocabulary=None, targets=None):
        """
        Parameters
        ----------
        n : int
            Size of n-grams to be tracked.
        boundary : str
            N-grams crossing boundary will not be counted,
            e.g. sentence </s> or document </doc> meta tags
        splitchar : str
            String that separates tokens in n-grams
        vocabulary : set
            If provided, n-grams with words not in vocabulary are skipped.
            Can be other container than set, if it has add method.
        targets : container
            If provided, n-grams not ending in target are counted as
            ending in "OOV" (OutOfVocabulary) instead, so probabilities
            can still be calculated.
        """
        if not targets:
            targets = ContainsEverything()
        if not vocabulary:
            vocabulary = ContainsEverything()
        self._n = n
        self._counts = TernarySearchTree(splitchar)
        self._vocabulary = vocabulary
        self._targets = targets
        self._boundary = boundary
        self._splitchar = splitchar

    def train(self, sequence):
        """Train model on all n-grams in sequence.
        Parameters
        ----------
        sequence : iterable of str
            Sequence of tokens to train on.
        Notes
        -----
        A sequence [A, B, C, D, E] with n==3 will result in these
        n-grams:
        [A, B, C]
        [B, C, D]
        [C, D, E]
        [D, E]
        [E]
        """
        n_gram = deque(maxlen=self.n)
        for element in sequence:
            if element == self.boundary:
                # train on smaller n-grams at end of sentence
                # but exclude full n_gram if it was already trained
                # on in last iteration
                not_trained = len(n_gram) < self.n
                for length in range(1, len(n_gram) + not_trained):
                    self._train(list(n_gram)[-length:])
                n_gram.clear()
                continue
            n_gram.append(element)
            if len(n_gram) == self.n:
                if element not in self.targets:
                    self._train(list(n_gram)[:-1])
                    continue
                self._train(n_gram)
        # train on last n-grams in sequence
        # ignore full n-gram if it has already been trained on
        if len(n_gram) == self.n:
            n_gram = list(n_gram)[1:]
        for length in range(1, len(n_gram) + 1):
            self._train(list(n_gram)[-length:])

    def probability(self, sequence):
        """Returns probability of the sequence.
        Parameters
        ----------
        sequence : iterable of str
            Sequence of tokens to get the probability for
        Returns
        -------
        float or list of float
            Probability of last element or probabilities of all elements
        """
        try:
            n_gram = sequence[-self.n:]
        # if sequence is generator (cannot slice - TypeError),
        # run through it and return probability for final element
        except TypeError:
            n_gram = deque(maxlen=self.n)
            for element in sequence:
                n_gram.append(element)
        probability = self._probability(n_gram)
        return probability

    def frequency(self, n_gram):
        """Return frequency of n_gram.
        Parameters
        ----------
        n_gram : list/tuple of str
        Returns
        -------
        int
            Frequency
        """
        n_gram_string = self.splitchar.join(n_gram)
        frequency = self._counts.frequency(n_gram_string)
        return frequency

    def _train(self, n_gram):
        # test for OOV words
        for idx, word in enumerate(n_gram):
            if word not in self.vocabulary:
                n_gram = list(n_gram)[:idx]
        n_gram_string = self.splitchar.join(n_gram)
        self._counts.insert(n_gram_string)

    def _probability(self, n_gram):
        frequency = self.frequency(n_gram)
        if frequency == 0:
            return 0
        *preceding, target = n_gram
        total = self.frequency(preceding)
        probability = frequency / total
        return probability

    def __contains__(self, n_gram):
        return n_gram in self._counts

    @property
    def n(self):
        return self._n

    @property
    def vocabulary(self):
        return self._vocabulary

    @property
    def targets(self):
        return self._targets

    @property
    def boundary(self):
        return self._boundary

    @property
    def splitchar(self):
        return self._splitchar
For testing with random strings:
import random
from string import ascii_letters

from language_model import LanguageModel


def generate_random_strings(num):
    random.seed(69)
    for i in range(num):
        length = random.choice(range(12))
        yield "".join(random.choices(ascii_letters, k=length))


lm = LanguageModel(4)
lm.train(generate_random_strings(1000000))
1 Answer
class ContainsEverything:
    """Dummy container that mimics containing everything.
    Has .add() method to mimic set.
    """
The indentation is borked here, and needs correcting before the code will run.
I profiled with guppy3 (inlining everything into one file for my convenience - that explains the __main__ below):
lm = LanguageModel(4)
lm.train(generate_random_strings(10000))

from guppy import hpy
h = hpy()
print(h.heap())
Nearly all of the memory was accounted for in the top two lines:
Partition of a set of 502128 objects. Total size = 43108362 bytes.
 Index  Count   %     Size   % Cumulative  % Kind (class / dict of class)
     0 233701  47 26174512  61  26174512  61 dict of __main__.Node
     1 233701  47 13087256  30  39261768  91 __main__.Node
I'm not entirely sure where the dict of __main__.Node comes from, but clearly Node is the culprit, and each node is contributing 56 bytes inherently and a further 112 bytes indirectly.
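(That dict of __main__.Node is the per-instance __dict__ that every ordinary Python object carries. If the tree were worth keeping, declaring __slots__ on Node would eliminate it; a rough sketch:)

class Node:
    # __slots__ removes the per-instance __dict__, so each node stores
    # only these five references
    __slots__ = ("char", "count", "lo", "eq", "hi")

    def __init__(self, char):
        self.char = char
        self.count = 0
        self.lo = None
        self.eq = None
        self.hi = None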
However, looking at the use of the tree:
self._counts = TernarySearchTree(splitchar)
...
frequency = self._counts.frequency(n_gram_string)
...
self._counts.insert(n_gram_string)
...
return n_gram in self._counts
I can't see any reason to use a tree. There's no use of internal nodes. As far as I can see, it can easily be replaced by a Counter, whereupon the memory usage drops by 90%:
Partition of a set of 34724 objects. Total size = 4141426 bytes.
 Index  Count   %     Size   % Cumulative  % Kind (class / dict of class)
     0   9838  28   884914  21    884914  21 str
     1   8509  25   613768  15   1498682  36 tuple
     2    415   1   365544   9   1864226  45 type
     3   2226   6   320544   8   2184770  53 types.CodeType
     4   4329  12   309952   7   2494722  60 bytes
     5      1   0   295024   7   2789746  67 collections.Counter
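For concreteness, here is a sketch of what such a replacement could look like (NgramCounter is a name I've made up; it keeps the interface LanguageModel uses and reproduces the tree's sub-sequence counts, which _probability needs when it looks up the preceding (n-1)-gram):

from collections import Counter

class NgramCounter:
    """Counter-backed stand-in for TernarySearchTree."""

    def __init__(self, splitchar=None):
        self._counts = Counter()
        self._splitchar = splitchar
        self._total = 0

    def insert(self, string):
        # count the full n-gram plus every sub-sequence ending just before
        # a splitchar, mirroring what the tree does while descending
        self._total += 1
        self._counts[string] += 1
        if self._splitchar is not None:
            parts = string.split(self._splitchar)
            for i in range(1, len(parts)):
                self._counts[self._splitchar.join(parts[:i])] += 1

    def frequency(self, string):
        if not string:
            return self._total
        return self._counts[string]

    def __contains__(self, string):
        return self._counts[string] > 0

    @property
    def splitchar(self):
        return self._splitchar

Swapping self._counts = TernarySearchTree(splitchar) for self._counts = NgramCounter(splitchar) in __init__ would then be the only change LanguageModel needs.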