6
\$\begingroup\$

I decided to implement some data structures- this time an AVL tree. I think the logic is correct. Is there a way to make it clearer and do you have any ideas about more tests to add?

# An AVL tree, python
import random
class TreeNode:
 def __init__(self, key, val, left=None, right=None, parent=None, bal=0):
 self.key = key
 self.payload = val
 self.leftChild = left
 self.rightChild = right
 self.parent = parent
 self.balanceFactor = bal
 def update_val(self, new_val): # added
 self.payload = new_val
 def hasLeftChild(self):
 return self.leftChild
 def hasRightChild(self):
 return self.rightChild
 def isLeftChild(self):
 return self.parent and self.parent.leftChild == self
 def isRightChild(self):
 return self.parent and self.parent.rightChild == self
 def isRoot(self):
 return not self.parent
 def isLeaf(self):
 return not (self.rightChild or self.leftChild)
 def hasAnyChildren(self):
 return self.rightChild or self.leftChild
 def hasBothChildren(self):
 return self.rightChild and self.leftChild
 def replaceNodeData(self, key, value, lc, rc):
 self.key = key
 self.payload = value
 self.leftChild = lc
 self.rightChild = rc
 if self.hasLeftChild():
 self.leftChild.parent = self
 if self.hasRightChild():
 self.rightChild.parent = self
 def findSuccessor(self):
 succ = None
 if self.hasRightChild():
 succ = self.rightChild.findMin()
 else:
 if self.parent:
 if self.isLeftChild():
 succ = self.parent
 else:
 self.parent.rightChild = None
 succ = self.parent.findSuccessor()
 self.parent.rightChild = self
 return succ
 def findMin(self):
 current = self
 while current.hasLeftChild():
 current = current.leftChild
 return current
 def spliceOut(self):
 if self.isLeaf():
 if self.isLeftChild():
 self.parent.leftChild = None
 else:
 self.parent.rightChild = None
 elif self.hasAnyChildren():
 if self.hasLeftChild():
 if self.isLeftChild():
 self.parent.leftChild = self.leftChild
 else:
 self.parent.rightChild = self.leftChild
 self.leftChild.parent = self.parent
 else:
 if self.isLeftChild():
 self.parent.leftChild = self.rightChild
 else:
 self.parent.rightChild = self.rightChild
 self.rightChild.parent = self.parent
