import sys
class Node:
    """A binary-tree node: a stored value plus optional left/right children."""

    def __init__(self, value):
        # Children start empty; BST.insert attaches new nodes to them.
        self.value = value
        self.left = self.right = None

    def __str__(self):
        # The value followed by one trailing space.
        return "{} ".format(self.value)
class Sum:
    """A mutable running total that recursive calls can share.

    Python ints are immutable, so the total is wrapped in an object that
    every call updates in place.
    """

    def __init__(self, val):
        self.s = val  # the current total

    def update(self, val):
        """Add ``val`` to the running total in place."""
        self.s += val

    def getS(self):
        """Return the current running total."""
        return self.s
class BST:
    """Binary search tree built from ``Node`` objects.

    Keys smaller than a node go to its left subtree, larger keys to its
    right; inserting a key that is already present does nothing.
    """

    def __init__(self):
        self.root = None

    def insert(self, key):
        """Insert ``key`` as a new leaf; keys already present are ignored."""
        if self.root is None:
            self.root = Node(key)
            return
        parent, curr = None, self.root
        while curr is not None:
            if curr.value == key:
                # Already present. Falling through here used to replace this
                # node (and its whole subtree) with a fresh leaf.
                return
            parent = curr
            curr = curr.right if curr.value < key else curr.left
        if parent.value < key:
            parent.right = Node(key)
        else:
            parent.left = Node(key)

    def delete(self, key):
        """Remove the node holding ``key``; do nothing if it is absent."""
        parent, curr = None, self.root
        while curr is not None and curr.value != key:
            parent = curr
            curr = curr.right if curr.value < key else curr.left
        if curr is None:
            return
        if curr.left is not None and curr.right is not None:
            # Two children: copy in the in-order successor (leftmost node of
            # the right subtree), then unlink that successor, which has no
            # left child.
            succ_parent, succ = curr, curr.right
            while succ.left is not None:
                succ_parent, succ = succ, succ.left
            curr.value = succ.value
            if succ_parent is curr:
                succ_parent.right = succ.right
            else:
                succ_parent.left = succ.right
            return
        # Zero or one child: splice the node out by promoting that child.
        child = curr.left if curr.left is not None else curr.right
        if parent is None:
            self.root = child
        elif parent.left is curr:
            parent.left = child
        else:
            parent.right = child

    def _doFind(self, root, key):
        """Return the node under ``root`` whose value equals ``key``, or None."""
        curr = root
        while curr is not None and curr.value != key:
            curr = curr.right if curr.value < key else curr.left
        return curr

    def find(self, key):
        """Return the node holding ``key``, or None if it is not in the tree."""
        return self._doFind(self.root, key)

    def _inorder(self, root):
        """Print the subtree at ``root`` in-order (left, node, right)."""
        if root is not None:
            self._inorder(root.left)
            print(root, " ")
            self._inorder(root.right)

    def inorder(self):
        self._inorder(self.root)

    def _preorder(self, root):
        """Print the subtree at ``root`` in pre-order (node, left, right)."""
        if root is not None:
            print(root, " ")
            self._preorder(root.left)
            self._preorder(root.right)

    def preorder(self):
        self._preorder(self.root)

    def _postorder(self, root):
        """Print the subtree at ``root`` in post-order (left, right, node)."""
        if root is not None:
            self._postorder(root.left)
            self._postorder(root.right)
            print(root, " ")

    def postorder(self):
        self._postorder(self.root)

    def sumRToL(self, root, s):
        """Walk right-to-left, replacing each value with the running total ``s``."""
        if root is not None:
            self.sumRToL(root.right, s)
            s.update(root.value)
            root.value = s.getS()
            self.sumRToL(root.left, s)

    def sumelementsfromRtoLinplace(self):
        """Replace every value with the sum of itself and all later in-order values."""
        s = Sum(0)
        self.sumRToL(self.root, s)

    def validate(self, root, low, high):
        """Return True if every value under ``root`` lies within [low, high].

        Each child is checked against the range narrowed by its parent.
        ``None`` for ``low`` or ``high`` means unbounded on that side. An
        explicit stack avoids the recursion limit on deep (e.g. degenerate)
        trees.
        """
        stack = [(root, low, high)]
        while stack:
            node, lo, hi = stack.pop()
            if node is None:
                continue
            if (lo is not None and node.value < lo) or (
                hi is not None and node.value > hi
            ):
                return False
            stack.append((node.left, lo, node.value))
            stack.append((node.right, node.value, hi))
        return True

    def validatebst(self):
        """Return True if the whole tree satisfies the BST ordering."""
        # Unbounded (None) limits: Python ints have no fixed maximum, so the
        # old +/-sys.maxsize sentinels wrongly rejected larger keys.
        return self.validate(self.root, None, None)

    def isSameTree(self, p, q):
        """Return True if trees ``p`` and ``q`` have the same shape and values.

        Uses an explicit stack instead of recursion.

        :type p: Node
        :type q: Node
        :rtype: bool
        """
        stack = [(p, q)]
        while stack:
            a, b = stack.pop()
            if a is None and b is None:
                continue
            if a is None or b is None or a.value != b.value:
                return False
            stack.append((a.right, b.right))
            stack.append((a.left, b.left))
        return True
