I developed a heap sort variant that sorts a heap through traversal. Unlike standard heap sort, the algorithm does not repeatedly remove the minimum element from the heap; instead, it traverses all the nodes of the heap, visiting those with smaller values first and adding them to the sorted list.
Here's the implementation of the algorithm in Python:
import heapq


def heap_sort(lst):
  """
  Performs a heap sort iteratively by traversing all nodes of the heap,
  visiting nodes with smaller values first and adding the values to sorted result list.

  Args:
    lst (list): The input list to be sorted.

  Returns:
    list: A sorted list containing elements of the input list in ascending order.

  Approach:
    - Convert the input list into a min-heap
    - Traverse and process the heap iteratively:
      - Maintain a set to track indices that have already been processed.
      - Replace the value at the current node with the value of either the
        left child, right child or parent, depending on specific conditions.
      - Accumulate results in a separate list as each node is processed.
  """

  def left_child(i):
    """
    Returns the value of the left child of the node at index i, or infinity if out of bounds.
    """
    return float("inf") if 2 * i + 1 >= len(lst) else lst[2 * i + 1]

  def right_child(i):
    """
    Returns the value of the right child of the node at index i, or infinity if out of bounds.
    """
    return float("inf") if 2 * i + 2 >= len(lst) else lst[2 * i + 2]

  def parent(i):
    """
    Returns the value of parent of the node at index i, or infinity if the node is the root.
    """
    return lst[(i - 1) // 2] if i > 0 else float("inf")

  heapq.heapify(lst)  # Build a min-heap from input list

  # A set to keep track of visited indices (nodes)
  visited_indices = set()
  # List to store the sorted result
  result = []
  # Start traversal from the root of the heap
  current_index = 0

  while len(result) < len(lst):
    if current_index not in visited_indices:
      # Add the current node's value to the result and mark it as visited
      result.append(lst[current_index])
      visited_indices.add(current_index)

    # Replace the current node value with value of either left, right or parent node
    if parent(current_index) < min(left_child(current_index), right_child(current_index)):
      lst[current_index] = min(left_child(current_index), right_child(current_index))
      current_index = (current_index - 1) // 2  # Move to the parent node
    elif left_child(current_index) < right_child(current_index):
      lst[current_index] = min(right_child(current_index), parent(current_index))
      current_index = 2 * current_index + 1  # Move to the left child
    else:
      lst[current_index] = min(left_child(current_index), parent(current_index))
      current_index = 2 * current_index + 2  # Move to the right child

  return result
The algorithm's main steps are:
- Convert the input list into a min-heap using heapq.heapify.
- Initialize result to an empty list [].
- Traverse the heap iteratively, starting from the root node, and replace the current node's value with the value of its parent, left child or right child, based on specific conditions. The traversal logic ensures that a parent node "calls" the smaller of its children to explore its subtree until a node with a value greater than the parent is encountered. If the children of the current node are all greater than the value of the current node's parent, the parent has nodes with smaller values in its other subtree; in that case, control is returned to the parent node.
- Append a node's value to the result the first time it is visited.
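For a quick sanity check, here is a small usage example (the input values are arbitrary):

data = [7, 2, 9, 4, 1]
print(heap_sort(data))  # expected output: [1, 2, 4, 7, 9]
# Note: data itself is heapified and overwritten by the traversal.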
My Questions
- How does the algorithm compare to the standard heap sort in terms of efficiency?
- Can it be optimized?
- Could the iterative traversal logic be simplified while maintaining correctness?
For those interested in a recursive solution, here's the implementation:
import heapq


def heap_sort_re(lst):
  def left_child(i):
    return float("inf") if 2 * i + 1 >= len(lst) else lst[2 * i + 1]

  def right_child(i):
    return float("inf") if 2 * i + 2 >= len(lst) else lst[2 * i + 2]

  def parent(i):
    return lst[(i - 1) // 2] if i > 0 else float("inf")

  def _recurse(i=0):
    if len(result) == len(lst):
      return
    if i not in visited:
      result.append(lst[i])
      visited.add(i)
    if parent(i) < min(left_child(i), right_child(i)):
      lst[i] = min(left_child(i), right_child(i))
      _recurse((i - 1) // 2)
    elif left_child(i) < right_child(i):
      lst[i] = min(right_child(i), parent(i))
      _recurse(2 * i + 1)
    else:
      lst[i] = min(left_child(i), parent(i))
      _recurse(2 * i + 2)

  heapq.heapify(lst)
  visited = set()
  result = []
  _recurse()
  return result
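Usage is the same as the iterative version. One caveat: the recursion depth grows with the number of traversal steps, so sufficiently large inputs may exceed Python's default recursion limit.

print(heap_sort_re([5, 3, 8, 1, 9, 2]))  # expected output: [1, 2, 3, 5, 8, 9]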
4 Answers
conventional names
current_index = 0
This name is not great. Prefer j (so we don't confuse it with the various i parameters). The OP name appears everywhere, so it is tediously distracting to pronounce those four syllables.
composition
heapq.heapify(lst) # Build a min-heap from input list
This is an expensive operation, and could easily be invoked by the caller. Consider making "heapify" the caller's responsibility. Possibly one could make multiple calls to heap_sort() or a similar routine, which exploit a pre-ordered list without needing the heapify overhead.
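For example, a split along these lines would let the caller pay the heapify cost once (the name sort_heapified is mine, purely illustrative):

import heapq

def sort_heapified(heap):
    """Sort a list the caller has ALREADY turned into a min-heap."""
    return [heapq.heappop(heap) for _ in range(len(heap))]

data = [5, 3, 8, 1]
heapq.heapify(data)          # heapify cost paid once, by the caller
print(sort_heapified(data))  # [1, 3, 5, 8]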
motivation
It does not appear the proposed heap_sort(lst) offers any advantage over sorted(lst).
Traditionally the review questions we ask about code at Pull Request time are:
- Is it correct?
- Is it maintainable?
I'm willing to believe the code is correct, despite the lack of automated
unit tests.
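Here is a minimal pytest sketch of the coverage I'd want to see before merging (the module name heap_sort_custom is an assumption on my part):

import random
from heap_sort_custom import heap_sort  # assumed module name

def test_edge_cases():
    assert heap_sort([]) == []
    assert heap_sort([42]) == [42]

def test_matches_builtin_sorted():
    random.seed(0)
    data = random.sample(range(1000), 100)
    assert heap_sort(list(data)) == sorted(data)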
But it absolutely fails the "maintainability" test, given that every single team member on a project will already be quite familiar with sorted(lst) and its performance, and will be confident in its correctness.
Even if timsort were implemented as python bytecode, I still wouldn't pit some newly minted untuned code against it. And CPython's timsort is native C code. I haven't run the OP code, but I am confident timsort is quicker than it is.
If a team member advocated merging such a pull request, I would not approve it until we saw some motivation, some feature offered by the code which the builtin sort does not supply.
nsmallest
Given a heapified list, I would expect idiomatic python source to simply execute
from heapq import nsmallest
...
return nsmallest(len(lst), lst) # same as `sorted(lst)`
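Spelled out as a complete function, that might look like this (the function name is mine, for illustration):

import heapq

def heap_sort_via_nsmallest(lst):
    heapq.heapify(lst)                     # note: reorders the caller's list in place
    return heapq.nsmallest(len(lst), lst)  # equivalent to sorted(lst)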
black
The python community does not use an indent level of 2, because indent has greater importance to python source code than it does to various children of Algol such as C and JS. A too-small indent level makes the source less readable to humans. Run $ black *.py before submitting a PR, so indent details won't be a distraction.
caller surprise
It's considered a bit rude to trash the caller's input without mentioning that in the docstring, e.g. lst.sort() or shuffle(lst). Either make a .copy(), or describe the side effect in the docstring.
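The private-copy route can be as thin as a wrapper around the existing function (a sketch, assuming heap_sort from the question is in scope):

def heap_sort_nondestructive(lst):
    """Return a sorted list without reordering the caller's lst."""
    return heap_sort(list(lst))  # work on a private copy

Alternatively, keep the in-place behaviour but state it in the docstring, e.g. "Note: lst is heapified and overwritten as a side effect."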
strictly increasing
There are several < comparisons where it is not completely clear we prefer that over <=. Citing a reference or giving a more complete specification would help reviewers to resolve such ambiguities.
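One cheap way to resolve the ambiguity empirically is a property check on inputs with many duplicates, where < versus <= is most likely to matter (a sketch, assuming heap_sort from the question is in scope):

import random

def check_against_sorted(trials=200):
    random.seed(1)
    for _ in range(trials):
        data = [random.randint(0, 5) for _ in range(random.randint(0, 40))]
        assert heap_sort(list(data)) == sorted(data), data

check_against_sorted()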
invariants
The main while loop tries to preserve several invariants, but it's not completely obvious what they all are. At least one of them is that "lst shall always be monotonic". And we have a nice loop variant, that "either result length increases to lst length, or something something" which is related to $\log n$ traversals. But it wasn't immediately obvious to me what the something something details would be, and that makes the "this is correct!" argument a bit harder.
Each of the three branches assigns a new value to j (to current_index). But it's not obvious what promise that assignment makes. It's worth a # comment, or better, an assert. Documenting the lst ordering the caller shall see after the call would be a good start.
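For instance, the append could go through a tiny helper that asserts the ordering promise (a sketch, assuming the intended invariant is that values are emitted in non-decreasing order):

def emit(result, value):
    """Append value, asserting that the output stays non-decreasing."""
    assert not result or result[-1] <= value, "heap traversal emitted a value out of order"
    result.append(value)

Inside the while loop, result.append(lst[current_index]) would then become emit(result, lst[current_index]).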
@user555045 offers a nice plot of several sort timings. I add a few more, with minor refactor. The $-1$ term is to defeat an .nsmallest() optimization, which would otherwise special-case to just use sorted(), explaining the suspicious resemblance between those two curves before introducing the $-1$ tweak.
We see that sorted() dominates, and sort_using_heap() is a strong competitor.
...
def nsmallest(arr):
    heapq.heapify(arr)
    return heapq.nsmallest(len(arr) - 1, arr)  # tweak to deoptimize


def sample(n, fn):
    data = makedata(n)
    return timeit.timeit(lambda: fn(data[:]), number=10)


def plot(ax, a, fn, color):
    ax.plot([sample(n, fn) for n in a], color=color, label=fn.__name__)


def main() -> None:
    a = [pow(2, i) for i in range(1, 19)]
    fig = plt.figure()
    ax = fig.add_subplot(2, 1, 1)
    for color, fn in [
        ("blue", heap_sort_custom),
        ("red", heapsort),
        ("green", sort_using_heap),
        ("purple", nsmallest),
        ("orange", sorted),
    ]:
        plot(ax, a, fn, color)
    ax.set_yscale("log")
    ax.set_xscale("log")
    plt.legend()
    plt.show()
- I agree it definitely doesn't make sense to prefer untested code like this to using the built-in sorted function that's well tested for correctness and optimized for performance. My motivation for posting here is, first, it's a new approach; I haven't seen any algorithm that simply traverses a heap to obtain a sorted list. And two, given the community here, it's possible a good use case can be found. – ariko stephen, Dec 28, 2024
- (Those graphs differ somewhat from the ones I got (/get after trying "deoptimisation"). They agree on heap_sort_custom & heapsort, but (deoptimised) nsmallest and sorted almost coincide here (3.12). I let it run up to 8M entries; past L3 size heapsort & sort_using_heap converge. The gap between heap_sort_custom and sorted seems to shrink from about 130:1 to 100:1.) – greybeard, Dec 29, 2024
heap_sort(['a', 'short', 'test']) fails; the domain of the function is one more thing I'd prefer the docstring to state.
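A minimal reproduction suggests the float("inf") sentinel is what breaks non-numeric input (my reading of it, not a full diagnosis): as soon as a child or parent lookup falls out of bounds, a string is compared against a float and the comparison raises.

try:
    "short" < float("inf")
except TypeError as exc:
    print(exc)  # '<' not supported between instances of 'str' and 'float'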
Following the Style Guide for Python Code more closely has the overall advantage of meeting the expectations of Python coders.
One advantage at SE would be to see most docstrings, statements and comments without horizontal scrolling.
How does the algorithm compare to the standard heap sort in terms of efficiency?
I ran some tests, plotted below. Y is time in seconds, X is the size of the array:
Note the log-log axes: heap_sort_custom is usually ~100x as slow as sort_using_heap, except for tiny arrays.
Full code below. This is not a rewrite, just a test; there are no modifications to heap_sort_custom. heapsort is just some arbitrary implementation of heapsort, the details don't matter. sort_using_heap uses a heap to sort, but isn't heapsort. It's a lot better than I expected, or heapsort is a lot worse; either way I didn't expect a large margin between them, but there is. But whatever, both of them are only on the graph to provide some point of reference.
import heapq
import random
import timeit
import matplotlib.pyplot as plt
import numpy as np


def heap_sort_custom(lst):
    """
    Performs a custom heap sort iterative algorithm without changing the size
    of the heap. Unlike in standard heap sort, no extractions are performed.

    Args:
        lst (list): The input list to be sorted.

    Returns:
        list: A sorted list containing elements of the input list in ascending order.

    Approach:
        - Convert the input list into a min-heap
        - Traverse and process the heap iteratively:
            - Maintain a set to track indices that have already been processed.
            - Replace the value at the current node with the value of either the
              left child, right child or parent, depending on specific conditions.
            - Accumulate results in a separate list as each node is processed.
    """

    def left_child(i):
        """
        Returns the value of the left child of the node at index i, or infinity if out of bounds.
        """
        return float("inf") if 2 * i + 1 >= len(lst) else lst[2 * i + 1]

    def right_child(i):
        """
        Returns the value of the right child of the node at index i, or infinity if out of bounds.
        """
        return float("inf") if 2 * i + 2 >= len(lst) else lst[2 * i + 2]

    def parent(i):
        """
        Returns the value of parent of the node at index i, or infinity if the node is the root.
        """
        return lst[(i - 1) // 2] if i > 0 else float("inf")

    heapq.heapify(lst)  # Build a min-heap from input list

    # A set to keep track of visited indices
    visited_indices = set()
    # List to store the sorted result
    sorted_result = []
    # Start traversal from the root of the heap
    current_index = 0

    while len(sorted_result) < len(lst):
        if current_index not in visited_indices:
            # Add the current node's value to the result and mark it as visited
            sorted_result.append(lst[current_index])
            visited_indices.add(current_index)

        # Replace the current node value with value of either left, right or parent node
        if parent(current_index) < min(left_child(current_index), right_child(current_index)):
            lst[current_index] = min(left_child(current_index), right_child(current_index))
            current_index = (current_index - 1) // 2  # Move to the parent node
        elif left_child(current_index) < right_child(current_index):
            lst[current_index] = min(right_child(current_index), parent(current_index))
            current_index = 2 * current_index + 1  # Move to the left child
        else:
            lst[current_index] = min(left_child(current_index), parent(current_index))
            current_index = 2 * current_index + 2  # Move to the right child

    return sorted_result


def sort_using_heap(arr):
    heapq.heapify(arr)
    result = []
    while arr:
        result.append(heapq.heappop(arr))
    return result


def heapsort(arr):
    def sift_down(arr, n, i):
        elem = arr[i]
        while True:
            l = 2 * i + 1
            if l >= n:
                arr[i] = elem
                return
            r = 2 * i + 2
            c = l
            if r < n and arr[l] < arr[r]:
                c = r
            if elem >= arr[c]:
                arr[i] = elem
                return
            arr[i] = arr[c]
            i = c

    n = len(arr)
    for i in range(n // 2, -1, -1):
        sift_down(arr, n, i)
    for i in range(n - 1, 0, -1):
        t = arr[i]
        arr[i] = arr[0]
        arr[0] = t
        sift_down(arr, i, 0)
    return arr


def makedata(n):
    res = list(range(n))
    random.seed(a=n)
    random.shuffle(res)
    return res


def sample1(n):
    data = makedata(n)
    return timeit.timeit(lambda: heap_sort_custom(data[:]), number=10)


def sample2(n):
    data = makedata(n)
    return timeit.timeit(lambda: heapsort(data[:]), number=10)


def sample3(n):
    data = makedata(n)
    return timeit.timeit(lambda: sort_using_heap(data[:]), number=10)


a = [pow(2, i) for i in range(1, 17)]
fig = plt.figure()
ax = fig.add_subplot(2, 1, 1)
ax.plot(a, [sample1(n) / 10 for n in a], color='blue', label='heap_sort_custom')
ax.plot(a, [sample2(n) / 10 for n in a], color='red', label='heapsort')
ax.plot(a, [sample3(n) / 10 for n in a], color='green', label='sort_using_heap')
ax.set_yscale('log')
ax.set_xscale('log')
plt.legend()
plt.show()
- You picked up some advice. I tried heapq.heapify() & nlargest(). And sorted(), with a not entirely expected result. – greybeard, Dec 28, 2024
I believe the reason for the slowness of heap_sort_custom is mostly the number of iterations it takes to visit all nodes.
Below is a plot of array size vs iteration count for heap_sort_custom and heapsort.
As seen in the plot, heap_sort_custom makes at least twice as many iterations as heapsort. The key to optimizing heap_sort_custom is optimizing the heap traversal.
I have modified heap_sort_custom and heapsort to count and return the total number of iterations taken to sort the heap. Note that in heapsort, the iteration count does not include iterations during heap build.
import heapq
import random
import matplotlib.pyplot as plt


def heap_sort_custom(lst):
    """
    Performs a custom heap sort iterative algorithm without changing the size
    of the heap. Unlike in standard heap sort, no extractions are performed.

    Args:
        lst (list): The input list to be sorted.

    Returns:
        list: A sorted list containing elements of the input list in ascending order.

    Approach:
        - Convert the input list into a min-heap
        - Traverse and process the heap iteratively:
            - Maintain a set to track indices that have already been processed.
            - Replace the value at the current node with the value of either the
              left child, right child or parent, depending on specific conditions.
            - Accumulate results in a separate list as each node is processed.
    """

    def left_child(i):
        """
        Returns the value of the left child of the node at index i, or infinity if out of bounds.
        """
        return float("inf") if 2 * i + 1 >= len(lst) else lst[2 * i + 1]

    def right_child(i):
        """
        Returns the value of the right child of the node at index i, or infinity if out of bounds.
        """
        return float("inf") if 2 * i + 2 >= len(lst) else lst[2 * i + 2]

    def parent(i):
        """
        Returns the value of parent of the node at index i, or infinity if the node is the root.
        """
        return lst[(i - 1) // 2] if i > 0 else float("inf")

    heapq.heapify(lst)  # Build a min-heap from input list

    # A set to keep track of visited indices
    visited_indices = set()
    # List to store the sorted result
    sorted_result = []
    # Start traversal from the root of the heap
    current_index = 0
    # Total number of while loop iterations
    iteration_count = 0

    while len(sorted_result) < len(lst):
        iteration_count += 1
        if current_index not in visited_indices:
            # Add the current node's value to the result and mark it as visited
            sorted_result.append(lst[current_index])
            visited_indices.add(current_index)

        # Replace the current node value with value of either left, right or parent node
        if parent(current_index) < min(left_child(current_index), right_child(current_index)):
            lst[current_index] = min(left_child(current_index), right_child(current_index))
            current_index = (current_index - 1) // 2  # Move to the parent node
        elif left_child(current_index) < right_child(current_index):
            lst[current_index] = min(right_child(current_index), parent(current_index))
            current_index = 2 * current_index + 1  # Move to the left child
        else:
            lst[current_index] = min(left_child(current_index), parent(current_index))
            current_index = 2 * current_index + 2  # Move to the right child

    return iteration_count


def heapsort(arr):
    def sift_down(arr, n, i, build=True):
        elem = arr[i]
        nonlocal iteration_count
        while True:
            if not build:
                iteration_count += 1
            l = 2 * i + 1
            if l >= n:
                arr[i] = elem
                return
            r = 2 * i + 2
            c = l
            if r < n and arr[l] < arr[r]:
                c = r
            if elem >= arr[c]:
                arr[i] = elem
                return
            arr[i] = arr[c]
            i = c

    # Total number of sift_down while loop iterations during the sorting phase;
    # doesn't include iterations during heap build
    iteration_count = 0
    n = len(arr)
    for i in range(n // 2, -1, -1):
        sift_down(arr, n, i)
    for i in range(n - 1, 0, -1):
        t = arr[i]
        arr[i] = arr[0]
        arr[0] = t
        sift_down(arr, i, 0, build=False)
    return iteration_count


def makedata(n):
    res = list(range(n))
    random.seed(a=n)
    random.shuffle(res)
    return res


def sample1(n):
    data = makedata(n)
    return heap_sort_custom(data[:])


def sample2(n):
    data = makedata(n)
    return heapsort(data[:])


a = [pow(2, i) for i in range(1, 18)]
fig = plt.figure()
ax = fig.add_subplot()
ax.plot(a, [sample1(n) for n in a], color='blue', label='heap_sort_custom')
ax.plot(a, [sample2(n) for n in a], color='red', label='heapsort')
ax.set_xlabel('Array Size')
ax.set_ylabel('Iteration Count')
ax.set_title('Comparison of Iteration Counts')
plt.legend()
plt.show()