Given an array:
arr = [2,3,1,4]
I can build an inversion-count array: for each number n1 in arr, count the numbers n2 that come after it such that n1 > n2. For the array above, this gives:
[1 1 0 0]
Similarly, the inversion array of:
[2, 1, 4, 3]
would be:
[1, 0, 1, 0]
For:
[20]
[1, 2, 3, 4, 5, 6]
[87, 78, 16, 94]
Output would be:
0
0 0 0 0 0 0
2 1 0 0
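To pin down the definition, here is a minimal brute-force sketch that reproduces the outputs listed above. The function name count_inversions_brute_force is hypothetical and not part of my submitted code. It runs in quadratic time, so it is only meant as a reference to check faster versions against, not as a solution for the stated limits.

```python
def count_inversions_brute_force(arr):
    """For each position i, count the later elements that are smaller than arr[i]."""
    n = len(arr)
    return [
        sum(1 for j in range(i + 1, n) if arr[i] > arr[j])
        for i in range(n)
    ]


if __name__ == "__main__":
    for sample in ([2, 3, 1, 4], [2, 1, 4, 3], [20], [87, 78, 16, 94]):
        # Print in the same space-separated form as the expected output.
        print(" ".join(str(c) for c in count_inversions_brute_force(sample)))
```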
Constraints:
- \$1 \le N \le 10^4\$
- \$1 \le i \le 10^6\$
The code that I wrote works for most cases but takes more than 10 seconds for an extra-large number of test cases.
from copy import copy

def merge(arr, left_lo, left_hi, right_lo, right_hi, dct):
    startL = left_lo
    startR = right_lo
    N = left_hi-left_lo + 1 + right_hi - right_lo + 1
    aux = [0] * N
    res = []
    for i in xrange(N):
        if startL > left_hi:
            aux[i] = arr[startR]
            startR += 1
        elif startR > right_hi:
            aux[i] = arr[startL]
            startL += 1
        elif arr[startL] <= arr[startR]:
            aux[i] = arr[startL]
            startL += 1
            # print aux
        else:
            aux[i] = arr[startR]
            res.append(startL)
            startR += 1
            # print aux
    for index in res:
        for x in xrange(index, left_hi+1):
            dct[arr[x]] += 1
    for i in xrange(left_lo, right_hi+1):
        arr[i] = aux[i - left_lo]
    return

def merge_sort(arr, lo, hi, dct):
    mid = (lo+hi)/2
    if lo<=mid<hi:
        merge_sort(arr, lo, mid, dct)
        merge_sort(arr, mid+1, hi, dct)
        merge(arr, lo, mid, mid+1, hi, dct)
    return

def count_inversion(arr, N):
    lo = 0
    hi = N-1
    dct = {i:0 for i in arr}
    arr2 = copy(arr)
    merge_sort(arr, lo, hi, dct)
    return ' '.join([str(dct[num]) for num in arr2])
count_inversion calls merge_sort, and that is where the LEFT > RIGHT inversions are counted. All numbers are stored in a dictionary with counts, such that whenever L > R occurs, the counts of all numbers in the left array, from L to the end of the left array, are incremented by 1.
Now I understand there could be a way to optimize this snippet:
for index in res:
    for x in xrange(index, left_hi+1):
        dct[arr[x]] += 1
1 Answer
Thanks for an interesting, well-posed question. Before we get to issues of performance, let me make some other suggestions for your code.
The way you have broken down the code into three functions is reasonable and logical. Nice!
The functions you wrote don't have docstrings, so it is hard to know how to use them. I couldn't get your count_inversion() function to run initially, for example, because I didn't know what N was supposed to be. Adding docstrings would make this clear.
For the specific case of the N parameter in count_inversion, why do you need it? When I used the function I did it like this:
arr = [2, 3, 1, 4]
arr2 = [2, 1, 4, 3]
arr3 = [20]
arr4 = [1, 2, 3, 4, 5, 6]
arr5 = [87, 78, 16, 94]
arrs_to_test = [arr, arr2, arr3, arr4, arr5]
[count_inversion(test, len(test)) for test in arrs_to_test]
That to me suggests that you don't need N as a parameter, instead just do something like:
def count_inversion(arr):
    # docstring goes here
    N = len(arr)
    # <<rest of code>>
I usually dislike when people are sticklers for PEP8 variable-naming conventions in mathematically oriented code, but I think your naming could use some work. For example, I had to read dct = {i:0 for i in arr} multiple times to understand that i was not an index but the data. So dct = {el:0 for el in arr} would have been more natural to me. Plus, dct isn't the best name either. If I understand the code correctly, perhaps result would be better?
Possible bug: Related to the above, do you really want to create a dictionary keyed on the elements of arr? Doing so means that the behavior when integers are repeated in the input is probably not what you want:
>>> repeated = [3, 2, 1, 0, 3, 4, 5]
>>> count_inversion(repeated, len(repeated))
'3 2 1 0 3 0 0'
Is the fifth element in this array really supposed to be "3"?
Right now the code functions because of the mutability of dct. I'm not 100% sure it could be done cleanly, but if possible I'd suggest rewriting merge_sort and merge to return dct instead of returning None. That way the initialization of dct could happen in those functions too, which feels more natural to me.
You are using your (mysteriously named) dct variable in a way that looks like a Counter, so you might consider using this standard-library datatype (collections.Counter).
Now, to issues of performance. In Python, line-profiler is an easy-to-use package for assessing the performance of your code. I use this package in a Jupyter notebook like this:
from random import randint
%load_ext line_profiler
big_test = [randint(0, 10000) for _ in range(1000)]
%lprun -f count_inversion -f merge_sort -f merge count_inversion(big_test, len(big_test))
The output of this operation is at the end of my answer. It shows that you are right, and the slowest part of the operation is indeed the nested for loops in merge().
Dictionaries, unlike lists, can't be mutated using slice notation. The only reason you need the for x in xrange(index, left_hi+1): nested loop is because you can't change whole swaths of the dictionary at once. With lists, you can. Thus, if you agree that the possible bug I described above is in fact a bug, you can switch to storing the output values in a list instead of a dictionary, and get rid of the nested loop. The speedup is very small for short input arrays but grows with array size. On my machine it led to a ~2-10× improvement at input arrays of 10,000 elements. I'll put the output of my line profiling below too.
Using numpy for applications like this makes sense because plain Python lists are not a fixed-type data structure with vectorized arithmetic. Interestingly, using numpy only to represent the dct variable in your original code, while leaving in place all other for loops, speeds up the execution another 2× or so for lists of 10,000 elements.
for index in res:
    sublist_length = left_hi+1 - index
    out[index:left_hi+1] += np.ones(sublist_length, dtype = int)
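Going one step further than the slice updates above, here is a minimal sketch of a different technique, not the one profiled in this answer. Instead of incrementing a whole range of counts for every right-hand element, it adds, for each left-hand element, the number of right-hand elements that were merged ahead of it. Counts are stored by original position rather than by value, so repeated integers are handled. The name count_smaller_after is hypothetical, and the sketch assumes the desired output is one count per input position.

```python
def count_smaller_after(arr):
    """Return a list whose i-th entry counts later elements smaller than arr[i]."""
    counts = [0] * len(arr)

    def sort(indices):
        # Sort a list of positions by the values they point to.
        if len(indices) <= 1:
            return indices
        mid = len(indices) // 2
        left = sort(indices[:mid])
        right = sort(indices[mid:])
        merged = []
        i = j = 0
        while i < len(left) or j < len(right):
            # Take from the right half only when it is strictly smaller,
            # so equal values are not counted as inversions.
            if j < len(right) and (i == len(left) or arr[right[j]] < arr[left[i]]):
                merged.append(right[j])
                j += 1
            else:
                # Exactly j right-hand elements are smaller than arr[left[i]].
                counts[left[i]] += j
                merged.append(left[i])
                i += 1
        return merged

    sort(list(range(len(arr))))
    return counts


if __name__ == "__main__":
    print(count_smaller_after([2, 3, 1, 4]))
    print(count_smaller_after([3, 2, 1, 0, 3, 4, 5]))
```

Each merge does a constant amount of work per element, so no nested loop over the left half is needed. The slicing in sort() allocates new lists at every level, which keeps the sketch short. An in-place index-based version like the original merge() would avoid those copies.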
Original code timing
from random import randint
big_test = [randint(0, 100) for _ in range(10000)]
%lprun -f merge count_inversion(copy(big_test), len(big_test))
Results in:
Timer unit: 1e-06 s
Total time: 26.9413 s
File: <ipython-input-1-77a541281305>
Function: merge at line 4
Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     4                                           def merge(arr, left_lo, left_hi, right_lo, right_hi, dct):
     5      9999         4407      0.4      0.0      startL = left_lo
     6      9999         4315      0.4      0.0      startR = right_lo
     7      9999         5812      0.6      0.0      N = left_hi-left_lo + 1 + right_hi - right_lo + 1
     8      9999         8330      0.8      0.0      aux = [0] * N
     9      9999         4401      0.4      0.0      res = []
    10    143615        72509      0.5      0.3      for i in range(N):
    11
    12    133616        60250      0.5      0.2          if startL > left_hi:
    13      5778         2956      0.5      0.0              aux[i] = arr[startR]
    14      5778         2700      0.5      0.0              startR += 1
    15    127838        57380      0.4      0.2          elif startR > right_hi:
    16      7503         3822      0.5      0.0              aux[i] = arr[startL]
    17      7503         3507      0.5      0.0              startL += 1
    18    120335        64095      0.5      0.2          elif arr[startL] <= arr[startR]:
    19     61505        31935      0.5      0.1              aux[i] = arr[startL]
    20     61505        28853      0.5      0.1              startL += 1
    21                                                       # print aux
    22                                                   else:
    23     58830        30780      0.5      0.1              aux[i] = arr[startR]
    24     58830        34221      0.6      0.1              res.append(startL)
    25     58830        28077      0.5      0.1              startR += 1
    26                                                       # print aux
    27
    28     68829        35529      0.5      0.1      for index in res:
    29  24750922     11838525      0.5     43.9          for x in range(index, left_hi+1):
    30  24692092     14464132      0.6     53.7              dct[arr[x]] += 1
    31
    32    143615        76619      0.5      0.3      for i in range(left_lo, right_hi+1):
    33    133616        73794      0.6      0.3          arr[i] = aux[i - left_lo]
    34      9999         4362      0.4      0.0      return
Improved Code Timing (list slicing)
Timer unit: 1e-06 s
Total time: 3.11468 s
File: <ipython-input-2-224c772db490>
Function: new_merge at line 4
Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     4                                           def new_merge(arr, left_lo, left_hi, right_lo, right_hi, out):
     5                                               # docstring goes here
     6      9999         4566      0.5      0.1      startL = left_lo
     7      9999         4491      0.4      0.1      startR = right_lo
     8      9999         6601      0.7      0.2      N = left_hi-left_lo + 1 + right_hi - right_lo + 1
     9      9999         7904      0.8      0.3      aux = [0] * N
    10      9999         4571      0.5      0.1      res = []
    11    143615        69607      0.5      2.2      for i in xrange(N):
    12
    13    133616        62294      0.5      2.0          if startL > left_hi:
    14      5778         3044      0.5      0.1              aux[i] = arr[startR]
    15      5778         2804      0.5      0.1              startR += 1
    16    127838        59349      0.5      1.9          elif startR > right_hi:
    17      7503         3945      0.5      0.1              aux[i] = arr[startL]
    18      7503         3673      0.5      0.1              startL += 1
    19    120335        67720      0.6      2.2          elif arr[startL] <= arr[startR]:
    20     61505        33280      0.5      1.1              aux[i] = arr[startL]
    21     61505        30419      0.5      1.0              startL += 1
    22                                                       # print aux
    23                                                   else:
    24     58830        31818      0.5      1.0              aux[i] = arr[startR]
    25     58830        34578      0.6      1.1              res.append(startL)
    26     58830        28988      0.5      0.9              startR += 1
    27                                                       # print aux
    28
    29     68829        37996      0.6      1.2      for index in res:
    30     58830        31690      0.5      1.0          sublist_length = left_hi+1 - index
    31     58830       164834      2.8      5.3          ones = [1]*sublist_length
    32     58830      2264227     38.5     72.7          out[index:left_hi+1] = map(add, out[index:left_hi+1], ones)
    33
    34    143615        73228      0.5      2.4      for i in xrange(left_lo, right_hi+1):
    35    133616        78612      0.6      2.5          arr[i] = aux[i - left_lo]
    36      9999         4440      0.4      0.1      return
Improved Code Timing (numpy)
Timer unit: 1e-06 s
Total time: 0.979072 s
File: <ipython-input-33-e2db83e49c93>
Function: d_merge at line 4
Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     4                                           def d_merge(arr, left_lo, left_hi, right_lo, right_hi, out):
     5                                               # docstring goes here
     6      9999         4439      0.4      0.5      startL = left_lo
     7      9999         4360      0.4      0.4      startR = right_lo
     8      9999         5728      0.6      0.6      N = left_hi-left_lo + 1 + right_hi - right_lo + 1
     9      9999         8283      0.8      0.8      aux = [0] * N
    10      9999         4511      0.5      0.5      res = []
    11    143615        69717      0.5      7.1      for i in xrange(N):
    12
    13    133616        62436      0.5      6.4          if startL > left_hi:
    14      5778         2955      0.5      0.3              aux[i] = arr[startR]
    15      5778         2779      0.5      0.3              startR += 1
    16    127838        59843      0.5      6.1          elif startR > right_hi:
    17      7503         3916      0.5      0.4              aux[i] = arr[startL]
    18      7503         3645      0.5      0.4              startL += 1
    19    120335        66024      0.5      6.7          elif arr[startL] <= arr[startR]:
    20     61505        32834      0.5      3.4              aux[i] = arr[startL]
    21     61505        29695      0.5      3.0              startL += 1
    22                                                       # print aux
    23                                                   else:
    24     58830        31913      0.5      3.3              aux[i] = arr[startR]
    25     58830        34546      0.6      3.5              res.append(startL)
    26     58830        28552      0.5      2.9              startR += 1
    27                                                       # print aux
    28
    29     68829        36887      0.5      3.8      for index in res:
    30     58830        31918      0.5      3.3          sublist_length = left_hi+1 - index
    31     58830       303818      5.2     31.0          out[index:left_hi+1] += np.ones(sublist_length, dtype = int)
    32
    33    143615        72165      0.5      7.4      for i in xrange(left_lo, right_hi+1):
    34    133616        73672      0.6      7.5          arr[i] = aux[i - left_lo]
    35      9999         4436      0.4      0.5      return
Note:
All the code I used (both mine and the original) is in a Jupyter notebook hosted at GitHub.
Thanks. Where I used the code, I wasn't sure np was available. It is of course an improvement, also for tackling the edge cases you noted. – user2290820, May 23, 2016 at 13:33
/ should be // in merge_sort(). Otherwise: RuntimeError: maximum recursion depth exceeded.
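The comment above matters when the code is run under Python 3, which is an assumption here since the question's use of xrange suggests Python 2. On Python 3, / between two integers returns a float, so mid never settles on an integer index and the lo<=mid<hi test in merge_sort() keeps succeeding. A small sketch of the difference follows, using a hypothetical helper named midpoint:

```python
def midpoint(lo, hi):
    """Integer midpoint of lo and hi; floor division behaves the same on Python 2 and 3."""
    return (lo + hi) // 2


if __name__ == "__main__":
    lo, hi = 0, 1
    # On Python 3 this is true division and yields a float (0.5 here).
    print((lo + hi) / 2)
    # Floor division always yields an int (0 here), so the recursion terminates.
    print(midpoint(lo, hi))
```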