def test_main():
    """Build two trees from keys 1..5, then print whether they match and are valid."""
    bst = BST()
    for key in (1, 2, 3, 4, 5):
        bst.insert(key)
    # Manual experiments: set bst.root.left = Node(34) to corrupt the tree
    # and watch validatebst() report False, or call
    # bst.sumelementsfromRtoLinplace() followed by bst.inorder().
    other = BST()
    for key in (1, 2, 3, 4, 5):
        other.insert(key)
    print('Same tree : ', bst.isSameTree(bst.root, other.root))
    print("Valid Tree : ", bst.validatebst())
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    test_main()
P.S.: I had to create the `Sum` class because there is no way to share the same primitive `int` across recursive calls: Python ints are immutable and there is no pass-by-reference. I wanted to avoid using global variables throughout.
1 Answer
Coding style
Python has some conventions about coding style, for example `snake_case` for variables and functions. You can find these in PEP 8.
Node.__repr__
For troubleshooting, this can be handy:
def __repr__(self):
    """Debug representation such as ``Node(3)``."""
    return "Node({})".format(self.value)
optionally including the values of the child nodes as well.
BST.update
Adding a simple method to insert multiple values can make initialization a lot simpler:
def update(self, values):
    """Insert each element of ``values`` into the tree, one at a time."""
    insert = self.insert
    for item in values:
        insert(item)
It also allows you to populate the tree directly in `__init__`:
def __init__(self, values=None):
    """Start with an empty tree, then insert ``values`` when they are given."""
    self.root = None
    if values is None:
        return
    self.update(values)
and use something like this in your tests:
bst = BST(range(5))  # inserts 0, 1, 2, 3, 4 one after another via update()
Node
All of the methods you prefix with an `_` make more sense as methods on `Node`.
_xxorder
For example, `_inorder`:
def inorder(self):
    """Yield this node and its descendants in order: left, node, right."""
    # Explicit stack instead of nested generators: descend left as far as
    # possible, emit, then continue with the right subtree.
    stack, node = [], self
    while stack or node is not None:
        while node is not None:
            stack.append(node)
            node = node.left
        node = stack.pop()
        yield node
        node = node.right
and then `BST.inorder`:
def inorder(self):
    """Return an iterator over the tree's nodes in order.

    An empty tree yields nothing, instead of raising AttributeError on
    ``None.inorder()``.
    """
    if self.root is None:
        return iter(())
    return self.root.inorder()
You can easily add a reverse iteration too (for example, to find the maximum of the tree):
def inorder_reverse(self):
    """Yield this node and its descendants in reverse order: right, node, left."""
    # Mirror image of an in-order walk, driven by an explicit stack.
    pending, current = [], self
    while pending or current is not None:
        while current is not None:
            pending.append(current)
            current = current.right
        current = pending.pop()
        yield current
        current = current.left
