I'd like to know what you think of my version of insertion sort. I tried to be Pythonic and avoid while loops with "ugly" index management:
def sort(a):
    for i, x in enumerate(a):
        for j, y in enumerate(a):
            if y >= x:
                a.insert(j, a.pop(i))
                break
On lists of 1000 random numbers, it seems to be about four times faster (23 ms vs 96 ms) than the implementation from the top-voted answer to the top result when searching for [python] insertion sort.
Benchmark code:
from random import random
from time import perf_counter as timer
from statistics import median

n = 1000
repeat = 50

def insertionSort(lst):
    for index in range(1, len(lst)):
        currentvalue = lst[index]
        position = index
        while position > 0 and lst[position - 1] > currentvalue:
            lst[position] = lst[position - 1]
            position = position - 1
        lst[position] = currentvalue

def sort(a):
    for i, x in enumerate(a):
        for j, y in enumerate(a):
            if y >= x:
                a.insert(j, a.pop(i))
                break

solutions = insertionSort, sort

for r in range(1, 6):
    print('Round %d:' % r, end='')
    a = [random() for _ in range(n)]
    for solution in solutions:
        times = []
        for _ in range(repeat):
            copy = a.copy()
            t0 = timer()
            solution(copy)
            times.append(timer() - t0)
            assert copy == sorted(a)
        print(' %6.2f ms' % (median(times) * 1e3), end='')
    print()
3 Answers
Code review
You should use better variable names than a, x and y. But otherwise, since your code works, your code is fine.
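For example, one possible renaming (a sketch; the names are my suggestion only, not from the post):

def insertion_sort(values):
    for source, current in enumerate(values):
        for target, existing in enumerate(values):
            if existing >= current:
                values.insert(target, values.pop(source))
                break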
Performance
It seems a couple of users are confused about why Python shows this seemingly strange performance.
Enumerate vs index
This is pretty simple: both enumerate and index have the same time and space complexity.

If we have a list where every item is False, set one item to True, and want to find its index, both run in \$O(n)\$ time. It may seem like enumerate is \$O(1)\$, however it is the same as zip(range(len(sequence)), sequence), and we know range is \$O(n)\$.

The difference in speed that we see is because index is just faster than enumerate.
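To see that raw speed difference directly, here is a minimal sketch using timeit (my illustration; exact numbers will vary by machine):

from timeit import timeit

xs = list(range(1000))

# find a value's position: Python-level enumerate scan vs the C-level list.index
enum_scan = (
    "for i, x in enumerate(xs):\n"
    "    if x == 999:\n"
    "        break"
)
print(timeit(enum_scan, globals=globals(), number=10_000))
print(timeit("xs.index(999)", globals=globals(), number=10_000))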
Your insertion sort vs Grajdeanu Alex's
This comes down to Python being slower than C. If we look at the core of Grajdeanu's solution:
currentvalue = lst[index]
position = index

while position > 0 and lst[position - 1] > currentvalue:
    lst[position] = lst[position - 1]
    position = position - 1

lst[position] = currentvalue
This is doing two things:

1. Finding the index to stop iterating at:

       while position > 0 and lst[position - 1] > currentvalue:

2. Performing an optimized version of insert and pop. This is because the loop only shifts a subset of the array, whereas insert and pop touch the entire array in the worst case. (Python lists are arrays in the backend.)

If you were to translate Grajdeanu Alex's solution into C, the code would outperform your insert and pop.
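To get a feel for that cost (my sketch, not from the original answer), compare an insert/pop pair at the front of a list, which shifts every element, with the same pair near the back:

from timeit import timeit

setup = "a = list(range(100_000))"

# insert at index 0 shifts all ~100,000 elements on every call
print(timeit("a.insert(0, a.pop())", setup=setup, number=100))

# insert near the end shifts (almost) nothing
print(timeit("a.insert(len(a) - 1, a.pop())", setup=setup, number=100))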
Bisecting
There's a nice property about insertion sort: as you're iterating through the data, everything before your index is already sorted. This means we can use a better algorithm to find where to insert.
We can use the strategy you use in "Guess a Number Between 1-100". By halving the part of the list we have to search on each check, we can find where to insert in \$O(\log(n))\$ time. This is faster than the \$O(n)\$ that your enumerate and Grajdeanu's algorithms run in.
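As an illustration (my sketch, not part of the original answer), this is the hand-rolled binary search for the insertion point, i.e. what bisect.bisect_left implements in C:

def insertion_point(sorted_part, value):
    # leftmost index at which value can be inserted while keeping order
    lo, hi = 0, len(sorted_part)
    while lo < hi:
        mid = (lo + hi) // 2
        if sorted_part[mid] < value:
            lo = mid + 1
        else:
            hi = mid
    return lo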
There is a library for this, bisect, and most of the legwork is in C too, so it's nice and fast.
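The test_peilon function timed below uses it; as a standalone sketch (note this builds a new sorted list rather than sorting in place):

import bisect

def insertion_sort_bisect(lst):
    output = []
    for item in lst:
        # binary search for the position (O(log n)), then a C-level insert
        bisect.insort(output, item)
    return output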
My timings
My code to get the timings:
import time
import math
import random
import copy
import bisect

import matplotlib.pyplot as plt
import numpy as np
from graphtimer import flat, Plotter, TimerNamespace


class Iteration(TimerNamespace):
    def test_baseline(data):
        pass

    def test_iterate(data):
        for value in data:
            pass

    def test_enumerate_list(data):
        for item in list(enumerate(data)):
            pass

    def test_enumerate_partial(data):
        for item in enumerate(data):
            pass

    def test_enumerate(data):
        for i, value in enumerate(data):
            pass


class Insertion(TimerNamespace):
    def test_baseline(data, i, value_i, j, value_j):
        pass

    def test_plain(data, i, value_i, j, value_j):
        data.insert(j, data.pop(i))

    def test_index(data, i, value_i, j, value_j):
        data.insert(data.index(value_j), data.pop(i))

    def test_python(data, i, value_i, j, value_j):
        while i < j:
            data[j] = data[j - 1]
            j -= 1
        data[j] = value_i


class Joined(TimerNamespace):
    def test_enumerate_plain(data, i, value_i, j, value_j):
        for j, value_j in enumerate(data):
            if value_i <= value_j:
                data.insert(j, data.pop(i))

    def test_enumerate_index(data, i, value_i, j, value_j):
        for j, value_j in enumerate(data):
            if value_i <= value_j:
                data.insert(data.index(value_j), data.pop(i))

    def test_iterate_index(data, i, value_i, j, value_j):
        for value_j in data:
            if value_i <= value_j:
                data.insert(data.index(value_j), data.pop(i))
                break


class Sorts(TimerNamespace):
    def test_manuel_base(a):
        for i, x in enumerate(a):
            for j, y in enumerate(a):
                if y >= x:
                    a.insert(j, a.pop(i))
                    break

    def test_manuel_insert(a):
        for i, x in enumerate(a):
            for y in a:
                if y >= x:
                    a.insert(a.index(y), a.pop(i))
                    break

    def test_other(lst):
        for index in range(1, len(lst)):
            currentvalue = lst[index]
            position = index
            while position > 0 and lst[position - 1] > currentvalue:
                lst[position] = lst[position - 1]
                position = position - 1
            lst[position] = currentvalue

    def test_peilon(lst):
        output = []
        for item in lst:
            bisect.insort(output, item)


memoize = {}

def create_args(size, *, _i):
    size = int(size)
    key = size, _i
    if key in memoize:
        return copy.deepcopy(memoize[key])
    array = random_array(size)
    j = random.randrange(0, size)
    array[:j] = sorted(array[:j])
    i = 0
    while array[i] < array[j]:
        i += 1
    output = array, i, array[i], j, array[j]
    memoize[key] = output
    return output

def random_array(size):
    array = list(range(int(size)))
    random.shuffle(array)
    return array

def main():
    fig, axs = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True,
                            subplot_kw=dict(xscale='log', yscale='log'))
    axis = [
        (Iteration, {'args_conv': lambda i: [None] * int(i)}),
        (Insertion, {'args_conv': create_args, 'stmt': 'fn(args[0].copy(), *args[1:])'}),
        (Joined, {'args_conv': create_args, 'stmt': 'fn(args[0].copy(), *args[1:])'}),
        (Sorts, {'args_conv': random_array, 'stmt': 'fn(args[0].copy(), *args[1:])'}),
    ]
    for graph, (plot, kwargs) in zip(iter(flat(axs)), axis):
        (
            Plotter(plot)
                .repeat(10, 2, np.logspace(1, 4), **kwargs)
                .min()
                .plot(graph, title=plot.__name__)
        )
    plt.show()

if __name__ == '__main__':
    main()
(Figure: log-log plots of the timings for the four groups: Iteration, Insertion, Joined and Sorts.)
Iteration
test_baseline: The timings are flat as they are the time it takes to run the test suite. When determining the performance of each function we need to see how far away from the baseline it is.

test_enumerate & test_enumerate_partial: These are roughly the same, and so we know that enumerate, and not tuple unpacking, is the factor at play that is taking up a lot of the performance.

test_enumerate, test_enumerate_list & test_iterate: We can see that adding more \$O(n)\$ operations makes the code slower. However, enumerate is a pretty slow function.

In all, enumerate is slow.
Insertion
test_baseline: Since we are copying the data in the test suite, we see that at times the other functions are running the fastest that they can. This is to be expected, as we are running the tests on a partially sorted array, ranging from unsorted to fully sorted.

test_plain: We can see that data.insert(j, data.pop(i)) is really fast and is consistently around test_baseline. This means that if enumerate were faster than data.index, then the other answer would not be true.

test_index & test_python: From the areas we can see that optimized Python runs significantly slower than Python's C methods. This is to be expected; Python is slow.
Joined
These merge the above two together to show the impact of the differences in timings. Each is a single insertion step of a full insertion sort.

Unsurprisingly, given the previous timings, test_enumerate_plain is by far the slowest.
Sorts
This shows that whilst your changes are fast, my answer from '17 is a pretty darn fast insertion sort.
Complexity vs Performance
It should be apparent that in Python these are two entirely different metrics. Complexity is more important when playing on a level playing field, which isn't the case in Python.
But just because Python isn't a level playing field doesn't make complexity useless. When programming, if you first aim for the best complexity, you'll have a good baseline to optimize from. From there you can focus on performance, which is harder to reason about and harder to compare. And in the worst case, converting the code into C will be far easier.
In which we defend the honor of enumerate()
Although I learned from and appreciated the write-up by Peilonrayz, I was not convinced by all of the characterizations. Also, I had some specific questions not covered in those benchmarks, so I explored on my own using the script below. These notes cover a few things I learned and reframe the discussion a bit.
enumerate()
itself is not slow. Merely invoking the enumerate()
callable
is an O(1)
operation, because it does nothing with the underlying iterable of
values other than store an iterator created from the original iterable.
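A quick way to see this (my illustration, not part of the benchmarks below):

# enumerate() only wraps an iterator; nothing is consumed up front
it = enumerate(range(10 ** 12))  # returns immediately
print(next(it))                  # (0, 0) -- work happens only on consumption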
Is consuming an iterable via enumerate()
slow? That depends on what the
alternative is. Compared to direct iteration (for x in xs
), yes it's slower
and the magnitude of the slowdown is not trivial. But we use enumerate()
for
a reason: we need the indexes too. In that context, there are three obvious
alternatives: manage the index yourself (i += 1
), use range()
for iteration
and then obtain the value via get-item (x = xs[i]
), or ask Python to
compute the index (i = xs.index(x)
). Compared to those alternatives,
enumerate()
is quite good: it's a little faster than managing the index
yourself or using range()
, and it is substantially faster than using
list.index()
every time. In that light, saying that "index()
is just faster
than enumerate()
" seems not quite right -- but perhaps I misunderstood or
there are errors in my findings.
Should you worry about tuple unpacking when using enumerate()? No, it adds
almost nothing. And especially don't avoid enumerate()
on performance grounds
if it forces you to use get-item on the tuple (i = x[0]
), because that is
slower than direct unpacking.
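For example (a sketch; the corresponding benchmark functions appear in the script at the end):

xs = list(range(1000))

for i, x in enumerate(xs):   # direct unpacking: adds almost nothing
    pass

for t in enumerate(xs):      # get-item on the tuple: measurably slower
    i = t[0]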
Some evidence. The numbers below are for a run of the script with
--count=1000
(how many numbers to be sorted) and --trials=100
(how many times
did we measure to get the statistics). The output here just adds up the total
of the times for all trials (--stat=total
), but you can also run the code to
see mean, min, and max as well (those results tell similar stories). For each
function, the table shows both a scaled value (2nd column) and the raw value
(3rd column). The scaled values are easier to compare because they are
expressed as a ratio relative to the minimum value in that column. The comment
column has a schematic summary of each function's behavior.
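For reference, output like the tables below can be reproduced along these lines (assuming the script at the end is saved as bench.py; the filename is my own):

    python bench.py enumerate --count=1000 --trials=100 --stat=total
    python bench.py sort --count=1000 --trials=100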
# Just calling enumerate().
# Nothing slow here: O(1).
enumerate_call_baseline : 1.0 : 0.000018 # it = None
enumerate_call : 2.0 : 0.000035 # it = enumerate()
# Direct Python iteration.
# If you need an index, don't use xs.index(x) as a general approach.
iterate_baseline : 38.4 : 0.000678 # for x in xs: pass
iterate_with_index : 190.0 : 0.003351 # for x in xs: i += 1
iterate_range_getitem : 198.6 : 0.458601 # for i in range(len(xs)): x = xs[i]
iterate_get_index : 24850.3 : 0.438433 # for x in xs: i = xs.index(x)
# Iteration with enumerate().
# Slow only when compared to a no-op for loop.
# If you need the indexes, use enumerate().
enumerate_consume : 155.6 : 0.002746 # for x in it: pass
enumerate_consume_unpack : 157.4 : 0.002778 # for i, x in it: pass
enumerate_consume_getitem : 263.8 : 0.005475 # for x in it: x[0]
Sometimes index() is faster. Here are the benchmarks for the sorting functions we have discussed. As others have reported, the classic compare-swap strategy is worse than those relying on the insert-index-pop family of methods.
sort_baseline : 1.0 : 0.007389 # xs.sort()
sort_classic_swap : 618.4 : 4.569107 # classic compare-swap
sort_insert_index_pop : 122.5 : 0.905445 # xs.insert(xs.index(x2), xs.pop(i))
sort_insert_pop : 150.7 : 1.113629 # xs.insert(j, xs.pop(i))
I find that counterintuitive at first glance. When reading through the code
of sort_insert_index_pop()
, my first impression was puzzlement. In
particular, don't insert()
, index()
, and pop()
each imply linear
scans/shifts of the data? That seems bad, right? Moreover, having done the
enumerate benchmarks, I am not entirely convinced by an explanation based
solely on the general point that language operations implemented in C (such as
list.index()
) have a big speed advantage over the language operations
implemented directly in Python. Although that point is both true and important,
the enumerate benchmarks prove that in the general case, retrieving indexes via
xs.index(x)
is very slow. Out of the two forces -- the speed of the C-based
list
methods vs the inefficiency of those costly scans/shifts -- which one
has a larger magnitude within the context of the short-circuiting behavior of
insertion sort?
Summary of the tradeoffs. The table below tries to summarize the advantages and disadvantages of the two approaches. The insert-index-pop approach uses the fastest looping style in its inner loop, makes many fewer swaps, in a faster language -- but the swap itself is algorithmically inefficient. We know from the benchmarks how those tradeoffs weigh out in the end, but I cannot say with confidence that a survey of experienced Python engineers would have necessarily predicted this empirical outcome in advance -- and that is what we mean when we describe something as counterintuitive.
| classic-swap | insert-index-pop
-------------------------------------------------------
| |
Looping machinery | |
| |
- for x in xs | . | inner
- enumerate()/range() | outer | outer
- while COND | inner | .
| |
Swaps | |
| |
- Number | N * N / 2 | N
- Cost per swap | 1 | N * 1.5
- Language | Python | C
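A back-of-envelope reading of that table (my arithmetic, for \$N = 1000\$): classic-swap performs roughly \$N^2/2 = 500{,}000\$ swaps at unit cost, all in Python, while insert-index-pop performs roughly \$N = 1000\$ swaps costing about \$1.5N\$ element operations each, i.e. roughly \$1{,}500{,}000\$ operations, but in C. The measured ratio (about 5x in favor of insert-index-pop) says the C constant factor more than covers its 3x disadvantage in raw operation count.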
The code:
import argparse
import sys
from collections import namedtuple
from random import randint, random, shuffle
from time import time

####
# Benchmarking machinery.
####

# Groups of functions that we will benchmark.
FUNC_NAMES = {
    'enumerate': [
        # Just calling enumerate().
        'enumerate_call_baseline',     # it = None
        'enumerate_call',              # it = enumerate()
        # Direct Python iteration.
        'iterate_baseline',            # for x in xs: pass
        'iterate_with_index',          # for x in xs: i += 1
        'iterate_range_getitem',       # for i in range(len(xs)): x = xs[i]
        'iterate_get_index',           # for x in xs: i = xs.index(x)
        # Iteration with enumerate().
        'enumerate_consume',           # for x in it: pass
        'enumerate_consume_unpack',    # for i, x in it: pass
        'enumerate_consume_getitem',   # for x in it: x[0]
    ],
    'sort': [
        'sort_baseline',               # xs.sort()
        'sort_classic_swap',           # classic index-based compare-swap
        'sort_insert_index_pop',       # xs.insert(xs.index(x2), xs.pop(i))
        'sort_insert_pop',             # xs.insert(j, xs.pop(i))
    ],
    'check_sorts': [],
}

# Constants and simple data types.
STAT_NAMES = ('count', 'total', 'mean', 'min', 'max')
VALUE_NAMES = ('randint', 'random', 'shuffle', 'direct')
Stats = namedtuple('Stats', STAT_NAMES)
Result = namedtuple('Result', 'func stats')

def main(args):
    # Parse command-line arguments.
    ap = argparse.ArgumentParser()
    ap.add_argument('scenario', choices = list(FUNC_NAMES))
    ap.add_argument('--stat', default = 'total', choices = STAT_NAMES)
    ap.add_argument('--count', type = int, default = 1000)
    ap.add_argument('--trials', type = int, default = 100)
    ap.add_argument('--values', default = 'randint', choices = VALUE_NAMES)
    ap.add_argument('--presort', action = 'store_true')
    opts = ap.parse_args(args)

    # Generate some values.
    xs = generate_values(opts.count, opts.values, opts.presort)

    # Either sanity check to ensure that our sorts actually sort.
    if opts.scenario == 'check_sorts':
        exp = sorted(xs)
        for fname in FUNC_NAMES['sort']:
            ys = xs.copy()
            f = globals()[fname]
            f(ys)
            print(ys == exp, fname)
    # Or benchmark some functions.
    else:
        funcs = [globals()[fname] for fname in FUNC_NAMES[opts.scenario]]
        results = measure_funcs(funcs, xs, opts.trials)
        report = list(summarize(opts, results))
        print('\n'.join(report))

def generate_values(count, mode, presort = False):
    # Various ways of generating numbers to be sorted or enumerated.
    if mode == 'randint':
        xs = [randint(1, 1000) for _ in range(count)]
    elif mode == 'random':
        xs = [random() for _ in range(count)]
    elif mode == 'shuffle':
        xs = list(range(count))
        shuffle(xs)
    elif mode == 'direct':
        xs = [int(x) for x in mode.split(',')]
    return sorted(xs) if presort else xs

def measure_funcs(funcs, xs, trials):
    # Benchmark several functions.
    results = []
    for f in funcs:
        stats = measure(trials, f, xs)
        r = Result(f, stats)
        results.append(r)
    return results

def measure(trials, func, xs):
    # Benchmark one function.
    times = []
    for t in range(trials):
        ys = xs.copy()
        t0 = time()
        func(ys)
        t1 = time()
        times.append(t1 - t0)
    count = len(xs)
    total = sum(times)
    mean = total / len(times)
    return Stats(count, total, mean, min(times), max(times))

def summarize(opts, results):
    # Generate tabular output.

    # Scenario header.
    fmt = '\n# {} : stat={}, count={}, trials={}'
    header = fmt.format(opts.scenario, opts.stat, opts.count, opts.trials)
    yield header

    # For the statistic we are analyzing, get its minimum value.
    min_tup = min(results, key = lambda tup: tup[1])
    min_val = getattr(min_tup[1], opts.stat)

    # Print table for that statistic.
    fmt = '{:<30} : {:8.1f} : {:.6f}'
    for f, stats in results:
        val = getattr(stats, opts.stat)
        scaled_val = val / min_val
        row = fmt.format(f.__name__, scaled_val, val)
        yield row

####
# Benchmarking targets: enumerate() vs alternatives.
####

def enumerate_call_baseline(xs):
    it = None

def enumerate_call(xs):
    it = enumerate(xs)

def iterate_baseline(xs):
    for x in xs:
        pass

def iterate_with_index(xs):
    i = 0
    for x in xs:
        i += 1

def iterate_range_getitem(xs):
    for i in range(len(xs)):
        x = xs[i]

def enumerate_consume(xs):
    it = enumerate(xs)
    for x in it:
        pass

def enumerate_consume_getitem(xs):
    it = enumerate(xs)
    for x in it:
        x[1]

def enumerate_consume_unpack(xs):
    it = enumerate(xs)
    for i, x in it:
        pass

def iterate_get_index(xs):
    for x in xs:
        i = xs.index(x)

####
# Benchmarking targets: in-place insertion sorts.
####

def sort_baseline(xs):
    xs.sort()

def sort_classic_swap(xs):
    for i in range(1, len(xs)):
        x = xs[i]
        while i > 0 and xs[i - 1] > x:
            xs[i] = xs[i - 1]
            i -= 1
        xs[i] = x

def sort_insert_pop(xs):
    for i, x1 in enumerate(xs):
        for j, x2 in enumerate(xs):
            if x2 >= x1:
                xs.insert(j, xs.pop(i))
                break

def sort_insert_index_pop(xs):
    for i, x1 in enumerate(xs):
        for x2 in xs:
            if x2 >= x1:
                xs.insert(xs.index(x2), xs.pop(i))
                break

if __name__ == '__main__':
    main(sys.argv[1:])
Comment: "Something that is an order of magnitude or two faster are pop and insert. Those two combined are about 20 times faster than index alone, on this test. And that's for a list of ints, whose comparisons I believe are fairly fast. For more slowly comparing values, index will become accordingly slower, while pop and insert aren't affected." (superb rain, Sep 16, 2020)
Most of the produced j indexes won't be used, so that's wasteful. It turns out that searching for the one we do need is faster. This further reduced the time to 19 ms:
def sort(a):
    for i, x in enumerate(a):
        for y in a:
            if y >= x:
                a.insert(a.index(y), a.pop(i))
                break
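A quick sanity check of the rewritten version (my addition, mirroring the assert in the benchmark above):

from random import random

a = [random() for _ in range(1000)]
expected = sorted(a)
sort(a)
assert a == expected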
Comment: "(Found this after posting the question and wasn't sure whether to edit it, and I think it works as a review.)" (Manuel, Sep 10, 2020)