class BinarySearchTree:
 def __init__(self):
 self.root = None
 self.size = 0
 def length(self):
 return self.size
 def __len__(self):
 return self.size
 def __getitem__(self, key):
 return self.get(key)
 def __setitem__(self, k, v):
 self.put(k, v)
 def put(self, key, val):
 if self.root:
 self._put(key, val, self.root)
 else:
 self.root = TreeNode(key, val)
 self.size = self.size + 1
 def _put(self, key, val, currentNode):
 if key == currentNode.key:
 currentNode.update_val(val)
 return
 if key < currentNode.key:
 if currentNode.hasLeftChild():
 self._put(key, val, currentNode.leftChild)
 else:
 currentNode.leftChild = TreeNode(key, val, parent=currentNode)
 self.updateBalance(currentNode.leftChild)
 else:
 if currentNode.hasRightChild():
 self._put(key, val, currentNode.rightChild)
 else:
 currentNode.rightChild = TreeNode(key, val, parent=currentNode)
 self.updateBalance(currentNode.rightChild)
 def rotateLeft(self, rotRoot):
 newRoot = rotRoot.rightChild
 rotRoot.rightChild = newRoot.leftChild
 if newRoot.leftChild != None:
 newRoot.leftChild.parent = rotRoot
 newRoot.parent = rotRoot.parent
 if rotRoot.isRoot():
 self.root = newRoot
 else:
 if rotRoot.isLeftChild():
 rotRoot.parent.leftChild = newRoot
 else:
 rotRoot.parent.rightChild = newRoot
 newRoot.leftChild = rotRoot
 rotRoot.parent = newRoot
 rotRoot.balanceFactor = rotRoot.balanceFactor + 1 - min(newRoot.balanceFactor, 0)
 newRoot.balanceFactor = newRoot.balanceFactor + 1 + max(rotRoot.balanceFactor, 0)
 def rotateRight(self, rotRoot):
 newRoot = rotRoot.leftChild
 rotRoot.leftChild = newRoot.rightChild
 if newRoot.rightChild != None:
 newRoot.rightChild.parent = rotRoot
 newRoot.parent = rotRoot.parent
 if rotRoot.isRoot():
 self.root = newRoot
 else:
 if rotRoot.isLeftChild():
 rotRoot.parent.leftChild = newRoot
 else:
 rotRoot.parent.rightChild = newRoot
 newRoot.rightChild = rotRoot
 rotRoot.parent = newRoot
 rotRoot.balanceFactor = rotRoot.balanceFactor - 1 - max(newRoot.balanceFactor, 0)
 newRoot.balanceFactor = newRoot.balanceFactor - 1 + min(0, rotRoot.balanceFactor)
 def updateBalance(self, node):
 if node.balanceFactor > 1 or node.balanceFactor < -1:
 self.rebalance(node)
 return
 if node.parent != None:
 if node.isLeftChild():
 node.parent.balanceFactor += 1
 elif node.isRightChild():
 node.parent.balanceFactor -= 1
 if node.parent.balanceFactor != 0:
 self.updateBalance(node.parent)
 def rebalance(self, node):
 if node.balanceFactor < 0:
 if node.rightChild.balanceFactor > 0:
 self.rotateRight(node.rightChild)
 self.rotateLeft(node)
 else:
 self.rotateLeft(node)
 elif node.balanceFactor > 0:
 if node.leftChild.balanceFactor < 0:
 self.rotateLeft(node.leftChild)
 self.rotateRight(node)
 else:
 self.rotateRight(node)
 def get(self, key):
 if self.root:
 res = self._get(key, self.root)
 if res:
 return res.payload
 else:
 return None
 else:
 return None
 def _get(self, key, currentNode):
 if not currentNode:
 return None
 elif currentNode.key == key:
 return currentNode
 elif key < currentNode.key:
 return self._get(key, currentNode.leftChild)
 else:
 return self._get(key, currentNode.rightChild)
 def __delitem__(self, key):
 self.delete(key)
 def delete(self, key):
 if self.size > 1:
 nodeToRemove = self._get(key, self.root)
 if nodeToRemove:
 self.remove(nodeToRemove)
 self.size = self.size - 1
 else:
 raise KeyError('Error, key not in tree')
 elif self.size == 1 and self.root.key == key:
 self.root = None
 self.size = self.size - 1
 else:
 raise KeyError('Error, key not in tree')
 def remove(self, currentNode):
 if currentNode.isLeaf(): # this is leaf
 if currentNode == currentNode.parent.leftChild:
 currentNode.parent.leftChild = None
 currentNode.parent.balanceFactor -= 1
 if currentNode.parent.balanceFactor < -1:
 self.updateBalance(currentNode.parent)
 else:
 currentNode.parent.rightChild = None
 currentNode.parent.balanceFactor += 1
 if currentNode.parent.balanceFactor > 1:
 self.updateBalance(currentNode.parent)
 elif currentNode.hasBothChildren(): # this is interior node
 succ = currentNode.findSuccessor()
 succ.spliceOut()
 if succ.isLeftChild():
 succ.parent.balanceFactor -= 1
 self.updateBalance(succ.parent)
 elif succ.isRightChild():
 succ.parent.balanceFactor += 1
 self.updateBalance(succ.parent)
 currentNode.key = succ.key
 currentNode.payload = succ.payload
 else: # this node has one child
 if currentNode.hasLeftChild():
 if currentNode.isLeftChild():
 currentNode.leftChild.parent = currentNode.parent
 currentNode.parent.leftChild = currentNode.leftChild
 currentNode.parent.balanceFactor -= 1
 self.updateBalance(currentNode.parent)
 elif currentNode.isRightChild():
 currentNode.leftChild.parent = currentNode.parent
 currentNode.parent.rightChild = currentNode.leftChild
 currentNode.parent.balanceFactor += 1
 self.updateBalance(currentNode.parent)
 else:
 currentNode.replaceNodeData(currentNode.leftChild.key,
 currentNode.leftChild.payload,
 currentNode.leftChild.leftChild,
 currentNode.leftChild.rightChild)
 else:
 if currentNode.isLeftChild():
 currentNode.rightChild.parent = currentNode.parent
 currentNode.parent.leftChild = currentNode.rightChild
 currentNode.parent.balanceFactor -= 1
 self.updateBalance(currentNode.parent)
 elif currentNode.isRightChild():
 currentNode.rightChild.parent = currentNode.parent
 currentNode.parent.rightChild = currentNode.rightChild
 currentNode.parent.balanceFactor += 1
 self.updateBalance(currentNode.parent)
 else:
 currentNode.replaceNodeData(currentNode.rightChild.key,
 currentNode.rightChild.payload,
 currentNode.rightChild.leftChild,
 currentNode.rightChild.rightChild)
