I implemented the Tonelli-Shanks algorithm as described on Wikipedia. I am posting it here for review and for sharing purposes.
Legendre symbol implementation:

```python
def legendre_symbol(a, p):
    """
    Legendre symbol
    Determine whether a is a quadratic residue modulo an odd prime
    http://en.wikipedia.org/wiki/Legendre_symbol
    """
    ls = pow(a, (p - 1)/2, p)
    if ls == p - 1:
        return -1
    return ls
```
Prime modular square root (I just renamed the solution variable R to x and n to a):
```python
def prime_mod_sqrt(a, p):
    """
    Square root modulo prime number
    Solve the equation
        x^2 = a mod p
    and return list of x solution
    http://en.wikipedia.org/wiki/Tonelli-Shanks_algorithm
    """
    a %= p

    # Simple case
    if a == 0:
        return [0]
    if p == 2:
        return [a]

    # Check solution existence on odd prime
    if legendre_symbol(a, p) != 1:
        return []

    # Simple case
    if p % 4 == 3:
        x = pow(a, (p + 1)/4, p)
        return [x, p-x]

    # Factor p-1 on the form q * 2^s (with Q odd)
    q, s = p - 1, 0
    while q % 2 == 0:
        s += 1
        q //= 2

    # Select a z which is a quadratic non-residue modulo p
    z = 1
    while legendre_symbol(z, p) != -1:
        z += 1
    c = pow(z, q, p)

    # Search for a solution
    x = pow(a, (q + 1)/2, p)
    t = pow(a, q, p)
    m = s
    while t != 1:
        # Find the lowest i such that t^(2^i) = 1
        i, e = 0, 2
        for i in xrange(1, m):
            if pow(t, e, p) == 1:
                break
            e *= 2

        # Update next value to iterate
        b = pow(c, 2**(m - i - 1), p)
        x = (x * b) % p
        t = (t * b * b) % p
        c = (b * b) % p
        m = i

    return [x, p-x]
```
If you have any optimizations or find any errors, please report them.
3 Answers
Good job! I don't have much to comment on in this code. You have written straightforward, clear code whose only complexity stems directly from the complexity of the operation it performs. It would be good to include some of your external commentary (such as the renames of `R` and `n`) in the code itself, to make it easier for someone to follow the documentation on Wikipedia. You may want to include some of that documentation as well.
For reference, the rest of this review assumes that the code functions correctly; I don't have my math head on tonight.
There appears to be one case of redundant code, unless `m` can ever be `1`, resulting in an empty range and thus no reassignment of `i`. Otherwise you can skip the assignment to `i` in the following:

```python
i, e = 0, 2
for i in xrange(1, m):
    ...
```
There are a number of small strength-reduction optimizations you might consider, but in Python their impact is likely to be minimal, so definitely profile before heading too far down the optimization path. For example, in the following `while` loop:
```python
# Factor p-1 on the form q * 2^s (with Q odd)
q, s = p - 1, 0
while q % 2 == 0:
    s += 1
    q //= 2
```
Both operations on `q` can be reduced. The modulus test can be rewritten as a bitwise AND, `q & 1`, and the division as a right shift, `q >>= 1`. Alternatively, you can use `divmod` to perform both operations at once.
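A minimal sketch of the factoring loop with both reductions applied, written for Python 3. The helper name `split_power_of_two` is hypothetical and not part of the original code; it assumes its argument is a positive integer (for `0` the loop would never end).

```python
def split_power_of_two(n):
    """Hypothetical helper: write n > 0 as q * 2**s with q odd, using bit operations."""
    q, s = n, 0
    while not q & 1:  # same test as q % 2 == 0
        s += 1
        q >>= 1       # same as q //= 2 for non-negative q
    return q, s
```

For example, `split_power_of_two(12)` returns `(3, 2)`, because 12 = 3 * 2**2.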
Similarly, `2**(m - i - 1)` is identical to `1 << (m - i - 1)` for non-negative exponents.
Thanks for the review, this is really nice feedback. I will remove the `i` assignment (I always thought `i` was only defined inside the for loop). I will also change the `2**(m - i - 1)`. I prefer to leave the other micro-optimizations unchanged for better code readability. (Phong, Mar 3, 2014 at 15:25)
By the way, I don't see how you would use `divmod` in the code. (Phong, Mar 3, 2014 at 15:28)
It would look something like `while True: next_q, q_odd = divmod(q, 2); if q_odd: break; s += 1; q = next_q` (squeezed onto one line for the comment; in real code each statement goes on its own line). The tricky question is whether avoiding the separate divisions would pay for the extra operations in Python, not to mention the tradeoffs in terms of clarity. (Michael Urman, Mar 4, 2014 at 3:18)
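Written out over several lines, the `divmod` variant from that comment might look like the sketch below (Python 3; `split_power_of_two_divmod` is a hypothetical helper name, and the argument is assumed to be a positive integer). Whether it actually beats separate `%` and `//` operations is something only profiling can tell.

```python
def split_power_of_two_divmod(n):
    """Hypothetical helper: write n > 0 as q * 2**s with q odd, one divmod per step."""
    q, s = n, 0
    while True:
        next_q, q_odd = divmod(q, 2)  # quotient and remainder in a single call
        if q_odd:
            break
        s += 1
        q = next_q
    return q, s
```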
To improve portability to Python 3, use `//` instead of `/` everywhere in your code. You already do this in lines like `q //= 2`, but not in lines like `x = pow(a, (p + 1)/4, p)`. In fact, consider including `from __future__ import division`.
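To illustrate the problem, here is a small Python 3 sketch of the `p % 4 == 3` shortcut (the function name `sqrt_mod_3_mod_4` is hypothetical). Under Python 3, `(p + 1) / 4` is a float, and three-argument `pow()` raises `TypeError` unless all its arguments are integers; `//` keeps the exponent an int.

```python
def sqrt_mod_3_mod_4(a, p):
    """Hypothetical helper: square roots of a quadratic residue a modulo a prime p, p % 4 == 3."""
    x = pow(a, (p + 1) // 4, p)  # floor division keeps the exponent an int
    return [x, p - x]


try:
    pow(2, (7 + 1) / 4, 7)  # the Python 2 spelling, run under Python 3
except TypeError as exc:
    print("TypeError:", exc)

print(sqrt_mod_3_mod_4(2, 7))  # [4, 3]: 4 * 4 = 16 = 2 (mod 7) and 3 * 3 = 9 = 2 (mod 7)
```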
Also, in a few benchmarks I ran, computing `2**x` was significantly slower than computing the equivalent `1<<x`, so that is another minor optimization that can be made.
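A rough way to reproduce that kind of comparison with the standard `timeit` module. The exponent `k = 64` and the run count are arbitrary assumptions, and the numbers will vary with the machine and Python version. The exponent is passed through `setup` so that the interpreter cannot constant-fold a literal like `2 ** 64` at compile time.

```python
import timeit

# Arbitrary benchmark parameters; adjust k and N to match your use case
SETUP = "k = 64"
N = 200_000

# Both expressions produce the same integer for any non-negative k
pow_time = timeit.timeit("2 ** k", setup=SETUP, number=N)
shift_time = timeit.timeit("1 << k", setup=SETUP, number=N)

print(f"2 ** k : {pow_time:.4f} s for {N} runs")
print(f"1 << k : {shift_time:.4f} s for {N} runs")
```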
Finally, again for portability to Python 3, you can replace the one use of `xrange` with `range`. I do not think there will be any significant performance loss in Python 2 in this particular case.
I confirm that the above code does not run on Python 3. (Svetlin Nakov, Dec 25, 2017 at 22:10)
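A sketch of the question's code ported to Python 3 along the lines suggested so far: `/` becomes `//`, `xrange` becomes `range`, `2**(m - i - 1)` becomes `1 << (m - i - 1)`, and the redundant `i = 0` is dropped. Dropping it is safe here because whenever `t != 1` inside the loop, `m` is at least `2`, so the `for` loop always runs and assigns `i`. This keeps the question's structure rather than being a definitive implementation; a brute-force comparison over small primes is a cheap way to check such a port.

```python
def legendre_symbol(a, p):
    """Legendre symbol of a modulo an odd prime p (Euler's criterion): 1, -1 or 0."""
    ls = pow(a, (p - 1) // 2, p)
    return -1 if ls == p - 1 else ls


def prime_mod_sqrt(a, p):
    """Return the list of solutions x of x^2 = a (mod p) for a prime p."""
    a %= p
    # Simple cases
    if a == 0:
        return [0]
    if p == 2:
        return [a]
    # A solution exists only if a is a quadratic residue modulo p
    if legendre_symbol(a, p) != 1:
        return []
    # Simple case: p = 3 (mod 4)
    if p % 4 == 3:
        x = pow(a, (p + 1) // 4, p)
        return [x, p - x]
    # Factor p - 1 in the form q * 2^s with q odd
    q, s = p - 1, 0
    while q % 2 == 0:
        s += 1
        q //= 2
    # Select a z which is a quadratic non-residue modulo p
    z = 1
    while legendre_symbol(z, p) != -1:
        z += 1
    c = pow(z, q, p)
    # Search for a solution
    x = pow(a, (q + 1) // 2, p)
    t = pow(a, q, p)
    m = s
    while t != 1:
        # Find the lowest i (0 < i < m) such that t^(2^i) = 1
        e = 2
        for i in range(1, m):
            if pow(t, e, p) == 1:
                break
            e *= 2
        # Update the values for the next iteration
        b = pow(c, 1 << (m - i - 1), p)
        x = (x * b) % p
        t = (t * b * b) % p
        c = (b * b) % p
        m = i
    return [x, p - x]
```

For example, `prime_mod_sqrt(10, 13)` returns `[7, 6]`, since 7 * 7 = 49 = 10 (mod 13) and 6 * 6 = 36 = 10 (mod 13).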
I know this is a bit late, but I have some more minor optimization suggestions:
- In your `legendre_symbol` implementation, you compute `pow(a, (p - 1)/2, p)`. You don't need to subtract `1` from `p`: since `p` is odd, integer division already gives `p/2 == (p - 1)/2`. Also, you can replace `p/2` with `p >> 1`, which is faster.
- In your simple-case handling, you can replace `p % 4` with `p & 3` and, again, `pow(a, (p + 1)/4, p)` with `pow(a, (p + 1) >> 2, p)`. Since you have checked that `p & 3 == 3`, an equivalent solution is `pow(a, (p >> 2) + 1, p)`; I would go for this one instead. It can make a difference when the right shift effectively reduces the byte size of `p`.
- There is another simple case you can check for: `p % 8 == 5`, or the equivalent `p & 7 == 5`. In that case, you can compute `pow(a, (p >> 3) + 1, p)` and check whether it is a solution (it is one if and only if `a` is a quartic residue modulo `p`); otherwise, multiply it by `pow(2, p >> 2, p)` to get a valid solution (and don't forget to reduce `% p` after the multiplication, of course).
- In your `while` loop, you need to find a fitting `i`. Let's see what your implementation does there if `i` is, for example, `4`:

  ```python
  pow(t, 2, p)
  pow(t, 4, p)   # calculates pow(t, 2, p)
  pow(t, 8, p)   # calculates pow(t, 4, p), which calculates pow(t, 2, p)
  pow(t, 16, p)  # calculates pow(t, 8, p), which calculates pow(t, 4, p), which calculates pow(t, 2, p)
  ```

  Do you see the redundancy? With increasing `i`, the number of multiplications grows quadratically, while it could just grow linearly:

  ```python
  i, t2i = 0, t
  for i in range(1, m):
      t2i = t2i * t2i % p
      if t2i == 1:
          break
  ```

- The last optimization is a rather simple one: I would just replace

  ```python
  t = (t * b * b) % p
  c = (b * b) % p
  ```

  with

  ```python
  c = (b * b) % p
  t = (t * c) % p
  ```

  which saves one multiplication.
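Putting these suggestions together, a Python 3 sketch might look like the following. It is one possible combination, not the answerer's own code: it keeps the question's structure and adds the `p >> 1` exponent, the `p & 3` and `p & 7 == 5` shortcuts, the linear squaring loop, the shift from the earlier answers, and the reordered update of `c` and `t`. The `p & 7 == 5` branch relies on `2` being a quadratic non-residue for every prime `p = 5 (mod 8)`, so `pow(2, p >> 2, p)`, that is 2^((p - 1)/4), is a square root of `-1` modulo `p`.

```python
def legendre_symbol(a, p):
    """Euler's criterion for an odd prime p; p >> 1 equals (p - 1) // 2 because p is odd."""
    ls = pow(a, p >> 1, p)
    return -1 if ls == p - 1 else ls


def prime_mod_sqrt(a, p):
    """Solutions x of x^2 = a (mod p) for a prime p, with the micro-optimizations above."""
    a %= p
    if a == 0:
        return [0]
    if p == 2:
        return [a]
    if legendre_symbol(a, p) != 1:
        return []
    if p & 3 == 3:
        # p = 4k + 3, so (p >> 2) + 1 == (p + 1) // 4
        x = pow(a, (p >> 2) + 1, p)
        return [x, p - x]
    if p & 7 == 5:
        # p = 8k + 5: a^(k+1) is a root iff a is a quartic residue; otherwise its
        # square is -a, and multiplying by a square root of -1 fixes that
        x = pow(a, (p >> 3) + 1, p)
        if x * x % p != a:
            x = x * pow(2, p >> 2, p) % p
        return [x, p - x]
    # Factor p - 1 = q * 2^s with q odd
    q, s = p - 1, 0
    while q % 2 == 0:
        s += 1
        q //= 2
    # Select a quadratic non-residue z modulo p
    z = 1
    while legendre_symbol(z, p) != -1:
        z += 1
    c = pow(z, q, p)
    x = pow(a, (q + 1) // 2, p)
    t = pow(a, q, p)
    m = s
    while t != 1:
        # Lowest i with t^(2^i) = 1, found by repeated squaring (linear in i)
        t2i = t
        for i in range(1, m):
            t2i = t2i * t2i % p
            if t2i == 1:
                break
        b = pow(c, 1 << (m - i - 1), p)
        x = x * b % p
        c = b * b % p
        t = t * c % p  # reuses c = b^2 instead of multiplying by b twice
        m = i
    return [x, p - x]
```

For instance, with `p = 13` (a `p & 7 == 5` prime) and `a = 10`, the first candidate is `9`, whose square is `3`, so the branch multiplies by `pow(2, 3, 13) == 8` and returns `[7, 6]`.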