
I implemented a word n-gram model using a character-level ternary search tree. It is intended to be passed a generator that yields a long sequence of words (from a corpus), and its requirements are that it

  • can return frequencies and probabilities for word n-grams
  • allows providing a vocabulary, such that n-grams containing words not in the vocabulary are not counted
  • allows providing a list of targets, so that only n-grams ending in a target are counted (as probabilities for these are needed)

I find that it works as expected, but it is quite slow and consumes a lot of memory. For n-grams of length 4, trained on a corpus of about 1 billion words, it currently consumes >120 GB of memory, despite being given a vocabulary restricted to words with a minimum frequency of 5, and it has been running for nearly 30 hours already. I know that Python requires a lot of memory, but I'm wondering if I'm missing something that would make it faster and maybe less memory intensive.
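For context, the intended usage looks roughly like this (the corpus, vocabulary, and target names here are placeholders of my own for illustration, not the real setup):

from language_model import LanguageModel

# toy stand-ins for the real corpus generator, vocabulary, and targets
corpus_words = ["the", "cat", "sat", "on", "the", "mat", "</s>"]
vocab = {"the", "cat", "sat", "on", "mat"}
targets = {"cat", "mat"}

lm = LanguageModel(4, vocabulary=vocab, targets=targets)
lm.train(corpus_words)

print(lm.frequency(["the", "cat"]))    # how often "the cat" was seen
print(lm.probability(["the", "mat"]))  # P(mat | the), from the stored counts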

File tst.py:

class Node():
    def __init__(self, char):
        self.char = char
        self.count = 0
        self.lo = None
        self.eq = None
        self.hi = None


class TernarySearchTree():
    """Ternary search tree that stores counts for n-grams
    and their subsequences.
    """

    def __init__(self, splitchar=None):
        """Initializes TST.

        Parameters
        ----------
        splitchar : str
            Character that separates tokens in n-gram.
            Counts are stored for complete n-grams and
            each sub-sequence ending in this character
        """
        self._root = None
        self._splitchar = splitchar
        self._total = 0

    def insert(self, string):
        """Insert string into Tree.

        Parameters
        ----------
        string : str
            String to be inserted.
        """
        self._root = self._insert(string, self._root)
        self._total += 1

    def frequency(self, string):
        """Return frequency of string.

        Parameters
        ----------
        string : str

        Returns
        -------
        int
            Frequency
        """
        if not string:
            return self._total
        node = self._search(string, self._root)
        if not node:
            return 0
        return node.count

    def _insert(self, string, node):
        """Insert string at a given node.
        """
        if not string:
            return node
        char, *rest = string
        if node is None:
            node = Node(char)
        if char == node.char:
            if not rest:
                node.count += 1
                return node
            else:
                if rest[0] == self.splitchar:
                    node.count += 1
                node.eq = self._insert(rest, node.eq)
        elif char < node.char:
            node.lo = self._insert(string, node.lo)
        else:
            node.hi = self._insert(string, node.hi)
        return node

    def _search(self, string, node):
        """Return node that string ends in.
        """
        if not string or not node:
            return node
        char, *rest = string
        if char == node.char:
            if not rest:
                return node
            return self._search(rest, node.eq)
        elif char < node.char:
            return self._search(string, node.lo)
        else:
            return self._search(string, node.hi)

    def __contains__(self, string):
        """Adds "string in TST" syntactic sugar.
        """
        node = self._search(string, self._root)
        if node:
            return node.count
        return False

    @property
    def splitchar(self):
        return self._splitchar
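To illustrate how the tree counts things (my own toy example, based on the docstring above): inserting a splitchar-joined n-gram increments the count of the full string and of every prefix that ends just before a splitchar, so frequencies for contexts come for free.

from tst import TernarySearchTree

tst = TernarySearchTree(splitchar="#")
for n_gram in ("the#cat#sat", "the#cat#ran", "the#dog"):
    tst.insert(n_gram)

print(tst.frequency("the"))          # 3 -- every insertion starts with the token "the"
print(tst.frequency("the#cat"))      # 2 -- prefix of two of the insertions
print(tst.frequency("the#cat#sat"))  # 1
print(tst.frequency(""))             # 3 -- total number of insertions
print("the#cat" in tst)              # 2 -- __contains__ returns the (truthy) count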

File language_model.py:

from collections import deque

from tst import TernarySearchTree


class ContainsEverything:
"""Dummy container that mimics containing everything.
Has .add() method to mimic set.
"""
    def __contains__(self, _):
        return True

    def add(self, _):
        pass


class LanguageModel():
    """N-gram (Markov) model that uses a ternary search tree.
    Tracks frequencies and calculates probabilities.

    Attributes
    ----------
    n : int
        Size of n-grams to be tracked.
    vocabulary : set
        If provided, n-grams containing words not in vocabulary are skipped.
        Can be other container than set, if it has add method.
    targets : container
        If provided, n-grams not ending in target are counted as
        ending in "OOV" (OutOfVocabulary) instead, so probabilities
        can still be calculated.
    boundary : str
        N-grams crossing boundary will not be counted,
        e.g. sentence </s> or document </doc> meta tags
    splitchar : str
        String that separates tokens in n-grams
    """

    def __init__(self, n, boundary="</s>", splitchar="#",
                 vocabulary=None, targets=None):
        """
        Parameters
        ----------
        n : int
            Size of n-grams to be tracked.
        boundary : str
            N-grams crossing boundary will not be counted,
            e.g. sentence </s> or document </doc> meta tags
        splitchar : str
            String that separates tokens in n-grams
        vocabulary : set
            If provided, n-grams with words not in vocabulary are skipped.
            Can be other container than set, if it has add method.
        targets : container
            If provided, n-grams not ending in target are counted as
            ending in "OOV" (OutOfVocabulary) instead, so probabilities
            can still be calculated.
        """
        if not targets:
            targets = ContainsEverything()
        if not vocabulary:
            vocabulary = ContainsEverything()
        self._n = n
        self._counts = TernarySearchTree(splitchar)
        self._vocabulary = vocabulary
        self._targets = targets
        self._boundary = boundary
        self._splitchar = splitchar

    def train(self, sequence):
        """Train model on all n-grams in sequence.

        Parameters
        ----------
        sequence : iterable of str
            Sequence of tokens to train on.

        Notes
        -----
        A sequence [A, B, C, D, E] with n==3 will result in these
        n-grams:
        [A, B, C]
        [B, C, D]
        [C, D, E]
        [D, E]
        [E]
        """
        n_gram = deque(maxlen=self.n)
        for element in sequence:
            if element == self.boundary:
                # train on smaller n-grams at end of sentence
                # but exclude full n_gram if it was already trained
                # on in last iteration
                not_trained = len(n_gram) < self.n
                for length in range(1, len(n_gram) + not_trained):
                    self._train(list(n_gram)[-length:])
                n_gram.clear()
                continue
            n_gram.append(element)
            if len(n_gram) == self.n:
                if element not in self.targets:
                    self._train(list(n_gram)[:-1])
                    continue
                self._train(n_gram)
        # train on last n-grams in sequence
        # ignore full n-gram if it has already been trained on
        if len(n_gram) == self.n:
            n_gram = list(n_gram)[1:]
        for length in range(1, len(n_gram) + 1):
            self._train(list(n_gram)[-length:])

    def probability(self, sequence):
        """Returns probability of the sequence.

        Parameters
        ----------
        sequence : iterable of str
            Sequence of tokens to get the probability for

        Returns
        -------
        float or list of float
            Probability of last element or probabilities of all elements
        """
        try:
            n_gram = sequence[-self.n:]
        # if sequence is generator (cannot slice - TypeError),
        # run through it and return probability for final element
        except TypeError:
            n_gram = deque(maxlen=self.n)
            for element in sequence:
                n_gram.append(element)
        probability = self._probability(n_gram)
        return probability

    def frequency(self, n_gram):
        """Return frequency of n_gram.

        Parameters
        ----------
        n_gram : list/tuple of str

        Returns
        -------
        int
            Frequency
        """
        n_gram_string = self.splitchar.join(n_gram)
        frequency = self._counts.frequency(n_gram_string)
        return frequency

    def _train(self, n_gram):
        # test for OOV words
        for idx, word in enumerate(n_gram):
            if word not in self.vocabulary:
                n_gram = list(n_gram)[:idx]
        n_gram_string = self.splitchar.join(n_gram)
        self._counts.insert(n_gram_string)

    def _probability(self, n_gram):
        frequency = self.frequency(n_gram)
        if frequency == 0:
            return 0
        *preceding, target = n_gram
        total = self.frequency(preceding)
        probability = frequency / total
        return probability

    def __contains__(self, n_gram):
        return n_gram in self._counts

    @property
    def n(self):
        return self._n

    @property
    def vocabulary(self):
        return self._vocabulary

    @property
    def targets(self):
        return self._targets

    @property
    def boundary(self):
        return self._boundary

    @property
    def splitchar(self):
        return self._splitchar

For testing with random strings:

import random
from string import ascii_letters

from language_model import LanguageModel


def generate_random_strings(num):
    random.seed(69)
    for i in range(num):
        length = random.choice(range(12))
        yield "".join(random.choices(ascii_letters, k=length))


lm = LanguageModel(4)
lm.train(generate_random_strings(1000000))
asked Mar 26, 2019 at 13:18

1 Answer

class ContainsEverything:
"""Dummy container that mimics containing everything.
Has .add() method to mimic set.
"""

The indentation is borked here, and needs correcting before the code will run.


I profiled with guppy3 (inlining everything into one file for my convenience - that explains the __main__ below):

lm = LanguageModel(4)
lm.train(generate_random_strings(10000))
from guppy import hpy
h = hpy(lm)
print(h.heap())

Nearly all of the memory was accounted for in the top two lines:

Partition of a set of 502128 objects. Total size = 43108362 bytes.
 Index Count % Size % Cumulative % Kind (class / dict of class)
 0 233701 47 26174512 61 26174512 61 dict of __main__.Node
 1 233701 47 13087256 30 39261768 91 __main__.Node

I'm not entirely sure where the dict of __main__.Node comes from, but clearly Node is the culprit, and each node is contributing 56 bytes inherently and a further 112 bytes indirectly.

However, looking at the use of the tree:

 self._counts = TernarySearchTree(splitchar)
 ...
 frequency = self._counts.frequency(n_gram_string)
 ...
 self._counts.insert(n_gram_string)
 ...
 return n_gram in self._counts

I can't see any reason to use a tree. There's no use of internal nodes. As far as I can see, it can easily be replaced by a Counter, whereupon the memory usage drops by 90%:

Partition of a set of 34724 objects. Total size = 4141426 bytes.
 Index Count % Size % Cumulative % Kind (class / dict of class)
 0 9838 28 884914 21 884914 21 str
 1 8509 25 613768 15 1498682 36 tuple
 2 415 1 365544 9 1864226 45 type
 3 2226 6 320544 8 2184770 53 types.CodeType
 4 4329 12 309952 7 2494722 60 bytes
 5 1 0 295024 7 2789746 67 collections.Counter
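For what it's worth, here is a minimal sketch of such a replacement (my illustration, not code from the answer): a Counter keyed by the joined n-gram string covers the three operations LanguageModel actually uses (insert, frequency, membership). To reproduce the tree's counts exactly, insert() below also bumps every splitchar-delimited prefix, which the TST does implicitly while descending.

from collections import Counter


class CounterCounts:
    """Hypothetical stand-in for the subset of TernarySearchTree that
    LanguageModel uses: insert(), frequency() and membership tests."""

    def __init__(self, splitchar=None):
        self._splitchar = splitchar
        self._counts = Counter()
        self._total = 0

    def insert(self, string):
        # count the full n-gram string
        self._counts[string] += 1
        if self._splitchar is not None:
            # also count "w1", "w1#w2", ... for an inserted "w1#w2#...#wk",
            # mirroring the prefix counts the tree keeps at splitchar nodes
            tokens = string.split(self._splitchar)
            for i in range(1, len(tokens)):
                self._counts[self._splitchar.join(tokens[:i])] += 1
        self._total += 1

    def frequency(self, string):
        # empty string means "total number of insertions", as in the tree
        if not string:
            return self._total
        return self._counts[string]

    def __contains__(self, string):
        return self._counts[string] > 0

With something like this, self._counts = TernarySearchTree(splitchar) in LanguageModel.__init__ could be swapped for CounterCounts(splitchar) without changing the rest of the class.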
answered Aug 23, 2019 at 10:42
