6
\$\begingroup\$

I've made an attempt at writing a skip list in Python. I'm using NumPy to generate geometric random variables, but since that's a bit of a heavy dependency to drag around I could easily implement that myself.

I based my implementation on the basic (none of the improvements such as capping node height etc.) algorithm here.

What do you think needs improvement?

import numpy as np
class SkipList:
 def __init__(self, p=0.5):
 """
 Create a Skiplist object.
 >>>> l = Skiplist() # An empty skip list
 >>>> l = Skiplist.from_iter(zip(range(5), range(5))) # A skip list from an iterable
 """
 self.p = p
 self.head = SkipList.Node()
 self.max_height = 1
 self.__length = 0
 def from_iter(it, p=0.5):
 """
 Create a SkipList from an iterable of (Key, Value) tuples
 """
 s = SkipList(p=p)
 for k, v in it:
 s.insert(k, v)
 return s
 def __getitem__(self, key):
 curr = self.head
 for level in range(self.max_height - 1, -1, -1):
 while curr.forward[level] and curr.forward[level].key < key:
 curr = curr.forward[level]
 res = curr.forward[0]
 if res and res.key == key:
 return res.value
 else:
 raise KeyError("Key {} not found".format(key))
 def __setitem__(self, key, value):
 """
 If the key is already present, the current value will be overwritten with the new value.
 """
 new_node = None
 curr = self.head
 update = [None for _ in range(self.max_height)]
 for level in range(self.max_height - 1, -1, -1):
 while curr.forward[level] and curr.forward[level].key < key:
 curr = curr.forward[level]
 update[level] = curr
 if curr.forward[0] and curr.forward[0].key == key:
 curr.forward[0].value = value
 else:
 height = np.random.geometric(self.p)
 new_forward = [n.forward[l] for l, n in enumerate(update[:height])]
 if height > self.max_height:
 new_forward += [None for _ in range(self.max_height, height)]
 self.head.forward += [None for _ in range(self.max_height, height)]
 update += [self.head for l in range(self.max_height, height)]
 self.max_height = height
 new_node = SkipList.Node(key=key, value=value, forward=new_forward)
 for l, n in enumerate(update[:height]):
 n.forward[l] = new_node
 self.__length += 1
 def __delitem__(self, key):
 curr = self.head
 update = [None for _ in range(self.max_height)]
 for level in range(self.max_height - 1, -1, -1):
 while curr.forward[level] and curr.forward[level].key < key:
 curr = curr.forward[level]
 update[level] = curr
 del_node = curr.forward[0]
 if del_node and del_node.key == key:
 for l, f in enumerate(del_node.forward):
 update[l].forward[l] = f
 self.__length -= 1
 else:
 raise KeyError("Key {} not found".format(key))
 def items(self):
 """
 Generator in the style of dict.items
 """
 curr = self.head.forward[0]
 while curr:
 yield (curr.key, curr.value)
 curr = curr.forward[0]
 def __contains__(self, key):
 try:
 self[key]
 except KeyError:
 return False
 return True
 def __iter__(self):
 curr = self.head.forward[0]
 while curr:
 yield curr.key
 curr = curr.forward[0]
 def __len__(self):
 return self.__length
 def __eq__(self, other):
 if len(self) != len(other):
 return False
 for (k1, v1), (k2, v2) in zip(self.items(), other.items()):
 if not (k1 == k2 and v1 == v2):
 return False
 return True
 class Node:
 def __init__(self, key=None, value=None, forward=None):
 if forward is None:
 forward = [None]
 self.key = key
 self.value = value
 self.forward = forward
Jamal
35.2k13 gold badges134 silver badges238 bronze badges
asked Jan 30, 2016 at 7:40
\$\endgroup\$

1 Answer 1

2
\$\begingroup\$

Remove repetition

You have almost identical code:

def items(self):
 """
 Generator in the style of dict.items
 """
 curr = self.head.forward[0]
 while curr:
 yield (curr.key, curr.value)
 curr = curr.forward[0]
def __iter__(self):
 curr = self.head.forward[0]
 while curr:
 yield curr.key
 curr = curr.forward[0]

You may avoid the repetition writing:

def __iter__(self):
 for key, _ in self.items():
 yield key

The nested loops:

 for level in range(self.max_height - 1, -1, -1):
 while curr.forward[level] and curr.forward[level].key < key:
 curr = curr.forward[level]

Are repeated identical 3 times, extract them into a function.

Use the all built-in

You do not need a manual for loop in __eq__:

def __eq__(self, other):
 if len(self) != len(other):
 return False
 return all(self_pair == other_pair
 for self_pair, other_pair in zip(self.items(), other.items())

all and avoiding tuple unpacking is closer to how you would describe the function in English (all pairs should be equal)

You may also use and instead of a separate if

def __eq__(self, other):
 return len(self) == len(other) and \
 all(self_pair == other_pair
 for self_pair, other_pair in zip(self.items(), other.items())

It makes the code even nearer to English (The length should be equal and all pairs should be equal)

answered Jan 30, 2016 at 15:16
\$\endgroup\$
1
  • \$\begingroup\$ Thanks! I've made the changes. One thing I'm unsure about is that the function I created to replace the repated loop code builds the update list even when it's not necessary (i.e. for getitem). \$\endgroup\$ Commented Jan 31, 2016 at 3:56

Your Answer

Draft saved
Draft discarded

Sign up or log in

Sign up using Google
Sign up using Email and Password

Post as a guest

Required, but never shown

Post as a guest

Required, but never shown

By clicking "Post Your Answer", you agree to our terms of service and acknowledge you have read our privacy policy.

Start asking to get answers

Find the answer to your question by asking.

Ask question

Explore related questions

See similar questions with these tags.