
I'd like a review of my Trie implementation, which I've tried to keep readable. Also, what other methods can or should be added to this data structure?

from collections import defaultdict

class Trie:
    def __init__(self):
        self.root = TrieNode()

    def add(self, word):
        curr = self.root
        for letter in word:
            node = curr.children.get(letter)
            if not node:
                node = TrieNode()
                curr.children[letter] = node
            curr = node
        curr.end_of_word = True

    def search(self, word):
        curr = self.root
        for w in word:
            node = curr.children.get(w)
            if not node:
                return False
            curr = node
        return curr.end_of_word

    def all_words_beginning_with_prefix(self, prefix):
        curr = self.root
        for letter in prefix:
            node = curr.children.get(letter)
            if node is None:
                raise KeyError("Prefix not in Trie")
            curr = node
        result = []

        def _find(n, p):
            if n.end_of_word:
                result.append(p)
            for k, v in n.children.items():
                _find(v, p + k)

        _find(curr, prefix)
        return result

class TrieNode:
    def __init__(self):
        self.children = defaultdict(TrieNode)
        self.end_of_word = False

This is the calling client that can be used to test the code.

if __name__ == '__main__':
    trie = Trie()
    trie.add('foobar')
    trie.add('foo')
    trie.add('bar')
    trie.add('foob')
    trie.add('foof')
    print(list(trie.all_words_beginning_with_prefix('foo')))
Gareth Rees
asked Jun 19, 2018 at 1:25

1 Answer


naming

Your naming is not consistent: sometimes you use letter, sometimes w for the same thing. I generally avoid one-letter variable names, but if you use them, be consistent.

dict.setdefault

Using dict.setdefault, you can simplify your Trie.add method significantly:

def add(self, word):
    curr = self.root
    for letter in word:
        curr = curr.children.setdefault(letter, TrieNode())
    curr.end_of_word = True

Then you can also change TrieNode.children to an ordinary dict.
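The add method above relies on two properties of dict.setdefault: it returns the existing value when the key is present, and otherwise stores the default and returns it. The default argument is evaluated before the call, so a TrieNode() is built on every iteration even when the child already exists (it is then simply discarded). A small demonstration with plain values standing in for TrieNode objects; make_default is a hypothetical helper that records each call:

```python
# Demonstrates the dict.setdefault behaviour that Trie.add relies on.
children = {}

first = children.setdefault('a', ['new node'])     # key absent: default is stored and returned
second = children.setdefault('a', ['other node'])  # key present: existing value is returned

print(first is second)  # True: the same stored object both times
print(children)         # {'a': ['new node']}

# The default expression is evaluated before the call, even if the key already exists.
calls = []

def make_default():
    # hypothetical helper: records that the default was constructed
    calls.append(1)
    return 'unused'

children.setdefault('a', make_default())
print(len(calls))       # 1: the default was built, then discarded
print(children)         # still {'a': ['new node']}
```

For a trie this cost is usually negligible; if it ever matters, an explicit "if letter not in curr.children" check avoids the extra allocation.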

string representation

For debugging, it can be handy to have a string representation of a TrieNode:

def __repr__(self):
    return f'TrieNode(end_of_word={self.end_of_word}, children={tuple(self.children)})'
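To show what this __repr__ prints, here is a hand-built node; the letters 'b' and 'f' are arbitrary example data. Because tuple(self.children) lists only the keys, the output stays one short line even for a large subtree:

```python
class TrieNode:
    def __init__(self):
        self.children = {}
        self.end_of_word = False

    def __repr__(self):
        # tuple(self.children) yields only the child letters, not the child nodes
        return f'TrieNode(end_of_word={self.end_of_word}, children={tuple(self.children)})'


node = TrieNode()
node.children['b'] = TrieNode()
node.children['f'] = TrieNode()
node.end_of_word = True

print(repr(node))        # TrieNode(end_of_word=True, children=('b', 'f'))
print(repr(TrieNode()))  # TrieNode(end_of_word=False, children=())
```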

getting a node

Currently, there is no way in your Trie to get a node. Adding such a method would simplify the rest of the implementation:

def __getitem__(self, word):
    curr = self.root
    for letter in word:
        curr = curr.children[letter]
    return curr

def get(self, word):
    return self[word]

If you changed the type of TrieNode.children to a dict instead of a defaultdict, this will raise a KeyError for a word that is not in the Trie. If you left it as a defaultdict, it will instead return an empty TrieNode, TrieNode(end_of_word=False, children=()), so you have to check for that and raise the KeyError yourself:


def __getitem__(self, word):
    curr = self.root
    for letter in word:
        curr = curr.children[letter]
    if not (curr.children or curr.end_of_word):
        raise KeyError(f'{word} not in Trie')
    return curr

>>> trie['foo']
TrieNode(end_of_word=True, children=('b', 'f'))
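One more reason to prefer the plain dict: indexing a defaultdict with a missing key does not just return a default, it also inserts that key. With defaultdict children, every failed lookup such as trie['fooz'] therefore leaves empty nodes behind in the Trie, even though the check above still raises KeyError. A minimal sketch of the difference, using two throwaway dictionaries (with list as the factory) rather than the trie itself:

```python
from collections import defaultdict

# defaultdict: indexing a missing key creates it, so the dict grows
auto = defaultdict(list)
value = auto['missing']
print(value)               # []
print('missing' in auto)   # True: the lookup inserted the key

# plain dict: a missing key raises KeyError and leaves the dict unchanged
plain = {}
try:
    plain['missing']
except KeyError:
    print('KeyError')
print('missing' in plain)  # False

# dict.get never calls the default factory, even on a defaultdict
print(auto.get('other'))   # None
print('other' in auto)     # False
```

The last two lines also explain why the question's code, which uses children.get(letter), gets None rather than a new node for a missing letter.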

Search

With a method to get a node, search becomes as simple as:

def search(self, word):
    try:
        return self[word].end_of_word
    except KeyError:
        return False
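To see the three cases search distinguishes (a stored word, a path that exists only as a prefix, and a missing path), here is a minimal self-contained sketch assuming the plain-dict TrieNode suggested above; the sample words are arbitrary:

```python
class TrieNode:
    def __init__(self):
        self.children = {}
        self.end_of_word = False


class Trie:
    def __init__(self):
        self.root = TrieNode()

    def add(self, word):
        curr = self.root
        for letter in word:
            curr = curr.children.setdefault(letter, TrieNode())
        curr.end_of_word = True

    def __getitem__(self, word):
        # plain dict children: a missing letter raises KeyError
        curr = self.root
        for letter in word:
            curr = curr.children[letter]
        return curr

    def search(self, word):
        try:
            return self[word].end_of_word
        except KeyError:
            return False


trie = Trie()
trie.add('foo')
trie.add('foobar')

print(trie.search('foo'))   # True: 'foo' was added as a word
print(trie.search('foob'))  # False: the node exists, but only as a prefix
print(trie.search('bar'))   # False: no such path, the KeyError is caught
```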

words starting with prefix

This name can be shortened to starts_with.

Here, I would move the iteration that finds the child words into TrieNode, and recursively descend through the nodes:

def child_words(self, prefix=''):
    if self.end_of_word:
        yield prefix
    for letter, node in self.children.items():
        word = prefix + letter
        yield from node.child_words(word)
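Because child_words is a generator, the words are produced lazily, one at a time. A minimal self-contained sketch using the question's test words; the add helper is a hypothetical stand-in for Trie.add so the block runs on its own. Note that the order follows dict insertion order of the child letters (guaranteed since Python 3.7), not alphabetical order:

```python
from itertools import islice


class TrieNode:
    def __init__(self):
        self.children = {}
        self.end_of_word = False

    def child_words(self, prefix=''):
        # pre-order: yield this node's word first, then recurse into each child
        if self.end_of_word:
            yield prefix
        for letter, node in self.children.items():
            yield from node.child_words(prefix + letter)


def add(root, word):
    # hypothetical helper for this sketch; same logic as Trie.add
    curr = root
    for letter in word:
        curr = curr.children.setdefault(letter, TrieNode())
    curr.end_of_word = True


root = TrieNode()
for w in ['foobar', 'foo', 'bar', 'foob', 'foof']:
    add(root, w)

# every word in the trie, in the order the child letters were first inserted
print(list(root.child_words()))  # ['foo', 'foob', 'foobar', 'foof', 'bar']

# laziness: islice takes the first two words without building the full list
print(list(islice(root.child_words(), 2)))  # ['foo', 'foob']
```

If sorted output is wanted, iterating over sorted(self.children.items()) inside child_words would give alphabetical order at the cost of a sort per node.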

Trie.starts_with becomes simply:

def starts_with(self, prefix):
    try:
        node = self[prefix]
    except KeyError:
        raise KeyError(f"Prefix `{prefix}` not in Trie")
    return node.child_words(prefix)

which returns a generator that yields the words:

>>> list(trie.starts_with('foo'))
['foo', 'foob', 'foobar', 'foof']
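Because starts_with itself contains no yield, it is an ordinary function that returns a generator: the lookup, and therefore the KeyError for an unknown prefix, happens at call time, while only the word enumeration is lazy. If the lookup were placed inside a generator function instead, the error would surface only on the first next(). A minimal sketch of that difference, with hypothetical lookup and lookup_lazy functions over a plain dict standing in for the trie:

```python
def lookup(mapping, key):
    # ordinary function: mapping[key] runs immediately, then a generator is returned
    words = mapping[key]
    return (word for word in words)


def lookup_lazy(mapping, key):
    # generator function: the body, including mapping[key], runs only on the first next()
    words = mapping[key]
    yield from words


data = {'foo': ['foo', 'foob']}
events = []

try:
    lookup(data, 'baz')
except KeyError:
    events.append('raised at call')

gen = lookup_lazy(data, 'baz')  # no error yet: the body has not started
events.append('generator created')
try:
    next(gen)
except KeyError:
    events.append('raised on next()')

print(events)  # ['raised at call', 'generator created', 'raised on next()']
print(list(lookup(data, 'foo')))  # ['foo', 'foob']
```

Raising at call time is usually the friendlier choice, since the error points at the line that passed the bad prefix.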

If you want to, you can even add an inclusive boolean flag that controls whether the prefix itself is yielded when it is a word:

def child_words(self, prefix='', inclusive=True):
    if inclusive and self.end_of_word:
        yield prefix
    for letter, node in self.children.items():
        word = prefix + letter
        yield from node.child_words(word, inclusive=True)
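The flag only affects the starting node: descendants are always visited with inclusive=True, so inclusive=False drops just the prefix word. A minimal self-contained sketch; add and node_for are hypothetical helpers that mirror Trie.add and Trie.__getitem__:

```python
class TrieNode:
    def __init__(self):
        self.children = {}
        self.end_of_word = False

    def child_words(self, prefix='', inclusive=True):
        # inclusive only decides whether *this* node's word is yielded;
        # every descendant is visited with inclusive=True
        if inclusive and self.end_of_word:
            yield prefix
        for letter, node in self.children.items():
            yield from node.child_words(prefix + letter, inclusive=True)


def add(root, word):
    # hypothetical helper for this sketch; same logic as Trie.add
    curr = root
    for letter in word:
        curr = curr.children.setdefault(letter, TrieNode())
    curr.end_of_word = True


def node_for(root, word):
    # hypothetical helper; same logic as Trie.__getitem__ with plain dict children
    curr = root
    for letter in word:
        curr = curr.children[letter]
    return curr


root = TrieNode()
for w in ['foobar', 'foo', 'foob', 'foof']:
    add(root, w)

foo = node_for(root, 'foo')
print(list(foo.child_words('foo')))                   # ['foo', 'foob', 'foobar', 'foof']
print(list(foo.child_words('foo', inclusive=False)))  # ['foob', 'foobar', 'foof']
```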

full code

class TrieNode:
    def __init__(self):
        self.children = dict()
        self.end_of_word = False

    def __repr__(self):
        return f'TrieNode(end_of_word={self.end_of_word},' \
            f' children={tuple(self.children)})'

    def child_words(self, prefix='', inclusive=True):
        if inclusive and self.end_of_word:
            yield prefix
        for letter, node in self.children.items():
            word = prefix + letter
            yield from node.child_words(word, inclusive=True)

class Trie_Maarten:
    def __init__(self):
        self.root = TrieNode()

    def add(self, word):
        curr = self.root
        for letter in word:
            curr = curr.children.setdefault(letter, TrieNode())
        curr.end_of_word = True

    def __getitem__(self, word):
        curr = self.root
        for letter in word:
            curr = curr.children[letter]
        return curr

    def get(self, word):
        return self[word]

    def search(self, word):
        try:
            return self[word].end_of_word
        except KeyError:
            return False

    def starts_with(self, prefix, inclusive=True):
        try:
            node = self[prefix]
        except KeyError:
            raise KeyError(f"Prefix `{prefix}` not in Trie")
        return node.child_words(prefix, inclusive)
answered Jun 19, 2018 at 10:45
