I've implemented an AA Tree (a type of self-balancing binary search tree) using Numba, a just-in-time compiler for Python. This implementation includes rank and select operations, which are useful for order statistics.
The AA Tree is designed to maintain balance for fast operations while being simpler to implement than other self-balancing trees like Red-Black trees. I chose to use Numba to potentially improve performance, especially for the rank and select operations.
Here's my implementation:
```python
# Order Statistics Tree based on AA Tree
from numba import njit
from numba import uint32, deferred_type, optional, uint16
from numba.experimental import jitclass
from collections import OrderedDict

node_type = deferred_type()

spec = OrderedDict()
spec['data'] = uint32
spec['left'] = optional(node_type)
spec['right'] = optional(node_type)
spec['level'] = uint16
spec['size'] = uint16
spec['left_size'] = uint16


@jitclass(spec)
class NumbaNodeAA:
    def __init__(self, data: int):
        self.data = data
        self.left = None
        self.right = None
        self.level = 1
        self.size = 1
        self.left_size = 0


node_type.define(NumbaNodeAA.class_type.instance_type)


class AAOSTN:
    def __init__(self):
        self.root = None

    def insert(self, data: int):
        self.root = self._insert(self.root, data)

    def _insert(self, node: NumbaNodeAA, data: int):
        if node is None:
            return NumbaNodeAA(data)
        if data < node.data:
            node.left = self._insert(node.left, data)
        elif data > node.data:
            node.right = self._insert(node.right, data)
        else:
            return node
        node = self._skew(node)
        node = self._split(node)
        node.size = self._update_size_static(node)
        return node

    def _skew(self, node: NumbaNodeAA):
        if node is None or node.left is None:
            return node
        if node.level == node.left.level:
            left = node.left
            node.left = left.right
            left.right = node
            node.size = self._update_size_static(node)
            left.size = self._update_size_static(left)
            return left
        return node

    def _split(self, node: NumbaNodeAA):
        if node is None or node.right is None or node.right.right is None:
            return node
        if node.level == node.right.right.level:
            right = node.right
            node.right = right.left
            right.left = node
            right.level += 1
            node.size = self._update_size_static(node)
            right.size = self._update_size_static(right)
            return right
        return node

    @staticmethod
    @njit
    def _update_size_static(node: NumbaNodeAA):
        if node is None:
            return 0
        left_size = node.left.size if node.left is not None else 0
        right_size = node.right.size if node.right is not None else 0
        node.left_size = left_size
        node.size = 1 + left_size + right_size
        return node.size

    @staticmethod
    @njit(fastmath=True)
    def _select_static(node: NumbaNodeAA, k: int):
        while node is not None:
            if k < node.left_size:
                node = node.left
            elif k > node.left_size:
                k -= node.left_size + 1
                node = node.right
            else:
                return node.data
        return -1

    @staticmethod
    @njit(fastmath=True)
    def _rank_static(node: NumbaNodeAA, data: int):
        rank = 0
        while node is not None:
            if data < node.data:
                node = node.left
            elif data > node.data:
                rank += node.left_size + 1
                node = node.right
            else:
                return rank + node.left_size
        return -1

    def select(self, k: uint32):
        return self._select_static(self.root, k)

    def rank(self, data: uint32):
        return self._rank_static(self.root, data)

    def delete(self, data: int):
        self.root = self._delete(self.root, data)

    def _delete(self, node: NumbaNodeAA, data: int):
        if node is None:
            return None
        if data < node.data:
            node.left = self._delete(node.left, data)
        elif data > node.data:
            node.right = self._delete(node.right, data)
        else:
            if node.left is None or node.right is None:
                return node.left if node.left is not None else node.right
            else:
                successor = self._find_min(node.right)
                node.data = successor.data
                node.right = self._delete(node.right, successor.data)
        node = self._decrease_level(node)
        node = self._skew(node)
        if node.right is not None:
            node.right = self._skew(node.right)
        if node.right is not None and node.right.right is not None:
            node.right.right = self._skew(node.right.right)
        node = self._split(node)
        if node.right is not None:
            node.right = self._split(node.right)
        node.size = self._update_size_static(node)
        return node

    def _decrease_level(self, node: NumbaNodeAA):
        left_level = node.left.level if node.left is not None else 0
        right_level = node.right.level if node.right is not None else 0
        level = min(left_level, right_level) + 1
        if level < node.level:
            node.level = level
            if node.right is not None and node.right.level > level:
                node.right.level = level
        return node

    def _find_min(self, node: NumbaNodeAA):
        current = node
        while current.left is not None:
            current = current.left
        return current

    def inorder(self):
        stack = []
        current = self.root
        while stack or current is not None:
            if current is not None:
                stack.append(current)
                current = current.left
            else:
                current = stack.pop()
                print(current.data, end=' ')
                current = current.right
        print()
```
The AAOSTN class represents the AA Tree. Key methods include:
- insert: Adds a new node to the tree
- delete: Removes a node from the tree
- select: Finds the k-th smallest element in the tree
- rank: Finds the rank of a given element in the tree
The tree maintains its balance using the skew and split operations, which are characteristic of AA Trees.
I'm particularly interested in feedback on:
- Overall code structure and organization
- Effective use of Numba (I'm new to using it)
- Correctness of the AA Tree implementation
- Performance considerations, especially for insertions
- Python best practices and any potential improvements
While the rank and select operations seem to perform well, I've noticed that insertions are slower than I expected. Any insights into this would be appreciated, but I welcome all forms of feedback to improve this code.
[Benchmark plots]
Answer
module-level docstring
# Order Statistics Tree based on AA Tree
This is a good comment. It would be even better as a """docstring""".
dicts are ordered
This is a bit odd:
spec = OrderedDict()
It would have been simpler to assign spec = {}.
Since Python 3.7 the language itself guarantees that dicts preserve order-of-insertion
(CPython 3.6 already did so as an implementation detail).
If there is some porting target that still needs OrderedDict,
it's important to add a # comment saying so.
Otherwise some future maintenance engineer is likely to simplify this.
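To illustrate the ordering point without requiring Numba, here is a sketch that uses a plain dict literal in place of the OrderedDict. The string values are placeholders for the Numba types (uint32, optional(node_type), uint16) used in the question. A plain dict keeps keys in insertion order, which is a language guarantee since Python 3.7.

```python
# Placeholder strings stand in for the Numba type objects in the real spec.
spec = {
    "data": "uint32",
    "left": "optional(node_type)",
    "right": "optional(node_type)",
    "level": "uint16",
    "size": "uint16",
    "left_size": "uint16",
}

# Iteration order matches the order in which the keys were written.
field_order = list(spec)
print(field_order)
```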
It looks like a @dataclass would have been more appropriate for this.
spec['size'] = ...
spec['left_size'] = ...
Those are curious identifiers, which raise more questions than they answer.
Consider using right_size adjacent to left_size,
or clarify that one is a "total" size.
EDIT: Subsequent code suggests we have a total_size here.
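One possible renaming, sketched as a plain-Python @dataclass. The names Node, total_size and update_sizes are hypothetical. This sketch is plain Python only, and I would not expect it to combine with @jitclass as-is.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional


@dataclass
class Node:
    """Hypothetical pure-Python AA node with self-describing size fields."""
    data: int
    left: Optional[Node] = None
    right: Optional[Node] = None
    level: int = 1
    total_size: int = 1  # nodes in this subtree, including this one
    left_size: int = 0   # nodes in the left subtree only


def update_sizes(node: Node) -> int:
    """Recompute left_size and total_size from the children; return total_size."""
    left_total = node.left.total_size if node.left is not None else 0
    right_total = node.right.total_size if node.right is not None else 0
    node.left_size = left_total
    node.total_size = 1 + left_total + right_total
    return node.total_size


root = Node(2, left=Node(1), right=Node(3))
print(update_sizes(root), root.left_size)
```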
meaningful identifier
This is a perfectly nice class name:
class AAOSTN:
But it really needs a """docstring""" explaining what the alphabet soup is all about.
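A sketch of such a docstring follows. The author should fill in what the acronym stands for, so it is deliberately not expanded here. The described behaviour is read off the question's code: indices are 0-based, both lookups return -1 on a miss, and inserting an existing key is a no-op.

```python
class AAOSTN:
    """Order statistics tree built on an AA tree.

    AAOSTN stands for: <expand the acronym here>.

    Every node caches its subtree sizes, so that:
      select(k) returns the k-th smallest key (0-based), or -1 if out of range;
      rank(x) returns the number of keys smaller than x, or -1 if x is absent.
    Keys are unique; inserting a key that is already present does nothing.
    """

    def __init__(self):
        self.root = None
```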
These are not good choices of identifier:
@staticmethod ...
def _update_size_static( ... ):
...
@staticmethod ...
def _select_static( ... ):
...
@staticmethod ...
def _rank_static( ... ):
Sure, they're static. But just because foo() is static doesn't mean we should call it foo_static(). Please elide those suffixes.
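Here is a sketch of the same helpers without the suffix. In the question each helper also carries @njit under @staticmethod, and that decorator stack would stay as it is. This runnable version drops @njit so it works without Numba. Node and OrderStatisticsTree are hypothetical stand-ins.

```python
class Node:
    """Hypothetical pure-Python stand-in for NumbaNodeAA."""

    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right
        self.left_size = 0


class OrderStatisticsTree:
    """Hypothetical shell showing the suffix-free helper names."""

    def __init__(self, root=None):
        self.root = root

    @staticmethod
    def _select(node, k):
        # left_size tells us whether index k is left of, at, or right of node.
        while node is not None:
            if k < node.left_size:
                node = node.left
            elif k > node.left_size:
                k -= node.left_size + 1
                node = node.right
            else:
                return node.data
        return -1

    @staticmethod
    def _rank(node, data):
        # Count the keys known to be smaller than data along the search path.
        rank = 0
        while node is not None:
            if data < node.data:
                node = node.left
            elif data > node.data:
                rank += node.left_size + 1
                node = node.right
            else:
                return rank + node.left_size
        return -1

    def select(self, k):
        return self._select(self.root, k)

    def rank(self, data):
        return self._rank(self.root, data)


root = Node(2, Node(1), Node(3))
root.left_size = 1
tree = OrderStatisticsTree(root)
print(tree.select(0), tree.select(2), tree.rank(3))
```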
We eagerly call _update_size_static() a fair amount. Consider letting some @property lazily compute the size when we actually need the value.
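One reading of that suggestion is sketched below in plain Python. The stored left_size field is dropped, and a property derives it from the left child on demand, so there is one less field to keep in sync during rotations. Node and _size are hypothetical names.

```python
def _size(node):
    """Subtree size of node, treating a missing child as size 0."""
    return node.size if node is not None else 0


class Node:
    """Hypothetical pure-Python node that stores only its total subtree size."""

    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right
        self.size = 1 + _size(left) + _size(right)

    @property
    def left_size(self):
        # Derived from the left child when read: O(1), nothing to update.
        return _size(self.left)


root = Node(2, Node(1), Node(3))
print(root.size, root.left_size)
```

A fully lazy total size would have to visit the whole subtree on every access. That would turn the O(log n) select and rank into O(n), so caching the total size, as the question does, still seems the better trade-off.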
flushing stdout
print(current.data, end=' ')
You might want to tack on ... , flush=True)
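Here is a sketch with flush=True added. Separating the traversal (a generator) from the printing is an optional extra, and it makes the traversal easy to test on its own. Node, inorder_values and print_inorder are hypothetical names.

```python
class Node:
    """Hypothetical pure-Python stand-in for NumbaNodeAA."""

    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def inorder_values(root):
    """Yield keys in ascending order, using an explicit stack as in the question."""
    stack = []
    current = root
    while stack or current is not None:
        if current is not None:
            stack.append(current)
            current = current.left
        else:
            current = stack.pop()
            yield current.data
            current = current.right


def print_inorder(root):
    for value in inorder_values(root):
        # flush=True writes each value out immediately, even if stdout is buffered.
        print(value, end=" ", flush=True)
    print()


root = Node(4, Node(2, Node(1), Node(3)), Node(5))
print_inorder(root)
```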
performance
Thank you for the performance plots (even though we can't reproduce such benchmarks).
insertions are slower than I expected.
Please update the Question to include the top 20-ish consumers of CPU cycles as revealed by cProfile on your chosen workload. Absent such measurements, it's hard to tell where one should focus effort on performance improvements.
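Such a report could be produced along these lines, using only the standard library's cProfile and pstats. The workload and profile_top helpers are hypothetical. A plain set stands in for the tree so the sketch runs on its own. With the question's code you would pass tree.insert for some tree = AAOSTN().

```python
import cProfile
import io
import pstats
import random


def workload(insert, n=20_000, seed=0):
    """Insert n random 32-bit keys using the given insert callable."""
    rng = random.Random(seed)
    for _ in range(n):
        insert(rng.randrange(2**32))


def profile_top(func, *args, limit=20):
    """Run func(*args) under cProfile; return the top `limit` rows as text."""
    profiler = cProfile.Profile()
    profiler.enable()
    func(*args)
    profiler.disable()
    out = io.StringIO()
    stats = pstats.Stats(profiler, stream=out)
    stats.sort_stats("cumulative").print_stats(limit)
    return out.getvalue()


if __name__ == "__main__":
    keys = set()  # stand-in; with the question's code: tree = AAOSTN(); tree.insert
    print(profile_top(workload, keys.add))
```

Sorting by "cumulative" puts the call chains that dominate the insert path near the top. Sorting by "tottime" instead highlights the individual functions that spend the most time in their own code.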