I've made an attempt at writing a skip list in Python. I'm using NumPy to generate geometric random variables, but since that's a bit of a heavy dependency to drag around I could easily implement that myself.
I based my implementation on the basic (none of the improvements such as capping node height etc.) algorithm here.
What do you think needs improvement?
import numpy as np
class SkipList:
def __init__(self, p=0.5):
"""
Create a Skiplist object.
>>>> l = Skiplist() # An empty skip list
>>>> l = Skiplist.from_iter(zip(range(5), range(5))) # A skip list from an iterable
"""
self.p = p
self.head = SkipList.Node()
self.max_height = 1
self.__length = 0
def from_iter(it, p=0.5):
"""
Create a SkipList from an iterable of (Key, Value) tuples
"""
s = SkipList(p=p)
for k, v in it:
s.insert(k, v)
return s
def __getitem__(self, key):
curr = self.head
for level in range(self.max_height - 1, -1, -1):
while curr.forward[level] and curr.forward[level].key < key:
curr = curr.forward[level]
res = curr.forward[0]
if res and res.key == key:
return res.value
else:
raise KeyError("Key {} not found".format(key))
def __setitem__(self, key, value):
"""
If the key is already present, the current value will be overwritten with the new value.
"""
new_node = None
curr = self.head
update = [None for _ in range(self.max_height)]
for level in range(self.max_height - 1, -1, -1):
while curr.forward[level] and curr.forward[level].key < key:
curr = curr.forward[level]
update[level] = curr
if curr.forward[0] and curr.forward[0].key == key:
curr.forward[0].value = value
else:
height = np.random.geometric(self.p)
new_forward = [n.forward[l] for l, n in enumerate(update[:height])]
if height > self.max_height:
new_forward += [None for _ in range(self.max_height, height)]
self.head.forward += [None for _ in range(self.max_height, height)]
update += [self.head for l in range(self.max_height, height)]
self.max_height = height
new_node = SkipList.Node(key=key, value=value, forward=new_forward)
for l, n in enumerate(update[:height]):
n.forward[l] = new_node
self.__length += 1
def __delitem__(self, key):
curr = self.head
update = [None for _ in range(self.max_height)]
for level in range(self.max_height - 1, -1, -1):
while curr.forward[level] and curr.forward[level].key < key:
curr = curr.forward[level]
update[level] = curr
del_node = curr.forward[0]
if del_node and del_node.key == key:
for l, f in enumerate(del_node.forward):
update[l].forward[l] = f
self.__length -= 1
else:
raise KeyError("Key {} not found".format(key))
def items(self):
"""
Generator in the style of dict.items
"""
curr = self.head.forward[0]
while curr:
yield (curr.key, curr.value)
curr = curr.forward[0]
def __contains__(self, key):
try:
self[key]
except KeyError:
return False
return True
def __iter__(self):
curr = self.head.forward[0]
while curr:
yield curr.key
curr = curr.forward[0]
def __len__(self):
return self.__length
def __eq__(self, other):
if len(self) != len(other):
return False
for (k1, v1), (k2, v2) in zip(self.items(), other.items()):
if not (k1 == k2 and v1 == v2):
return False
return True
class Node:
def __init__(self, key=None, value=None, forward=None):
if forward is None:
forward = [None]
self.key = key
self.value = value
self.forward = forward
1 Answer 1
Remove repetition
You have almost identical code:
def items(self):
"""
Generator in the style of dict.items
"""
curr = self.head.forward[0]
while curr:
yield (curr.key, curr.value)
curr = curr.forward[0]
def __iter__(self):
curr = self.head.forward[0]
while curr:
yield curr.key
curr = curr.forward[0]
You may avoid the repetition writing:
def __iter__(self):
for key, _ in self.items():
yield key
The nested loops:
for level in range(self.max_height - 1, -1, -1):
while curr.forward[level] and curr.forward[level].key < key:
curr = curr.forward[level]
Are repeated identical 3 times, extract them into a function.
Use the all
built-in
You do not need a manual for loop in __eq__
:
def __eq__(self, other):
if len(self) != len(other):
return False
return all(self_pair == other_pair
for self_pair, other_pair in zip(self.items(), other.items())
all
and avoiding tuple unpacking is closer to how you would describe the function in English (all pairs should be equal)
You may also use and
instead of a separate if
def __eq__(self, other):
return len(self) == len(other) and \
all(self_pair == other_pair
for self_pair, other_pair in zip(self.items(), other.items())
It makes the code even nearer to English (The length should be equal and all pairs should be equal)
-
\$\begingroup\$ Thanks! I've made the changes. One thing I'm unsure about is that the function I created to replace the repated loop code builds the update list even when it's not necessary (i.e. for getitem). \$\endgroup\$Davis Yoshida– Davis Yoshida2016年01月31日 03:56:42 +00:00Commented Jan 31, 2016 at 3:56
Explore related questions
See similar questions with these tags.