Code Review

coderodde

Polynomial.java - multiplication algorithms: naïve, Karatsuba, FFT

Intro

This post is a supplement to Polynomial.java - a Java class for dealing with polynomials with BigDecimal coefficients. It presents three polynomial multiplication algorithms. Given two polynomials, one of degree \$n_1\$ and the other of degree \$n_2\$, the multiplication algorithms are:

  • Naïve: running in \$\Theta(n_1 n_2)\$,
  • Karatsuba: running in \$\Theta(\max(n_1, n_2)^{1.58})\$,
  • FFT: running in \$\Theta(N \log N)\$, where \$N\$ is the smallest power of \$2\$ that is at least \$\max(n_1, n_2)\$.
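
To make the quantities in the list above concrete, here is a minimal sketch that is not part of the library: it uses plain `long[]` coefficient arrays (lowest power first) instead of `Polynomial`, and the names `NaiveProductSketch`, `multiply` and `ceilPowerOfTwo` are made up for illustration. It shows the quadratic schoolbook product and the rounding of a length up to a power of two.

```java
import java.util.Arrays;

public class NaiveProductSketch {

    // Schoolbook product. Coefficients are stored lowest power first,
    // so {1, 2} stands for 1 + 2x. Runs in time proportional to
    // a.length * b.length.
    public static long[] multiply(long[] a, long[] b) {
        long[] c = new long[a.length + b.length - 1];

        for (int i = 0; i < a.length; i++) {
            for (int j = 0; j < b.length; j++) {
                // x^i * x^j contributes to the coefficient of x^(i + j).
                c[i + j] += a[i] * b[j];
            }
        }

        return c;
    }

    // Smallest power of two that is >= x, assuming x >= 1.
    public static int ceilPowerOfTwo(int x) {
        int n = 1;

        while (n < x) {
            n <<= 1;
        }

        return n;
    }

    public static void main(String[] args) {
        // (1 + 2x)(3 + 4x) = 3 + 10x + 8x^2
        System.out.println(
                Arrays.toString(multiply(new long[]{1, 2},
                                         new long[]{3, 4})));

        System.out.println(ceilPowerOfTwo(5)); // 8
    }
}
```

The product of polynomials with degrees \$n_1\$ and \$n_2\$ has \$n_1 + n_2 + 1\$ coefficients, which is why the result array has length `a.length + b.length - 1`.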

Code

com.github.coderodde.math.PolynomialMultiplier.java:

package com.github.coderodde.math;
import java.math.BigDecimal;
import java.util.Arrays;
/**
 * This class contains some polynomial multiplication algorithms.
 * 
 * @version 1.1.1 (Nov 24, 2024)
 * @since 1.0.0 (Nov 22, 2024)
 */
public final class PolynomialMultiplier {
 
 public static final BigDecimal DEFAULT_EPSILON = BigDecimal.valueOf(0.01);
 public static final int DEFAULT_SCALE = 6;
 
 /**
 * Multiplies {@code p1} and {@code p2} naively in {@code O(NM)} time.
 * 
 * @param p1 the first polynomial.
 * @param p2 the second polynomial.
 * 
 * @return the product of {@code p1} and {@code p2}.
 */
 public static Polynomial multiplyViaNaive(final Polynomial p1,
 final Polynomial p2) {
 final int coefficientArrayLength = p1.length()
 + p2.length()
 - 1;
 
 final BigDecimal[] coefficientArray = 
 new BigDecimal[coefficientArrayLength];
 
 // Initialize the result coefficient array:
 Arrays.fill(
 coefficientArray, 
 0, 
 coefficientArrayLength, 
 BigDecimal.ZERO);
 
 for (int index1 = 0;
 index1 < p1.length(); 
 index1++) {
 
 for (int index2 = 0; 
 index2 < p2.length(); 
 index2++) {
 
 final BigDecimal coefficient1 = p1.getCoefficient(index1);
 final BigDecimal coefficient2 = p2.getCoefficient(index2);
 
 coefficientArray[index1 + index2] = 
 coefficientArray[index1 + index2]
 .add(coefficient1.multiply(coefficient2));
 }
 }
 
 return new Polynomial(coefficientArray);
 }
 
 /**
 * Multiplies the two input polynomials in {@code O(N^1.58)} time where 
 * {@code N} is the degree-bound of both {@code p1} and {@code p2}. This
 * implementation is adapted from 
 * <a href="https://www.cs.dartmouth.edu/~deepc/LecNotes/cs31/lec6.pdf">
 * these lecture slides</a>.
 * 
 * @param p1 the first polynomial.
 * @param p2 the second polynomial.
 * @param epsilon the epsilon value for detecting zero coefficients.
 * 
 * @return the product of the two input polynomials.
 */
 public static Polynomial multiplyViaKaratsuba(Polynomial p1,
 Polynomial p2,
 final BigDecimal epsilon) {
 
 final int n = Math.max(p1.length(), 
 p2.length());
 
 p1 = p1.setLength(n);
 p2 = p2.setLength(n);
 
 final Polynomial rawPolynomial = multiplyViaKaratsubaImpl(p1, p2);
 
 return rawPolynomial.minimizeDegree(epsilon);
 }
 
 /**
 * Delegates the multiplication to
 * {@link #multiplyViaKaratsuba(com.github.coderodde.math.Polynomial, com.github.coderodde.math.Polynomial, java.math.BigDecimal)}.
 * Uses the default epsilon value of {@code 0.01} for zero comparisons.
 * 
 * @param p1 the first polynomial.
 * @param p2 the second polynomial.
 * 
 * @return the product of the two input polynomials.
 */
 public static Polynomial multiplyViaKaratsuba(Polynomial p1,
 Polynomial p2) {
 return multiplyViaKaratsuba(p1, p2, DEFAULT_EPSILON);
 }
 
 /**
 * Multiplies the two input polynomials in {@code O(N log N)} time where 
 * {@code N} is the degree-bound of both {@code p1} and {@code p2}. This
 * implementation is adapted from 
 * <a href="https://archive.org/details/introduction-to-algorithms-third-edition-2009/page/898/mode/1up?view=theater">
 * Introduction to Algorithms, 3rd edition, Chapter 30</a>.
 * 
 * @param p1 the first polynomial.
 * @param p2 the second polynomial.
 * @param scale the scale to use for {@link java.math.BigDecimal} values.
 * @param epsilon the epsilon value for detecting zero coefficients.
 * 
 * @return the product of the two input polynomials.
 */
 public static Polynomial multiplyViaFFT(Polynomial p1,
 Polynomial p2,
 final int scale,
 final BigDecimal epsilon) {
 
 final int length = Math.max(p1.length(), 
 p2.length());
 
 final int n = Utils.getClosestUpwardPowerOfTwo(length) * 2;
 
 p1 = p1.setLength(n);
 p2 = p2.setLength(n);
 
 ComplexPolynomial a = computeFFT(new ComplexPolynomial(p1), scale);
 ComplexPolynomial b = computeFFT(new ComplexPolynomial(p2), scale);
 ComplexPolynomial c = multiplyPointwise(a, b);
 ComplexPolynomial r = computeInverseFFT(c, scale);
 
 return r.convertToPolynomial().minimizeDegree(epsilon);
 }
 
 /**
 * Delegates the multiplication to
 * {@link #multiplyViaFFT(com.github.coderodde.math.Polynomial, com.github.coderodde.math.Polynomial, int, java.math.BigDecimal)}.
 * Uses the default scale of {@code 6} and the default epsilon value of
 * {@code 0.01} for zero comparisons.
 * 
 * @param p1 the first polynomial.
 * @param p2 the second polynomial.
 * 
 * @return the product of the two input polynomials.
 */
 public static Polynomial multiplyViaFFT(final Polynomial p1,
 final Polynomial p2) {
 
 return multiplyViaFFT(p1, p2, DEFAULT_SCALE, DEFAULT_EPSILON);
 }
 
 private static ComplexPolynomial 
 multiplyPointwise(final ComplexPolynomial a,
 final ComplexPolynomial b) {
 
 final ComplexPolynomial c = new ComplexPolynomial(a.length());
 
 for (int i = 0; i < c.length(); i++) {
 c.setCoefficient(
 i, 
 a.getCoefficient(i)
 .multiply(b.getCoefficient(i)));
 }
 
 return c;
 }
 
 private static ComplexPolynomial 
 computeFFT(final ComplexPolynomial complexPolynomial,
 final int scale) {
 
 final int n = complexPolynomial.length();
 
 if (n == 1) {
 return complexPolynomial.setScale(scale);
 }
 
 ComplexNumber omega = ComplexNumber.one(); 
 
 final ComplexNumber root = 
 ComplexNumber
 .getPrincipalRootOfUnity(n)
 .setScale(scale);
 
 final ComplexPolynomial[] a = complexPolynomial.split();
 final ComplexPolynomial a0 = a[0];
 final ComplexPolynomial a1 = a[1];
 
 final ComplexPolynomial y0 = computeFFT(a0, scale);
 final ComplexPolynomial y1 = computeFFT(a1, scale);
 final ComplexPolynomial y = new ComplexPolynomial(n);
 
 for (int k = 0; k < n / 2; k++) {
 final ComplexNumber y0k = y0.getCoefficient(k);
 final ComplexNumber y1k = y1.getCoefficient(k);
 
 y.setCoefficient(k, 
 y0k.add(omega.multiply(y1k)).setScale(scale));
 
 y.setCoefficient(k + n / 2, 
 y0k.substract(
 omega.multiply(y1k))
 .setScale(scale));
 
 omega = omega.multiply(root).setScale(scale);
 }
 
 return y;
 }
 
 private static ComplexPolynomial 
 computeInverseFFT(ComplexPolynomial complexPolynomial,
 final int scale) {
 
 complexPolynomial = complexPolynomial.getConjugate();
 complexPolynomial = computeFFT(complexPolynomial, scale);
 complexPolynomial = complexPolynomial.getConjugate();
 divide(complexPolynomial, complexPolynomial.length());
 return complexPolynomial.setScale(scale);
 }
 
 private static void divide(final ComplexPolynomial cp, 
 final int n) {
 
 final BigDecimal nn = BigDecimal.valueOf(n);
 
 for (int i = 0; i < cp.length(); i++) {
 cp.setCoefficient(i, cp.getCoefficient(i).divide(nn));
 }
 }
 
 private static Polynomial multiplyViaKaratsubaImpl(final Polynomial p1,
 final Polynomial p2) {
 final int n = Math.max(p1.getDegree(),
 p2.getDegree());
 
 if (n == 0 || n == 1) {
 return multiplyViaNaive(p1, p2);
 }
 
 final int m = (int) Math.ceil(n / 2.0);
 
 final BigDecimal[] pPrime = new BigDecimal[m + 1];
 final BigDecimal[] qPrime = new BigDecimal[m + 1];
 
 for (int i = 0; i < m; i++) {
 pPrime[i] = p1.getCoefficientInternal(i)
 .add(p1.getCoefficientInternal(m + i));
 
 qPrime[i] = p2.getCoefficientInternal(i)
 .add(p2.getCoefficientInternal(m + i));
 }
 
 if (n > 2 * m - 1) {
 pPrime[m] = p1.getCoefficientInternal(n);
 qPrime[m] = p2.getCoefficientInternal(n);
 } else {
 pPrime[m] = BigDecimal.ZERO;
 qPrime[m] = BigDecimal.ZERO;
 }
 
 Polynomial r1 = getR1Polynomial(p1, p2, m);
 Polynomial r2 = getR2Polynomial(p1, p2, m, n);
 Polynomial r3 = getR3Polynomial(pPrime, 
 qPrime);
 
 final BigDecimal[] r4Coefficients = new BigDecimal[2 * m + 1];
 
 for (int i = 0; i <= 2 * m; i++) {
 r4Coefficients[i] = getCoefficient(r1, 
 r2, 
 r3, 
 i);
 }
 
 final Polynomial r4 = new Polynomial(r4Coefficients);
 
 final BigDecimal[] rCoefficients = new BigDecimal[2 * n + 1];
 
 for (int i = 0; i <= 2 * n; i++) {
 rCoefficients[i] = getCoefficient(r1,
 r2,
 r4,
 i,
 m);
 }
 
 return new Polynomial(rCoefficients);
 }
 
 private static BigDecimal getCoefficient(final Polynomial r1,
 final Polynomial r2,
 final Polynomial r3,
 final int i) {
 
 final BigDecimal term1 = r1.getCoefficientInternal(i);
 final BigDecimal term2 = r2.getCoefficientInternal(i);
 final BigDecimal term3 = r3.getCoefficientInternal(i);
 
 return term3.add(term1.negate()).add(term2.negate());
 }
 
 private static BigDecimal getCoefficient(final Polynomial r1, 
 final Polynomial r2,
 final Polynomial r4,
 final int i,
 final int m) {
 
 final BigDecimal term1 = r1.getCoefficientInternal(i);
 final BigDecimal term4 = r4.getCoefficientInternal(i - m);
 final BigDecimal term2 = r2.getCoefficientInternal(i - m * 2);
 
 return term1.add(term2).add(term4);
 }
 
 private static Polynomial getR1Polynomial(final Polynomial p,
 final Polynomial q,
 final int m) {
 
 final BigDecimal[] pCoefficients = p.toCoefficientArray(m);
 final BigDecimal[] qCoefficients = q.toCoefficientArray(m);
 
 final Polynomial pResultPolynomial = new Polynomial(pCoefficients);
 final Polynomial qResultPolynomial = new Polynomial(qCoefficients);
 
 return multiplyViaKaratsubaImpl(pResultPolynomial,
 qResultPolynomial);
 }
 
 private static Polynomial getR2Polynomial(final Polynomial p,
 final Polynomial q,
 final int m,
 final int n) {
 
 final BigDecimal[] pCoefficients = p.toCoefficientArray(m, n);
 final BigDecimal[] qCoefficients = q.toCoefficientArray(m, n);
 
 final Polynomial pResultPolynomial = new Polynomial(pCoefficients);
 final Polynomial qResultPolynomial = new Polynomial(qCoefficients);
 
 return multiplyViaKaratsubaImpl(pResultPolynomial, 
 qResultPolynomial);
 }
 
 private static Polynomial getR3Polynomial(final BigDecimal[] pCoefficients,
 final BigDecimal[] qCoefficients) 
 {
 final Polynomial pResultPolynomial = new Polynomial(pCoefficients);
 final Polynomial qResultPolynomial = new Polynomial(qCoefficients);
 
 return multiplyViaKaratsubaImpl(pResultPolynomial,
 qResultPolynomial);
 }
}
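
For comparison with `multiplyViaFFT` above, here is a compact sketch of the same pipeline (forward transform, pointwise product, inverse transform via conjugation) on primitive `double` arrays. It is only an illustration under stated assumptions: integer coefficients small enough for `double` rounding to recover them exactly, `long[]` arrays stored lowest power first, and a hypothetical class name `FftProductSketch`. It does not use `BigDecimal`, `ComplexNumber` or `ComplexPolynomial`.

```java
import java.util.Arrays;

public class FftProductSketch {

    // Recursive radix-2 FFT using the root exp(+2*pi*i/n).
    // re/im hold a complex vector whose length is a power of two.
    // Returns {real parts, imaginary parts} of the transform.
    static double[][] fft(double[] re, double[] im) {
        int n = re.length;
        if (n == 1) {
            return new double[][] {{re[0]}, {im[0]}};
        }

        int h = n / 2;
        double[] evenRe = new double[h], evenIm = new double[h];
        double[] oddRe = new double[h], oddIm = new double[h];

        // Even-indexed coefficients go to one half, odd-indexed to the other.
        for (int k = 0; k < h; k++) {
            evenRe[k] = re[2 * k];
            evenIm[k] = im[2 * k];
            oddRe[k] = re[2 * k + 1];
            oddIm[k] = im[2 * k + 1];
        }

        double[][] e = fft(evenRe, evenIm);
        double[][] o = fft(oddRe, oddIm);
        double[] yRe = new double[n], yIm = new double[n];

        for (int k = 0; k < h; k++) {
            // Twiddle factor omega^k, computed directly to avoid
            // accumulating rounding error from repeated multiplication.
            double angle = 2.0 * Math.PI * k / n;
            double wRe = Math.cos(angle), wIm = Math.sin(angle);
            double tRe = wRe * o[0][k] - wIm * o[1][k];
            double tIm = wRe * o[1][k] + wIm * o[0][k];

            yRe[k] = e[0][k] + tRe;
            yIm[k] = e[1][k] + tIm;
            yRe[k + h] = e[0][k] - tRe;
            yIm[k + h] = e[1][k] - tIm;
        }

        return new double[][] {yRe, yIm};
    }

    public static long[] multiply(long[] a, long[] b) {
        int resultLength = a.length + b.length - 1;
        int n = 1;
        while (n < resultLength) {
            n <<= 1; // Pad to a power of two that fits the whole product.
        }

        double[] aRe = new double[n], bRe = new double[n];
        for (int i = 0; i < a.length; i++) aRe[i] = a[i];
        for (int i = 0; i < b.length; i++) bRe[i] = b[i];

        double[][] fa = fft(aRe, new double[n]);
        double[][] fb = fft(bRe, new double[n]);

        // Pointwise product, stored already conjugated for the inverse.
        double[] cRe = new double[n], cIm = new double[n];
        for (int k = 0; k < n; k++) {
            cRe[k] = fa[0][k] * fb[0][k] - fa[1][k] * fb[1][k];
            cIm[k] = -(fa[0][k] * fb[1][k] + fa[1][k] * fb[0][k]);
        }

        // Inverse transform: conj(fft(conj(c))) / n. Only the real part
        // is needed, and conjugation does not change it.
        double[][] y = fft(cRe, cIm);
        long[] c = new long[resultLength];
        for (int i = 0; i < resultLength; i++) {
            c[i] = Math.round(y[0][i] / n);
        }

        return c;
    }

    public static void main(String[] args) {
        // (1 + 2x + 3x^2)(4 + 5x) = 4 + 13x + 22x^2 + 15x^3
        System.out.println(
                Arrays.toString(multiply(new long[]{1, 2, 3},
                                         new long[]{4, 5})));
    }
}
```

Computing each twiddle factor from its angle, rather than multiplying `omega` by `root` and rounding at every step, is a design choice in this sketch to limit accumulated rounding error. Whether it helps in the `BigDecimal` version would have to be measured.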

com.github.coderodde.math.ComplexNumber.java:

package com.github.coderodde.math;
import java.math.BigDecimal;
import java.math.RoundingMode;
/**
 * This class implements basic facilities for dealing with complex numbers in
 * the Fast Fourier Transform algorithms.
 * 
 * @version 1.1.0 (Nov 24, 2024)
 * @since 1.0.0 (Nov 23, 2024)
 */
public final class ComplexNumber {
 
 /**
 * The real part of this complex number.
 */
 private final BigDecimal realPart;
 
 /**
 * The imaginary part of this complex number.
 */
 private final BigDecimal imagPart;
 
 public ComplexNumber(final BigDecimal realPart,
 final BigDecimal imagPart) {
 
 this.realPart = realPart;
 this.imagPart = imagPart;
 }
 
 /**
 * Constructs this complex number as the real unity.
 */
 public ComplexNumber() {
 this(BigDecimal.ONE,
 BigDecimal.ZERO);
 }
 
 /**
 * Constructs this complex number.
 * 
 * @param realPart the real part.
 * @param imagPart the imaginary part.
 */
 public ComplexNumber(final double realPart,
 final double imagPart) {
 
 this(BigDecimal.valueOf(realPart),
 BigDecimal.valueOf(imagPart));
 }
 
 /**
 * Constructs this complex number as a real number with value 
 * {@code realPart}.
 * 
 * @param realPart the real part.
 */
 public ComplexNumber(final BigDecimal realPart) {
 this(realPart, BigDecimal.ZERO);
 }
 
 public ComplexNumber setScale(final int scale,
 final RoundingMode roundingMode) {
 
 return new ComplexNumber(realPart.setScale(scale, roundingMode),
 imagPart.setScale(scale, roundingMode));
 }
 
 public ComplexNumber setScale(final int scale) {
 return setScale(scale, 
 RoundingMode.HALF_UP);
 }
 
 public BigDecimal getRealPart() {
 return realPart;
 }
 
 public BigDecimal getImaginaryPart() {
 return imagPart;
 }
 
 public ComplexNumber add(final ComplexNumber other) {
 return new ComplexNumber(realPart.add(other.realPart),
 imagPart.add(other.imagPart));
 }
 
 public ComplexNumber negate() {
 return new ComplexNumber(realPart.negate(),
 imagPart.negate());
 }
 
 public ComplexNumber substract(final ComplexNumber other) {
 return add(other.negate());
 }
 
 public ComplexNumber getConjugate() {
 return new ComplexNumber(realPart, imagPart.negate());
 }
 
 /**
 * Multiplies this complex number with the {@code other} complex number.
 * 
 * @param other the second complex number to multiply.
 * 
 * @return the complex product of this and {@code other} complex numbers. 
 */
 public ComplexNumber multiply(final ComplexNumber other) {
 final BigDecimal resultRealPart = 
 realPart.multiply(other.realPart)
 .subtract(imagPart.multiply(other.imagPart));
 
 final BigDecimal resultImagPart = 
 realPart.multiply(other.imagPart)
 .add(imagPart.multiply(other.realPart));
 
 return new ComplexNumber(resultRealPart,
 resultImagPart);
 }
 
 public ComplexNumber divide(final BigDecimal r) {
 return new ComplexNumber(realPart.divide(r),
 imagPart.divide(r));
 }
 
 @Override
 public boolean equals(final Object o) {
 if (o == this) {
 return true;
 }
 
 if (o == null) {
 return false;
 }
 
 if (!getClass().equals(o.getClass())) {
 return false;
 }
 
 final ComplexNumber other = (ComplexNumber) o;
 
 return realPart.equals(other.realPart) &&
 imagPart.equals(other.imagPart);
 }
 
 @Override
 public String toString() {
 final StringBuilder sb = new StringBuilder();
 
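 // signum() ignores the scale, unlike equals(BigDecimal.ZERO):
 if (imagPart.signum() == 0) {
 return realPart.toString();
 }
 
 // The sign is printed as the operator, so append the absolute value:
 sb.append("(")
 .append(realPart)
 .append((imagPart.signum() > 0) ? " + " : " - ")
 .append(imagPart.abs())
 .append("i)");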
 
 return sb.toString();
 }
 
 /**
 * Returns the {@code n}th principal complex root of unity.
 * 
 * @param n the root order.
 * 
 * @return a principal complex root of unity.
 */
 public static ComplexNumber getPrincipalRootOfUnity(final int n) {
 
 final double u = 2.0 * Math.PI / (double) n;
 final double re = Math.cos(u);
 final double im = Math.sin(u);
 
 return new ComplexNumber(re, im);
 }
 
 public static ComplexNumber one() {
 return new ComplexNumber();
 }
 
 public static ComplexNumber zero() {
 return new ComplexNumber(BigDecimal.ZERO);
 }
}
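
Two `BigDecimal` behaviours matter for `ComplexNumber` above: `equals` compares the scale as well as the value, and the one-argument `divide` throws when the exact quotient has no terminating decimal expansion. The following sketch demonstrates both with standard `java.math.BigDecimal` calls only. The class and helper names (`BigDecimalPitfallsSketch`, `isZero`, `dividesExactly`) are made up for illustration.

```java
import java.math.BigDecimal;

public class BigDecimalPitfallsSketch {

    // Scale-insensitive zero test: true for 0, 0.0, 0.000000, ...
    public static boolean isZero(BigDecimal value) {
        return value.signum() == 0;
    }

    // True if the exact, unrounded division succeeds, i.e. the quotient
    // has a terminating decimal expansion.
    public static boolean dividesExactly(BigDecimal value,
                                         BigDecimal divisor) {
        try {
            value.divide(divisor);
            return true;
        } catch (ArithmeticException ex) {
            return false;
        }
    }

    public static void main(String[] args) {
        BigDecimal scaledZero = new BigDecimal("0.000000");

        // equals() also compares the scale (6 versus 0 here):
        System.out.println(scaledZero.equals(BigDecimal.ZERO));         // false
        System.out.println(scaledZero.compareTo(BigDecimal.ZERO) == 0); // true
        System.out.println(isZero(scaledZero));                         // true

        // 1/8 = 0.125 terminates; 1/3 does not:
        System.out.println(
                dividesExactly(BigDecimal.ONE, BigDecimal.valueOf(8))); // true
        System.out.println(
                dividesExactly(BigDecimal.ONE, BigDecimal.valueOf(3))); // false
    }
}
```

In the inverse FFT the divisor is always a power of two, so the exact `divide` in `ComplexNumber.divide` does not throw there. A caller that passes any other divisor would need the overload that takes a scale and a `RoundingMode`.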

com.github.coderodde.math.ComplexPolynomial.java:

package com.github.coderodde.math;
import static com.github.coderodde.math.Utils.powerToSuperscript;
import java.math.BigDecimal;
/**
 * This class implements polynomials over complex numbers. We need this class
 * in the FFT algorithms.
 * 
 * @version 1.0.0 (Nov 23, 2024)
 * @since 1.0.0 (Nov 23, 2024)
 */
public final class ComplexPolynomial {
 
 private final ComplexNumber[] coefficients;
 
 public ComplexPolynomial(final int length) {
 this.coefficients = new ComplexNumber[length];
 }
 
 public ComplexPolynomial(final Polynomial polynomial) {
 this.coefficients = new ComplexNumber[polynomial.length()];
 
 for (int i = 0; i < coefficients.length; i++) {
 coefficients[i] = new ComplexNumber(polynomial.getCoefficient(i));
 }
 }
 
 ComplexPolynomial(final ComplexNumber[] coefficients) {
 this.coefficients = coefficients;
 }
 
 public ComplexPolynomial setScale(final int scale) {
 final ComplexNumber[] coefficients = new ComplexNumber[length()];
 
 for (int i = 0; i < coefficients.length; i++) {
 coefficients[i] = this.coefficients[i].setScale(scale);
 }
 
 return new ComplexPolynomial(coefficients);
 }
 
 public ComplexPolynomial shrinkToHalf() {
 final ComplexNumber[] coefficients = new ComplexNumber[length() / 2];
 
 for (int i = 0; i < coefficients.length; i++) {
 coefficients[i] = this.getCoefficient(i);
 }
 
 return new ComplexPolynomial(coefficients);
 }
 
 public ComplexPolynomial getConjugate() {
 final ComplexNumber[] coefficients = new ComplexNumber[length()];
 
 for (int i = 0; i < coefficients.length; i++) {
 coefficients[i] = this.coefficients[i].getConjugate();
 }
 
 return new ComplexPolynomial(coefficients);
 }
 
 public int length() {
 return coefficients.length;
 }
 
 public ComplexNumber getCoefficient(final int coefficientIndex) {
 return coefficients[coefficientIndex];
 }
 
 public void setCoefficient(final int coefficientIndex, 
 final ComplexNumber coefficient) {
 
 this.coefficients[coefficientIndex] = coefficient;
 }
 
 public ComplexPolynomial[] split() {
 final int nextLength = coefficients.length / 2;
 final ComplexPolynomial[] result = new ComplexPolynomial[2];
 
 result[0] = new ComplexPolynomial(nextLength);
 result[1] = new ComplexPolynomial(nextLength);
 
 boolean readToFirst = true;
 int coefficientIndex1 = 0;
 int coefficientIndex2 = 0;
 
 for (int i = 0; i < coefficients.length; i++) {
 final ComplexNumber currentComplexNumber = coefficients[i];
 
 if (readToFirst) {
 readToFirst = false;
 result[0].setCoefficient(coefficientIndex1++,
 currentComplexNumber);
 } else {
 readToFirst = true;
 result[1].setCoefficient(coefficientIndex2++, 
 currentComplexNumber);
 }
 }
 
 return result;
 }
 
 /**
 * Converts this complex polynomial to an ordinary polynomial by ignoring 
 * the imaginary parts in this complex polynomial and copying only the 
 * real part to the resultant polynomial.
 * 
 * @return an ordinary polynomial.
 */
 public Polynomial convertToPolynomial() {
 final BigDecimal[] data = new BigDecimal[length()];
 
 for (int i = 0; i < length(); i++) {
 data[i] = coefficients[i].getRealPart();
 }
 
 return new Polynomial(data);
 }
 
 @Override
 public String toString() {
 // A single coefficient is a constant term; print it without "x":
 if (length() == 1) {
 return coefficients[0].toString();
 }
 final StringBuilder sb = new StringBuilder();
 
 boolean first = true;
 
 for (int pow = length() - 1; pow >= 0; pow--) {
 if (first) {
 first = false;
 sb.append(getCoefficient(pow))
 .append("x");
 
 if (pow > 1) {
 sb.append(powerToSuperscript(pow));
 }
 } else {
 sb.append(" + ").append(getCoefficient(pow));
 
 if (pow > 0) {
 sb.append("x");
 }
 
 if (pow > 1) {
 sb.append(powerToSuperscript(pow));
 }
 }
 }
 
 return sb.toString().replaceAll(",", ".");
 }
}
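
The even/odd split performed by `split()` can also be written with index arithmetic instead of a toggling flag. This is a hypothetical sketch on `int[]` arrays, assuming an even input length. `EvenOddSplitSketch` is not part of the library.

```java
import java.util.Arrays;

public class EvenOddSplitSketch {

    // Returns {even-indexed coefficients, odd-indexed coefficients}.
    // Assumes coefficients.length is even, as it is in a radix-2 FFT.
    public static int[][] split(int[] coefficients) {
        int half = coefficients.length / 2;
        int[] even = new int[half];
        int[] odd = new int[half];

        for (int i = 0; i < half; i++) {
            even[i] = coefficients[2 * i];
            odd[i] = coefficients[2 * i + 1];
        }

        return new int[][] {even, odd};
    }

    public static void main(String[] args) {
        int[][] parts = split(new int[]{10, 11, 12, 13});
        System.out.println(Arrays.toString(parts[0])); // [10, 12]
        System.out.println(Arrays.toString(parts[1])); // [11, 13]
    }
}
```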

Typical demo output

Naïve: 7096 milliseconds.
Karatsuba: 3960 milliseconds.
FFT: 1866 milliseconds.
Naïve agrees with Karatsuba and FFT: true.
Naïve output: 63x???? + 87x???? - 3x???? - 35x???? - 6x???? - 154x???� - 162x???� + 28x??? - 1...
FFT output: 63x???? + 87x???? - 3x???? - 35x???? - 6x???? - 154x???� - 162x???� + 28x??? - 1...

Critique request

As always, I would like to receive any commentary you can come up with.
