I implemented (a refinement of) the Sieve of Eratosthenes to count the primes less than a given number n. This is a coding exercise from LeetCode. The Solution class is not needed for the algorithm itself; it is only there so the code can be submitted on LeetCode.
The code passes the tests when n is not very large, but fails with "Time Limit Exceeded" when the input is 858232. I am looking for ways to speed up the code.
from math import sqrt

class Solution(object):
    def countPrimes(self, n):
        """
        :type n: int
        :rtype: int
        """
        # Sieve of Eratosthenes
        if n <= 2:
            return 0
        sieve = [False, False] + [True] * (n - 2)
        for p in range(2, int(sqrt(n)) + 1):
            if sieve[p]:
                for x in range(p * p, n, p):
                    sieve[x] = False
        return sum(sieve)
FYI I submitted your solution a couple of times on LeetCode and it seems to succeed. It says 83.93% faster than other Python3 solutions. – Marc, Jul 8, 2022 at 11:11
@Marc: thanks for that. It is indeed strange that when submitting several times, the code passes all the tests in some submissions. In the Time-Limit-Exceeded one, it complains about the mentioned input 858232. – anon, Jul 8, 2022 at 13:43
2 Answers
The math module also provides isqrt(n) (Python 3.8+), which computes the integer square root. Use it instead of int(sqrt(n)).
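As a small illustration of why isqrt is the safer choice (a sketch, not part of the original answer): isqrt works on exact integers, so it returns the floor of the square root for arbitrarily large inputs, whereas int(sqrt(n)) goes through a float and could in principle be off by one for very large n. The helper name sieve_limit is hypothetical.

```python
from math import isqrt, sqrt

def sieve_limit(n):
    """Largest p worth sieving with: the floor of sqrt(n), computed exactly."""
    return isqrt(n)

# For the input from the question, both approaches give the same bound.
print(sieve_limit(858232), int(sqrt(858232)))

# isqrt stays exact for integers far beyond float precision.
r = 10**20 + 1
print(isqrt(r * r) == r, isqrt(r * r - 1) == r - 1)
```

For the sizes LeetCode uses, the two bounds agree, so the gain here is mainly clarity and exactness rather than speed.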
The most obvious way to speed up your solution is to stop testing all the even numbers.
This checks all the numbers >= 2:
for p in range(2, int(sqrt(n)) + 1):
    if sieve[p]:
But the first step eliminates all even numbers > 2, so you don't need to test them.
# handle 2 separately
for x in range(2 * 2, n, 2):
    sieve[x] = False
# now just check the odd numbers
for p in range(3, isqrt(n) + 1, 2):
    if sieve[p]:
        # p*p + p is even, so we can skip it
        for x in range(p * p, n, 2*p):
            sieve[x] = False
Of course, you could just initialize the sieve with all the even numbers > 2 set to False:
sieve = [False, True] * ((n + 1) // 2)
sieve[1] = False
sieve[2] = True
# if `n` is odd, then sieve is one element too long
if n&1: sieve.pop()
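To see what this initialization produces, here is a minimal sketch that wraps it in a function and prints the result for a small n. The function name init_odd_sieve is hypothetical, and the sketch assumes n >= 3, since sieve[2] does not exist for smaller n.

```python
def init_odd_sieve(n):
    """Initial sieve for n >= 3: only 2 and the odd numbers >= 3 are True."""
    sieve = [False, True] * ((n + 1) // 2)  # even indices False, odd indices True
    sieve[1] = False  # 1 is not prime
    sieve[2] = True   # 2 is the only even prime
    if n & 1:         # for odd n the list has n + 1 entries, so drop the last one
        sieve.pop()
    return sieve

print(init_odd_sieve(10))
```

The resulting list always has exactly n entries, so index x still stands for the number x, as in the original sieve.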
All together, it runs twice as fast for n = 858233:
from math import isqrt

def countPrimes(n):
    """
    :type n: int
    :rtype: int
    """
    # Sieve of Eratosthenes
    if n <= 2:
        return 0
    sieve = [False, True] * ((n + 1) // 2)
    sieve[1] = False
    sieve[2] = True
    if n&1: sieve.pop()
    for p in range(3, isqrt(len(sieve)) + 1, 2):
        if sieve[p]:
            for x in range(p * p, n, 2*p):
                sieve[x] = False
    return sum(sieve)
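The speed-up will vary by machine and Python version, so it is worth measuring locally. The following is a rough sketch of such a comparison using timeit. The names count_primes_basic and count_primes_odd are hypothetical labels for the question's sieve and the odd-only variant, and the repeat count of 5 is an arbitrary choice.

```python
from math import isqrt
from timeit import timeit

def count_primes_basic(n):
    """Plain sieve, as in the question (with isqrt instead of int(sqrt(n)))."""
    if n <= 2:
        return 0
    sieve = [False, False] + [True] * (n - 2)
    for p in range(2, isqrt(n) + 1):
        if sieve[p]:
            for x in range(p * p, n, p):
                sieve[x] = False
    return sum(sieve)

def count_primes_odd(n):
    """Odd-only variant: evens are pre-cleared, only odd multiples are crossed out."""
    if n <= 2:
        return 0
    sieve = [False, True] * ((n + 1) // 2)
    sieve[1] = False
    sieve[2] = True
    if n & 1:
        sieve.pop()
    for p in range(3, isqrt(n) + 1, 2):
        if sieve[p]:
            for x in range(p * p, n, 2 * p):
                sieve[x] = False
    return sum(sieve)

if __name__ == "__main__":
    n = 858233
    # Both variants must agree before the timings mean anything.
    assert count_primes_basic(n) == count_primes_odd(n)
    for f in (count_primes_basic, count_primes_odd):
        print(f.__name__, timeit(lambda: f(n), number=5))
```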
Beyond treating all even numbers > 2 as composite, you can also exclude the multiples of 3 and sieve only the numbers of the form 6i - 1 and 6i + 1, which speeds things up further:
def countPrimes(n):
    if n<=2:
        return 0
    if n<6:
        return n//3+n//4
    dimv=n//6
    # index i stands for 6*i-1 in sieve5mod6 and for 6*i+1 in sieve1mod6
    sieve5mod6 = [True] * (dimv+1)
    sieve1mod6 = [True] * (dimv+1)
    imax=int((n**0.5)/6)+1
    for i in range(1,imax+1):
        if sieve5mod6[i]:
            imin=6*i*i
            pi=6*i-1
            sieve5mod6[imin::pi]=[False]*((dimv-imin)//pi+1)
            sieve1mod6[imin-2*i::pi]=[False]*((dimv-imin+2*i)//pi+1)
        if sieve1mod6[i]:
            imin=6*i*i
            pi=6*i+1
            sieve5mod6[imin::pi]=[False]*((dimv-imin)//pi+1)
            sieve1mod6[imin+2*i::pi]=[False]*((dimv-imin-2*i)//pi+1)
    # sieve5mod6[0] and sieve1mod6[0] stay True, which accounts for 2 and 3 in the count
    # for an exact result it is necessary to distinguish the different values of n%6:
    # the last entry of sieve1mod6 (6*dimv+1) is below n only when n%6 > 1
    if n%6>1: return sum(sieve5mod6)+sum(sieve1mod6)
    return sum(sieve5mod6)+sum(sieve1mod6[:-1:])
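Part of the speed of the code above comes from the slice assignments, which clear a whole arithmetic progression in one step instead of looping in Python. As a hedged sketch, the same idea can be applied to the simpler odd-only sieve. The function name count_primes_sliced and the use of a bytearray are assumptions made for illustration, not something taken from either answer.

```python
from math import isqrt

def count_primes_sliced(n):
    """Count primes below n, clearing composites with slice assignment."""
    if n <= 2:
        return 0
    sieve = bytearray([1]) * n  # one byte per number, 1 means "possibly prime"
    sieve[0] = sieve[1] = 0
    # clear the even numbers > 2; bytes(k) is k zero bytes
    sieve[4::2] = bytes(len(range(4, n, 2)))
    for p in range(3, isqrt(n) + 1, 2):
        if sieve[p]:
            # clear p*p, p*p + 2p, p*p + 4p, ... in a single operation
            sieve[p * p::2 * p] = bytes(len(range(p * p, n, 2 * p)))
    return sum(sieve)

print(count_primes_sliced(100))
```

Using len(range(...)) for the length keeps the right-hand side exactly as long as the slice it replaces, which extended slice assignment requires.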