A number chain is created by continuously adding the squares of the digits in a number to form a new number until it has been seen before.
For example,
44 -> 32 -> 13 -> 10 -> 1 -> 1
85 -> 89 -> 145 -> 42 -> 20 -> 4 -> 16 -> 37 -> 58 -> 89
Therefore any chain that arrives at 1 or 89 will become stuck in an endless loop. What is most amazing is that EVERY starting number will eventually arrive at 1 or 89.
How many starting numbers below ten million will arrive at 89?
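The chain rule in the problem statement can be reproduced with a short sketch. The names `digit_square_sum` and `number_chain` are my own. The chain stops at the first repeated value, which is how both example chains end.

```python
def digit_square_sum(n):
    """Sum of the squares of the base-10 digits of n."""
    return sum(int(d) ** 2 for d in str(n))


def number_chain(start):
    """Return the chain from start, ending with the first repeated value."""
    chain = [start]
    seen = {start}
    while True:
        nxt = digit_square_sum(chain[-1])
        chain.append(nxt)
        if nxt in seen:  # a value seen before: the chain is now stuck in a loop
            return chain
        seen.add(nxt)


print(number_chain(44))  # [44, 32, 13, 10, 1, 1]
print(number_chain(85))  # [85, 89, 145, 42, 20, 4, 16, 37, 58, 89]
```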
This solution to Project Euler problem 92 takes about 80 seconds. How can I reduce the time to well under 60 seconds?
def squareDigits(num):
    total = 0
    for digit in str(num):
        total += int(digit) ** 2
    return total

pastChains = [None] * 10000001
pastChains[1], pastChains[89] = False, True
for num in range(2, 10000001):
    chain = [num]
    while pastChains[chain[-1]] is None:
        chain.append(squareDigits(chain[-1]))
    for term in chain:
        if pastChains[term] is not None:
            pastChains[term] = pastChains[chain[-1]]
print pastChains.count(True)
Comment by SylvainD (Sep 24, 2013 at 20:52): Memoization might help you.
Comment by Vedran Šego (Sep 24, 2013 at 21:08): There is a nice combinatorial approach, described here.
2 Answers
1. Bug
There's a bug in your program. This line
if pastChains[term] is not None:
should be
if pastChains[term] is None:
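To see what the bug does, here is a small harness. `count_arrivals` and its `buggy` flag are hypothetical names of mine, and `square_digits` mirrors the question's `squareDigits`. With the original test, only entries that are already known (1 and 89) ever get written, so nothing new is remembered, every chain is walked all the way down, and the final count is just 1.

```python
def square_digits(num):
    """Sum of the squares of the base-10 digits of num."""
    total = 0
    for digit in str(num):
        total += int(digit) ** 2
    return total


def count_arrivals(limit, buggy):
    """The question's algorithm for starting numbers below limit.

    buggy=True keeps the original `is not None` test; buggy=False uses the
    corrected `is None` test. Assumes limit >= 1000, so that every chain
    value (at most 243 for starting numbers below 1000) fits in the table.
    """
    past = [None] * limit
    past[1], past[89] = False, True
    for num in range(2, limit):
        chain = [num]
        while past[chain[-1]] is None:
            chain.append(square_digits(chain[-1]))
        for term in chain:
            if buggy:
                store = past[term] is not None  # only already-known terms
            else:
                store = past[term] is None      # newly discovered terms
            if store:
                past[term] = past[chain[-1]]
    return past.count(True)


print(count_arrivals(1000, buggy=True))  # 1: only 89 itself is ever marked
print(count_arrivals(1000, buggy=False))
```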
2. Piecewise improvement
In this section I'm going to show how you can speed up a program like this by making a series of small piecewise improvements. This doesn't often work for the Project Euler problems: normally you have to come up with new algorithmic ideas. But here you're pretty close to getting under the minute mark, and that makes piecewise improvement a plausible approach.
Note that I'm using Python 2.7 here, which is generally faster than Python 3.
2.1. Base case
I'm going to start by putting your code into a form that is more convenient for testing (and where we can run it for limits smaller than ten million, which is convenient when taking frequent measurements). I'll also improve the coding style while I'm at it: add docstrings, name local variables in lower_case_with_underscores, and use the more meaningful values 1 and 89 instead of the Booleans False and True.
def square_digits(n):
    """Return the sum of squares of the base-10 digits of n."""
    total = 0
    for digit in str(n):
        total += int(digit) ** 2
    return total

def problem92a(limit):
    """Return the count of starting numbers below limit that eventually arrive
    at 89, as a result of iterating the sum-of-squares-of-digits.
    """
    arrive = [None] * limit  # Number eventually arrived at, or None if unknown.
    arrive[1], arrive[89] = 1, 89
    for n in range(2, limit):
        chain = [n]
        while arrive[chain[-1]] is None:
            chain.append(square_digits(chain[-1]))
        for term in chain:
            if arrive[term] is None:
                arrive[term] = arrive[chain[-1]]
    return arrive.count(89)
>>> from timeit import timeit
>>> timeit(lambda:problem92a(10**7), number=1)
185.13595986366272
That's a lot slower than your reported "80 seconds": clearly you have a much faster machine than I do!
2.2. Avoiding number/string conversion
Now, let's do some work on the square_digits function. As written, this function has to convert the integer n to a string and then convert each digit back to a number. We could avoid these conversions by working with numbers throughout:
def square_digits(n):
    """Return the sum of squares of the base-10 digits of n."""
    total = 0
    while n:
        total += (n % 10) ** 2
        n //= 10
    return total
>>> timeit(lambda:problem92a(10**7), number=1)
61.409788846969604
Nearly under the minute already!
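Before trusting a speed-up like this, it is worth checking that the two versions agree and timing them in isolation. The following sketch does both with `timeit`. The names `square_digits_str` and `square_digits_int` are mine, and the printed times will vary from machine to machine.

```python
from timeit import timeit


def square_digits_str(n):
    """Original version: convert to a string and back."""
    total = 0
    for digit in str(n):
        total += int(digit) ** 2
    return total


def square_digits_int(n):
    """Arithmetic version: peel off digits with % and //."""
    total = 0
    while n:
        total += (n % 10) ** 2
        n //= 10
    return total


# Both versions must agree before the timings mean anything.
assert all(square_digits_str(n) == square_digits_int(n) for n in range(100000))

for f in (square_digits_str, square_digits_int):
    # Machine- and interpreter-dependent; no particular figures are implied.
    print('%s: %.3f s' % (f.__name__, timeit(lambda: f(9876543), number=20000)))
```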
2.3. Second version
Here are some obvious minor improvements:
- Avoid the repeated lookup of chain[-1] by remembering the value in a local variable.
- Reduce the size of chain by one (because the last element in the chain is remembered in the local variable).
- The test arrive[term] is None is unnecessary: by this point in the code we know that only the last term in chain was found in arrive.
That yields the following code:
def problem92b(limit):
    """Return the count of starting numbers below limit that eventually arrive
    at 89, as a result of iterating the sum-of-squares-of-digits.
    """
    arrive = [None] * limit  # Number eventually arrived at, or None if unknown.
    arrive[1], arrive[89] = 1, 89
    for n in range(2, limit):
        chain = []
        while not arrive[n]:
            chain.append(n)
            n = square_digits(n)
        dest = arrive[n]
        for term in chain:
            arrive[term] = dest
    return arrive.count(89)
When transforming code like this, it's always important to check that our transformations didn't break anything. Here's where the ability to run the program for small values of limit comes in handy:
>>> all(problem92a(i) == problem92b(i) for i in range(1000, 2000))
True
And this yields a further 12% speedup:
>>> timeit(lambda:problem92b(10**7), number=1)
53.771003007888794
2.4. Third version
The largest number under ten million is 9999999, whose sum of squares of digits is just 567. All other numbers in the range have even smaller sums of squares of digits. So for 568 and up, there is no need to follow the chain: we can just look up the answer directly. That suggests the following approach:
def problem92c(limit):
    """Return the count of starting numbers below limit that eventually arrive
    at 89, as a result of iterating the sum-of-squares-of-digits.
    """
    sum_limit = len(str(limit - 1)) * 9 ** 2 + 1
    arrive = [None] * sum_limit
    arrive[1], arrive[89] = 1, 89
    for n in range(2, sum_limit):
        chain = []
        while not arrive[n]:
            chain.append(n)
            n = square_digits(n)
        dest = arrive[n]
        for term in chain:
            arrive[term] = dest
    c = arrive.count(89)
    for n in range(sum_limit, limit):
        c += arrive[square_digits(n)] == 89
    return c
Again, we'd better check that we didn't break anything:
>>> all(problem92a(i) == problem92c(i) for i in range(1000, 2000))
True
And this yields a further 30% improvement:
>>> timeit(lambda:problem92c(10**7), number=1)
37.56105399131775
That cuts the runtime by about 80% compared with the original version, just by making piecewise, localized improvements.
3. New algorithm
For more radical improvements to the runtime, we need to rethink the algorithm completely. Here's a sketch of the combinatorial approach:
For each number k up to 567 that arrives at 89, we work out all the ways to partition k into a sum of no more than seven squares from 1² to 9² (using the partitioning technique we used in our answer to problem 76, say). This gives us a multiset of digits, and we can use combinatorial techniques to count the number of ways this multiset can be realised as a number up to 9999999.
For example, the starting number 4 arrives at 89. 4 can be partitioned into 2² or into 1² + 1² + 1² + 1². In the first case, there are 7 starting numbers in the range (2, 20, 200, 2000, 20000, 200000 and 2000000). In the second case, writing C(n, k) for the number of combinations of k items out of n, there are C(3,3) + C(4,3) + C(5,3) + C(6,3) = 35 starting numbers in the range (1111, 10111, 11011, and so on).
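The counts in this example can be checked mechanically. With d digits the leading digit must be one of the 1s, and the other three 1s go into the remaining d - 1 positions, which is where C(d - 1, 3) comes from. The sketch below (function names are mine) evaluates that formula and compares it with a direct scan over a smaller range, since scanning every number below ten million is slow in pure Python.

```python
from math import factorial


def comb(n, k):
    """Binomial coefficient C(n, k), written out for older Pythons."""
    if k < 0 or k > n:
        return 0
    return factorial(n) // (factorial(k) * factorial(n - k))


def count_by_formula_ones(max_digits):
    """Numbers made of four 1s and otherwise 0s, up to max_digits digits."""
    return sum(comb(d - 1, 3) for d in range(4, max_digits + 1))


def count_by_scan(max_digits, nonzero):
    """Count 1 <= n < 10**max_digits whose non-zero digits are `nonzero`."""
    target = sorted(nonzero)
    return sum(1 for n in range(1, 10 ** max_digits)
               if sorted(str(n).replace('0', '')) == target)


print(count_by_formula_ones(7))  # 35, matching the example above
# Compare formula and scan on five digits, where the scan is quick.
assert count_by_formula_ones(5) == count_by_scan(5, '1111') == 5
assert count_by_scan(5, '2') == 5  # 2, 20, 200, 2000 and 20000
```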
And here's some code, which I hope is self-documenting! Ask if you don't understand anything.
It uses the @memoized decorator from the Python Decorator Library to avoid unnecessary re-computation.
from collections import Counter
from math import factorial
from memoize import memoized

@memoized
def partitions(n, k, v):
    """Return partitions of n into at most k items from v, with
    repetition. v must be a tuple sorted into numerical order. Each
    partition is returned as multiset in the form of a Counter object
    mapping items from v to the number of times they are used in the
    partition.

    >>> partitions(4, 7, (1, 4))
    [Counter({1: 4}), Counter({4: 1})]
    """
    if n == 0:
        # Base case: the empty partition.
        return [Counter()]
    if k == 0 or len(v) == 0 or n < v[0]:
        # No partitions possible here.
        return []
    pp = [p.copy() for p in partitions(n - v[0], k - 1, v)]
    for p in pp:
        p[v[0]] += 1
    return pp + partitions(n, k, v[1:])

@memoized
def multinomial(n, k):
    """Return the multinomial coefficient n! / k[0]! k[1]! ... k[m]!.

    >>> multinomial(6, (2, 2, 2))
    90
    """
    result = factorial(n)
    for i in k:
        result //= factorial(i)
    return result

@memoized
def number_count(digit_counts, min_digits, max_digits):
    """Return the count of numbers (with between min_digits and max_digits
    inclusive) whose distinct non-zero digits have counts given by the
    sequence digit_counts. For example if we have three identical
    non-zero digits and four digits in total:

    >>> number_count((3,), 4, 4)
    3

    because the possible numbers resemble 1011, 1101, and 1110.
    Similarly

    >>> number_count((1,1), 4, 4)
    6

    because the possible numbers resemble 1002, 1020, 1200, 2001,
    2010, and 2100.
    """
    nonzero_digits = sum(digit_counts)
    total = 0
    for digits in range(max(min_digits, nonzero_digits), max_digits + 1):
        for i, d in enumerate(digit_counts):
            counts = (digit_counts[:i] + (d - 1,) + digit_counts[i+1:]
                      + (digits - nonzero_digits,))
            total += multinomial(digits - 1, tuple(sorted(counts)))
    return total

def problem92d(limit):
    """Return the count of starting numbers below limit that eventually arrive
    at 89, as a result of iterating the sum-of-squares-of-digits.
    """
    max_digits = len(str(limit - 1))
    assert(limit == 10 ** max_digits)  # algorithm works for powers of 10 only
    sum_limit = max_digits * 9 ** 2 + 1
    arrive = [None] * sum_limit
    arrive[1], arrive[89] = 1, 89
    for n in range(2, sum_limit):
        chain = []
        while not arrive[n]:
            chain.append(n)
            n = square_digits(n)
        dest = arrive[n]
        for term in chain:
            arrive[term] = dest
    total = 0
    squares = tuple(i ** 2 for i in range(1, 10))
    for n in range(2, sum_limit):
        if arrive[n] == 89:
            for p in partitions(n, max_digits, squares):
                total += number_count(tuple(sorted(p.values())), 1, max_digits)
    return total
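The `memoize` module imported above is not part of the standard library: the answer takes `memoized` from the Python Decorator Library. If you don't have that file to hand, a minimal stand-in along these lines should serve for this code, because every argument passed to the memoized functions here is hashable (integers and tuples). It is my own sketch, not the library's implementation; on Python 3.2 and later, `functools.lru_cache(maxsize=None)` does the same job.

```python
import functools


def memoized(func):
    """Cache func's results keyed on its positional arguments.

    A hypothetical stand-in for the Python Decorator Library recipe.
    It assumes all arguments are hashable and that callers never mutate
    the objects it returns.
    """
    cache = {}

    @functools.wraps(func)
    def wrapper(*args):
        if args not in cache:
            cache[args] = func(*args)
        return cache[args]

    return wrapper


calls = []


@memoized
def slow_square(n):
    calls.append(n)  # record each real evaluation
    return n * n


assert slow_square(12) == 144
assert slow_square(12) == 144
assert calls == [12]  # the second call was served from the cache
```

Because the cache hands back the same list object on every call, partitions copies each Counter before incrementing it rather than modifying cached results in place.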
As usual, we need to check that this radically new approach computes the right results:
>>> all(problem92c(10**i) == problem92d(10**i) for i in range(3, 7))
True
Note that since we are using memoization, for fair timing results we must reload the code before running the timer (otherwise we'll get misleadingly fast times due to some of the computation having previously been done and stored in the memoization caches). But with a fresh Python instance:
>>> timeit(lambda:problem92d(10**7), number=1)
0.7318389415740967
Less than a second! I hope it's clear now how important the choice of algorithm is. With the piecewise improvement approach the code became about five times faster (which is a pretty decent result). But with a better algorithm the code is about 250 times faster.
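Restarting Python is one way to empty the memoization caches before timing. If the memoized functions use `functools.lru_cache` instead of `@memoized` (an assumption, and one that needs Python 3.2 or later), each cached function has a `cache_clear()` method that does the same job between runs. A sketch using the answer's `multinomial`, with a hypothetical `workload` function standing in for the real computation:

```python
import functools
from math import factorial
from timeit import timeit


@functools.lru_cache(maxsize=None)
def multinomial(n, k):
    """Return n! / (k[0]! k[1]! ... k[m]!); k must be a tuple."""
    result = factorial(n)
    for i in k:
        result //= factorial(i)
    return result


def workload():
    """Repeated calls with repeated arguments, as the memoized code makes."""
    return sum(multinomial(n, (n // 2, n - n // 2))
               for n in range(300) for _ in range(20))


multinomial.cache_clear()          # start cold, like a fresh interpreter
cold = timeit(workload, number=1)  # this run fills the cache
warm = timeit(workload, number=1)  # this run is served from the cache
print(multinomial.cache_info())
multinomial.cache_clear()          # reset before the next fair measurement
```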
4. Questions in comments
You asked some questions in the comments.
1. Obviously in Python 3 you'd use range, not xrange. And in fact, even in Python 2.7, substituting xrange for range makes hardly any difference to the runtime (a percent or so), so in my revised answer I've eschewed that.
2. In Python, small integer objects like 1 and 89 are shared and reused:
   >>> x = 89
   >>> x is 8900 / 100
   True
   If you look at the source code for longobject.c, you'll see that numbers from -NSMALLNEGINTS (−5) up to NSMALLPOSINTS-1 (256) are preallocated in an array. So there is no memory penalty to using them.
3. Iteration works in Python by fetching successive values from an "iterator" object (here the object returned by range(2, limit)), not by incrementing the loop variable as you might do in languages like C. So there is no need to be scared about updating the loop variable.
4. Yes, that's right. Project Euler asks "How many starting numbers *below* ten million will arrive at 89?" (my emphasis). So it's convenient to use limit like this.
5. Probably a copy-paste mistake on my part. I've revised the answer and hopefully the sequence of improvements is clearer now.
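Points 2 and 3 can be checked directly. In the sketch below (variable names are mine; the size comparison assumes CPython), a list stores references, so its size depends only on its length, and rebinding the loop variable inside the body has no effect on what the range iterator yields next.

```python
import sys

# Point 2: a list holds references, so a list of 89s and a list of Trues
# of the same length occupy the same space (all 1000 slots here refer to
# one shared object in each list).
ints = [89] * 1000
bools = [True] * 1000
print(sys.getsizeof(ints) == sys.getsizeof(bools))  # True

# Point 3: each pass asks the range iterator for its next value, so
# reassigning n inside the body does not disturb the iteration.
seen = []
for n in range(2, 6):
    seen.append(n)
    n = 89  # rebinding n here changes nothing for the next pass
print(seen)  # [2, 3, 4, 5]
```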
Comment by asp (Sep 25, 2013 at 0:57): Thanks so much! Some questions: 1. I understand that xrange() cannot be used in Python 3? 2. Would storing integers in the list instead of Booleans take up much more memory? 3. Wouldn't n = square_digits(n) affect the iteration? 4. range(x, y) iterates from x to y - 1, so square_digits(limit) would be skipped? 5. The timeit for 92a is 2.3686327934265137 and the timeit for 92b is 3.6810498237609863. How is that a 25% speedup?
Comment by Brian (Sep 26, 2013 at 14:24): I wonder if psyco (or similar tools) could speed it up even further. Though such tools may have a slight warm-up, so they might be ineffective on your final solution.
Comment by Gareth Rees (Sep 26, 2013 at 14:33): @Brian: I find that even allowing for warmup, problem92d runs slower in PyPy (about 1.0 s) than in Python 2.7 (about 0.7 s).
Small Hints
Do away with your squareDigits() function: it's time consuming, and performs a lot of work that could be more efficiently derived from results of squareDigits() of smaller numbers.
Don't bother distinguishing between pastChains and the current chain of interest. In this problem, all chains are the same: given any number, its chain always terminates the same way. Be like an elephant: never forget the result of any calculation.
Spoiler alert!
My solution completes in less time than the piecewise-optimized solution by @GarethRees: 10 seconds on my machine. The solution relies heavily on memoization.
In particular, computing the sum of the squares of the digits of large numbers is quite time consuming. My solution never needs to do that: it just adds the square of the ones' digit to the previously calculated sum of the squares of the more significant digits. It also doesn't bother storing any chains; it only needs to keep the furthest known reduced value of each chain. The reduction procedure also doesn't require any calculation, only array lookups.
def euler92(limit):
    # Let sumsq uphold the invariant that the nth element contains
    # the sum of the squares of the digits of n, or better yet, a
    # value somewhere along its reduction chain.

    # Start with the base case: sumsq[0] = 0 * 0 = 0.
    # Also preallocate many unknowns...
    sumsq = [sum((0 * 0,))] + [None] * limit

    # ... and fill them in. Note how we reuse previous sums!
    for i in xrange(1 + limit):
        sumsq[i] = (i % 10) ** 2 + sumsq[i // 10]

    # Keep reducing each element until everything has converged
    # on either 1 or 89.
    all_converged = False
    while not all_converged:
        all_converged, eighty_nines = True, 0
        for i in xrange(1, 1 + limit):
            if sumsq[i] == 1:
                pass
            elif sumsq[i] == 89:
                eighty_nines += 1
            else:
                all_converged = False
                # sumsq[sumsq[i]] is a quick way to calculate
                # the sum of the squares of the digits, and maybe
                # even skip a few steps down the chain.
                sumsq[i] = sumsq[sumsq[i]]
    return eighty_nines

print euler92(10000000)
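The reuse in the first loop rests on the recurrence S(i) = (i mod 10)² + S(i // 10), with S(0) = 0: dropping the ones' digit leaves a smaller number whose sum is already in the table. A small sketch (the name `digit_square_table` is mine) that builds such a table and checks it against the direct digit-by-digit computation:

```python
def digit_square_table(limit):
    """Return a list whose i-th entry is the sum of squared digits of i.

    Each entry reuses the entry for i // 10 (i with its ones' digit
    removed), so only one squaring is done per entry.
    """
    table = [0] * limit
    for i in range(1, limit):
        table[i] = (i % 10) ** 2 + table[i // 10]
    return table


table = digit_square_table(100000)
# Agreement with the direct computation over the whole table.
assert all(table[i] == sum(int(d) ** 2 for d in str(i)) for i in range(100000))
print(table[85])  # 89
```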
Comment by Gareth Rees (Sep 25, 2013 at 12:38): Nice! (I find that sum(0 * 0) raises TypeError: 'int' object is not iterable, but with that fixed, I make this about twice as fast as my problem92c; you must have a much faster machine than me.)
Comment by 200_success (Sep 25, 2013 at 15:02): @GarethRees I like that this code is still recognizable as a solution to the original problem. Anyway, I've relaxed my speed claims and fixed the sum(0 * 0). Thanks!