I implemented (a refinement of) the Sieve of Eratosthenes to count the primes less than a given number n. This is a coding exercise from LeetCode; the class Solution is not needed for the algorithm itself and is only there because LeetCode expects it for submissions.
The code passes the tests when n is not very large, but fails with "Time Limit Exceeded" on the input 858232. I am looking for ways to speed up the code.
from math import sqrt


class Solution(object):
    def countPrimes(self, n):
        """
        :type n: int
        :rtype: int
        """
        # Sieve of Eratosthenes
        if n <= 2:
            return 0
        sieve = [False, False] + [True] * (n - 2)
        for p in range(2, int(sqrt(n)) + 1):
            if sieve[p]:
                for x in range(p * p, n, p):
                    sieve[x] = False
        return sum(sieve)
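Before comparing speeds, it helps to be sure that every variant still returns the right count. A minimal sketch of such a check, assuming a trial-division reference is acceptable for small n; the helper names `naive_counts` and `check_count_primes` are hypothetical. The larger reference values are the standard prime counts below 10^4, 10^5 and 10^6.

```python
from math import isqrt


def naive_counts(limit):
    """counts[n] = number of primes < n for 0 <= n <= limit, by trial division."""
    counts = [0] * (limit + 1)
    running = 0
    for k in range(limit):
        counts[k] = running  # primes strictly below k
        if k >= 2 and all(k % d for d in range(2, isqrt(k) + 1)):
            running += 1
    counts[limit] = running
    return counts


def check_count_primes(count_fn, limit=2000):
    """Assert that count_fn(n) is correct for every 0 <= n <= limit,
    plus a few well-known larger values of the prime-counting function."""
    counts = naive_counts(limit)
    for n in range(limit + 1):
        assert count_fn(n) == counts[n], (n, count_fn(n), counts[n])
    for n, expected in ((10**4, 1229), (10**5, 9592), (10**6, 78498)):
        assert count_fn(n) == expected, (n, count_fn(n), expected)
```

Usage with the code above would be `check_count_primes(Solution().countPrimes)`; a silent return means all checks passed.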
Comment (Marc, Jul 8, 2022): FYI, I submitted your solution a couple of times on LeetCode and it seems to succeed. It says it is faster than 83.93% of other Python3 solutions.

Comment (anon, Jul 8, 2022): @Marc: thanks for that. It is indeed strange that, when submitting several times, the code passes all the tests in some submissions. In the Time-Limit-Exceeded one, it fails on the mentioned input 858232.
2 Answers
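The math library also has the isqrt(n) function (Python 3.8+), which computes the integer square root exactly. Use it instead of int(sqrt(n)): it avoids the round trip through a float and states the intent directly.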
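A small illustration of the difference. For inputs of LeetCode size both expressions give the same loop bound; the float version only goes wrong for very large integers, far beyond anything this problem uses. The value `(2**53 + 1) ** 2` is just one example chosen to trigger the rounding.

```python
from math import isqrt, sqrt

# Loop bound for the input from the question: 926**2 = 857476 <= 858232 < 927**2
print(isqrt(858232))  # 926

# Converting a huge int to float rounds it, and sqrt rounds again,
# so int(sqrt(n)) can come out one too small; isqrt works on the exact integer.
big = (2**53 + 1) ** 2
print(int(sqrt(big)))  # 9007199254740992
print(isqrt(big))      # 9007199254740993
```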
The most obvious way to speed up your solution is to stop testing all the even numbers.
This checks all the numbers >= 2:
for p in range(2, int(sqrt(n)) + 1):
    if sieve[p]:
But the first pass, with p = 2, already eliminates every even number > 2, so you don't need to test them again. Handle 2 separately and then look only at odd candidates:
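from math import isqrt

# handle 2 separately: cross out every even number > 2
for x in range(2 * 2, n, 2):
    sieve[x] = False
# now just check the odd numbers
for p in range(3, isqrt(n) + 1, 2):
    if sieve[p]:
        # p is odd, so p*p + p is even and already crossed out;
        # stepping by 2*p visits only the odd multiples p*p, p*p + 2p, ...
        for x in range(p * p, n, 2*p):
            sieve[x] = False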
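To see what the `2*p` step saves, it helps to list the indices the inner loop actually visits. A small illustration for p = 3 and n = 50; the helper name `crossed_out` is made up for this example.

```python
def crossed_out(p, n, step_factor):
    """Indices visited by the inner loop `for x in range(p * p, n, step_factor * p)`."""
    return list(range(p * p, n, step_factor * p))


# Step p (question's code): every multiple from p*p on, half of them even.
print(crossed_out(3, 50, 1))  # [9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 48]

# Step 2*p: only the odd multiples; the even ones were removed when 2 was handled.
print(crossed_out(3, 50, 2))  # [9, 15, 21, 27, 33, 39, 45]
```

The inner loop does roughly half the work for every odd prime, on top of skipping the even values of p in the outer loop.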
Of course, you could instead initialize the sieve with all the even numbers > 2 already set to False:
sieve = [False, True] * ((n + 1) // 2)
sieve[1] = False
sieve[2] = True
# if `n` is odd, then sieve is one element too long
if n & 1: sieve.pop()
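The initialization above, wrapped in a function for inspection (the name `init_sieve` is hypothetical). Like the answer's code, it assumes n >= 3, since n <= 2 returns early before this point.

```python
def init_sieve(n):
    """Odd indices True, even indices False, then fix up 1 and 2. Assumes n >= 3."""
    sieve = [False, True] * ((n + 1) // 2)  # length n if n is even, n + 1 if n is odd
    sieve[1] = False                         # 1 is not prime
    sieve[2] = True                          # 2 is the only even prime
    if n & 1:                                # n odd: drop the extra element
        sieve.pop()
    return sieve


print(init_sieve(10))
# [False, False, True, True, False, True, False, True, False, True]
print(len(init_sieve(10)), len(init_sieve(11)))  # 10 11
```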
Altogether, it runs twice as fast for n = 858233:
from math import isqrt


def countPrimes(n):
    """
    :type n: int
    :rtype: int
    """
    # Sieve of Eratosthenes, odd candidates only
    if n <= 2:
        return 0
    # even indices False, odd indices True; then fix up 1 and 2
    sieve = [False, True] * ((n + 1) // 2)
    sieve[1] = False
    sieve[2] = True
    if n & 1: sieve.pop()
    for p in range(3, isqrt(len(sieve)) + 1, 2):
        if sieve[p]:
            for x in range(p * p, n, 2*p):
                sieve[x] = False
    return sum(sieve)
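A further step, which the mod-6 version below also uses: replace the inner Python loop with a single slice assignment, so the crossing-out happens inside the list implementation instead of one interpreted iteration per index. A sketch under that assumption (not timed here), keeping everything else from the function above; the name `count_primes_slices` is hypothetical.

```python
from math import isqrt


def count_primes_slices(n):
    """Odd-only sieve as above, but crossing out with one slice assignment per prime."""
    if n <= 2:
        return 0
    sieve = [False, True] * ((n + 1) // 2)
    sieve[1] = False
    sieve[2] = True
    if n & 1:
        sieve.pop()
    for p in range(3, isqrt(n) + 1, 2):
        if sieve[p]:
            # An extended slice needs a right-hand side of exactly the same length;
            # len(range(...)) counts the same indices p*p, p*p + 2p, ... < n.
            sieve[p * p::2 * p] = [False] * len(range(p * p, n, 2 * p))
    return sum(sieve)
```

For example, `count_primes_slices(100)` returns 25.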
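In addition to setting all even numbers > 2 to False up front, multiples of 3 can be excluded as well for a further speedup. Every prime greater than 3 has the form 6i - 1 or 6i + 1, so two sieves, one for each of these residue classes, are enough: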
def countPrimes(n):
    if n<=2:
        return 0
    if n<6:
        return n//3+n//4  # n = 3 -> 1; n = 4 or 5 -> 2
    # sieve5mod6[i] stands for 6*i - 1, sieve1mod6[i] for 6*i + 1
    dimv=n//6
    sieve5mod6 = [True] * (dimv+1)
    sieve1mod6 = [True] * (dimv+1)
    imax=int((n**0.5)/6)+1
    for i in range(1,imax+1):
        if sieve5mod6[i]:
            # p = 6i - 1 is prime: cross out its multiples in both classes
            imin=6*i*i
            pi=6*i-1
            sieve5mod6[imin::pi]=[False]*((dimv-imin)//pi+1)
            sieve1mod6[imin-2*i::pi]=[False]*((dimv-imin+2*i)//pi+1)
        if sieve1mod6[i]:
            # p = 6i + 1 is prime: same, with the start offsets for this class
            imin=6*i*i
            pi=6*i+1
            sieve5mod6[imin::pi]=[False]*((dimv-imin)//pi+1)
            sieve1mod6[imin+2*i::pi]=[False]*((dimv-imin-2*i)//pi+1)
    # sieve5mod6[0] and sieve1mod6[0] stay True and are used to count 2 and 3.
    # sieve1mod6[dimv] stands for 6*dimv + 1, which is < n only when n % 6 > 1,
    # so for an exact result it has to be dropped for the other values of n % 6.
    if n%6>1: return sum(sieve5mod6)+sum(sieve1mod6)
    return sum(sieve5mod6)+sum(sieve1mod6[:-1])
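The start offsets in that function are the hardest part to read. Assuming, as the code implies, that `sieve5mod6[i]` stands for 6i - 1 and `sieve1mod6[i]` for 6i + 1 (with `imin` = 6i^2), they follow from expanding the first multiple of p that lands in each residue class:

```latex
\begin{aligned}
p = 6i-1:&\quad p^2 = 36i^2-12i+1 = 6(6i^2-2i)+1 &&\Rightarrow \texttt{sieve1mod6}[6i^2-2i] \\
         &\quad p(p+2) = 36i^2-1 = 6\cdot 6i^2-1 &&\Rightarrow \texttt{sieve5mod6}[6i^2] \\
p = 6i+1:&\quad p^2 = 36i^2+12i+1 = 6(6i^2+2i)+1 &&\Rightarrow \texttt{sieve1mod6}[6i^2+2i] \\
         &\quad p(p-2) = 36i^2-1 = 6\cdot 6i^2-1 &&\Rightarrow \texttt{sieve5mod6}[6i^2]
\end{aligned}
```

Moving to the next multiple of p within the same class adds 6p to the value, that is, p to the index, which is why every slice steps by `pi`. For p = 6i + 1 the 5-mod-6 slice starts at p(p - 2), just below p^2; that value is still a genuine multiple of p, so crossing it out is harmless.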