
I have the following implementation of the Miller-Rabin test in Python:

from random import randint
import decimal
decimal.getcontext().prec = 10000000
def miller_rabin(n, k):
  if n == 2 or n == 3:
    return True
  elif n == 1 or n <= 0 or str(n).isdigit() == False or n % 2 == 0:
    return False
  else:
    s = 1
    while True:
      if (n - 1) % (2 ** (s + 1)) == 0:
        s += 1
      else:
        break
    d = (n - 1) / (2 ** s)
    for _ in range(k):
      a = randint(2, n - 2)
      x = (decimal.Decimal(a) ** decimal.Decimal(d)) % decimal.Decimal(n)
      for _ in range(s):
        y = x * x % n
        if y == 1 and x != 1 and x != n - 1:
          return False
        x = y
      if y != 1:
        return False
    return True
  • n is the number to be tested,
  • k is the number of rounds of testing to perform,
  • s and d are used to write n − 1 in the form \$2^s \cdot d\$ (with d odd), and
  • a is a base such that \$2 \le a \le n-2\$.

decimal is there to accommodate some of the resulting decimal values that may appear.

It works fine from what I've tested thus far. However, I'm pretty sure it's not very efficient. How can I optimize my code?

Peilonrayz
asked Aug 24, 2023 at 13:48
  • Why should this accommodate strings? I'd expect that this function assumes integers. Or actually, it seems like you're checking for whole numbers or not, and there are probably saner ways to do that. (Commented Aug 24, 2023 at 14:24)

2 Answers


Bug: decimal exponentiation (yes, decimal exponentiation is a bug)

This is wrong:

x = (decimal.Decimal(a) ** decimal.Decimal(d)) % decimal.Decimal(n)
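The failure is easy to reproduce by shrinking the precision from the question's ten million digits to 10. The sketch below uses a hypothetical helper, `decimal_modpow`, that evaluates the same Decimal expression inside `decimal.localcontext()` so the global context is left alone; the small precision is an assumption chosen only to make the effect visible with small numbers.

```python
import decimal


def decimal_modpow(a: int, d: int, n: int, prec: int) -> int:
    """Hypothetical helper: the question's Decimal expression at a given precision."""
    with decimal.localcontext() as ctx:
        ctx.prec = prec  # only affects this block, not the global context
        x = (decimal.Decimal(a) ** decimal.Decimal(d)) % decimal.Decimal(n)
    return int(x)


# 3**25 == 847288609443 has 12 significant digits, so at prec=10 the power is
# rounded before the modulo and the remainder comes out silently wrong.
print(decimal_modpow(3, 25, 100003, prec=10) == pow(3, 25, 100003))  # False

# When a**d dwarfs n, the integer quotient needs more digits than prec allows,
# so Decimal raises instead of returning a value.
try:
    decimal_modpow(3, 100, 7, prec=10)
except decimal.InvalidOperation:
    print("InvalidOperation")
```

Raising the precision only moves the threshold: for any fixed `prec`, large enough `a ** d` triggers one of these two failures.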

Using very big decimals (prec = 10000000) may make this appear to work for some values. But no fixed precision is sufficient for every input: for any fixed precision, you can find an input such that a ** d is too large to be represented exactly at that precision. Apart from sometimes breaking, it's also always a performance bug, because the full power a ** d, with up to ten million digits, is computed before the modulo is applied. This line should be:

x = pow(a, d, n)

This modular exponentiation does not need huge precision: it reduces modulo n as it goes, so it only works with values that are about as big as n. That keeps it efficient.
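To see why the numbers stay small, here is a minimal sketch of one standard technique for modular exponentiation, right-to-left square-and-multiply. The name `modpow` is hypothetical, and this is not CPython's actual implementation of `pow`; it only illustrates the idea that every intermediate product is reduced modulo `mod` immediately.

```python
def modpow(base: int, exp: int, mod: int) -> int:
    """Hypothetical square-and-multiply sketch of what pow(base, exp, mod) computes."""
    result = 1 % mod  # handles mod == 1, where every result is 0
    base %= mod
    while exp > 0:
        if exp & 1:                       # lowest bit set: fold this power in
            result = result * base % mod
        base = base * base % mod          # base**(2**i) for the next bit
        exp >>= 1
    return result                         # every intermediate stayed below mod**2


print(modpow(3, 25, 100003) == pow(3, 25, 100003))  # True
print(modpow(2, 10, 1000))                          # 24
```

The loop runs once per bit of the exponent, so the cost grows with the number of digits of d rather than with d itself.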

If it seems like d could be a float, well, it can't.

d = (n - 1) / (2 ** s)

The division should be integer division, or a bitwise shift. \$2^s\$ divides n - 1 by definition: s is defined as the highest integer such that \$2^s\$ divides n - 1 (the point is to divide the powers of 2 out of n - 1). In Python 3, / always returns a float, which can silently lose precision once d exceeds 2**53.
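A minimal sketch of the bitwise approach, using a hypothetical helper `split_power_of_two` and assuming n is an odd integer greater than 1. It finds s from the lowest set bit of n - 1 with `int.bit_length`, then shifts right by s, so no powers of 2 are ever computed and no floats appear.

```python
def split_power_of_two(n: int) -> tuple[int, int]:
    """Hypothetical helper: return (s, d) with n - 1 == 2**s * d and d odd.

    Assumes n is an odd integer greater than 1, as in the Miller-Rabin setup.
    """
    m = n - 1
    s = (m & -m).bit_length() - 1  # m & -m isolates the lowest set bit, i.e. 2**s
    return s, m >> s               # right shift by s divides exactly by 2**s


n = 2**61 - 1                             # a Mersenne prime, well above 2**53
print(split_power_of_two(n))              # (1, 1152921504606846975)
print(int((n - 1) / 2) == (n - 1) // 2)   # False: / rounded the quotient to a float
```

The last line shows the float problem concretely: the exact quotient is 2**60 - 1, but true division returns the nearest float, which is 2**60.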

x * x % n

This could also be pow(x, 2, n), but here it's not critical.

answered Aug 24, 2023 at 15:35
  • So the definition of d should use // instead of /, right? (Commented Aug 24, 2023 at 16:48)
  • @pointySphere yes, or right-shift by s directly. Alternatively, you can find d and s simultaneously by right-shifting d in the loop that finds s (that also simplifies the loop exit condition). (Commented Aug 24, 2023 at 16:52)

Use the built-in float.is_integer instead of string coercion to check for integrality.
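A short comparison of the two checks, using a hypothetical helper `is_whole` built on `numbers.Integral` and `float.is_integer`. It shows that `str(n).isdigit()` rejects a whole float such as 7.0, while the `is_integer` check accepts it.

```python
from numbers import Integral


def is_whole(n: object) -> bool:
    """Hypothetical helper: True for ints and for floats with no fractional part."""
    if isinstance(n, Integral):
        return True
    return isinstance(n, float) and n.is_integer()


for value in (7, 7.0, 7.5, -3):
    print(value, str(value).isdigit(), is_whole(value))
# 7 True True
# 7.0 False True
# 7.5 False False
# -3 False True
```

Negative numbers are still whole, so a primality test needs the separate n <= 1 check on top of this one.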

n == 1 or n <= 0 is just n <= 1.

Any time there's a return, you don't need to follow it with an else or elif.

It's important that you accept an optional random generator argument, for the purposes of reproducibility in unit tests. Also: add unit tests.

Never indent at two spaces; use 4.

It's unsurprising to see single-letter variables in a math routine, but you should comment them all. (You still haven't offered explanations for x and y.)

When factoring out powers of 2, you needn't exponentiate.

Don't use Decimal; Python has arbitrary-precision integers by default and as in the comments, Miller-Rabin is integral throughout.

Suggested

import random
from typing import Callable


def factor_2(n: int) -> tuple[int, int]:
    """
    Factor powers of 2 from n-1
    :param n: Miller-Rabin test subject; n must be odd.
    :return: s, d such that n-1 = 2**s * d (for odd d)
    >>> factor_2(5)
    (2, 1)
    >>> s, d = factor_2(67)
    >>> 2**s * d
    66
    >>> d % 2
    1
    >>> s, d = factor_2(1201)
    >>> 2**s * d
    1200
    >>> d % 2
    1
    """
    s = 0
    d = n - 1
    while d % 2 == 0:
        s += 1
        d //= 2
    return s, d


def miller_rabin(
    n: float | int,
    n_rounds: int,
    randint: Callable[[int, int], int] = random.randint,
) -> bool:
    """
    Stochastic primality test.
    :param n: Number to be tested
    :param n_rounds: Maximum number of rounds to run before assuming prime.
        Called 'k' in some literature.
    :param randint: Optional, alternative randint for reproducibility
    :return: True if probably prime, False if definitely composite
    """
    if n == 2 or n == 3:  # Fast path for small primes
        return True
    if n <= 1 or n % 2 == 0:  # Fast path for small or even composites
        return False
    if isinstance(n, float):
        if n.is_integer():
            n = int(n)  # n is a whole float; coerce to int and continue
        else:
            return False  # n is non-integral, so not prime by definition
    s, d = factor_2(n)
    for _ in range(n_rounds):
        # base such that 2 <= a <= n - 2
        a = randint(2, n - 2)
        x = pow(base=a, exp=d, mod=n)
        for _ in range(s):
            y = x*x % n
            if y == 1 and x != 1 and x != n - 1:
                return False
            x = y
        if y != 1:
            return False
    return True


def test() -> None:
    rand = random.Random(x=0)

    # Test first 20 numbers for primality
    for n, should_be_prime in enumerate((
        # n = 0 .. 9
        0,0,1,1,0, 1,0,1,0,0,
        # n = 10 .. 19
        0,1,0,1,0, 0,0,1,0,1,
    )):
        actual = miller_rabin(n=n, n_rounds=10, randint=rand.randint)
        assert actual == should_be_prime

    # Test large prime and composite, int and float forms
    assert miller_rabin(n=5264191, n_rounds=10, randint=rand.randint)
    assert miller_rabin(n=5264191., n_rounds=10, randint=rand.randint)
    assert not miller_rabin(n=5264193, n_rounds=10, randint=rand.randint)
    assert not miller_rabin(n=5264193., n_rounds=10, randint=rand.randint)
    assert not miller_rabin(n=5264193.1, n_rounds=10, randint=rand.randint)

    # Test to see that the function runs OK with the default rand implementation
    assert miller_rabin(n=2, n_rounds=1)
    assert isinstance(miller_rabin(n=2, n_rounds=3), bool)


if __name__ == '__main__':
    test()
answered Aug 24, 2023 at 15:17