I'm using Python 2.7. Here is my code, which implements trie matching of whole words and also supports the wildcards `?` (match any one character) and `*` (match zero or more characters).

I'd like to know whether I can make any improvements (in terms of performance), and if there are any functional bugs in the code, please feel free to point them out.
```python
from collections import defaultdict

class TrieNode:
    def __init__(self):
        self.children = defaultdict(TrieNode)
        self.isEnd = False

    def insert(self, source):
        if not source:
            return
        node = self
        for s in source:
            node = node.children[s]
        node.isEnd = True

    def search(self, source):
        if not source:
            return self.isEnd == True
        if source[0] == '?':
            for (ch,node) in self.children.items():
                if node.search(source[1:]) == True:
                    return True
            return False
        elif source[0] == '*':
            for (ch, node) in self.children.items():
                if ((node.search(source[1:]) == True) or
                    (node.search(source) == True)):
                    return True
            return False
        else:
            if source[0] in self.children:
                node = self.children[source[0]]
                return node.search(source[1:])
            else:
                return False

if __name__ == "__main__":
    root = TrieNode()
    root.insert('abc')
    root.insert('acd')
    root.insert('bcd')
    print root.search('acd') # True
    print root.search('acdd') # False
    print root.search('aaa') # False
    print root.search('a?d') # True
    print root.search('a*d') # True
    print root.search('ad*') # False
    print root.search('*c*') # True
    print root.search('*c?') # True
    print root.search('*a?') # False
```
1. Review
The implementation with `defaultdict` is very elegant. I like it.

The code is not portable to Python 3 because of the use of the `print` statement. It would be easy to make it portable using `from __future__ import print_function`.

Set-like data structures contain elements, so `element` would be a better variable name than `source`.

It's inconvenient to create a trie from some other data structure like a list of strings: you have to call the `insert` method for each element you want to add. It would be better if the constructor took an optional iterable and inserted all the elements, just like the constructors for other collections in Python.

Instead of `self.isEnd == True`, just write `self.isEnd`, and similarly for other comparisons against `True`.

The attributes `children` and `isEnd` are only intended for use by the `TrieNode` class itself. It's conventional to give such attributes names starting with `_`.

The code for `insert` starts like this:

```python
if not source:
    return
```

which means that if you try to add the empty string to the trie, it apparently succeeds, but in fact the empty string was not added. This is misleading. If you really want to prevent the caller from adding the empty string, then raise an exception:

```python
if not source:
    raise ValueError("empty string not supported")
```
But this seems inconvenient to me, so instead I suggest removing these two lines and allowing the empty string as an element. But this reveals a problem with the `search` method:

```python
>>> root = TrieNode()
>>> root.insert('')
>>> root.search('*')
False
```

To fix this, revise the `search` method as follows, so that `*` can also match zero characters at the current node:

```python
elif source[0] == '*':
    if self.search(source[1:]):
        return True
    for (ch, node) in self.children.items():
        if node.search(source):
            return True
    return False
```

(but see below for further improvements to this code).
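To try the revised `*` branch, here is a minimal Python 3 port of the question's class with the two guard lines removed from `insert` and the `*` case rewritten as above. It is only a sketch for experimenting. Note that the zero-length case also fixes terms like `'a*bc'` against `'abc'`: the original branch always called `search` on a child, so `*` consumed at least one character and that match was rejected.

```python
from collections import defaultdict


class TrieNode:
    def __init__(self):
        self.children = defaultdict(TrieNode)
        self.isEnd = False

    def insert(self, source):
        # No empty-string guard: '' is stored by marking the root as an end.
        node = self
        for s in source:
            node = node.children[s]
        node.isEnd = True

    def search(self, source):
        if not source:
            return self.isEnd
        if source[0] == '?':
            # '?' consumes exactly one character: any child will do.
            return any(node.search(source[1:]) for node in self.children.values())
        elif source[0] == '*':
            # '*' matches zero characters here...
            if self.search(source[1:]):
                return True
            # ...or consumes one character and stays on the same '*'.
            return any(node.search(source) for node in self.children.values())
        elif source[0] in self.children:
            return self.children[source[0]].search(source[1:])
        else:
            return False


root = TrieNode()
root.insert('')
print(root.search('*'))     # True: '*' may match the empty string
root.insert('abc')
print(root.search('a*bc'))  # True: zero-length '*' in the middle
print(root.search('abc*'))  # True: zero-length '*' at the end
print(root.search('a*d'))   # False
```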
In this code:

```python
for (ch,node) in self.children.items():
    if node.search(source[1:]) == True:
        return True
return False
```

you don't make any use of `ch`, so it would be simpler to write:

```python
for node in self.children.values():
    if node.search(source[1:]):
        return True
return False
```

which is the same as:

```python
return any(node.search(source[1:]) for node in self.children.values())
```
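The equivalence relies on `any` stopping at the first true value when it is given a generator expression, just as the explicit loop returns early. A small check, where `probe` is a hypothetical function used only to record which items were evaluated:

```python
calls = []


def probe(x):
    # Record each call so we can see how far any() got.
    calls.append(x)
    return x == 2


result = any(probe(x) for x in [1, 2, 3, 4])
print(result)  # True
print(calls)   # [1, 2]: probe(3) and probe(4) were never evaluated
```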
Instead of computing `source[1:]` on each iteration of the loop, compute it once and remember it in a local variable:

```python
rest = source[1:]
return any(node.search(rest) for node in self.children.values())
```

But better still, reorganize the search so that it remembers the current index into the search term; then you don't have to construct all these substrings. See below for how this might be implemented.
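A minimal sketch of the index idea applied to the question's boolean `search`, ported to Python 3 and keeping the question's `children`/`isEnd` names: a nested helper (the name `_search` is illustrative) carries the position `i` in the term instead of slicing a new substring at every step. It includes the zero-length `*` fix discussed earlier.

```python
from collections import defaultdict


class TrieNode:
    def __init__(self):
        self.children = defaultdict(TrieNode)
        self.isEnd = False

    def insert(self, source):
        node = self
        for s in source:
            node = node.children[s]
        node.isEnd = True

    def search(self, term):
        n = len(term)

        def _search(node, i):
            # True if term[i:] matches some element stored below node.
            if i == n:
                return node.isEnd
            c = term[i]
            if c == '?':
                return any(_search(child, i + 1) for child in node.children.values())
            if c == '*':
                # Zero characters, or one character while staying on the '*'.
                return _search(node, i + 1) or any(
                    _search(child, i) for child in node.children.values())
            child = node.children.get(c)  # .get does not create a new node
            return child is not None and _search(child, i + 1)

        return _search(self, 0)


root = TrieNode()
for word in ['abc', 'acd', 'bcd']:
    root.insert(word)
print([root.search(t) for t in ['acd', 'acdd', 'a?d', 'a*d', 'ad*', '*c?', '*a?']])
# [True, False, True, True, False, True, False]
```

Only integers are passed down the recursion, so no substrings are built; the recursion depth is still proportional to the length of the element being matched.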
The interface to the `search` method makes it impossible to tell if strings containing wildcards are elements of the trie. Consider:

```python
>>> import random
>>> root = TrieNode()
>>> root.insert(random.choice(['?', 'x']))
>>> root.search('?')
True
```
Is `'?'` a member of the trie or not? It would make sense to provide an alternative search method that doesn't use wildcards: for example, the caller could write `'element' in trie` instead of `trie.search('element')`. This requires a `__contains__` method, which is straightforward to implement since it doesn't need to consider wildcards:

```python
def __contains__(self, element):
    node = self
    for k in element:
        if k not in node._children:
            return False
        node = node._children[k]
    return node._end
```
Now you can accurately determine membership for strings containing wildcards (in this run, `random.choice` must have picked `'x'`):

```python
>>> '?' in root
False
```
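With the randomness removed, the difference between the two questions is easier to see. This sketch uses a minimal trie with the underscore names from the `__contains__` method above (`_children`, `_end`) plus a boolean wildcard search; the class name `Trie` is only for illustration.

```python
from collections import defaultdict


class Trie:
    def __init__(self, iterable=()):
        self._children = defaultdict(Trie)
        self._end = False
        for element in iterable:
            self.add(element)

    def add(self, element):
        node = self
        for k in element:
            node = node._children[k]
        node._end = True

    def __contains__(self, element):
        # Exact membership: wildcard characters are treated literally.
        node = self
        for k in element:
            if k not in node._children:
                return False
            node = node._children[k]
        return node._end

    def search(self, term, i=0):
        # Wildcard match: '?' is any one character, '*' any run of characters.
        if i == len(term):
            return self._end
        if term[i] == '?':
            return any(c.search(term, i + 1) for c in self._children.values())
        if term[i] == '*':
            return self.search(term, i + 1) or any(
                c.search(term, i) for c in self._children.values())
        child = self._children.get(term[i])
        return child is not None and child.search(term, i + 1)


for contents in (['x'], ['?']):
    trie = Trie(contents)
    print(contents, trie.search('?'), '?' in trie)
# ['x'] True False
# ['?'] True True
```

The wildcard search gives the same answer for both tries, while `in` tells them apart.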
The `search` method only tells you whether some element in the trie matched the search term, and doesn't tell you which element or elements match. When the search term contains wildcards, it would be more useful if the search generated the matching elements. For example, if you have added the words in a dictionary to a trie, then you'd like to be able to query `'p?t'` and get out `'pat'`, `'pet'`, `'pit'`, `'pot'`, `'put'`. See the revised code in §2 below, and further discussion in §3.

The data structure presents a set-like interface (it represents a collection of distinct strings where you can add new elements and test elements for membership). It would therefore make sense to design the interface so that it uses the same method names as Python's built-in sets, that is, `add` instead of `insert`, and to implement more of the set interface, for example `__iter__`, `__len__`, `__or__`, `__and__` and so on.

The easiest way to do this is to inherit from `collections.abc.Set` (in Python 2.7 this class is `collections.Set`). The idea is that your class implements the `__contains__`, `__iter__` and `__len__` methods, and the `Set` class implements the other set methods in terms of these. See the revised code in §2 below.

The code has test cases, but it's hard to check that they are correct: you have to carefully compare the output against the expected output. It would be better to get the computer to do the work here, using the features in the `unittest` module.
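A sketch of how the question's checks might look as a `unittest` test case. To keep it self-contained it includes a compact Python 3 port of the question's class with the zero-length `*` fix discussed above; the test class and method names are hypothetical.

```python
import unittest
from collections import defaultdict


class TrieNode:
    def __init__(self):
        self.children = defaultdict(TrieNode)
        self.isEnd = False

    def insert(self, source):
        node = self
        for s in source:
            node = node.children[s]
        node.isEnd = True

    def search(self, source):
        if not source:
            return self.isEnd
        rest = source[1:]
        if source[0] == '?':
            return any(node.search(rest) for node in self.children.values())
        if source[0] == '*':
            return self.search(rest) or any(
                node.search(source) for node in self.children.values())
        node = self.children.get(source[0])
        return node is not None and node.search(rest)


class TestTrieSearch(unittest.TestCase):
    def setUp(self):
        self.root = TrieNode()
        for word in ('abc', 'acd', 'bcd'):
            self.root.insert(word)

    def test_matches(self):
        for term in ('acd', 'a?d', 'a*d', '*c*', '*c?'):
            with self.subTest(term=term):
                self.assertTrue(self.root.search(term))

    def test_non_matches(self):
        for term in ('acdd', 'aaa', 'ad*', '*a?'):
            with self.subTest(term=term):
                self.assertFalse(self.root.search(term))

    def test_empty_string(self):
        self.assertFalse(self.root.search(''))
        self.root.insert('')
        self.assertTrue(self.root.search(''))
        self.assertTrue(self.root.search('*'))
```

Run it with `python -m unittest` followed by the module name; `subTest` reports each failing term separately instead of stopping at the first one.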
2. Revised code
```python
from __future__ import print_function
from collections import defaultdict
try:
    from collections.abc import Set  # Python 3.3 and later
except ImportError:
    from collections import Set  # Python 2.7

class TrieNode(Set):
    """A set of strings implemented using a trie."""

    def __init__(self, iterable=()):
        self._children = defaultdict(TrieNode)
        self._end = False
        for element in iterable:
            self.add(element)

    def add(self, element):
        node = self
        for s in element:
            node = node._children[s]
        node._end = True

    def __contains__(self, element):
        node = self
        for k in element:
            if k not in node._children:
                return False
            node = node._children[k]
        return node._end

    def search(self, term):
        """Return the elements of the set matching the search term, which may
        include wildcards ? (matching exactly one character) and *
        (matching zero or more characters).

        """
        results = set() # Set of elements matching search term.
        element = []    # Current element reached in search.
        def _search(m, node, i):
            # Having just matched m, search for term[i:] starting at node.
            element.append(m)
            if i == len(term):
                if node._end:
                    results.add(''.join(element))
            elif term[i] == '?':
                for k, child in node._children.items():
                    _search(k, child, i + 1)
            elif term[i] == '*':
                _search('', node, i + 1)
                for k, child in node._children.items():
                    _search(k, child, i)
            elif term[i] in node._children:
                _search(term[i], node._children[term[i]], i + 1)
            element.pop()
        _search('', self, 0)
        return results

    def __iter__(self):
        return iter(self.search('*'))

    def __len__(self):
        return sum(1 for _ in self)
```
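A short usage sketch for the revised class. It repeats the class in condensed form (Python 3 only, importing from `collections.abc`) so that it runs on its own, and then exercises set operations that `Set` derives from `__contains__`, `__iter__` and `__len__`. The word list is just sample data.

```python
from collections import defaultdict
from collections.abc import Set


class TrieNode(Set):
    """A set of strings implemented using a trie (condensed copy of §2)."""

    def __init__(self, iterable=()):
        self._children = defaultdict(TrieNode)
        self._end = False
        for element in iterable:
            self.add(element)

    def add(self, element):
        node = self
        for s in element:
            node = node._children[s]
        node._end = True

    def __contains__(self, element):
        node = self
        for k in element:
            if k not in node._children:
                return False
            node = node._children[k]
        return node._end

    def search(self, term):
        results, element = set(), []

        def _search(m, node, i):
            element.append(m)
            if i == len(term):
                if node._end:
                    results.add(''.join(element))
            elif term[i] == '?':
                for k, child in node._children.items():
                    _search(k, child, i + 1)
            elif term[i] == '*':
                _search('', node, i + 1)
                for k, child in node._children.items():
                    _search(k, child, i)
            elif term[i] in node._children:
                _search(term[i], node._children[term[i]], i + 1)
            element.pop()

        _search('', self, 0)
        return results

    def __iter__(self):
        return iter(self.search('*'))

    def __len__(self):
        return sum(1 for _ in self)


words = TrieNode(['pat', 'pet', 'pit', 'pot', 'put', 'pan'])
print(sorted(words.search('p?t')))     # ['pat', 'pet', 'pit', 'pot', 'put']
print(len(words))                      # 6
print(sorted(words & {'pan', 'pen'}))  # ['pan']
print(sorted(words | {'pen'}))         # ['pan', 'pat', 'pen', 'pet', 'pit', 'pot', 'put']
print(words == {'pat', 'pet', 'pit', 'pot', 'put', 'pan'})  # True
```

The `&` and `|` results are new `TrieNode` instances: `Set` builds them by passing an iterable to the constructor, which is one reason for the constructor to accept an iterable.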
3. Efficiency
The `Set` class requires implementations of `__iter__` and `__len__`, so in the code above I implemented them as simply as possible to avoid making the answer too complicated. But if this is going to be used for anything serious then it would be a good idea to provide specialized implementations with better performance. The `__iter__` method could be implemented like this, using the stack of iterators pattern:

```python
def __iter__(self):
    element = ['']
    stack = [iter([('', self)])]
    while stack:
        for k, node in stack[-1]:
            element.append(k)
            if node._end:
                yield ''.join(element)
            stack.append(iter(node._children.items()))
            break
        else:
            element.pop()
            stack.pop()
```
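One practical benefit of the stack of iterators version is that it does not recurse. The recursive `search('*')` behind §2's `__iter__` goes roughly one Python call deeper per character, so on an element longer than CPython's default recursion limit (usually 1000) it would likely raise `RecursionError`. This sketch, using a stripped-down class (the name `IterTrie` is illustrative) with only `add` and the iterative `__iter__` from above, iterates over a 10,000-character element without trouble.

```python
from collections import defaultdict


class IterTrie:
    def __init__(self, iterable=()):
        self._children = defaultdict(IterTrie)
        self._end = False
        for element in iterable:
            self.add(element)

    def add(self, element):
        node = self
        for s in element:
            node = node._children[s]
        node._end = True

    def __iter__(self):
        # element and stack grow and shrink together: each descent appends
        # one character and the children iterator of the node it leads to.
        element = ['']
        stack = [iter([('', self)])]
        while stack:
            for k, node in stack[-1]:
                element.append(k)
                if node._end:
                    yield ''.join(element)
                stack.append(iter(node._children.items()))
                break
            else:
                # Current iterator is exhausted: back up one level.
                element.pop()
                stack.pop()


long_word = 'a' * 10000
trie = IterTrie(['', 'ab', 'b', long_word])
found = list(trie)
print(len(found))          # 4
print(long_word in found)  # True
```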
And the `__len__` method can be made $O(1)$ if you maintain in each node a count of the number of elements in the trie rooted at that node:

```python
def __init__(self, iterable=()):
    self._children = defaultdict(TrieNode)
    self._end = False
    self._len = 0
    for element in iterable:
        self.add(element)

def add(self, element):
    node = self
    for s in element:
        node = node._children[s]
    if not node._end:
        node._end = True
        node._len += 1
        node = self
        for s in element:
            node._len += 1
            node = node._children[s]

def __len__(self):
    return self._len
```
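A quick check of the counting version, which also confirms that adding an existing element a second time leaves the counts unchanged. The class name `CountingTrie` is illustrative; `_len` on each node is the number of elements in the sub-trie rooted there, as described above.

```python
from collections import defaultdict


class CountingTrie:
    def __init__(self, iterable=()):
        self._children = defaultdict(CountingTrie)
        self._end = False
        self._len = 0  # number of elements stored in the sub-trie rooted here
        for element in iterable:
            self.add(element)

    def add(self, element):
        node = self
        for s in element:
            node = node._children[s]
        if not node._end:
            # New element: the end node is counted here, and the loop below
            # counts the root and every intermediate node on the path.
            node._end = True
            node._len += 1
            node = self
            for s in element:
                node._len += 1
                node = node._children[s]

    def __len__(self):
        return self._len


trie = CountingTrie(['abc', 'acd', 'bcd', 'abc', ''])
print(len(trie))                 # 4: the duplicate 'abc' is not counted
print(trie._children['a']._len)  # 2: 'abc' and 'acd'
```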
I made the `search` method return a set of elements matching the search term. It would be more general if this method generated the matching elements one at a time. This would more efficiently support use cases such as finding any element matching a search term (by stopping the iteration after the first result), or determining whether there is a unique element matching a search term (by stopping the iteration after the second result).

In Python 3 it's easy to implement this using `yield from`, but in Python 2 you have to write:

```python
for result in _search(...):
    yield result
```

which is rather verbose. But see below for an alternative.
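A sketch of a generator-based search in Python 3, using `yield from` to delegate to the recursive calls. For brevity it uses a minimal node class and a free function, and it builds each prefix as a string rather than keeping the `element` list; the names `Node`, `add` and `iter_matches` are assumptions for illustration. Because it keeps the recursive structure of §2, a term such as `'**'` can reach the same element along more than one path and yield it more than once, so deduplicate if that matters.

```python
from collections import defaultdict
from itertools import islice


class Node:
    def __init__(self):
        self.children = defaultdict(Node)
        self.end = False


def add(root, element):
    node = root
    for ch in element:
        node = node.children[ch]
    node.end = True


def iter_matches(node, term, i=0, prefix=''):
    """Yield elements below node (spelled from prefix) that match term[i:]."""
    if i == len(term):
        if node.end:
            yield prefix
    elif term[i] == '?':
        for ch, child in node.children.items():
            yield from iter_matches(child, term, i + 1, prefix + ch)
    elif term[i] == '*':
        yield from iter_matches(node, term, i + 1, prefix)
        for ch, child in node.children.items():
            yield from iter_matches(child, term, i, prefix + ch)
    elif term[i] in node.children:
        yield from iter_matches(node.children[term[i]], term, i + 1, prefix + term[i])


root = Node()
for word in ['pat', 'pet', 'pit', 'pot', 'put', 'pan']:
    add(root, word)

# Any match: stop after the first result.
print(next(iter_matches(root, 'p?t'), None))  # 'pat' (children are visited in insertion order)
# Unique match? Ask for at most two results and count them.
print(len(list(islice(iter_matches(root, 'p?n'), 2))) == 1)  # True
print(len(list(islice(iter_matches(root, 'p?t'), 2))) == 1)  # False
```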
The implementation of the `search` method is inefficient in the way it handles the `*` wildcard. The problem is that when the current character in the search term is `*`, the search descends the trie twice at each node (once having consumed the `*`, once not). This means that each node in the trie may be visited many times in the course of the search.

It is possible to revise the code so that each node in the trie is visited at most once, by modifying the search method so that instead of considering a single index `i` in the search term, it considers the set of all indexes in the search term that are consistent with the characters matched so far.

Here's one way to implement this, again using the stack of iterators pattern:
```python
def search(self, term):
    """Generate the elements of the set matching the search term, which may
    include wildcards.

    """
    def _add(indexes, i):
        # Add i to set of indexes into term, taking account of wildcards.
        while i < len(term) and term[i] == '*':
            indexes.add(i)
            i += 1
        indexes.add(i)
    indexes = set()
    _add(indexes, 0)
    if self._end and len(term) in indexes:
        yield ''
    indexes_stack = [indexes] # Stack of sets of indexes into term.
    element = ['']            # Current element reached in search.
    iter_stack = [iter(self._children.items())] # Stack of iterators.
    while iter_stack:
        for k, node in iter_stack[-1]:
            new_indexes = set()
            for i in indexes_stack[-1]:
                if i >= len(term):
                    continue
                elif term[i] == '*':
                    _add(new_indexes, i)
                elif term[i] == '?' or term[i] == k:
                    _add(new_indexes, i + 1)
            if new_indexes:
                element.append(k)
                if node._end and len(term) in new_indexes:
                    yield ''.join(element)
                indexes_stack.append(new_indexes)
                iter_stack.append(iter(node._children.items()))
                break
        else:
            element.pop()
            indexes_stack.pop()
            iter_stack.pop()
```
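To gain some confidence in the set-of-indexes version, this self-contained sketch wraps the `search` generator above in a minimal class (named `IndexTrie` only for illustration) and compares it against a brute-force check with the standard library's `fnmatch.fnmatchcase`, which also treats `?` and `*` as wildcards. The comparison sticks to the letters `a` and `b`, because `fnmatchcase` additionally gives `[` a special meaning.

```python
from collections import defaultdict
from fnmatch import fnmatchcase
from itertools import product


class IndexTrie:
    def __init__(self, iterable=()):
        self._children = defaultdict(IndexTrie)
        self._end = False
        for element in iterable:
            self.add(element)

    def add(self, element):
        node = self
        for s in element:
            node = node._children[s]
        node._end = True

    def search(self, term):
        """Generate the elements matching term ('?' and '*' wildcards)."""
        def _add(indexes, i):
            # Add i, plus the positions reachable by matching '*' with nothing.
            while i < len(term) and term[i] == '*':
                indexes.add(i)
                i += 1
            indexes.add(i)

        indexes = set()
        _add(indexes, 0)
        if self._end and len(term) in indexes:
            yield ''
        indexes_stack = [indexes]
        element = ['']
        iter_stack = [iter(self._children.items())]
        while iter_stack:
            for k, node in iter_stack[-1]:
                new_indexes = set()
                for i in indexes_stack[-1]:
                    if i >= len(term):
                        continue
                    elif term[i] == '*':
                        _add(new_indexes, i)
                    elif term[i] == '?' or term[i] == k:
                        _add(new_indexes, i + 1)
                if new_indexes:
                    element.append(k)
                    if node._end and len(term) in new_indexes:
                        yield ''.join(element)
                    indexes_stack.append(new_indexes)
                    iter_stack.append(iter(node._children.items()))
                    break
            else:
                element.pop()
                indexes_stack.pop()
                iter_stack.pop()


# All strings over {a, b} of length 0 to 4.
words = [''.join(p) for n in range(5) for p in product('ab', repeat=n)]
trie = IndexTrie(words)
# All search terms over {a, b, ?, *} of length 0 to 3.
terms = [''.join(p) for n in range(4) for p in product('ab?*', repeat=n)]
mismatches = [t for t in terms
              if sorted(trie.search(t)) != sorted(w for w in words if fnmatchcase(w, t))]
print(len(terms), mismatches)  # 85 []
```

Since each trie node is entered at most once, each element is also yielded at most once, which is why the sorted lists can be compared directly.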
- Learned a lot from you Gareth Rees! Mark your reply as answer and vote up! – Lin Ma, Oct 1, 2016 at 6:54
- Would have been useful to see the full version of your final solution. That way others could have tested it more easily. – JGFMK, May 22, 2018 at 5:42
- @JGFMK: By all means suggest an edit. – Gareth Rees, May 22, 2018 at 8:40
- I would suggest having either a single source file you could copy/paste, or a link to the full-blown code somewhere like here: onlinegdb.com/online_python_interpreter, where you can run the code directly. It's too complicated to work out how to fit the pieces back together, especially because of the versions. – JGFMK, May 23, 2018 at 16:54