Polynomial.java - multiplication algorithms: naïve, Karatsuba, FFT
Intro
This post is a supplement to Polynomial.java - a Java class for dealing with polynomials with BigDecimal coefficients. It presents the three polynomial multiplication algorithms. Given two polynomials: one of degree \$n_1\$, and another one of degree \$n_2\$, the multiplication algorithms are:
- Naïve: running in \$\Theta(n_1 n_2)\$,
- Karatsuba: running in \$\Theta(\max(n_1, n_2)^{1.58})\$,
- FFT: running in \$\Theta(N \log N)\$, where \$N\$ is the closest (from above) power of \2ドル\$ to the \$\max(n_1, n_2)\$.
Code
com.github.coderodde.math.PolynomialMultiplier.java:
package com.github.coderodde.math;
import java.math.BigDecimal;
import java.util.Arrays;
/**
* This class contains some polynomial multiplication algorithms.
*
* @version 1.1.1 (Nov 24, 2024)
* @since 1.0.0 (Nov 22, 2024)
*/
public final class PolynomialMultiplier {
public static final BigDecimal DEFAULT_EPSILON = BigDecimal.valueOf(0.01);
public static final int DEFAULT_SCALE = 6;
/**
* Multiplies {@code p1} and {@code p2} naively in {@code O(NM)} time.
*
* @param p1 the first polynomial.
* @param p2 the second polynomial.
*
* @return the product of {@code p1} and {@code p2}.
*/
public static Polynomial multiplyViaNaive(final Polynomial p1,
final Polynomial p2) {
final int coefficientArrayLength = p1.length()
+ p2.length()
- 1;
final BigDecimal[] coefficientArray =
new BigDecimal[coefficientArrayLength];
// Initialize the result coefficient array:
Arrays.fill(
coefficientArray,
0,
coefficientArrayLength,
BigDecimal.ZERO);
for (int index1 = 0;
index1 < p1.length();
index1++) {
for (int index2 = 0;
index2 < p2.length();
index2++) {
final BigDecimal coefficient1 = p1.getCoefficient(index1);
final BigDecimal coefficient2 = p2.getCoefficient(index2);
coefficientArray[index1 + index2] =
coefficientArray[index1 + index2]
.add(coefficient1.multiply(coefficient2));
}
}
return new Polynomial(coefficientArray);
}
/**
* Multiplies the two input polynomials in {@code O(N^(1.58)} time where
* {@code N} is the degree-bound of both {@code p1} and {@code p2}. This
* implementation is adopted from
* <a href="https://www.cs.dartmouth.edu/~deepc/LecNotes/cs31/lec6.pdf">
* these lecture slides</a>.
*
* @param p1 the first polynomial.
* @param p2 the second polynomial.
* @param epsilon the epsilon value for detecting zero coefficients.
*
* @return the product of the two input polynomials.
*/
public static Polynomial multiplyViaKaratsuba(Polynomial p1,
Polynomial p2,
final BigDecimal epsilon) {
final int n = Math.max(p1.length(),
p2.length());
p1 = p1.setLength(n);
p2 = p2.setLength(n);
final Polynomial rawPolynomial = multiplyViaKaratsubaImpl(p1, p2);
return rawPolynomial.minimizeDegree(epsilon);
}
/**
* Delegates the multiplication
* {@link #multiplyViaKaratsuba(com.github.coderodde.math.Polynomial, com.github.coderodde.math.Polynomial, java.math.BigDecimal)}.
* Uses the default epsilon value of {@code 0.01} for zero comparisons.
*
* @param p1 the first polynomial.
* @param p2 the second polynomial.
*
* @return the product of the two input polynomials.
*/
public static Polynomial multiplyViaKaratsuba(Polynomial p1,
Polynomial p2) {
return multiplyViaKaratsuba(p1, p2, DEFAULT_EPSILON);
}
/**
* Multiplies the two input polynomials in {@code O(N log N} time where
* {@code N} is the degree-bound of both {@code p1} and {@code p2}. This
* implementation is adopted from
* <a href="https://archive.org/details/introduction-to-algorithms-third-edition-2009/page/898/mode/1up?view=theater">
* Introduction to Algorirthms, 3rd edition, Chapter 30</a>.
*
* @param p1 the first polynomial.
* @param p2 the second polynomial.
* @param scale the scale to use for {@link java.math.BigDecimal} values.
* @param epsilon the epsilon value for detecting zero coefficients.
*
* @return the product of the two input polynomials.
*/
public static Polynomial multiplyViaFFT(Polynomial p1,
Polynomial p2,
final int scale,
final BigDecimal epsilon) {
final int length = Math.max(p1.length(),
p2.length());
final int n = Utils.getClosestUpwardPowerOfTwo(length) * 2;
p1 = p1.setLength(n);
p2 = p2.setLength(n);
ComplexPolynomial a = computeFFT(new ComplexPolynomial(p1), scale);
ComplexPolynomial b = computeFFT(new ComplexPolynomial(p2), scale);
ComplexPolynomial c = multiplyPointwise(a, b);
ComplexPolynomial r = computeInverseFFT(c, scale);
return r.convertToPolynomial().minimizeDegree(epsilon);
}
/**
* Delegates the multiplication
* {@link #multiplyViaFFT(com.github.coderodde.math.Polynomial, com.github.coderodde.math.Polynomial, java.math.BigDecimal)}.
* Uses the default epsilon value of {@code 0.01} for zero comparisons.
*
* @param p1 the first polynomial.
* @param p2 the second polynomial.
*
* @return the product of the two input polynomials.
*/
public static Polynomial multiplyViaFFT(final Polynomial p1,
final Polynomial p2) {
return multiplyViaFFT(p1, p2, DEFAULT_SCALE, DEFAULT_EPSILON);
}
private static ComplexPolynomial
multiplyPointwise(final ComplexPolynomial a,
final ComplexPolynomial b) {
final ComplexPolynomial c = new ComplexPolynomial(a.length());
for (int i = 0; i < c.length(); i++) {
c.setCoefficient(
i,
a.getCoefficient(i)
.multiply(b.getCoefficient(i)));
}
return c;
}
private static ComplexPolynomial
computeFFT(final ComplexPolynomial complexPolynomial,
final int scale) {
final int n = complexPolynomial.length();
if (n == 1) {
return complexPolynomial.setScale(scale);
}
ComplexNumber omega = ComplexNumber.one();
final ComplexNumber root =
ComplexNumber
.getPrincipalRootOfUnity(n)
.setScale(scale);
final ComplexPolynomial[] a = complexPolynomial.split();
final ComplexPolynomial a0 = a[0];
final ComplexPolynomial a1 = a[1];
final ComplexPolynomial y0 = computeFFT(a0, scale);
final ComplexPolynomial y1 = computeFFT(a1, scale);
final ComplexPolynomial y = new ComplexPolynomial(n);
for (int k = 0; k < n / 2; k++) {
final ComplexNumber y0k = y0.getCoefficient(k);
final ComplexNumber y1k = y1.getCoefficient(k);
y.setCoefficient(k,
y0k.add(omega.multiply(y1k)).setScale(scale));
y.setCoefficient(k + n / 2,
y0k.substract(
omega.multiply(y1k))
.setScale(scale));
omega = omega.multiply(root).setScale(scale);
}
return y;
}
private static ComplexPolynomial
computeInverseFFT(ComplexPolynomial complexPolynomial,
final int scale) {
complexPolynomial = complexPolynomial.getConjugate();
complexPolynomial = computeFFT(complexPolynomial, scale);
complexPolynomial = complexPolynomial.getConjugate();
divide(complexPolynomial, complexPolynomial.length());
return complexPolynomial.setScale(scale);
}
private static void divide(final ComplexPolynomial cp,
final int n) {
final BigDecimal nn = BigDecimal.valueOf(n);
for (int i = 0; i < cp.length(); i++) {
cp.setCoefficient(i, cp.getCoefficient(i).divide(nn));
}
}
private static Polynomial multiplyViaKaratsubaImpl(final Polynomial p1,
final Polynomial p2) {
final int n = Math.max(p1.getDegree(),
p2.getDegree());
if (n == 0 || n == 1) {
return multiplyViaNaive(p1, p2);
}
final int m = (int) Math.ceil(n / 2.0);
final BigDecimal[] pPrime = new BigDecimal[m + 1];
final BigDecimal[] qPrime = new BigDecimal[m + 1];
for (int i = 0; i < m; i++) {
pPrime[i] = p1.getCoefficientInternal(i)
.add(p1.getCoefficientInternal(m + i));
qPrime[i] = p2.getCoefficientInternal(i)
.add(p2.getCoefficientInternal(m + i));
}
if (n > 2 * m - 1) {
pPrime[m] = p1.getCoefficientInternal(n);
qPrime[m] = p2.getCoefficientInternal(n);
} else {
pPrime[m] = BigDecimal.ZERO;
qPrime[m] = BigDecimal.ZERO;
}
Polynomial r1 = getR1Polynomial(p1, p2, m);
Polynomial r2 = getR2Polynomial(p1, p2, m, n);
Polynomial r3 = getR3Polynomial(pPrime,
qPrime);
final BigDecimal[] r4Coefficients = new BigDecimal[2 * m + 1];
for (int i = 0; i <= 2 * m; i++) {
r4Coefficients[i] = getCoefficient(r1,
r2,
r3,
i);
}
final Polynomial r4 = new Polynomial(r4Coefficients);
final BigDecimal[] rCoefficients = new BigDecimal[2 * n + 1];
for (int i = 0; i <= 2 * n; i++) {
rCoefficients[i] = getCoefficient(r1,
r2,
r4,
i,
m);
}
return new Polynomial(rCoefficients);
}
private static BigDecimal getCoefficient(final Polynomial r1,
final Polynomial r2,
final Polynomial r3,
final int i) {
final BigDecimal term1 = r1.getCoefficientInternal(i);
final BigDecimal term2 = r2.getCoefficientInternal(i);
final BigDecimal term3 = r3.getCoefficientInternal(i);
return term3.add(term1.negate()).add(term2.negate());
}
private static BigDecimal getCoefficient(final Polynomial r1,
final Polynomial r2,
final Polynomial r4,
final int i,
final int m) {
final BigDecimal term1 = r1.getCoefficientInternal(i);
final BigDecimal term4 = r4.getCoefficientInternal(i - m);
final BigDecimal term2 = r2.getCoefficientInternal(i - m * 2);
return term1.add(term2).add(term4);
}
private static Polynomial getR1Polynomial(final Polynomial p,
final Polynomial q,
final int m) {
final BigDecimal[] pCoefficients = p.toCoefficientArray(m);
final BigDecimal[] qCoefficients = q.toCoefficientArray(m);
final Polynomial pResultPolynomial = new Polynomial(pCoefficients);
final Polynomial qResultPolynomial = new Polynomial(qCoefficients);
return multiplyViaKaratsubaImpl(pResultPolynomial,
qResultPolynomial);
}
private static Polynomial getR2Polynomial(final Polynomial p,
final Polynomial q,
final int m,
final int n) {
final BigDecimal[] pCoefficients = p.toCoefficientArray(m, n);
final BigDecimal[] qCoefficients = q.toCoefficientArray(m, n);
final Polynomial pResultPolynomial = new Polynomial(pCoefficients);
final Polynomial qResultPolynomial = new Polynomial(qCoefficients);
return multiplyViaKaratsubaImpl(pResultPolynomial,
qResultPolynomial);
}
private static Polynomial getR3Polynomial(final BigDecimal[] pCoefficients,
final BigDecimal[] qCoefficients)
{
final Polynomial pResultPolynomial = new Polynomial(pCoefficients);
final Polynomial qResultPolynomial = new Polynomial(qCoefficients);
return multiplyViaKaratsubaImpl(pResultPolynomial,
qResultPolynomial);
}
}
com.github.coderodde.math.ComplexNumber.java:
package com.github.coderodde.math;
import java.math.BigDecimal;
import java.math.RoundingMode;
/**
* This class implements basic facilities for dealing with complex number in the
* Fast Fourier Transform algorithms.
*
* @version 1.1.0 (Nov 24, 2024)
* @since 1.0.0 (Nov 23, 2024)
*/
public final class ComplexNumber {
/**
* The real part of this complex number.
*/
private final BigDecimal realPart;
/**
* The imaginary part of this complex number.
*/
private final BigDecimal imagPart;
public ComplexNumber(final BigDecimal realPart,
final BigDecimal imagPart) {
this.realPart = realPart;
this.imagPart = imagPart;
}
/**
* Constructs this complex number as the real unity.
*/
public ComplexNumber() {
this(BigDecimal.ONE,
BigDecimal.ZERO);
}
/**
* Constructs this complex number.
*
* @param realPart the real part.
* @param imagPart the imaginary part.
*/
public ComplexNumber(final double realPart,
final double imagPart) {
this(BigDecimal.valueOf(realPart),
BigDecimal.valueOf(imagPart));
}
/**
* Construct this complex number as a real number with value
* {@code realPart}.
*
* @param realPart the real part.
*/
public ComplexNumber(final BigDecimal realPart) {
this(realPart, BigDecimal.ZERO);
}
public ComplexNumber setScale(final int scale,
final RoundingMode roundingMode) {
return new ComplexNumber(realPart.setScale(scale, roundingMode),
imagPart.setScale(scale, roundingMode));
}
public ComplexNumber setScale(final int scale) {
return setScale(scale,
RoundingMode.HALF_UP);
}
public BigDecimal getRealPart() {
return realPart;
}
public BigDecimal getImaginaryPart() {
return imagPart;
}
public ComplexNumber add(final ComplexNumber other) {
return new ComplexNumber(realPart.add(other.realPart),
imagPart.add(other.imagPart));
}
public ComplexNumber negate() {
return new ComplexNumber(realPart.negate(),
imagPart.negate());
}
public ComplexNumber substract(final ComplexNumber other) {
return add(other.negate());
}
public ComplexNumber getConjugate() {
return new ComplexNumber(realPart, imagPart.negate());
}
/**
* Multiplies this complex number with the {@code other} complex number.
*
* @param other the second complex number to multiply.
*
* @return the complex product of this and {@code other} complex numbers.
*/
public ComplexNumber multiply(final ComplexNumber other) {
final BigDecimal resultRealPart =
realPart.multiply(other.realPart)
.subtract(imagPart.multiply(other.imagPart));
final BigDecimal resultImagPart =
realPart.multiply(other.imagPart)
.add(imagPart.multiply(other.realPart));
return new ComplexNumber(resultRealPart,
resultImagPart);
}
public ComplexNumber divide(final BigDecimal r) {
return new ComplexNumber(realPart.divide(r),
imagPart.divide(r));
}
@Override
public boolean equals(final Object o) {
if (o == this) {
return true;
}
if (o == null) {
return false;
}
if (!getClass().equals(o.getClass())) {
return false;
}
final ComplexNumber other = (ComplexNumber) o;
return realPart.equals(other.realPart) &&
imagPart.equals(other.imagPart);
}
@Override
public String toString() {
final StringBuilder sb = new StringBuilder();
if (imagPart.equals(BigDecimal.ZERO)) {
return realPart.toString();
}
sb.append("(")
.append(realPart)
.append((imagPart.compareTo(BigDecimal.ZERO) > 0) ? " + " : " - ")
.append(imagPart)
.append("i)");
return sb.toString();
}
/**
* Returns the {@code n}th principal complex root of unity.
*
* @param n the root order.
*
* @return a principal complex root of unity.
*/
public static ComplexNumber getPrincipalRootOfUnity(final int n) {
final double u = 2.0 * Math.PI / (double) n;
final double re = Math.cos(u);
final double im = Math.sin(u);
return new ComplexNumber(re, im);
}
public static ComplexNumber one() {
return new ComplexNumber();
}
public static ComplexNumber zero() {
return new ComplexNumber(BigDecimal.ZERO);
}
}
com.github.coderodde.math.ComplexPolynomial.java:
package com.github.coderodde.math;
import static com.github.coderodde.math.Utils.powerToSuperscript;
import java.math.BigDecimal;
/**
* This class implements the polynomial over complex number. We need this class
* in the FFT algorithms.
*
* @version 1.0.0 (Nov 23, 2024)
* @since 1.0.0 (Nov 23, 2024)
*/
public final class ComplexPolynomial {
private final ComplexNumber[] coefficients;
public ComplexPolynomial(final int length) {
this.coefficients = new ComplexNumber[length];
}
public ComplexPolynomial(final Polynomial polynomial) {
this.coefficients = new ComplexNumber[polynomial.length()];
for (int i = 0; i < coefficients.length; i++) {
coefficients[i] = new ComplexNumber(polynomial.getCoefficient(i));
}
}
ComplexPolynomial(final ComplexNumber[] coefficients) {
this.coefficients = coefficients;
}
public ComplexPolynomial setScale(final int scale) {
final ComplexNumber[] coefficients = new ComplexNumber[length()];
for (int i = 0; i < coefficients.length; i++) {
coefficients[i] = this.coefficients[i].setScale(scale);
}
return new ComplexPolynomial(coefficients);
}
public ComplexPolynomial shrinkToHalf() {
final ComplexNumber[] coefficients = new ComplexNumber[length() / 2];
for (int i = 0; i < coefficients.length; i++) {
coefficients[i] = this.getCoefficient(i);
}
return new ComplexPolynomial(coefficients);
}
public ComplexPolynomial getConjugate() {
final ComplexNumber[] coefficients = new ComplexNumber[length()];
for (int i = 0; i < coefficients.length; i++) {
coefficients[i] = this.coefficients[i].getConjugate();
}
return new ComplexPolynomial(coefficients);
}
public int length() {
return coefficients.length;
}
public ComplexNumber getCoefficient(final int coefficientIndex) {
return coefficients[coefficientIndex];
}
public void setCoefficient(final int coefficientIndex,
final ComplexNumber coefficient) {
this.coefficients[coefficientIndex] = coefficient;
}
public ComplexPolynomial[] split() {
final int nextLength = coefficients.length / 2;
final ComplexPolynomial[] result = new ComplexPolynomial[2];
result[0] = new ComplexPolynomial(nextLength);
result[1] = new ComplexPolynomial(nextLength);
boolean readToFirst = true;
int coefficientIndex1 = 0;
int coefficientIndex2 = 0;
for (int i = 0; i < coefficients.length; i++) {
final ComplexNumber currentComplexNumber = coefficients[i];
if (readToFirst) {
readToFirst = false;
result[0].setCoefficient(coefficientIndex1++,
currentComplexNumber);
} else {
readToFirst = true;
result[1].setCoefficient(coefficientIndex2++,
currentComplexNumber);
}
}
return result;
}
/**
* Converts this complex polynomial to an ordinary polynomial by ignoring
* the imaginary parts in this complex polynomial and copying only the
* real part to the resultant polynomial.
*
* @return an ordinary polynomial.
*/
public Polynomial convertToPolynomial() {
final BigDecimal[] data = new BigDecimal[length()];
for (int i = 0; i < length(); i++) {
data[i] = coefficients[i].getRealPart();
}
return new Polynomial(data);
}
@Override
public String toString() {
if (length() == 0) {
return coefficients[0].toString();
}
final StringBuilder sb = new StringBuilder();
boolean first = true;
for (int pow = length() - 1; pow >= 0; pow--) {
if (first) {
first = false;
sb.append(getCoefficient(pow))
.append("x");
if (pow > 1) {
sb.append(powerToSuperscript(pow));
}
} else {
sb.append(" + ").append(getCoefficient(pow));
if (pow > 0) {
sb.append("x");
}
if (pow > 1) {
sb.append(powerToSuperscript(pow));
}
}
}
return sb.toString().replaceAll(",", ".");
}
}
Typical demo output
Na�ve: 7096 milliseconds.
Karatsuba: 3960 milliseconds.
FFT: 1866 milliseconds.
Na�ve agrees with Karatsuba and FFT: true.
Na�ve output: 63x???? + 87x???? - 3x???? - 35x???? - 6x???? - 154x???� - 162x???� + 28x??? - 1...
FFT output: 63x???? + 87x???? - 3x???? - 35x???? - 6x???? - 154x???� - 162x???� + 28x??? - 1...
Critique request
As always, I would like to receive any commentary you can come up with.