The problem statement:
An arithmetic progression is a sequence of the form a, a+b, a+2b, ..., a+nb where n=0, 1, 2, 3, ... . For this problem, a is a non-negative integer and b is a positive integer.
Write a program that finds all arithmetic progressions of length n in the set S of bisquares. The set of bisquares is defined as the set of all integers of the form p^2 + q^2 (where p and q are non-negative integers).
The two lines of input are n and m: the length of each progression and the upper bound that limits the search for bisquares, respectively.
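To make the definitions concrete, here is a small sketch. The helper names bisquares_up_to and is_progression_in are illustrative only, not part of the problem. It lists the bisquares for m = 2. It then checks that 1, 5, 9, 13 (a = 1, b = 4) is a progression of length 4 inside the bisquares for m = 3.

```python
def bisquares_up_to(m):
    """Sorted list of distinct values p*p + q*q with 0 <= p, q <= m."""
    return sorted({p * p + q * q for p in range(m + 1) for q in range(m + 1)})


def is_progression_in(values, start, diff, length):
    """True if start, start + diff, ..., start + (length - 1) * diff are all in values."""
    return all(start + i * diff in values for i in range(length))


print(bisquares_up_to(2))  # [0, 1, 2, 4, 5, 8]
print(is_progression_in(set(bisquares_up_to(3)), 1, 4, 4))  # True: 1 = 0+1, 5 = 1+4, 9 = 0+9, 13 = 4+9
```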
My code solves the problem correctly; however, it is too slow for the 5-second time limit. Is there any way to optimize it further?
import heapq
with open('ariprog.in', 'r') as fin, open('ariprog.out', 'w') as fout:
    length = int(fin.readline().strip())
    pqMax = int(fin.readline().strip())

    def isValid(start, diff):
        for i in range(1, length):
            if start + (diff*i) not in bisquares:
                return False
        return True

    bisquares = set()
    for p in range(pqMax+1):
        for q in range(pqMax+1):
            bisquares.add(p**2 + q**2)
    maxBS = max(bisquares)  # do the calculation here so we only have to do it once
    res = []
    for start in bisquares:
        for diff in range(1, (maxBS-start)//(length-1)+1):
            if start + (diff*(length-1)) > maxBS:
                break
            if isValid(start, diff):
                heapq.heappush(res, (diff, start))  # diff goes first in the tuple so it sorts by smallest diff first
    if res:  # if res is non-empty
        while res:
            diff, start = heapq.heappop(res)
            fout.write(f"{start} {diff}\n")
    else:
        fout.write("NONE\n")
Some tiny things:
- Instead of using the power operator, p**2 + q**2, use multiplication, p*p + q*q.
- When generating bisquares, use q in range(p, pqMax+1) to avoid duplicating work such as 1^2 + 2^2 = 2^2 + 1^2.
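Here is a quick sketch confirming that these two tweaks produce exactly the same set while visiting roughly half of the (p, q) pairs. The function names are illustrative only.

```python
def bisquares_full(pq_max):
    # Original approach: every ordered pair (p, q), using the power operator.
    return {p**2 + q**2 for p in range(pq_max + 1) for q in range(pq_max + 1)}


def bisquares_half(pq_max):
    # Suggested approach: only pairs with q >= p, using plain multiplication.
    # p*p + q*q is symmetric in p and q, so pairs with q < p add nothing new.
    result = set()
    for p in range(pq_max + 1):
        pp = p * p
        for q in range(p, pq_max + 1):
            result.add(pp + q * q)
    return result


for m in (0, 1, 5, 30):
    assert bisquares_full(m) == bisquares_half(m)
```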
Most importantly, you can get a noticeable speedup with modular arithmetic. I was inspired to try this after looking at some of the output, and because your code doesn't really use any properties of bisquares. The brute force below works over residues rather than actual values. It shows that any arithmetic progression of length at least 4, 6 or 14 must have a difference that is a multiple of 4, 12 or 84, respectively.
import math

def evaluate_residues(mod_num, min_len):
    # Every residue that p*p + q*q can take modulo mod_num
    unique_residues = set()
    for p in range(mod_num):
        for q in range(p, mod_num):
            start = p * p + q * q
            unique_residues.add(start % mod_num)
    # gcd of every nonzero diff (mod mod_num) that admits min_len terms
    # staying inside those residues
    common_diff = 0
    for start in unique_residues:
        for diff in range(1, mod_num):
            valid = True
            for i in range(min_len):
                if ((start + i * diff) % mod_num) not in unique_residues:
                    valid = False
                    break
            if valid:
                common_diff = math.gcd(common_diff, diff)
    print(common_diff)

evaluate_residues(8, 4)  # yields 4
evaluate_residues(8 * 9, 6)  # yields 12
evaluate_residues(8 * 9 * 49, 14)  # yields 84
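The first result can also be checked by hand. This also shows why the step for length 4 is 4 and not 8, since 1, 5, 9, 13 is a length-4 progression with difference 4.

Squares are 0, 1 or 4 modulo 8, so bisquares are never 3, 6 or 7 modulo 8. Take a progression a, a+b, a+2b, a+3b and consider the cases for b:

- If b is odd, the terms alternate in parity, so two terms two positions apart are odd. Odd bisquares are 1 or 5 modulo 8, hence 1 modulo 4.
- If b is 2 modulo 4, the terms alternate between a and a+2 modulo 4. Since no bisquare is 3 modulo 4, two terms two positions apart are 2 modulo 4. A bisquare that is 2 modulo 4 must be 2 modulo 8.

This covers only the mod-8 case; the 12 and 84 come from the brute force.

```latex
\begin{aligned}
&x^2 \bmod 8 \in \{0, 1, 4\} \;\Rightarrow\; (p^2 + q^2) \bmod 8 \in \{0, 1, 2, 4, 5\} \\
&b \text{ odd}: \quad a + ib \equiv a + (i+2)b \equiv 1 \pmod 4 \text{ is impossible, since } 2b \equiv 2 \pmod 4 \\
&b \equiv 2 \pmod 4: \quad \text{the two terms} \equiv 2 \pmod 4 \text{ differ by } 2b \equiv 4 \pmod 8, \text{ so one is} \equiv 6 \pmod 8 \\
&\Rightarrow\; b \equiv 0 \pmod 4
\end{aligned}
```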
These results can be used as the step size of the diff for loop:
# every valid diff is a multiple of step (see evaluate_residues above)
step = 1
if length >= 4:
    step = 4
if length >= 6:
    step = 12
if length >= 14:
    step = 84
# ...
for diff in range(step, (maxBS - start) // (length - 1) + 1, step):
    ...  # loop body unchanged
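The step values come from a residue argument rather than from the search itself, so a cheap safety net is to compare them against plain brute force for small bounds. This sketch uses two illustrative helpers:

- diff_step wraps the step table in a function.
- all_progressions is a step-1 reference search.

It only checks small bounds; the residue computation is what covers the general case. A table that used 8 for length 4 would fail here on 1, 5, 9, 13.

```python
def diff_step(length):
    # Step table for the diff loop: every valid diff is a multiple of this.
    step = 1
    if length >= 4:
        step = 4
    if length >= 6:
        step = 12
    if length >= 14:
        step = 84
    return step


def all_progressions(length, pq_max):
    # Plain brute force with step 1, used only as a reference.
    squares = {p * p + q * q for p in range(pq_max + 1) for q in range(p, pq_max + 1)}
    top = max(squares)
    found = []
    for start in squares:
        for diff in range(1, (top - start) // (length - 1) + 1):
            if all(start + i * diff in squares for i in range(1, length)):
                found.append((start, diff))
    return found


# Every progression found by brute force must respect the step table.
for length in (3, 4, 5, 6, 7):
    for start, diff in all_progressions(length, 20):
        assert diff % diff_step(length) == 0, (length, start, diff)
```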
Putting all of these changes together reduces the running time from more than 58 seconds to less than 1 second when length is 16 and pqMax is 250. When length is 14 and pqMax is 250, it drops from more than 66 seconds to less than 6 seconds. That may still not fit within the 5-second limit, so it may be necessary to search through the bisquares more efficiently. Overall, I think your solution is well-written, and I couldn't think of any way to improve its time complexity.
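Below is one possible way to put the pieces together. It is an untimed sketch, not a measured improvement. It keeps your overall loop structure but swaps in two techniques of my own choosing:

- A bytearray lookup table indexed by value replaces the set, so each membership test becomes a plain index.
- A single sort at the end replaces the heap, giving the same order (by difference, then start).

The names find_progressions and solve are illustrative.

```python
def find_progressions(length, pq_max):
    """Return (start, diff) pairs of bisquare progressions, sorted by diff, then start."""
    limit = 2 * pq_max * pq_max  # the largest bisquare, reached at p = q = pq_max
    # Lookup table: is_bisquare[k] is 1 exactly when k = p*p + q*q for some p, q <= pq_max.
    is_bisquare = bytearray(limit + 1)
    for p in range(pq_max + 1):
        pp = p * p
        for q in range(p, pq_max + 1):  # q >= p avoids duplicate pairs
            is_bisquare[pp + q * q] = 1
    starts = [k for k in range(limit + 1) if is_bisquare[k]]

    # Step table derived from the residue argument.
    step = 1
    if length >= 4:
        step = 4
    if length >= 6:
        step = 12
    if length >= 14:
        step = 84

    found = []
    for start in starts:
        max_diff = (limit - start) // (length - 1)  # keeps every index <= limit
        for diff in range(step, max_diff + 1, step):
            if all(is_bisquare[start + i * diff] for i in range(1, length)):
                found.append((diff, start))  # diff first so sorting orders by diff
    found.sort()
    return [(start, diff) for diff, start in found]


def solve(in_path="ariprog.in", out_path="ariprog.out"):
    # Same input/output format as the original program.
    with open(in_path) as fin:
        length = int(fin.readline())
        pq_max = int(fin.readline())
    result = find_progressions(length, pq_max)
    with open(out_path, "w") as fout:
        if result:
            for start, diff in result:
                fout.write(f"{start} {diff}\n")
        else:
            fout.write("NONE\n")
```

Calling solve() reads ariprog.in and writes ariprog.out. You can call find_progressions directly for testing; for example, find_progressions(4, 3) returns [(1, 4)].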
ariprog.in contains two lines, the first being n and the second being the largest that p and q can be. Is that correct? Do you have a bound on how big those numbers can be?