I decided to implement some data structures, this time an AVL tree. I think the logic is correct. Is there a way to make the code clearer, and do you have any ideas for more tests to add?
# An AVL tree, python
import random
class TreeNode:
    """One node of a binary search tree, carrying an AVL balance factor.

    Attributes:
        key: the key used for ordering.
        payload: the value stored under ``key``.
        leftChild / rightChild: child nodes, or ``None``.
        parent: parent node, or ``None`` for a root.
        balanceFactor: maintained by the owning tree (not by this class).
    """

    def __init__(self, key, val, left=None, right=None, parent=None, bal=0):
        self.key = key
        self.payload = val
        self.leftChild = left
        self.rightChild = right
        self.parent = parent
        self.balanceFactor = bal

    def __repr__(self):
        return "TreeNode(key=%r, payload=%r, bal=%r)" % (
            self.key, self.payload, self.balanceFactor)

    def update_val(self, new_val):
        """Replace the payload stored in this node."""
        self.payload = new_val

    # NOTE: the has* helpers return the child node itself (truthy) or None,
    # not a strict bool; this is kept so existing callers keep working.
    def hasLeftChild(self):
        """Return the left child (truthy) if there is one, else ``None``."""
        return self.leftChild

    def hasRightChild(self):
        """Return the right child (truthy) if there is one, else ``None``."""
        return self.rightChild

    def isLeftChild(self):
        """Return True if this node is its parent's left child."""
        return self.parent is not None and self.parent.leftChild is self

    def isRightChild(self):
        """Return True if this node is its parent's right child."""
        return self.parent is not None and self.parent.rightChild is self

    def isRoot(self):
        """Return True if this node has no parent."""
        return self.parent is None

    def isLeaf(self):
        """Return True if this node has no children."""
        return not (self.rightChild or self.leftChild)

    def hasAnyChildren(self):
        """Return a child node (truthy) if there is at least one, else ``None``."""
        return self.rightChild or self.leftChild

    def hasBothChildren(self):
        """Return a truthy value only if both children are present."""
        return self.rightChild and self.leftChild

    def replaceNodeData(self, key, value, lc, rc):
        """Overwrite key, payload and both children, re-parenting the children.

        ``balanceFactor`` is left untouched; the caller must fix it up.
        """
        self.key = key
        self.payload = value
        self.leftChild = lc
        self.rightChild = rc
        if self.leftChild is not None:
            self.leftChild.parent = self
        if self.rightChild is not None:
            self.rightChild.parent = self

    def findSuccessor(self):
        """Return the node with the next larger key, or ``None`` if none exists.

        The tree is only read, never modified.
        """
        if self.rightChild is not None:
            return self.rightChild.findMin()
        # No right subtree: climb until we leave a left subtree; the node we
        # arrive at is the successor (None once we run past the root).
        node = self
        while node.parent is not None and node.parent.rightChild is node:
            node = node.parent
        return node.parent

    def findMin(self):
        """Return the leftmost (smallest-key) node of this node's subtree."""
        current = self
        while current.leftChild is not None:
            current = current.leftChild
        return current

    def spliceOut(self):
        """Unlink this node from its parent, promoting its only child if any.

        Intended for nodes with at most one child that have a parent; with
        two children the left child is promoted and the right subtree is
        dropped, and calling it on a root fails because there is no parent.
        Balance factors are not adjusted.
        """
        child = self.leftChild if self.leftChild is not None else self.rightChild
        if self.isLeftChild():
            self.parent.leftChild = child
        else:
            self.parent.rightChild = child
        if child is not None:
            child.parent = self.parent