The same goes for `_doFind`. `Node.find`:
def find(self, key):
    """Return the node in this subtree whose value equals ``key``, or None."""
    # Iterative descent: no recursion-depth limit, and no local named
    # ``next`` shadowing the builtin.
    node = self
    while node is not None:
        if node.value == key:
            return node
        node = node.right if node.value < key else node.left
    return None  # or raise KeyError if a miss should be an error
and `BST.find`:
def find(self, key):
    """Return the node holding ``key``, or None (also when the tree is empty)."""
    if self.root is None:
        return None
    return self.root.find(key)
Magic methods
`isSameTree` compares two trees. Why not name it `__eq__`?
Your implementation doesn't really use `self`, so it might make more sense to move it to `Node`, where it can compare subtrees.
`Node.__eq__`:
def __eq__(self, other):
    """Structural equality: same values arranged in the same tree shape.

    Operands of another type (including None) get NotImplemented, so Python
    falls back to its default comparison instead of raising AttributeError
    on ``other.value``.
    """
    if not isinstance(other, type(self)):
        return NotImplemented
    return (
        self.value == other.value
        and self.left == other.left
        and self.right == other.right
    )
`BST.__eq__`:
def __eq__(self, other):
    """Trees are equal when their roots compare equal (structure and values).

    Returns NotImplemented for non-tree operands instead of raising
    AttributeError on ``other.root``.
    """
    if not isinstance(other, type(self)):
        return NotImplemented
    return self.root == other.root
You can easily implement the iterator protocol on `BST`:
__iter__ = inorder  # iter(bst) and for-loops walk the nodes in-order
and `reversed`:
__reversed__ = inorder_reverse  # reversed(bst) walks the nodes in reverse in-order
Sum
You don't need the `Sum` class; you can just pass the running total along as an argument and return it. This method also seems more appropriate on the `Node` class:
def sumRToL(self, partial_sum=0):
    """Add to each value in this subtree every value that follows it in-order.

    Nodes are visited right to left while a running total is carried along.

    :param partial_sum: total already accumulated from values that come after
        this whole subtree (0 when called on the root).
    :return: the running total after this subtree has been processed, so the
        caller can continue the accumulation.
    """
    if self.right is not None:
        partial_sum = self.right.sumRToL(partial_sum)
    self.value += partial_sum
    partial_sum = self.value
    if self.left is not None:
        # The left subtree's total must flow back to the caller; previously
        # it was discarded (and the method name was misspelled ``sumRTol``).
        partial_sum = self.left.sumRToL(partial_sum)
    return partial_sum
Using this on mutable values might have strange effects.
On `BST`:
def sumelementsfromRtoLinplace(self):
    """Run the right-to-left cumulative sum over the whole tree, in place."""
    root = self.root
    if root is None:
        return
    root.sumRToL()
validate
Checking whether your tree is valid becomes very easy with the iterator we just implemented. Using `pairwise` from the itertools recipes (available as `itertools.pairwise` since Python 3.10):
def validate(self):
return all(a > b for a, b in pairwise(self)) # or self.inorder() for extra clarity
testing
These unit tests are better placed in a separate file that imports this module and uses one of the unit-test frameworks. I'm quite happy with pytest.
import pytest
from binary_tree import BST
def test_order():
    """In-order iteration yields the inserted values in ascending order."""
    expected = list(range(10))
    tree = BST(range(10))
    assert [node.value for node in tree.inorder()] == expected
    assert [node.value for node in tree] == expected
def test_reverse():
    """Reverse in-order iteration yields the values in descending order."""
    tree = BST(range(10))
    descending = list(range(9, -1, -1))
    assert [node.value for node in tree.inorder_reverse()] == descending
    assert [node.value for node in reversed(tree)] == descending
def test_equal():
    """Trees built from the same keys compare equal; different ones do not."""
    small = BST(range(5))
    small_twin = BST(range(5))
    larger = BST(range(6))
    shifted = BST(range(-3, 6))
    assert small == small_twin
    assert small != larger
    assert larger != small
    assert small != shifted
    # ... further cases elided in the answer
