I was solving the "Find the Min" problem from Facebook Hacker Cup.
The code below works fine for the sample inputs given there, but for inputs as large as \$n = 10^9\$ it takes hours to return the solution.
Problem statement:
After sending smileys, John decided to play with arrays. Did you know that hackers enjoy playing with arrays? John has a zero-based index array, m, which contains n non-negative integers. However, only the first k values of the array are known to him, and he wants to figure out the rest.

John knows the following: for each index i, where k <= i < n, m[i] is the minimum non-negative integer which is not contained in the previous k values of m.

For example, if k = 3, n = 4 and the known values of m are [2, 3, 0], he can figure out that m[3] = 1.

John is very busy making the world more open and connected, as such, he doesn't have time to figure out the rest of the array. It is your task to help him.

Given the first k values of m, calculate the nth value of this array (i.e. m[n - 1]).

Because the values of n and k can be very large, we use a pseudo-random number generator to calculate the first k values of m. Given positive integers a, b, c and r, the known values of m can be calculated as follows:

m[0] = a
m[i] = (b * m[i - 1] + c) % r, 0 < i < k
Input
The first line contains an integer T (T <= 20), the number of test cases.
This is followed by T test cases, consisting of 2 lines each.
The first line of each test case contains 2 space separated integers, n, k (\$1 \le k \le 10^5\$, \$k < n \le 10^9\$).
The second line of each test case contains 4 space separated integers a, b, c, r (\$0 \le a, b, c \le 10^9\$, \$1 \le r \le 10^9\$).
My solution:
import sys

cases=sys.stdin.readlines()

def func(line1,line2):
    n,k=map(int,line1.split())
    a,b,c,r =map(int,line2.split())
    m=[None]*n
    m[0]=a
    for i in xrange(1,k):
        m[i]= (b * m[i - 1] + c) % r
    #print m
    for j in range(0,n-k):
        temp=set(m[j:k+j])
        i=-1
        while True:
            i+=1
            if i not in temp:
                m[k+j]=i
                break
    return m[-1]

for ind,case in enumerate(xrange(1,len(cases),2)):
    ans=func(cases[case],cases[case+1])
    print "Case #{0}: {1}".format(ind+1,ans)
Sample input:
5
97 39
34 37 656 97
186 75
68 16 539 186
137 49
48 17 461 137
98 59
6 30 524 98
46 18
7 11 9 46
2 Answers
1. Improving your code
If you used if __name__ == '__main__': to guard the code that should be executed when your program is run as a script, then you'd be able to work on the code (for example, run timings) from the interactive interpreter.

The name func is not very informative. And there's no docstring or doctests. What does this function do? What arguments does it take? What does it return?

func has many different tasks: it reads input, it generates pseudo-random numbers, and it computes the sequence in the problem. The code would be easier to read, test and maintain if you split these tasks into separate functions.

Reading all the lines of a file into memory by calling the readlines method is usually not the best way to read a file in Python. It's better to read the lines one at a time by iterating over the file or using next, if possible.

When computing the result, you keep the whole sequence (all \$n\$ values) in memory. But this is not necessary: you only need to keep the last \$k\$ values (the ones that you use to compute the minimum excluded number, or mex). It will be convenient to use a collections.deque for this.

Similarly, it's not necessary to build the set of the last \$k\$ values every time you want to compute the next minimum excluded number. If you kept this set around, then at each stage, all you'd have to do is to add one new value to the set and remove up to one old value. It will be convenient to use a collections.Counter for this.
Applying all of these improvements results in the code shown below.
import sys
from itertools import count, islice
from collections import Counter, deque

def pseudorandom(a, b, c, r):
    """Generate pseudo-random numbers starting with a and proceeding
    according to the linear congruential recurrence
    a -> (b * a + c) % r.

    >>> list(islice(pseudorandom(34, 37, 656, 97), 10))
    [34, 71, 82, 4, 28, 43, 16, 84, 78, 50]

    """
    while True:
        yield a
        a = (b * a + c) % r

def mex(c):
    """Return the "mex" (Minimum EXcluded) number in the Counter c.

    >>> mex(Counter(range(10)))
    10
    >>> mex(Counter([2, 3, 0]))
    1

    """
    for i in count():
        if c[i] == 0:
            return i

def iter_mex(k, it):
    """Generate the sequence whose first k values are given by the
    iterable `it`, and whose ith value (for i > k) is the "mex"
    (minimum excluded number) of the previous k values of the
    sequence.

    """
    q = deque(islice(it, k))
    for m in q:
        yield m
    c = Counter(q)
    while True:
        m = mex(c)
        yield m
        q.append(m)
        c[m] += 1
        c[q.popleft()] -= 1

def nth_iter_mex(n, k, a, b, c, r):
    """Return the nth value of the sequence whose first k values are
    given by pseudorandom(a, b, c, r) and whose values thereafter
    are the mex of the previous k values of the sequence.

    """
    seq = iter_mex(k, pseudorandom(a, b, c, r))
    return next(islice(seq, n, None))
def main(f):
    cases = int(next(f))
    for i in xrange(cases):
        n, k = map(int, next(f).split())
        a, b, c, r = map(int, next(f).split())
        # Cases are numbered from 1 in the output, as in the original code.
        print('Case #{}: {}'.format(i + 1, nth_iter_mex(n - 1, k, a, b, c, r)))

if __name__ == '__main__':
    main(sys.stdin)
2. Timing
Python comes with the timeit module for timing code. Let's have a look at how long the program takes on some medium-sized test instances:
>>> from timeit import timeit
>>> timeit(lambda:nth_iter_mex(10**5, 10**2, 34, 37, 656, 97), number=1)
1.6787691116333008
>>> timeit(lambda:nth_iter_mex(10**5, 10**3, 34, 37, 656, 97), number=1)
16.67353105545044
>>> timeit(lambda:nth_iter_mex(10**5, 10**4, 34, 37, 656, 97), number=1)
143.31502509117126
The runtime of the program is approximately proportional to both \$n\$ and \$k\$, so extrapolating from the above timings, I expect that in the worst case, when \$n = 10^9\$ and \$k = 10^5\$, the computation will take about five months. So we've got quite a lot of improvement to make!
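The five-month figure can be reproduced with a few lines of arithmetic. This is only a rough sketch (written for Python 3): it assumes the runtime scales exactly in proportion to the product of \$n\$ and \$k\$, and it starts from the slowest timing measured above. The variable names are illustrative, not part of the program.

```python
# Measured above: n = 10**5, k = 10**4 took about 143.3 seconds.
measured_seconds = 143.3
measured_n, measured_k = 10**5, 10**4

# Worst case allowed by the problem statement.
worst_n, worst_k = 10**9, 10**5

# Assumption: runtime is proportional to n * k.
scale = (worst_n / measured_n) * (worst_k / measured_k)
estimated_seconds = measured_seconds * scale
estimated_days = estimated_seconds / (24 * 60 * 60)

print(round(estimated_days))  # about 166 days, i.e. roughly five months
```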
3. Speeding up the mex-finding
In the implementation above it takes \$Θ(k)\$ time to find the mex of the last \$k\$ numbers, which means that the whole runtime is \$Θ(nk)\$. We can improve the mex finding to \$O(\log k)\$ as follows.
First, note that the mex of the last \$k\$ numbers is always a number between \$0\$ and \$k\$ inclusive. So if we keep track of the set of excluded numbers in this range (that is, the numbers between \$0\$ and \$k\$ inclusive that do not appear in the last \$k\$ elements of the sequence), then the mex is the smallest number in this set. And if we keep this set of excluded numbers in a heap then we can find the smallest in \$O(\log k)\$.
Like this:
import heapq

def iter_mex(k, it):
    """Generate the sequence whose first k values are given by the
    iterable `it`, and whose ith value (for i > k) is the "mex"
    (minimum excluded number) of the previous k values of the
    sequence.

    """
    q = deque(islice(it, k))
    for m in q:
        yield m
    excluded = list(set(xrange(k + 1)).difference(q))
    heapq.heapify(excluded)
    c = Counter(q)
    while True:
        mex = heapq.heappop(excluded)
        yield mex
        q.append(mex)
        c[mex] += 1
        old = q.popleft()
        c[old] -= 1
        if c[old] == 0:
            heapq.heappush(excluded, old)
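One way to gain some confidence in the heap-based generator is to compare it against a deliberately naive reference on a small input. The sketch below is a Python 3 port of the generator above (renamed iter_mex_heap here to keep it apart from the original), together with a hypothetical brute_force helper that recomputes each mex from scratch. It uses the example from the problem statement, where the known values are [2, 3, 0].

```python
import heapq
from collections import Counter, deque
from itertools import islice


def iter_mex_heap(k, it):
    """Python 3 port of the heap-based generator above."""
    q = deque(islice(it, k))
    yield from q
    # Numbers in 0..k that are missing from the current window.
    excluded = list(set(range(k + 1)).difference(q))
    heapq.heapify(excluded)
    c = Counter(q)
    while True:
        mex = heapq.heappop(excluded)
        yield mex
        q.append(mex)
        c[mex] += 1
        old = q.popleft()
        c[old] -= 1
        if c[old] == 0:
            # `old` no longer appears in the window, so it is excluded again.
            heapq.heappush(excluded, old)


def brute_force(k, first, n):
    """Hypothetical reference: first n values, each mex computed naively."""
    m = list(first[:k])
    while len(m) < n:
        window = set(m[-k:])
        i = 0
        while i in window:
            i += 1
        m.append(i)
    return m[:n]


first = [2, 3, 0]
fast = list(islice(iter_mex_heap(3, iter(first)), 12))
slow = brute_force(3, first, 12)
print(fast == slow)
```

Agreement on small inputs does not prove the heap version correct, but a mismatch would point to a bookkeeping error in the window updates.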
Now the timings are much more satisfactory:
>>> timeit(lambda:nth_iter_mex(10**5, 10**2, 34, 37, 656, 97), number=1)
0.2606780529022217
>>> timeit(lambda:nth_iter_mex(10**5, 10**3, 34, 37, 656, 97), number=1)
0.2634279727935791
>>> timeit(lambda:nth_iter_mex(10**5, 10**4, 34, 37, 656, 97), number=1)
0.32929110527038574
>>> timeit(lambda:nth_iter_mex(10**6, 10**5, 34, 37, 656, 97), number=1)
3.5652129650115967
However, we still expect that when \$n = 10^9\$ and \$k = 10^5\$, the computation will take around an hour, which is still far too long.
4. A better algorithm
Let's have a look at the actual numbers in the sequence generated by iter_mex and see if there's a clue as to a better way to calculate them. Here are the first hundred values from a sequence with \$k = 13\$ and \$r = 23\$:
>>> list(islice(iter_mex(13, pseudorandom(34, 37, 656, 23)), 100))
[34, 5, 13, 10, 14, 1, 3, 8, 9, 0, 12, 19, 2, 4, 6, 5, 7, 10, 11, 1,
3, 8, 9, 0, 12, 13, 2, 4, 6, 5, 7, 10, 11, 1, 3, 8, 9, 0, 12, 13,
2, 4, 6, 5, 7, 10, 11, 1, 3, 8, 9, 0, 12, 13, 2, 4, 6, 5, 7, 10,
11, 1, 3, 8, 9, 0, 12, 13, 2, 4, 6, 5, 7, 10, 11, 1, 3, 8, 9, 0,
12, 13, 2, 4, 6, 5, 7, 10, 11, 1, 3, 8, 9, 0, 12, 13, 2, 4, 6, 5]
You should have noticed a pattern there, but let's lay it out in rows of length \$k + 1 = 14\$ to make it obvious:
[34, 5, 13, 10, 14, 1, 3, 8, 9, 0, 12, 19, 2, 4,
6, 5, 7, 10, 11, 1, 3, 8, 9, 0, 12, 13, 2, 4,
6, 5, 7, 10, 11, 1, 3, 8, 9, 0, 12, 13, 2, 4,
6, 5, 7, 10, 11, 1, 3, 8, 9, 0, 12, 13, 2, 4,
6, 5, 7, 10, 11, 1, 3, 8, 9, 0, 12, 13, 2, 4,
6, 5, 7, 10, 11, 1, 3, 8, 9, 0, 12, 13, 2, 4,
6, 5, 7, 10, 11, 1, 3, 8, 9, 0, 12, 13, 2, 4,
6, 5]
You can see that the first \$k + 1\$ values are different, but thereafter each group of \$k + 1\$ items from the sequence is identical (and moreover is a permutation of the numbers \$0, \ldots, k\$).
Let's check that this happens for some bigger examples:
>>> it = iter_mex(10**5, pseudorandom(34, 37, 656, 4294967291))
>>> next(islice(it, 10**5, None))
0
>>> s, t = [list(islice(it, 10**5 + 1)) for _ in xrange(2)]
>>> s == t
True
>>> sorted(s) == range(10**5 + 1)
True
Let's prove that this always happens. First, as noted above, the mex of \$k\$ numbers is always a number between \$0\$ and \$k\$ inclusive, so all numbers in the sequence after the first \$k\$ lie in this range. Each number in the sequence is different from the previous \$k\$, so each group of \$k + 1\$ numbers after the first \$k\$ must be a permutation of the numbers from \$0\$ to \$k\$ inclusive. So if the (\$k + 1\$)th last number is \$j\$, then the last \$k\$ numbers will be a permutation of the numbers from \$0\$ to \$k\$ inclusive, except for \$j\$. So \$j\$ is their mex, and that will be the next number. Hence the pattern repeats.
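The argument can also be checked mechanically on the small example. The following Python 3 sketch rebuilds the \$k = 13\$, \$r = 23\$ sequence with a hypothetical brute-force helper (mex_sequence is not part of the code above), skips the first \$k\$ known values, and cuts the rest into blocks of length \$k + 1\$. If the reasoning holds, every block is identical and each is a permutation of \$0, \ldots, k\$.

```python
def mex_sequence(first, length):
    """Hypothetical brute-force helper: extend `first` (the k known values)
    to `length` values, each new value being the mex of the previous k."""
    k = len(first)
    m = list(first)
    while len(m) < length:
        window = set(m[-k:])
        i = 0
        while i in window:
            i += 1
        m.append(i)
    return m


# Known values from the recurrence with a=34, b=37, c=656, r=23, k=13.
first = [34]
for _ in range(12):
    first.append((37 * first[-1] + 656) % 23)
k = len(first)

# Generate the k known values plus four full blocks of length k + 1.
seq = mex_sequence(first, k + 4 * (k + 1))
tail = seq[k:]
blocks = [tail[i:i + k + 1] for i in range(0, len(tail), k + 1)]

print(all(block == blocks[0] for block in blocks))
print(sorted(blocks[0]) == list(range(k + 1)))
```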
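So now it's clear how to solve the problem. If \$n > k\$, element number \$n\$ in the sequence (counting from zero) is the same as element number \$k + 1 + (n \bmod (k + 1))\$, because the two indices are congruent modulo \$k + 1\$ and both lie in the repeating part of the sequence.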
The implementation is straightforward:
def nth_iter_mex(n, k, a, b, c, r):
    """Return the nth value of the sequence whose first k values are
    given by pseudorandom(a, b, c, r) and whose values thereafter
    are the mex of the previous k values of the sequence.

    """
    seq = iter_mex(k, pseudorandom(a, b, c, r))
    if n > k:
        n = k + 1 + n % (k + 1)
    return next(islice(seq, n, None))
and now the timings are acceptable:
>>> timeit(lambda:nth_iter_mex(10**9, 10**5, 34, 37, 656, 4294967291), number=1)
1.6802959442138672
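For readers on Python 3, the ideas above (the linear congruential generator, the heap of excluded numbers, and the index reduction) can be combined into one self-contained sketch. The function name nth_value and the SAMPLE_CASES list are illustrative choices, not part of the code above; the cases are the five from the sample input in the question. Because of the index reduction, at most \$2k + 2\$ values are ever generated, whatever the size of \$n\$.

```python
import heapq
from collections import Counter


def nth_value(n, k, a, b, c, r):
    """Hypothetical helper: return m[n - 1] for one test case (k < n)."""
    # Known values from the linear congruential recurrence.
    m = [a]
    for _ in range(k - 1):
        m.append((b * m[-1] + c) % r)

    target = n - 1  # zero-based index of the wanted value
    if target > k:
        # The sequence repeats with period k + 1 from index k onwards.
        target = k + 1 + target % (k + 1)

    counts = Counter(m)
    missing = [v for v in range(k + 1) if counts[v] == 0]
    heapq.heapify(missing)
    while len(m) <= target:
        mex = heapq.heappop(missing)
        m.append(mex)
        counts[mex] += 1
        old = m[-k - 1]  # value that just left the window of the last k
        counts[old] -= 1
        if counts[old] == 0:
            heapq.heappush(missing, old)
    return m[target]


# (n, k, a, b, c, r) for the five sample cases in the question.
SAMPLE_CASES = [
    (97, 39, 34, 37, 656, 97),
    (186, 75, 68, 16, 539, 186),
    (137, 49, 48, 17, 461, 137),
    (98, 59, 6, 30, 524, 98),
    (46, 18, 7, 11, 9, 46),
]

for number, case in enumerate(SAMPLE_CASES, start=1):
    print('Case #{}: {}'.format(number, nth_value(*case)))
```

Keeping the whole list m is acceptable here only because its length is bounded by about \$2k\$; a deque, as in the generator above, would work equally well.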
Untested, but using Counter from collections may be quicker than forming a set of the last k values each time.
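import collections

# Assumes m holds only the first k (pseudo-random) values at this point.
counter = collections.Counter(m)
for j in xrange(n - k):
    i = 0
    while counter[i]:
        i += 1
    counter[m[j]] -= 1  # m[j] leaves the window of the last k values
    counter[i] = 1      # the new mex enters the window
    m.append(i)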
If making the counter takes a long time because k is very large, you could consider making it in 'chunks', reading say the smallest 100 values from m initially then reading another 100 only when i gets larger than the smallest 100.
If that's still not fast enough you could try using deque, also from collections, and manually updating the set in each cycle of the loop. You can also take advantage of the fact that the number removed from the top of the list on each cycle gives you a clue as to what the next number has to be (higher or lower); and can use a slightly simplified algorithm when dealing with numbers that the programme has added to the list (in which any sequence of k values can contain no duplicates), as opposed to the original pseudo-random list (which may). (Again, the following is untested.)
from collections import deque

def get_next(i):
    while i in set_last_k:
        i += 1
    return i

for line1, line2 in zip(cases[1::2], cases[2::2]):
    n, k = map(int, line1.split())
    a, b, c, r = map(int, line2.split())
    m = [a]
    for i in xrange(k - 1):
        m.append((b * m[-1] + c) % r)

    last_k = deque(m)
    set_last_k = set(last_k)
    next = get_next(0)

    for j in xrange(min(k, n - k)):  # original list - may contain duplicates
        i = next
        removed = last_k.popleft()
        if removed in last_k:
            next = get_next(i+1)
        else:
            set_last_k.remove(removed)
            if removed < i:
                next = removed
            else:
                next = get_next(i+1)
        m.append(i)
        last_k.append(i)
        set_last_k.add(i)

    if n > 2*k:
        for j in xrange(n - 2*k):  # extended list - no duplicates
            i = next
            removed = last_k.popleft()
            set_last_k.remove(removed)
            if removed < i:
                next = removed
            else:
                next = get_next(i + 1)
            m.append(i)
            last_k.append(i)
            set_last_k.add(i)

    print len(m), m[-1]