class BinarySearchTree:
    """Self-balancing (AVL) binary search tree mapping keys to payloads.

    Every node's ``balanceFactor`` is height(left subtree) minus
    height(right subtree); insertion and deletion keep it within [-1, 1].
    Supports ``tree[key]``, ``tree[key] = value``, ``del tree[key]``,
    ``len(tree)``, ``key in tree`` and in-order iteration over the keys.
    """

    def __init__(self):
        self.root = None
        self.size = 0

    def length(self):
        """Return the number of distinct keys stored."""
        return self.size

    def __len__(self):
        return self.size

    def __contains__(self, key):
        # Without this, ``key in tree`` falls back to calling __getitem__
        # with 0, 1, 2, ... which never raises here and so never terminates.
        return self._get(key, self.root) is not None

    def __iter__(self):
        """Yield the keys in ascending order."""
        stack = []
        node = self.root
        while stack or node is not None:
            while node is not None:
                stack.append(node)
                node = node.leftChild
            node = stack.pop()
            yield node.key
            node = node.rightChild

    def __getitem__(self, key):
        return self.get(key)

    def __setitem__(self, k, v):
        self.put(k, v)

    def put(self, key, val):
        """Insert ``key`` with payload ``val``, or overwrite an existing key.

        ``size`` grows only when a new key is actually added.
        """
        if self.root is None:
            self.root = TreeNode(key, val)
            self.size += 1
        elif self._put(key, val, self.root):
            self.size += 1

    def _put(self, key, val, currentNode):
        """Insert below ``currentNode``; return True if a node was created."""
        while True:
            if key == currentNode.key:
                currentNode.update_val(val)
                return False
            if key < currentNode.key:
                if currentNode.leftChild is None:
                    currentNode.leftChild = TreeNode(key, val, parent=currentNode)
                    self.updateBalance(currentNode.leftChild)
                    return True
                currentNode = currentNode.leftChild
            else:
                if currentNode.rightChild is None:
                    currentNode.rightChild = TreeNode(key, val, parent=currentNode)
                    self.updateBalance(currentNode.rightChild)
                    return True
                currentNode = currentNode.rightChild

    def rotateLeft(self, rotRoot):
        """Rotate left around ``rotRoot``; its right child takes its place."""
        newRoot = rotRoot.rightChild
        rotRoot.rightChild = newRoot.leftChild
        if newRoot.leftChild is not None:
            newRoot.leftChild.parent = rotRoot
        newRoot.parent = rotRoot.parent
        if rotRoot.isRoot():
            self.root = newRoot
        elif rotRoot.isLeftChild():
            rotRoot.parent.leftChild = newRoot
        else:
            rotRoot.parent.rightChild = newRoot
        newRoot.leftChild = rotRoot
        rotRoot.parent = newRoot
        # Order matters: newRoot's new factor is computed from rotRoot's
        # already-updated factor.
        rotRoot.balanceFactor = rotRoot.balanceFactor + 1 - min(newRoot.balanceFactor, 0)
        newRoot.balanceFactor = newRoot.balanceFactor + 1 + max(rotRoot.balanceFactor, 0)

    def rotateRight(self, rotRoot):
        """Rotate right around ``rotRoot``; its left child takes its place."""
        newRoot = rotRoot.leftChild
        rotRoot.leftChild = newRoot.rightChild
        if newRoot.rightChild is not None:
            newRoot.rightChild.parent = rotRoot
        newRoot.parent = rotRoot.parent
        if rotRoot.isRoot():
            self.root = newRoot
        elif rotRoot.isLeftChild():
            rotRoot.parent.leftChild = newRoot
        else:
            rotRoot.parent.rightChild = newRoot
        newRoot.rightChild = rotRoot
        rotRoot.parent = newRoot
        # Order matters: newRoot's new factor is computed from rotRoot's
        # already-updated factor.
        rotRoot.balanceFactor = rotRoot.balanceFactor - 1 - max(newRoot.balanceFactor, 0)
        newRoot.balanceFactor = newRoot.balanceFactor - 1 + min(0, rotRoot.balanceFactor)

    def updateBalance(self, node):
        """Retrace upwards after an INSERTION made ``node``'s subtree taller.

        Not suitable after a deletion: there the ancestors' factors move in
        the opposite direction (see ``_retraceAfterRemoval``).
        """
        if node.balanceFactor > 1 or node.balanceFactor < -1:
            self.rebalance(node)
            return
        if node.parent is not None:
            if node.isLeftChild():
                node.parent.balanceFactor += 1
            elif node.isRightChild():
                node.parent.balanceFactor -= 1
            # A parent that ends at 0 absorbed the growth; otherwise it grew too.
            if node.parent.balanceFactor != 0:
                self.updateBalance(node.parent)

    def rebalance(self, node):
        """Restore balance at ``node`` with a single or double rotation.

        Afterwards ``node.parent`` is the root of the rotated subtree.
        """
        if node.balanceFactor < 0:
            # Right-heavy; a left-leaning right child needs a double rotation.
            if node.rightChild.balanceFactor > 0:
                self.rotateRight(node.rightChild)
            self.rotateLeft(node)
        elif node.balanceFactor > 0:
            # Left-heavy; a right-leaning left child needs a double rotation.
            if node.leftChild.balanceFactor < 0:
                self.rotateLeft(node.leftChild)
            self.rotateRight(node)

    def get(self, key):
        """Return the payload stored under ``key``, or ``None`` if absent."""
        node = self._get(key, self.root)
        return node.payload if node is not None else None

    def _get(self, key, currentNode):
        """Return the node holding ``key`` below ``currentNode``, or ``None``."""
        while currentNode is not None:
            if key == currentNode.key:
                return currentNode
            if key < currentNode.key:
                currentNode = currentNode.leftChild
            else:
                currentNode = currentNode.rightChild
        return None

    def __delitem__(self, key):
        self.delete(key)

    def delete(self, key):
        """Remove ``key`` from the tree.

        Raises:
            KeyError: if ``key`` is not present.
        """
        nodeToRemove = self._get(key, self.root)
        if nodeToRemove is None:
            raise KeyError('Error, key not in tree')
        self.remove(nodeToRemove)
        self.size = self.size - 1

    def remove(self, currentNode):
        """Unlink ``currentNode``'s entry and rebalance; ``size`` is not changed.

        A node with two children keeps its place and takes over the key and
        payload of its in-order successor; the successor node (which has no
        left child) is the one physically unlinked.
        """
        if currentNode.hasBothChildren():
            succ = currentNode.rightChild.findMin()
            currentNode.key = succ.key
            currentNode.payload = succ.payload
            currentNode = succ
        # From here on currentNode has at most one child.
        child = currentNode.leftChild
        if child is None:
            child = currentNode.rightChild
        parent = currentNode.parent
        if child is not None:
            child.parent = parent
        if parent is None:
            # Removing the root: its only child (if any) becomes the new root.
            self.root = child
            return
        # Record the side BEFORE unlinking; afterwards the node can no longer
        # be identified as either child of its parent.
        shrankLeft = parent.leftChild is currentNode
        if shrankLeft:
            parent.leftChild = child
        else:
            parent.rightChild = child
        self._retraceAfterRemoval(parent, shrankLeft)

    def _retraceAfterRemoval(self, node, shrankLeft):
        """Fix balance factors upwards after a subtree of ``node`` got shorter.

        ``shrankLeft`` tells whether the shortened subtree is the left one.
        """
        while node is not None:
            node.balanceFactor += -1 if shrankLeft else 1
            if node.balanceFactor in (-1, 1):
                # Was 0 before: the other side still defines the height.
                return
            if node.balanceFactor != 0:
                # +-2: rotate; the rotated subtree's root is node.parent.
                self.rebalance(node)
                node = node.parent
                if node.balanceFactor != 0:
                    # Rotation kept the subtree's height; nothing to propagate.
                    return
            # Factor is 0: this subtree got shorter, so tell the parent.
            parent = node.parent
            if parent is not None:
                shrankLeft = parent.leftChild is node
            node = parent
