I was trying to write Sieve for the first time and I came up with this code:
def sieve(num):
    numbers = set(range(3, num, 2))
    for i in range(3, int(num**(1/2)) + 1, 2):
        num_set = set(i*c for c in range(3, num//2, 2))
        numbers = numbers - num_set
    return list(sorted((2, *numbers)))
The problem is that for num > 10**6 the time to create the prime numbers increases considerably. And when I tried num = 10**8, my computer stopped working, started to make awkward noises, and I had to restart it.
I think the problem is that I am dealing with sets. For large numbers (for instance, in the num = 10**8 case) the set cannot be produced since my computer cannot process that much information, hence it stops.
Is there a way to solve this memory or time problem using my code, or should I use a different algorithm?
3 Answers
Using a set() is your bottleneck, memory-wise.
>>> numbers = set(range(3, 10**8, 2))
>>> sys.getsizeof(numbers)
2147483872
>>> sys.getsizeof(numbers) + sum(map(sys.getsizeof, numbers))
3547483844
A set of odd numbers up to 100 million consumes 3.5 GB (thank-you @ShadowRanger) of memory. When you do an operation like numbers = numbers - num_set, you'll need to have 3 sets in memory at once: one for the original set, one for the set of numbers you are removing, and one for the resulting set. This will be greater than 7 GB of memory, since some of the numbers you are removing aren't in the original set.
You don't need to realize the entire set of numbers you are removing in memory. You could simply remove the numbers from the set one at a time:
for c in range(3, num // 2, 2):
    numbers.discard(i * c)   # discard, not remove: the product may already be gone, or never present
This modifies the set in place, so the memory requirement will not exceed the initial 2 GB for the set structure itself.
Why are you looping c over range(3, num // 2, 2)? This is doing way too much work. The maximum value of c should satisfy i*c < num, since no product i*c larger than num will be in the original set. You should instead loop over range(3, num // i + 1, 2). This will shrink the set of numbers you are removing as the primes you find get larger.
Why start removing multiples at 3*i? When i is 97, you've already removed the multiples of 3, 5, 7, 11, 13, 17, ... up to 89. The first multiple you need to remove is 97*97. You would then continue with 99*97, 101*97, and so on, up to num. So the range should begin with i, not 3.
for c in range(i, num // i + 1, 2):
    numbers.discard(i * c)
Actually, this is still too complicated. Let's get rid of the multiplication. This also greatly simplifies the upper limit of the range.
for multiple in range(i*i, num, 2*i):
    numbers.discard(multiple)
Or equivalently, pass a generator to difference_update to remove items in bulk, without realizing the whole set of numbers to be removed in memory at once:
numbers.difference_update(multiple for multiple in range(i*i, num, 2*i))
Even with all of the above changes, you still require 2 GB of memory to compute the primes up to 100 million. And since a set is unordered, you still have to sort the surviving numbers afterwards to get your ordered list of primes.
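Putting all of the set-based improvements together (my consolidation, not code from the original post; note that difference_update happily accepts a range directly, no generator expression needed):

def sieve_sets(num):
    # Set-based sieve with all of the fixes above applied
    numbers = set(range(3, num, 2))
    for i in range(3, int(num ** 0.5) + 1, 2):
        if i in numbers:                # only sieve on values that survived, i.e. primes
            numbers.difference_update(range(i * i, num, 2 * i))
    return [2] + sorted(numbers)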
A better way is to maintain an array of flags, one per candidate number. With 100 million candidate numbers, if each flag used only a single byte, you'd only require 100 MB of memory, a savings of a factor of 20. And since the array of flags is ordered, no sorting of the array would be required.
A bytearray is one such structure: an array of bytes. You can mark each candidate in the array with a 1, and any non-prime (a multiple of another prime) with a 0.
def sieve(num):
    flags = bytearray(num)          # Initially, all bytes are zero
    flags[2] = 1                    # Two is prime
    for i in range(3, num, 2):
        flags[i] = 1                # Odd numbers are prime candidates
    # Find primes and eliminate multiples of those primes
    for i in range(3, int(num ** 0.5) + 1, 2):
        if flags[i]:
            for multiple in range(i * i, num, 2 * i):
                flags[multiple] = 0
    return [i for i, flag in enumerate(flags) if flag]
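For example, for a small bound:

>>> sieve(30)
[2, 3, 5, 7, 11, 13, 17, 19, 23, 29]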
Conserving a little bit more memory, you can store your list of primes in an array.array:
import array

def sieve(num):
    flags = bytearray(num)          # Initially, all bytes are zero
    flags[2] = 1                    # Two is prime
    for i in range(3, num, 2):
        flags[i] = 1                # Odd numbers are prime candidates
    # Find primes and eliminate multiples of those primes
    for i in range(3, int(num ** 0.5) + 1, 2):
        if flags[i]:
            for multiple in range(i * i, num, 2 * i):
                flags[multiple] = 0
    return array.array('I', (i for i, flag in enumerate(flags) if flag))
For primes up to \10ドル^8\$, the array.array('I', ...) stores the 5.7 million primes in a mere 23 MB of memory. The list version takes a whopping 212 MB.
Note: If you are using a 32-bit version of Python, you may need the type-code 'L' instead of 'I' to get storage for 4-byte integers in the array.
For the truly memory conscious, install the bitarray module.
pip3 install bitarray
In addition to using only a single bit per flag (1/8th the memory usage of the bytearray sieve), it supports some truly fantastic slice assignments from a single boolean scalar, which turns clearing all the multiples of a prime into a single statement.
import array
from bitarray import bitarray

def sieve(num):
    flags = bitarray(num)
    flags.setall(False)
    flags[2] = True                 # Two is prime
    flags[3::2] = True              # Odd numbers are prime candidates
    for i in range(3, int(num ** 0.5) + 1, 2):
        if flags[i]:
            flags[i*i:num:2*i] = False  # Eliminate multiples of this prime
    primes = array.array('I', (i for i, flag in enumerate(flags) if flag))
    return primes
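As a sanity check, there are exactly 5,761,455 primes below \10ドル^8\$ (the 5.7 million mentioned above), so you would expect:

>>> primes = sieve(10**8)
>>> len(primes)
5761455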
Timings (seconds):
10^3: 0.000
10^4: 0.000
10^5: 0.004
10^6: 0.051
10^7: 0.428
10^8: 4.506
Note: Updated timing info. I just noticed I had for i in range(3, num + 1, 2) in the last implementation instead of for i in range(3, int(num ** 0.5) + 1, 2), resulting in a lot of wasted time doing nothing.
Python 3.8 Update: Using math.isqrt(num) is better than int(num ** 0.5):

for i in range(3, math.isqrt(num) + 1, 2):
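math.isqrt computes the exact integer square root without ever converting to a float, so it cannot suffer rounding error, and it works on integers far too large for a float at all (well beyond any sieve bound, but it shows why isqrt is the right tool):

>>> import math
>>> math.isqrt(10**400) == 10**200   # exact on arbitrarily large ints
True
>>> (10**400) ** 0.5                 # the float path simply overflows
Traceback (most recent call last):
  ...
OverflowError: int too large to convert to float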
- A note: sys.getsizeof(numbers) is only telling you the size of the set structure itself, not all the ints stored in it. The very smallest ints are singletons (CPython implementation detail), but you'd have to pay the memory cost of all the rest, so the memory used is sys.getsizeof(numbers) + sum(map(sys.getsizeof, numbers)), which on a 64-bit build of Python adds just shy of another 1.4 GB of memory onto the cost, starting at 28 bytes per int for magnitudes of 30 bits and below, adding four more bytes for every additional 30 bits of magnitude or part thereof. – ShadowRanger, Jun 19, 2019 at 3:43
- @ShadowRanger Excellent point. I did that for the list to get the 212 MB, but neglected to do so for the numbers set. – AJNeufeld, Jun 19, 2019 at 5:14
- Great answer. Exactly what I needed. I actually changed my code a bit after Jackson's answer (the range parts). However, it was against the rules, so I couldn't change my code in the post. – camarman, Jun 19, 2019 at 11:03
- @Nzall flags[start:stop:step] references a "slice" of the flags bit array, beginning with the start element, then the start+step element, then the start+2*step element, all the way up to (but not including) the stop element. Since we start at a multiple of the current prime, and go up by a multiple of the prime, these elements all have indices which are multiples of the prime. The = False assigns all of those elements in the bit array to false, which doesn't remove those elements, but flags them as not prime candidates, so "removes" them from consideration. – AJNeufeld, Jun 19, 2019 at 13:50
- @AJNeufeld I see. I now also realize that you start at the square of the prime because every multiple of the prime before that has already been falsified by earlier prime falsifications. – Nzall, Jun 19, 2019 at 14:46
Is there a way to solve this memory or time problem using my code, or should I use a different algorithm?
The algorithm is fine for the kind of scale you're talking about. It's the implementation of the algorithm which needs optimisation.
To tackle the memory issue, look at your use of set. Given that the elements of the set are integers from a fixed range, and moderately dense in that range (about 1 in 18 numbers up to \10ドル^8\$ is prime), the ideal would be a data structure which uses 1 bit per number. (I'm not sure whether one is available in Python. In the worst case, since it has big integers, you can use bit manipulations on numbers.) But failing that, a simple array of Boolean values probably has less overhead than a set.
return list(sorted((2, *numbers)))

This is actually quite heavy-weight. It's probably not the bottleneck, but it might be worth asking yourself whether the caller needs a list. Perhaps you can use yield inside the main loop and skip the post-processing altogether. Perhaps the memory pressure isn't so bad as to prevent you from accumulating the list in order. And if the memory pressure is a problem, perhaps you can break the set into pages, something like (warning: code not tested, and this doesn't include the other ideas I've mentioned):
primes = [2]
page_size = 1000000

def remove_multiples(p, page, page_start, page_end):
    # First odd multiple of p inside this page, no smaller than p*p
    first = max(p * p, ((page_start + p - 1) // p) * p)
    if first % 2 == 0:                  # p is odd, so the next multiple up is odd
        first += p
    page.difference_update(range(first, page_end, 2 * p))

for page_start in range(3, num, page_size):
    page_end = min(num, page_start + page_size)
    page = set(range(page_start, page_end, 2))
    for p in primes[1:]:                # skip 2; the page holds odd numbers only
        remove_multiples(p, page, page_start, page_end)
    for p in range(page_start, page_end, 2):
        if p in page:
            primes.append(p)
            remove_multiples(p, page, page_start, page_end)
Note: I've thrown out several ideas. I understand that you're doing this as a learning exercise, and trying out various different directions should be useful for that purpose even if you conclude that there isn't enough benefit to compensate for the added complexity.
- In general I don't need to turn the set into a list, but when I need indexing, for instance, I will need to turn it into a list. I don't know much of anything about bit manipulation... – camarman, Jun 18, 2019 at 19:59
- My version, which uses a Boolean list, can't work with \$n=10^8\$ with 32 GB of memory. Chunking up the values is a good idea tho, probably could get it so it uses a list with vectorization pretty easily. – Jun 19, 2019 at 0:13
I think your performance problems at 10**6 elements start here:
for i in range(3, int(num**(1/2)) + 1, 2):
This iterates i over the odd numbers [3, 5, 7, 9, 11, 13, 15, ...]; for each i you build a set of multiples and remove those multiples from the numbers set. But when you've removed all the multiples of 3, you still try to remove the multiples of 9, 15, 21, ..., all of which already went when you removed the multiples of three.
In a classic implementation of the sieve you would find the next smallest prime and remove its multiples, then find the next smallest prime and remove its multiples, until you get to the square root of num.
For example with num = 25:
- [], [3,5,7,9,11,13,15,17,19,21,23,25] -- 3 is the lowest, so remove its multiples
- [3], [5,7,11,13,17,19,23,25] -- 5 is next lowest, so remove its multiples
- [3,5], [7,11,13,17,19,23] -- we've reached the square root of num; only primes are left
So after each removal you want to find the new minimal element left in numbers. The problem with a set is that it's unordered, so an operation like min() is an O(N) operation: the entire set has to be scanned. You may be able to get round this by looking for an OrderedSet implementation, in which case each time you find a prime you remove its multiples, move the prime itself to a separate set, say, and the next prime to remove is the minimal value in the numbers set.
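A sketch of that classic shape (my illustration; with a plain set, the min() call is exactly the O(N) scan described above, which an ordered structure would make cheap):

def sieve_classic(num):
    numbers = set(range(3, num, 2))
    primes = [2]
    while numbers:
        p = min(numbers)                # O(N) scan of the whole set
        if p * p >= num:
            break                       # everything still in the set is prime
        primes.append(p)
        # removes p itself and all of its odd multiples
        numbers.difference_update(range(p, num, 2 * p))
    return primes + sorted(numbers)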
As Peilonrayz points out in a comment, when you start to get toward 10**8 elements you need to think about how much memory these sets are going to need. You might well need a data structure that uses a lot less memory.
- Thanks, my code now runs 8x faster. It can calculate up to 10**6 in 3.5 sec. – camarman, Jun 18, 2019 at 15:46
- Depending on the Python version, this might or might not generate a list (the Python type). I'm pretty sure this is Python 3, which means that, no, memory isn't a concern with the range statement, as it'll always be stored as just 3 ints in a special object, and the iterator will probably just store 1 int (the index). It's still a performance concern to hit all those unnecessary numbers, of course, just not a memory one. What might be causing memory issues is the set(...) bit, which does store that whole thing. – anon, Jun 18, 2019 at 20:10
- @ I agree with you – camarman, Jun 18, 2019 at 20:57