# End of the Tree
# Tests helping methods
def height_node(tree_node):
 if not tree_node:
 return 0
 else:
 return 1 + max(height_node(tree_node.leftChild), height_node(tree_node.rightChild))
def is_balanced(tree_node):
 return abs(height_node(tree_node.leftChild) - height_node(tree_node.rightChild)) <= 1
def list_print(tree_node):
 def top_height(tree_node):
 if not tree_node:
 return 0
 else:
 return 1 + top_height(tree_node.parent)
 if not tree_node:
 return []
 else:
 size = height_node(tree_node.root)
 l1 = [[] for x in range(size)]
 def travel_list(current):
 if current:
 travel_list(current.leftChild)
 l1[top_height(current) - 1].append(current.key)
 travel_list(current.rightChild)
 return l1
 l = travel_list(tree_node.root)
 for x in range(len(l)):
 print(l[x])
# tests in a loop:
for f in range(100):
 mytree1 = BinarySearchTree()
 for x in range(1000):
 mytree1.put(random.randint(-10000, 10000), "a")
 for i in range(100):
 if mytree1.get(i):
 mytree1.delete(i)
 if not is_balanced(mytree1.root):
 print("not good")
 h = height_node(mytree1.root)
 print("height: ", h)
 list_print(mytree1)
 break
 del (mytree1)
print("OK")

It's based on this page: Balanced Binary Search Trees

asked Nov 20, 2016 at 18:18
\$\endgroup\$
3
  • \$\begingroup\$ I'm missing documentation strings. Replacing the key (/children) of a node should come with warnings - or checks (/restoration(s)) of order. The interface looks fat. I'd expect findMin with BST rather than node, and findMax and findPredecessor to be at least mentioned as being left out for being just symmetrical. \$\endgroup\$ Commented Nov 21, 2016 at 7:34
  • \$\begingroup\$ Thanks for remarks, will work on it. Yes, it is fat, but all these things need to be there. My thoughts about findMin, findPredecessor and findMin are the same, but the way how function remove receives it's argument (the argument type is Node) makes this design easier to implement. \$\endgroup\$ Commented Nov 21, 2016 at 10:00
  • \$\begingroup\$ Seeing the amount of code identical to the hyperlinked base, I won't post a review I put hours in (for my own benefit - tinkering with "rank balanced trees", factoring out symmetries in Python binary tree implementations). \$\endgroup\$ Commented Nov 11, 2021 at 6:53

1 Answer 1

1
\$\begingroup\$

In terms of testing, delete has some additional cases that should be tested. Knuth's "Art of Computer Programming" Vol. 3 has a lot of great information regarding AVL trees and that is why in my tests I implemented a delete from a Fibonacci Tree. You can see some tests in Java here, they should translate easily to Python.

For simplicity, you might consider changing to storing height/rank instead of balance factor. The bottom-up rebalancing for insertion is very easy to understand if you have read about rank balanced trees.

answered Nov 22, 2017 at 22:23
\$\endgroup\$

Your Answer

Draft saved
Draft discarded

Sign up or log in

Sign up using Google
Sign up using Email and Password

Post as a guest

Required, but never shown

Post as a guest

Required, but never shown

By clicking "Post Your Answer", you agree to our terms of service and acknowledge you have read our privacy policy.

Start asking to get answers

Find the answer to your question by asking.

Ask question

Explore related questions

See similar questions with these tags.