
I've implemented an AA Tree (a type of self-balancing binary search tree) using Numba, a just-in-time compiler for Python. This implementation includes rank and select operations, which are useful for order statistics.

An AA Tree keeps itself balanced, giving O(log n) searches, insertions and deletions, and it is simpler to implement than other self-balancing trees such as Red-Black trees. I chose Numba in the hope of improving performance, especially for the rank and select operations.

Here's my implementation:

# Order Statistics Tree based on AA Tree
from numba import njit
from numba import uint32, deferred_type, optional, uint16
from numba.experimental import jitclass
from collections import OrderedDict
node_type = deferred_type()
spec = OrderedDict()
spec['data'] = uint32
spec['left'] = optional(node_type)
spec['right'] = optional(node_type)
spec['level'] = uint16
spec['size'] = uint16
spec['left_size'] = uint16
@jitclass(spec)
class NumbaNodeAA:
    def __init__(self, data: int):
        self.data = data
        self.left = None
        self.right = None
        self.level = 1
        self.size = 1
        self.left_size = 0
node_type.define(NumbaNodeAA.class_type.instance_type)

class AAOSTN:
    def __init__(self):
        self.root = None

    def insert(self, data: int):
        self.root = self._insert(self.root, data)

    def _insert(self, node: NumbaNodeAA, data: int):
        if node is None:
            return NumbaNodeAA(data)
        if data < node.data:
            node.left = self._insert(node.left, data)
        elif data > node.data:
            node.right = self._insert(node.right, data)
        else:
            return node
        node = self._skew(node)
        node = self._split(node)
        node.size = self._update_size_static(node)
        return node

    def _skew(self, node: NumbaNodeAA):
        if node is None or node.left is None:
            return node
        if node.level == node.left.level:
            left = node.left
            node.left = left.right
            left.right = node
            node.size = self._update_size_static(node)
            left.size = self._update_size_static(left)
            return left
        return node

    def _split(self, node: NumbaNodeAA):
        if node is None or node.right is None or node.right.right is None:
            return node
        if node.level == node.right.right.level:
            right = node.right
            node.right = right.left
            right.left = node
            right.level += 1
            node.size = self._update_size_static(node)
            right.size = self._update_size_static(right)
            return right
        return node

    @staticmethod
    @njit
    def _update_size_static(node: NumbaNodeAA):
        if node is None:
            return 0
        left_size = node.left.size if node.left is not None else 0
        right_size = node.right.size if node.right is not None else 0
        node.left_size = left_size
        node.size = 1 + left_size + right_size
        return node.size

    @staticmethod
    @njit(fastmath=True)
    def _select_static(node: NumbaNodeAA, k: int):
        while node is not None:
            if k < node.left_size:
                node = node.left
            elif k > node.left_size:
                k -= node.left_size + 1
                node = node.right
            else:
                return node.data
        return -1

    @staticmethod
    @njit(fastmath=True)
    def _rank_static(node: NumbaNodeAA, data: int):
        rank = 0
        while node is not None:
            if data < node.data:
                node = node.left
            elif data > node.data:
                rank += node.left_size + 1
                node = node.right
            else:
                return rank + node.left_size
        return -1

    def select(self, k: uint32):
        return self._select_static(self.root, k)

    def rank(self, data: uint32):
        return self._rank_static(self.root, data)

    def delete(self, data: int):
        self.root = self._delete(self.root, data)

    def _delete(self, node: NumbaNodeAA, data: int):
        if node is None:
            return None
        if data < node.data:
            node.left = self._delete(node.left, data)
        elif data > node.data:
            node.right = self._delete(node.right, data)
        else:
            if node.left is None or node.right is None:
                return node.left if node.left is not None else node.right
            else:
                successor = self._find_min(node.right)
                node.data = successor.data
                node.right = self._delete(node.right, successor.data)
        node = self._decrease_level(node)
        node = self._skew(node)
        if node.right is not None:
            node.right = self._skew(node.right)
        if node.right is not None and node.right.right is not None:
            node.right.right = self._skew(node.right.right)
        node = self._split(node)
        if node.right is not None:
            node.right = self._split(node.right)
        node.size = self._update_size_static(node)
        return node

    def _decrease_level(self, node: NumbaNodeAA):
        left_level = node.left.level if node.left is not None else 0
        right_level = node.right.level if node.right is not None else 0
        level = min(left_level, right_level) + 1
        if level < node.level:
            node.level = level
            if node.right is not None and node.right.level > level:
                node.right.level = level
        return node

    def _find_min(self, node: NumbaNodeAA):
        current = node
        while current.left is not None:
            current = current.left
        return current

    def inorder(self):
        stack = []
        current = self.root
        while stack or current is not None:
            if current is not None:
                stack.append(current)
                current = current.left
            else:
                current = stack.pop()
                print(current.data, end=' ')
                current = current.right
        print()

The AAOSTN class represents the AA Tree. Key methods include:

  • insert: Adds a new node to the tree
  • delete: Removes a node from the tree
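  • select: Finds the k-th smallest element in the tree (k is 0-based; returns -1 if k is out of range)
  • rank: Finds the rank of a given element, i.e. the number of smaller elements in the tree (returns -1 if the element is absent)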
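The rank and select semantics that the posted code implements (0-based, with -1 for a missing element or an out-of-range index) can be pinned down with a plain sorted list and the standard bisect module. This is a hypothetical reference sketch, not part of the posted code, that could serve as a test oracle:

```python
from bisect import bisect_left


def reference_select(sorted_items, k):
    """Return the k-th smallest item (0-based), or -1 if k is out of range."""
    if 0 <= k < len(sorted_items):
        return sorted_items[k]
    return -1


def reference_rank(sorted_items, value):
    """Return how many items are smaller than value, or -1 if value is absent."""
    i = bisect_left(sorted_items, value)
    if i < len(sorted_items) and sorted_items[i] == value:
        return i
    return -1


# The tree ignores duplicate inserts, so the oracle works on a sorted set of keys.
items = sorted({50, 20, 80, 10, 30})
print(reference_select(items, 0), reference_select(items, 4), reference_select(items, 5))
print(reference_rank(items, 30), reference_rank(items, 35))
```

Inserting the same keys into an AAOSTN and comparing tree.select(k) and tree.rank(x) against these two functions over many random inputs would be a cheap randomized correctness test.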

The tree maintains its balance using the skew and split operations, which are characteristic of AA Trees.
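For reference, the two primitives can be sketched in plain Python (no Numba, no size bookkeeping) on a hypothetical Node class. skew turns a horizontal left link into a right link with a right rotation; split breaks up two consecutive horizontal right links with a left rotation and promotes the middle node one level. Inserting 1 through 7 in ascending order should yield a perfectly balanced tree rooted at 4:

```python
class Node:
    """Minimal AA tree node (hypothetical, for illustration only)."""

    def __init__(self, key, level=1, left=None, right=None):
        self.key = key
        self.level = level
        self.left = left
        self.right = right


def skew(t):
    # A left child on the same level is a horizontal left link: rotate right.
    if t is not None and t.left is not None and t.left.level == t.level:
        left = t.left
        t.left = left.right
        left.right = t
        return left
    return t


def split(t):
    # Two horizontal right links in a row: rotate left, promote the middle node.
    if (t is not None and t.right is not None and t.right.right is not None
            and t.right.right.level == t.level):
        right = t.right
        t.right = right.left
        right.left = t
        right.level += 1
        return right
    return t


def insert(t, key):
    if t is None:
        return Node(key)
    if key < t.key:
        t.left = insert(t.left, key)
    elif key > t.key:
        t.right = insert(t.right, key)
    else:
        return t  # duplicates are ignored, as in the posted code
    return split(skew(t))


def inorder(t):
    return [] if t is None else inorder(t.left) + [t.key] + inorder(t.right)


root = None
for k in [1, 2, 3, 4, 5, 6, 7]:
    root = insert(root, k)
print(root.key, root.level, inorder(root))  # 4 3 [1, 2, 3, 4, 5, 6, 7]
```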

I'm particularly interested in feedback on:

  1. Overall code structure and organization
  2. Effective use of Numba (I'm new to using it)
  3. Correctness of the AA Tree implementation
  4. Performance considerations, especially for insertions
  5. Python best practices and any potential improvements
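On point 3, one way to gain confidence is a checker that walks the tree and asserts the five AA invariants plus the stored size fields. This is a hypothetical helper, not part of the posted code. It assumes nodes expose data, left, right, level, size and left_size as NumbaNodeAA does (jitclass attributes are readable from Python), but the sketch is exercised only on a plain stand-in class here; something like check_aa(tree.root) after every insert and delete in a randomized test would be the intended use.

```python
def check_aa(node, lo=None, hi=None):
    """Assert AA-tree invariants and size fields below node; return subtree size.

    Uses assert, so it is a testing aid only (assertions vanish under python -O).
    """
    if node is None:
        return 0
    # Keys must respect BST order; the posted tree stores unique keys.
    assert lo is None or node.data > lo, "BST order violated"
    assert hi is None or node.data < hi, "BST order violated"
    left, right = node.left, node.right
    # 1. Leaves are at level 1.
    if left is None and right is None:
        assert node.level == 1, "leaf must be at level 1"
    # 2. A left child is exactly one level below its parent.
    if left is not None:
        assert left.level == node.level - 1, "left child must be one level down"
    # 3. A right child is on the same level or one below.
    if right is not None:
        assert right.level in (node.level, node.level - 1), "bad right child level"
        # 4. A right grandchild is strictly below its grandparent.
        if right.right is not None:
            assert right.right.level < node.level, "right grandchild too high"
    # 5. Every node above level 1 has two children.
    if node.level > 1:
        assert left is not None and right is not None, "level > 1 needs two children"
    left_size = check_aa(left, lo, node.data)
    right_size = check_aa(right, node.data, hi)
    # The stored augmentation must match the real subtree sizes.
    assert node.left_size == left_size, "stale left_size"
    assert node.size == 1 + left_size + right_size, "stale size"
    return node.size


class N:
    """Plain stand-in with NumbaNodeAA's field names (hypothetical)."""

    def __init__(self, data, level=1, left=None, right=None):
        self.data, self.level, self.left, self.right = data, level, left, right
        self.left_size = left.size if left is not None else 0
        self.size = 1 + self.left_size + (right.size if right is not None else 0)


print(check_aa(N(2, 2, N(1), N(3))))  # 3
try:
    # 1 -> 2 -> 3 as a chain of horizontal right links: split was never applied.
    check_aa(N(1, 1, None, N(2, 1, None, N(3))))
except AssertionError as exc:
    print("invalid:", exc)  # invalid: right grandchild too high
```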

While the rank and select operations seem to perform well, I've noticed that insertions are slower than I expected. Any insights into this would be appreciated, but I welcome all forms of feedback to improve this code.

Benchmark plots (images not reproduced here): rank, select, and insertion.

J_H
asked Aug 29, 2024 at 17:19

1 Answer


module-level docstring

# Order Statistics Tree based on AA Tree

This is a good comment. It would be even better as a """docstring""".

dicts are ordered

This is a bit odd:

spec = OrderedDict()

It would have been simpler to assign spec = {}. Since Python 3.7 the language itself guarantees that a plain dict preserves insertion order (CPython 3.6 already did so as an implementation detail). If there is some porting target that still needs OrderedDict, it's important to # comment on that. Otherwise some future maintenance engineer is likely to simplify this.
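A small illustration of the ordering point, assuming Python 3.7 or later. Strings stand in for the numba types so the snippet runs without numba installed. Numba's jitclass documentation also shows the spec written as a plain list of (name, type) tuples, which sidesteps the question entirely.

```python
from collections import OrderedDict

# Strings stand in for uint32, optional(node_type), uint16 so this runs without numba.
fields = [
    ("data", "uint32"),
    ("left", "optional(node_type)"),
    ("right", "optional(node_type)"),
    ("level", "uint16"),
    ("size", "uint16"),
    ("left_size", "uint16"),
]

ordered_spec = OrderedDict(fields)
plain_spec = dict(fields)  # dict insertion order is guaranteed since Python 3.7

# Compare key sequences: OrderedDict == dict ignores order, so list() is used instead.
print(list(plain_spec) == list(ordered_spec))  # True
```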

It looks like a @dataclass would have been more appropriate for this.

spec['size'] = ...
spec['left_size'] = ...

Those are curious identifiers, which raise more questions than they answer. Consider adding a right_size next to left_size, or make it clear that one of them is a "total" size.

EDIT: Subsequent code suggests we have a total_size here.

meaningful identifier

This is a perfectly nice class name:

class AAOSTN:

But it really needs a """docstring""" explaining what the alphabet soup is all about.

These are not good choices of identifier:

    @staticmethod ...
    def _update_size_static( ... ):
        ...
    @staticmethod ...
    def _select_static( ... ):
        ...
    @staticmethod ...
    def _rank_static( ... ):

Sure, they're static. But just because foo() is static doesn't mean we should call it foo_static(). Please elide those suffixes.

We eagerly compute _update_size_static() a fair amount. Consider letting some @property lazily compute it when we actually need the value.
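One possible reading of this suggestion, sketched in plain Python on a hypothetical Node class: store only the total size and derive left_size from the left child's stored size whenever it is read. The lookup stays O(1), and rotations then only need to recompute size, so there is one less field to keep in sync. Numba's jitclass documentation lists properties as supported, but this sketch has not been compiled with Numba.

```python
class Node:
    """Hypothetical AA node that stores only the total subtree size."""

    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right
        self.size = 1 + _size(left) + _size(right)  # total size of this subtree

    @property
    def left_size(self):
        # Derived on demand from the child's stored total, so it cannot go stale.
        return _size(self.left)


def _size(node):
    return node.size if node is not None else 0


def select(node, k):
    """Return the k-th smallest key (0-based), or -1 if k is out of range."""
    while node is not None:
        if k < node.left_size:
            node = node.left
        elif k > node.left_size:
            k -= node.left_size + 1
            node = node.right
        else:
            return node.data
    return -1


root = Node(20, Node(10), Node(40, Node(30), Node(50)))
print([select(root, k) for k in range(6)])  # [10, 20, 30, 40, 50, -1]
```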

flushing stdout

 print(current.data, end=' ')

You might want to tack on flush=True, i.e. print(current.data, end=' ', flush=True).

performance

Thank you for the performance plots (even though we can't reproduce such benchmarks).

insertions are slower than I expected.

Please update the Question to include the top 20-ish consumers of CPU cycles as revealed by cProfile on your chosen workload. Absent such measurements, it's hard to tell where one should focus effort on performance improvements.
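A minimal sketch of collecting that measurement with the standard-library cProfile and pstats modules. The build function is a hypothetical stand-in workload, not the asker's benchmark; substituting AAOSTN insertions into it would profile the real code. cProfile itemizes Python-level calls such as _insert, _skew and _split, while time inside @njit-compiled functions should show up as single opaque call entries.

```python
import bisect
import cProfile
import io
import pstats
import random


def build(n):
    """Stand-in insertion workload (hypothetical).

    To profile the real code, replace the body with something like
    tree = AAOSTN(), followed by tree.insert(x) for each key.
    """
    keys = random.sample(range(10 * n), n)
    out = []
    for x in keys:
        bisect.insort(out, x)
    return out


# With Numba, run the workload once first so one-time JIT compilation
# is not counted in the profile.
build(100)

profiler = cProfile.Profile()
profiler.enable()
result = build(20_000)
profiler.disable()

buf = io.StringIO()
stats = pstats.Stats(profiler, stream=buf).sort_stats("cumulative")
stats.print_stats(20)  # show only the top 20 entries
report = buf.getvalue()
print(report)
```

From the command line, python -m cProfile -s cumulative your_script.py produces a similar report without modifying the script.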

answered Aug 29, 2024 at 22:00