I wrote a Python program that draws the Mandelbrot set fractal. However, the program is very slow. I would appreciate any feedback on the program, but I would especially appreciate feedback on how to improve the performance of my program.
    from PIL import Image, ImageDraw
    import numpy as np

    def FractalFunc(x, c):
        return (x * x) + c

    def IterateFractal(c, iterateMax=100):
        z = 0
        last = np.nan
        for i in range(iterateMax):
            # c = FractalFunc(c, (.05+.5j)) # in case I want to create the Julia set
            z = FractalFunc(z,c)
            # if (np.abs(c) > 2):
            if (np.abs(z) > 2):
                last = i
                break
        return last

    def QuickFractal(width=200, height=200, amin=-1.5, amax=1.5, bmin=-1.5, bmax=1.5,
                     iterateMax=100, filename="blah.png"):
        (w, h, iMax, iterationValues, histogram) = Fractal(width, height, amin, amax, bmin, bmax, iterateMax)
        iFractal = DrawFractal(w, h, iMax, iterationValues, histogram)
        iFractal.show()
        iFractal.save(filename)

    def ShowFractal(frac, filename="test.png"):
        (w, h, iMax, iterationValues, histogram) = frac
        iFractal = DrawFractal(w, h, iMax, iterationValues, histogram)
        iFractal.show()
        iFractal.save(filename)

    def Fractal(width=200, height=200, amin=-1.5, amax=1.5, bmin=-1.5, bmax=1.5,
                iterateMax=100):
        aStep = (amax - amin) / (width)
        bStep = (bmax - bmin) / (height)
        """compute iteration values"""
        iterationValues = np.zeros((width, height))
        histogram = np.zeros(iterateMax)
        a = amin
        for x in range(width):
            b = bmin
            for y in range(height):
                c = np.complex(a,b)
                i = IterateFractal(c, iterateMax)
                iterationValues[x][y] = i
                if not(np.isnan(i)):
                    histogram[i] = histogram[i] + 1
                b = b + bStep
            a = a + aStep
        return (width, height, iterateMax, iterationValues, histogram)

    def DrawFractal(width, height, iterateMax, iterationValues, histogram):
        hue = np.zeros(iterateMax)
        hue[0] = histogram[0] / sum(histogram)
        image = Image.new("RGB", (width, height))
        for i in range(1, iterateMax):
            hue[i] = hue[i-1] + (histogram[i] / sum(histogram))
        """second pass to draw the values"""
        for x in range(width):
            for y in range(height):
                i_count = iterationValues[x,y]
                colorPixel = FractalColor(i_count, iterateMax, hue)
                image.putpixel((x,y), colorPixel)
        return image

    def FractalColor(iterate, maxIt, hues):
        if np.isnan(iterate):
            return (0,0,0)
        else:
            grad_point_1 = np.array((0, 0, 255))
            grad_point_2 = np.array((255, 255, 255))
            grad_point_3 = np.array((0, 0, 255))
            hue_mod = (iterate % 100) / 100
            if (0 <= hue_mod < (.5)):
                delta = grad_point_2 - grad_point_1
                start = grad_point_1
            if ((.5) <= hue_mod < (1)):
                delta = grad_point_3 - grad_point_2
                start = grad_point_2
            scalar = ((hue_mod * 100) % (100 / 2))/30
            arr = np.floor(start + np.dot(scalar, delta))
            return tuple(arr.astype(int))

Example output:
Comments:

You can get some optimisations from the Wikipedia article. – ovs, Apr 22, 2017 at 21:12

Please do not update the code in your question to incorporate feedback from answers; doing so goes against the Question + Answer style of Code Review. This is not a forum where you should keep the most updated version in your question. Please see what you may and may not do after receiving answers. – Simon Forsberg, Jun 5, 2017 at 22:08
1 Answer
The key thing to remember when working with numerical code is that the CPython interpreter is pretty slow (it trades speed for flexibility) and so you must avoid running in the interpreter whenever possible. Instead of iterating in slow Python bytecode using `for x in ...`, operate on whole arrays by calling the appropriate NumPy function or method, which dispatches to fast compiled code operating on fixed-size numbers.

So to speed up `Fractal` we must turn it inside out: instead of operating on one pixel at a time, we must operate on the whole array at once.
Let's start out by measuring the performance of the implementation in the post:

    >>> from timeit import timeit
    >>> timeit(Fractal, number=1)
    1.4722420040052384

The constants \$c\$ are computed one at a time:

    a = amin
    for x in range(width):
        b = bmin
        for y in range(height):
            c = np.complex(a,b)
            # ...
            b = b + bStep
        a = a + aStep

Instead, compute a whole array of constants \$c\$ using `numpy.linspace` and `numpy.meshgrid`:

    a, b = np.meshgrid(np.linspace(amin, amax, width),
                       np.linspace(bmin, bmax, height), sparse=True)
    c = a + 1j * b

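To see what `sparse=True` and broadcasting do here, the following is a minimal sketch on a deliberately tiny grid (the 4 by 3 size is an assumption chosen for illustration, not a value from the post). One difference worth being aware of: `np.linspace` includes the endpoint by default, whereas the loop in the post stops one step short of `amax` and `bmax`, so the two grids are not sample-for-sample identical.

```python
import numpy as np

# Hypothetical small grid: 4 columns (real axis) by 3 rows (imaginary axis).
width, height = 4, 3
a, b = np.meshgrid(np.linspace(-1.5, 1.5, width),
                   np.linspace(-1.5, 1.5, height),
                   sparse=True)

# With sparse=True, a is a single row and b is a single column.
print(a.shape, b.shape)  # (1, 4) (3, 1)

# Broadcasting expands the sum to the full grid of complex constants,
# indexed as c[row, column], i.e. c[y, x].
c = a + 1j * b
print(c.shape)  # (3, 4)
```

If matching the original sampling exactly matters, passing `endpoint=False` to `np.linspace` should reproduce the `amin + x * aStep` spacing.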
The `IterateFractal` function operates on the \$z\$ value for a single pixel at a time. Instead, we should operate on a whole array of \$z\$ values:

    z = np.zeros_like(c)
    iterations = np.zeros_like(c, dtype=int)
    for i in range(1, iterateMax + 1):
        z = z ** 2 + c
        iterations[abs(z) <= 2] = i

Note that in this implementation, points that never escape from the set have `iterations` equal to `iterateMax` (rather than NaN as in the post). This requires a corresponding change to `FractalColor`:

    def FractalColor(iterate, maxIt, hues):
        if iterate >= maxIt:
            return (0,0,0)
        else:
            # ...

Instead of computing `histogram` one pixel at a time, use `numpy.bincount`:

    histogram = np.bincount(iterations.flatten())

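As a quick illustration of what `numpy.bincount` returns, here is a small sketch using made-up iteration counts (the array and the value 8 passed to `minlength` are assumptions for the example). The result has one entry per integer from 0 up to the largest value present, so its length depends on the data unless `minlength` is given.

```python
import numpy as np

# Hypothetical 2x3 array of iteration counts.
iterations = np.array([[0, 1, 1],
                       [4, 4, 4]])

# bincount needs a 1-D array of non-negative integers, hence flatten().
histogram = np.bincount(iterations.flatten())
print(histogram)  # [1 2 0 0 3]

# minlength pads the result with zeros to a guaranteed length.
padded = np.bincount(iterations.flatten(), minlength=8)
print(padded)  # [1 2 0 0 3 0 0 0]
```

In the context of the post, passing `minlength=iterateMax + 1` may be worth considering, since a view containing no points of the set would otherwise produce a histogram shorter than `DrawFractal` expects.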
This results in the revised code:

    def Fractal2(width=200, height=200, amin=-1.5, amax=1.5, bmin=-1.5, bmax=1.5,
                 iterateMax=100):
        a, b = np.meshgrid(np.linspace(amin, amax, width),
                           np.linspace(bmin, bmax, height), sparse=True)
        c = a + 1j * b
        z = np.zeros_like(c)
        iterations = np.zeros_like(c, dtype=int)
        for i in range(1, iterateMax + 1):
            z = z ** 2 + c
            iterations[abs(z) <= 2] = i
        histogram = np.bincount(iterations.flatten())
        return width, height, iterateMax, iterations.T, histogram

Note the `iterations.T` at the end there: the `.T` property is a shorthand for `numpy.transpose`. This is needed because NumPy prefers row-major ordering where a 2-dimensional array is indexed by row, column, but `DrawFractal` is using column-major ordering where an image is indexed by x, y. See the NumPy documentation on "Multidimensional Array Indexing Order Issues". If you updated `DrawFractal` to use row-major indexing then you would avoid the need for this transpose.

This is about 40 times as fast as the original code:

    >>> timeit(Fractal2, number=1)
    0.03542686498258263

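The indexing-order point can be checked on a tiny array. This is only a sketch with an assumed 2 by 3 "image"; the name `by_xy` is introduced here for illustration and does not appear in the post.

```python
import numpy as np

height, width = 2, 3  # hypothetical tiny image
# NumPy convention: first index is the row (y), second is the column (x).
iterations = np.arange(height * width).reshape((height, width))

by_xy = iterations.T  # same data, now indexed as [x, y]
print(by_xy.shape)  # (3, 2)
print(by_xy[2, 1] == iterations[1, 2])  # True

# The transpose is a view onto the same memory, not a copy.
print(np.shares_memory(by_xy, iterations))  # True
```

Because `.T` returns a view, the transpose itself costs almost nothing; the ordering only matters for how `DrawFractal` indexes the result.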
There are a couple of problems with the whole-array approach. First, it wastes work: once a \$z\$ value has escaped the set, we don't need to keep iterating it. Second, by continuing to iterate the \$z\$ value, we might find that it overflows the range of floating-point numbers, resulting in unwanted overflow warnings.
So what we can do is to keep track of the indexes of the pixels that have not yet escaped the set, and only operate on the corresponding values of \$z\$:

    def Fractal3(width=200, height=200, amin=-1.5, amax=1.5, bmin=-1.5, bmax=1.5,
                 iterateMax=100):
        a, b = np.meshgrid(np.linspace(amin, amax, width),
                           np.linspace(bmin, bmax, height), sparse=True)
        c = (a + 1j * b).flatten()
        z = np.zeros_like(c)
        ix = np.arange(len(c))  # Indexes of pixels that have not escaped.
        iterations = np.empty_like(ix)
        for i in range(iterateMax):
            zix = z[ix] = z[ix] ** 2 + c[ix]
            escaped = abs(zix) > 2
            iterations[ix[escaped]] = i
            ix = ix[~escaped]
        iterations[ix] = iterateMax
        histogram = np.bincount(iterations)
        iterations = iterations.reshape((height, width))
        return width, height, iterateMax, iterations.T, histogram

Note that when operating on a set of indexes like this, it's most convenient to work with a flattened array, so that we only need one array of indexes. (If we used a two-dimensional array, we'd need to maintain two arrays of indexes `ix` and `iy`.) So we call `numpy.ndarray.flatten` at the start to flatten the array to a single dimension, and `numpy.reshape` at the end to restore it to two dimensions.

This is about 70 times as fast as the original code:

    >>> timeit(Fractal3, number=1)
    0.02115203905850649

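To watch the index array shrink, here is a small trace of the same bookkeeping on four hand-picked constants. The values of `c`, the name `max_iter`, and the `active_counts` list are assumptions added for this illustration; the loop body follows the pattern used in `Fractal3`.

```python
import numpy as np

# Hypothetical constants: 0 and -1 stay bounded, 1 and 3 escape.
c = np.array([0.0, 1.0, 3.0, -1.0], dtype=complex)
max_iter = 5

z = np.zeros_like(c)
ix = np.arange(len(c))                  # indexes of points still being iterated
iterations = np.full(len(c), max_iter)  # default: never escaped
active_counts = []                      # how many points each pass works on

for i in range(max_iter):
    active_counts.append(len(ix))
    zix = z[ix] = z[ix] ** 2 + c[ix]    # update only the active points
    escaped = abs(zix) > 2
    iterations[ix[escaped]] = i         # record the escape iteration
    ix = ix[~escaped]                   # drop escaped points from the work list

print(iterations)     # [5 2 0 5]
print(active_counts)  # [4, 3, 3, 2, 2]
```

Each pass touches only the entries listed in `ix`, which is where the saving over the whole-array version comes from; escaped values of `z` are never squared again, so they cannot overflow.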
Now that you've seen how to make `Fractal` operate on the whole array (rather than one pixel at a time), you should be able to do the same for `DrawFractal`. (Hint: instead of calling `Image.putpixel` for each pixel, call `Image.putdata` once.)
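One possible way to follow that hint is to build a colour lookup table once and index it with the whole array of iteration counts. The sketch below is an assumption-laden illustration, not the colouring from the post: `colour_iterations` is a hypothetical helper, the blue-white-blue gradient is simplified, and the counts are assumed to be integers from 0 to `iterate_max` as returned by `Fractal3`.

```python
import numpy as np

def colour_iterations(iterations, iterate_max):
    """Map integer iteration counts to RGB values via a lookup table.

    Simplified blue-white-blue gradient; not the exact colours
    produced by FractalColor in the post.
    """
    # Position of each possible count along the gradient, in [0, 1).
    t = np.arange(iterate_max + 1) / (iterate_max + 1)
    # Triangle wave: 0 at both ends of the gradient, 1 in the middle.
    ramp = 1 - np.abs(2 * t - 1)
    palette = np.empty((iterate_max + 1, 3), dtype=np.uint8)
    palette[:, 0] = palette[:, 1] = np.rint(255 * ramp)  # red and green
    palette[:, 2] = 255                                  # blue stays full
    palette[iterate_max] = (0, 0, 0)  # points that never escaped are black
    # Integer-array indexing colours every pixel in one step.
    return palette[iterations]

# Hypothetical 2x2 array of counts with iterate_max = 4.
iterations = np.array([[0, 2], [4, 1]])
rgb = colour_iterations(iterations, iterate_max=4)
print(rgb.shape)  # (2, 2, 3)
```

The result is a `(height, width, 3)` array of `uint8`, which Pillow can turn into an image directly with `Image.fromarray(rgb)`; that avoids building a list of tuples for `Image.putdata`. Note that this expects row-major counts, so the untransposed `iterations` array would be the one to pass in.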
Comments:

So I finally got around to implementing all of these suggestions. I do have one question: can `Image.putdata` accept anything other than a one-dimensional array of tuples? – user4253, Jun 5, 2017 at 21:23

Right now I'm using the code `for i in range(len(hues)): hues[i] = (red[i], green[i], blue[i])` (red, green, and blue are 2d numpy arrays, hues is a 1d python list), but I'm not sure if it's slowing down the code significantly. – user4253, Jun 5, 2017 at 21:27

I went ahead and posted my updated version of the code: I would love it if you took a look. I still haven't implemented point six: I plan to once I spend some more time making sure I understand all of your advice. Thank you for the answer, by the way, it really improved my understanding of python and numpy. – user4253, Jun 5, 2017 at 21:34