Total code
from general_tools.itertools_recipes import pairwise
class Node:
    """A binary-tree node that can traverse, search and compare its subtree."""

    def __init__(self, value):
        self.left: "Node | None" = None
        self.right: "Node | None" = None
        self.value = value

    def inorder(self):
        """Yield the nodes of this subtree in order: left, self, right."""
        if self.left is not None:
            yield from self.left.inorder()
        yield self
        if self.right is not None:
            yield from self.right.inorder()

    def inorder_reverse(self):
        """Yield the nodes of this subtree in reverse order: right, self, left."""
        if self.right is not None:
            yield from self.right.inorder_reverse()
        yield self
        if self.left is not None:
            yield from self.left.inorder_reverse()

    def preorder(self):
        """Yield the nodes of this subtree in pre-order: self, left, right."""
        yield self
        # Children must recurse with the same traversal; delegating to
        # inorder() here produced a mix of pre-order and in-order.
        if self.left is not None:
            yield from self.left.preorder()
        if self.right is not None:
            yield from self.right.preorder()

    def postorder(self):
        """Yield the nodes of this subtree in post-order: left, right, self."""
        if self.left is not None:
            yield from self.left.postorder()
        if self.right is not None:
            yield from self.right.postorder()
        yield self

    def find(self, key):
        """Return the node in this subtree whose value equals ``key``, or None."""
        # Iterative descent: no recursion-depth limit and no local named
        # ``next`` shadowing the builtin.
        node = self
        while node is not None:
            if node.value == key:
                return node
            node = node.right if node.value < key else node.left
        return None  # or raise KeyError if a miss should be an error

    def __eq__(self, other):
        """Structural equality: same values arranged in the same tree shape.

        Non-Node operands (including None) get NotImplemented, so Python
        falls back to its default comparison instead of raising
        AttributeError. Defining ``__eq__`` also makes Node unhashable,
        which suits a mutable node.
        """
        if not isinstance(other, Node):
            return NotImplemented
        return (
            self.value == other.value
            and self.left == other.left
            and self.right == other.right
        )

    def sumRToL(self, partial_sum=0):
        """Add to each value in this subtree every value that follows it in-order.

        :param partial_sum: total already accumulated from values after this
            whole subtree (0 when called on the root).
        :return: the running total after this subtree, for the caller to
            continue from.
        """
        if self.right is not None:
            partial_sum = self.right.sumRToL(partial_sum)
        self.value += partial_sum
        partial_sum = self.value
        if self.left is not None:
            # Propagate the left subtree's total (previously discarded, and
            # the call was misspelled ``sumRTol``).
            partial_sum = self.left.sumRToL(partial_sum)
        return partial_sum

    def __str__(self):
        return f"{self.value} "

    def __repr__(self):
        return f"Node({self.value})"
class BST:
def __init__(self, values=None):
self.root: Node = None
if values is not None:
self.update(values)
def insert(self, key):
if self.root is None:
self.root = Node(key)
return
curr = self.root
parent = None
while curr and curr.value != key:
parent, curr = curr, curr.right if curr.value < key else curr.left
if parent is not None:
if parent.value < key:
parent.right = Node(key)
else:
parent.left = Node(key)
def update(self, values):
for value in values:
self.insert(value)
def delete(self, key):
pass
def find(self, key):
return self.root.find(key)
def inorder(self):
return self.root.inorder()
def inorder_reverse(self):
return self.root.inorder_reverse()
def preorder(self):
return self.root.preorder()
def postorder(self):
return self.root.postorder()
def sumelementsfromRtoLinplace(self):
if self.root is not None:
self.root.sumRToL()
def validatebst(self):
return all(a > b for a, b in pairwise(self))
__iter__ = inorder
__reversed__ = inorder_reverse
def __eq__(self, other):
return self.root == other.root
-
Consistency and compatibility are more important, as stated in a previous answer and even in PEP 8 (for backward compatibility); I personally prefer `mixedCase` over `lower_case_with_underscores` as well. — programmer, commented Mar 9, 2019 at 10:56