# End of the Tree
# Tests helping methods
def height_node(tree_node):
    """Return the number of levels in the subtree rooted at ``tree_node``.

    An empty subtree (a falsy ``tree_node``) has height 0.
    """
    levels = 0
    frontier = [tree_node] if tree_node else []
    while frontier:
        levels += 1
        # Collect every existing child of the current level.
        frontier = [
            child
            for node in frontier
            for child in (node.leftChild, node.rightChild)
            if child
        ]
    return levels
def is_balanced(tree_node):
    """Return True if EVERY node under ``tree_node`` satisfies the AVL rule.

    The rule: the heights of a node's two subtrees differ by at most one.
    Heights are measured from the actual structure, not from stored balance
    factors. An empty subtree (``None``) counts as balanced.
    """
    def checked_height(node):
        # Returns the subtree height, or -1 as soon as a violation is found.
        if not node:
            return 0
        left = checked_height(node.leftChild)
        if left < 0:
            return -1
        right = checked_height(node.rightChild)
        if right < 0 or abs(left - right) > 1:
            return -1
        return 1 + max(left, right)

    return checked_height(tree_node) >= 0
def list_print(tree_node):
    """Print the tree's keys one level per line and return the levels.

    Args:
        tree_node: despite the name, the TREE object (something with a
            ``root`` attribute), not a single node.

    Returns:
        A list with one list of keys per level, top level first and each
        level ordered left to right; ``[]`` for a falsy/empty tree.
    """
    if not tree_node:
        return []
    levels = []
    frontier = [tree_node.root] if tree_node.root is not None else []
    while frontier:
        levels.append([node.key for node in frontier])
        # Next level: all existing children, kept in left-to-right order.
        frontier = [
            child
            for node in frontier
            for child in (node.leftChild, node.rightChild)
            if child is not None
        ]
    for level in levels:
        print(level)
    return levels
# tests in a loop:
# Randomised smoke test: 100 trials, each inserting 1000 random keys and then
# deleting whichever of the keys 0..99 are present. On the first tree that
# fails the balance check, diagnostics are printed and the run stops.
for _ in range(100):
    mytree1 = BinarySearchTree()
    for _ in range(1000):
        mytree1.put(random.randint(-10000, 10000), "a")
    for i in range(100):
        if mytree1.get(i):
            mytree1.delete(i)
    if not is_balanced(mytree1.root):
        print("not good")
        h = height_node(mytree1.root)
        print("height: ", h)
        list_print(mytree1)
        break
else:
    # Reached only when no trial hit the break above, so "OK" is never
    # printed after a failure report.
    print("OK")
It's based on this page: Balanced Binary Search Trees
1 Answer 1
In terms of testing, delete has some additional cases that should be tested. Knuth's "Art of Computer Programming" Vol. 3 has a lot of great information regarding AVL trees and that is why in my tests I implemented a delete from a Fibonacci Tree. You can see some tests in Java here, they should translate easily to Python.
For simplicity, you might consider changing to storing height/rank instead of balance factor. The bottom-up rebalancing for insertion is very easy to understand if you have read about rank balanced trees.
Explore related questions
See similar questions with these tags.
`findMin` with BST rather than node, and `findMax` and `findPredecessor` to be at least mentioned as being left out for being just symmetrical.