I'm working on a primality test and have written a recursive function that returns the value of the function
\$b^{q-1} \bmod q\$
where \$3 \le q \le 32000\$.
Is there any way to speed up my function? It works, but takes a while to return the answer as \$q\$ approaches 32000.
Variables:
pow = \$q-1\$
mod = \$q\$
b is a variable in the range \$1 < b < q\$
If q is prime, then b will be equal to q; if not, b will be a "strong" feature of non-primality. See the Miller–Rabin primality test.
public int fF(int q)
{
int b = 2, v = 0;
while(b < q)
{
v = operate(b, q-1, q);
if (v != 1)
break;
b++;
}
return b;
}
int operate(int b, int pow, int mod)
{
if (pow == 2)
return (b * b) % mod;
return (pow % 2 != 0) ? (b * operate(b, pow - 1, mod)) % mod : (operate(b, pow / 2, mod) * operate(b, pow / 2, mod)) % mod;
}
Naming
Oh my, what would Mr. Maintainer think if he inherited this code? Single-letter variable names and method names like fF would make it hard for him to figure out what is happening.
So let us clean this up a little:
public int fF(int possiblePrime)
{
int baseNumber = 2, v = 0;
while (baseNumber < possiblePrime)
{
int exponent = possiblePrime - 1;
v = operate(baseNumber, exponent, possiblePrime);
if (v != 1)
break;
baseNumber++;
}
return baseNumber;
}
int operate(int baseNumber, int exponent, int divisor)
{
if (exponent == 2)
return (baseNumber * baseNumber) % divisor;
return (exponent % 2 != 0) ? (baseNumber * operate(baseNumber, exponent - 1, divisor)) % divisor : (operate(baseNumber, exponent / 2, divisor) * operate(baseNumber, exponent / 2, divisor)) % divisor;
}
Style
As many will agree, using braces {} even for single-statement if blocks is a must. So let us add them, and also replace the ternary expression with an int result variable that we return:
public int fF(int possiblePrime)
{
int baseNumber = 2, v = 0;
while (baseNumber < possiblePrime)
{
int exponent = possiblePrime - 1;
v = operate(baseNumber, exponent, possiblePrime);
if (v != 1)
{
break;
}
baseNumber++;
}
return baseNumber;
}
int operate(int baseNumber, int exponent, int divisor)
{
int result = 0;
if (exponent == 2)
{
result = (baseNumber * baseNumber) % divisor;
}
else if (exponent % 2 != 0)
{
result = (baseNumber * operate(baseNumber, exponent - 1, divisor)) % divisor;
}
else
{
result = (operate(baseNumber, exponent / 2, divisor) * operate(baseNumber, exponent / 2, divisor)) % divisor;
}
return result;
}
Refactoring
Now let us focus on operate().
What you are doing in every branch is computing number * number % divisor, so let us extract this into a method:
private int calculateProductModulo(int firstValue, int secondValue, int moduloNumber)
{
return (firstValue * secondValue) % moduloNumber;
}
The operate() method now looks like this:
int operate(int baseNumber, int exponent, int divisor)
{
int result = 0;
if (exponent == 2)
{
result = calculateProductModulo(baseNumber, baseNumber, divisor);
}
else if (exponent % 2 != 0)
{
result = calculateProductModulo(baseNumber, operate(baseNumber, exponent - 1, divisor), divisor);
}
else
{
result = calculateProductModulo(operate(baseNumber, exponent / 2, divisor), operate(baseNumber, exponent / 2, divisor), divisor);
}
return result;
}
If we now pull the recursive calls out of the calls to calculateProductModulo(), we see clearly what you stated in your own answer:
int operate(int baseNumber, int exponent, int divisor)
{
int result = 0;
if (exponent == 2)
{
result = calculateProductModulo(baseNumber, baseNumber, divisor);
}
else if (exponent % 2 != 0)
{
int recursiveResult = operate(baseNumber, exponent - 1, divisor);
result = calculateProductModulo(baseNumber, recursiveResult, divisor);
}
else
{
int recursiveResult1 = operate(baseNumber, exponent / 2, divisor);
int recursiveResult2 = operate(baseNumber, exponent / 2, divisor);
result = calculateProductModulo(recursiveResult1, recursiveResult2, divisor);
}
return result;
}
The code calls the same method twice with the same arguments.
Let us eliminate the duplicate call:
int operate(int baseNumber, int exponent, int divisor)
{
int result = 0;
int recursiveResult = 0;
if (exponent == 2)
{
result = calculateProductModulo(baseNumber, baseNumber, divisor);
}
else if (exponent % 2 != 0)
{
recursiveResult = operate(baseNumber, exponent - 1, divisor);
result = calculateProductModulo(baseNumber, recursiveResult, divisor);
}
else
{
recursiveResult = operate(baseNumber, exponent / 2, divisor);
result = calculateProductModulo(recursiveResult, recursiveResult, divisor);
}
return result;
}
The idea of fast exponentiation is corrupted by the following statement
result=operate(b, pow / 2, mod) * operate(b, pow / 2, mod) # (1)
instead of
result=operate(b, pow / 2, mod)**2
or
aux=operate(b, pow / 2, mod)
result=(aux*aux)%mod
It actually degrades the performance from \$O(\log(\text{pow}))\$ multiplications to \$\text{pow}-1\$ multiplications, which is the cost of the naive exponentiation algorithm (multiplying \$b\$ by itself \$\text{pow}-1\$ times). The performance gain comes from avoiding the second evaluation in (1).
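To make the difference measurable, here is a minimal C++ sketch that counts modular multiplications for both variants. The counter `multiplications` and the names `mul_mod`, `pow_double_call` and `pow_single_call` are hypothetical helpers introduced only for this illustration, and the recursion bottoms out at an exponent of 1 rather than 2 to keep the counting simple.

```cpp
#include <cstdint>

// Hypothetical counter, for illustration only: counts modular multiplications.
static std::uint64_t multiplications = 0;

// Multiplies in 64 bits so the product cannot overflow, then reduces.
std::uint32_t mul_mod(std::uint32_t x, std::uint32_t y, std::uint32_t mod) {
    ++multiplications;
    return static_cast<std::uint32_t>((static_cast<std::uint64_t>(x) * y) % mod);
}

// Mirrors statement (1): the even branch evaluates the same sub-power twice.
// Assumes pow >= 1.
std::uint32_t pow_double_call(std::uint32_t b, std::uint32_t pow, std::uint32_t mod) {
    if (pow == 1) return b % mod;
    if (pow % 2 != 0) return mul_mod(b, pow_double_call(b, pow - 1, mod), mod);
    return mul_mod(pow_double_call(b, pow / 2, mod),
                   pow_double_call(b, pow / 2, mod), mod);
}

// Evaluates the sub-power once and squares the stored value.
// Assumes pow >= 1.
std::uint32_t pow_single_call(std::uint32_t b, std::uint32_t pow, std::uint32_t mod) {
    if (pow == 1) return b % mod;
    if (pow % 2 != 0) return mul_mod(b, pow_single_call(b, pow - 1, mod), mod);
    const std::uint32_t aux = pow_single_call(b, pow / 2, mod);
    return mul_mod(aux, aux, mod);
}
```

Resetting the counter before each call, an exponent of 1024 costs 1023 multiplications with `pow_double_call` and 10 with `pow_single_call`, while both return the same value.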
I've realized my problem, so I've changed
return (pow % 2 != 0) ? (b * operate(b, pow - 1, mod)) % mod : (operate(b, pow / 2, mod) * operate(b, pow / 2, mod)) % mod;
to
return (pow % 2 != 0) ? (b * operate(b, pow - 1, mod)) % mod : (int)Math.Pow(operate(b, pow / 2, mod), 2) % mod;
In the first return I was calling the recursive function twice and then evaluating the square. Instead, in the second, I call it once and then evaluate the square. It runs much faster now.
The code isn't tail recursive, so you actually pay for the recursion with stack space and function calls.
You can make it iterative by starting at the highest bit of the exponent and working your way down to the lowest. Computing a^d mod n then becomes:
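#include <bit>      // std::countl_zero (C++20)
#include <cstdint>
using std::int32_t;
using std::uint32_t;
using std::uint64_t;
// Computes a^d mod n for d >= 1, scanning the exponent from its highest bit down.
uint32_t pow_n(uint32_t a, uint32_t d, uint32_t n) {
    if (d == 0) __builtin_unreachable();  // precondition d != 0 (GCC/Clang builtin)
    unsigned shift = std::countl_zero(d) + 1;  // skips leading zeros and the top set bit
    uint32_t t = a % n;  // reduce once so that d == 1 also returns a value below n
    // Shifting a 32-bit value by 32 is undefined, so d == 1 (shift == 32) is handled explicitly.
    int32_t m = shift < 32 ? static_cast<int32_t>(d << shift) : 0;
    for (unsigned i = 32 - shift; i > 0; --i) {
        t = ((uint64_t)t * t) % n;             // square for every remaining bit
        if (m < 0) t = ((uint64_t)t * a) % n;  // top bit of m set: multiply by a
        m <<= 1;
    }
    return t;
}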
Note: std::countl_zero(d) is C++20, but it is available as a compiler intrinsic before that (or in C) in many compilers. Modern CPUs tend to have opcodes that make this faster than bit-banging the result manually.
Note2: The code uses the sign of m to extract the highest bit. You can change that to extract the bit at index i instead of shifting m.
Note3: The code works for all uint32_t values. If your numbers are small (<65536) then you can avoid the uint64_t casts.
- It would be more instructive to stick with the language of the original question code (i.e. C#), rather than presenting a C++ function, I think. (Toby Speight, Jul 5, 2022 at 6:46)
- @TobySpeight Sorry, if you know enough C# to rewrite the code then please do. (Goswin von Brederlow, Jul 5, 2022 at 8:52)