Intro
This time, I have produced Polynomial.java. It is a simple polynomial class that stores its coefficients as double
values in an array.
Code
com.github.coderodde.math.Polynomial.java:
package com.github.coderodde.math;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
/**
* This class implements a polynomial.
*
* @version 1.0.0 (Nov 21, 2024)
* @since 1.0.0 (Nov 21, 2024)
*/
public final class Polynomial {
/**
* The actual array storing coefficients. {@code coefficients[i]} is the
* coefficient of the term with degree {@code i}, so the constant term
* appears at {@code coefficients[0]}.
*/
private final double[] coefficients;
/**
* Constructs this polynomial from the input coefficients. The lowest degree
* coefficient is in {@code coefficients[0]} and so on.
*
* @param coefficients the array of coefficients. Each array component must
* be a real number.
*/
public Polynomial(double... coefficients) {
int i;
for (i = coefficients.length - 1; i >= 0; i--) {
if (coefficients[i] != 0.0) {
break;
}
}
if (i == -1) {
// Special case: a "null" polynomial, we convert it to y = 0.
this.coefficients = new double[]{ 0.0 };
} else {
this.coefficients = new double[i + 1];
System.arraycopy(coefficients,
0,
this.coefficients,
0,
this.coefficients.length);
validateCoefficients();
}
}
/**
* Evaluates this polynomial at the point {@code x}.
*
* @param x the argument value for this polynomial.
*
* @return the value of this polynomial at the specified {@code x}
* coordinate.
*/
public double evaluate(final double x) {
validateX(x);
double value = 0.0;
for (int pow = 0; pow < coefficients.length; pow++) {
value += coefficients[pow] * Math.pow(x, pow);
}
return value;
}
/**
* Gets the {@code coefficientIndex}th coefficient.
*
* @param coefficientIndex the index of the target coefficient.
*
* @return the target coefficient.
*/
public double getCoefficient(final int coefficientIndex) {
try {
return coefficients[coefficientIndex];
} catch (final ArrayIndexOutOfBoundsException ex) {
final String exceptionMessage =
String.format(
"coefficientIndex[%d] is out of " +
"valid bounds [0, %d)",
coefficientIndex,
coefficients.length);
throw new IllegalArgumentException(exceptionMessage, ex);
}
}
/**
* Gets the number of coefficients in this polynomial.
*
* @return the number of coefficients.
*/
public int length() {
return coefficients.length;
}
/**
* Constructs and returns an instance of {@link Polynomial} that is the
* summation of {@code this} polynomial and {@code othr}.
*
* @param othr the second polynomial to sum.
*
* @return the sum of {@code this} and {@code othr} polynomials.
*/
public Polynomial sum(final Polynomial othr) {
final int thisPolynomialLength = this.length();
final int othrPolynomialLength = othr.length();
final int longestPolynomialLength = Math.max(thisPolynomialLength,
othrPolynomialLength);
final double[] sumCoefficients = new double[longestPolynomialLength];
final Polynomial shrtPolynomial;
final Polynomial longPolynomial;
if (thisPolynomialLength <= othrPolynomialLength) {
shrtPolynomial = this;
longPolynomial = othr;
} else {
shrtPolynomial = othr;
longPolynomial = this;
}
int coefficientIndex = 0;
for (; coefficientIndex < shrtPolynomial.length();
coefficientIndex++) {
sumCoefficients[coefficientIndex] +=
shrtPolynomial.getCoefficient(coefficientIndex) +
longPolynomial.getCoefficient(coefficientIndex);
}
for (; coefficientIndex < longPolynomial.length();
coefficientIndex++) {
sumCoefficients[coefficientIndex] =
longPolynomial.getCoefficient(coefficientIndex);
}
return new Polynomial(sumCoefficients);
}
/**
* Returns the degree of this polynomial.
*
* @return the degree of this polynomial.
*/
public int getDegree() {
return coefficients.length - 1;
}
/**
* Constructs and returns an instance of {@link Polynomial} that is the
* multiplication of {@code this} and {@code othr} polynomials.
*
* @param othr the second polynomial to multiply.
*
* @return the product of this and the input polynomials.
*/
public Polynomial multiply(final Polynomial othr) {
final int coefficientArrayLength = this.length()
+ othr.length()
- 1;
final double[] coefficientArray = new double[coefficientArrayLength];
for (int index1 = 0;
index1 < this.length();
index1++) {
for (int index2 = 0;
index2 < othr.length();
index2++) {
final double coefficient1 = this.getCoefficient(index1);
final double coefficient2 = othr.getCoefficient(index2);
coefficientArray[index1 + index2] += coefficient1
* coefficient2;
}
}
return new Polynomial(coefficientArray);
}
private void validateCoefficients() {
for (int i = 0; i < coefficients.length; i++) {
validateCoefficient(i);
}
}
private void validateCoefficient(final int coefficientIndex) {
final double coefficient = coefficients[coefficientIndex];
if (Double.isNaN(coefficient)) {
final String exceptionMessage =
String.format("The coefficient[%d] is NaN.",
coefficientIndex);
throw new IllegalArgumentException(exceptionMessage);
}
if (Double.isInfinite(coefficient)) {
final String exceptionMessage =
String.format("The coefficient[%d] is infinite: %f",
coefficientIndex,
coefficient);
throw new IllegalArgumentException(exceptionMessage);
}
}
@Override
public boolean equals(final Object o) {
if (o == this) {
return true;
}
if (o == null) {
return false;
}
if (!getClass().equals(o.getClass())) {
return false;
}
final Polynomial othr = (Polynomial) o;
return Arrays.equals(this.coefficients,
othr.coefficients);
}
@Override
public int hashCode() {
// Generated by NetBeans 23:
int hash = 3;
hash = 43 * hash + Arrays.hashCode(coefficients);
return hash;
}
@Override
public String toString() {
if (getDegree() == 0) {
return String.format("%f", coefficients[0]).replace(",", ".");
}
final StringBuilder sb = new StringBuilder();
boolean first = true;
for (int pow = getDegree(); pow >= 0; pow--) {
if (first) {
first = false;
sb.append(getCoefficient(pow))
.append("x^")
.append(pow);
} else {
final double coefficient = getCoefficient(pow);
if (coefficient > 0.0) {
sb.append(" + ")
.append(coefficient);
} else if (coefficient < 0.0) {
sb.append(" - ")
.append(Math.abs(coefficient));
} else {
// Once here, there is no term with degree pow:
continue;
}
if (pow > 0) {
sb.append("x^")
.append(pow);
}
}
}
return sb.toString().replaceAll(",", ".");
}
public static Builder getPolynomialBuilder() {
return new Builder();
}
public static final class Builder {
private final Map<Integer, Double> map = new HashMap<>();
private int maximumCoefficientIndex = 0;
public Builder add(final int coefficientIndex,
final double coefficient) {
this.maximumCoefficientIndex =
Math.max(this.maximumCoefficientIndex,
coefficientIndex);
map.put(coefficientIndex,
coefficient);
return this;
}
public Polynomial build() {
final double[] coefficients =
new double[maximumCoefficientIndex + 1];
for (final Map.Entry<Integer, Double> e : map.entrySet()) {
coefficients[e.getKey()] = e.getValue();
}
return new Polynomial(coefficients);
}
}
private void validateX(final double x) {
if (Double.isNaN(x)) {
throw new IllegalArgumentException("x is NaN.");
}
if (Double.isInfinite(x)) {
final String exceptionMessage =
String.format("x is infinite: %f", x);
throw new IllegalArgumentException(exceptionMessage);
}
}
}
com.github.coderodde.math.PolynomialTest.java:
package com.github.coderodde.math;
import org.junit.Test;
import static org.junit.Assert.*;
public final class PolynomialTest {
private static final double E = 1E-3;
@Test
public void evaluate1() {
Polynomial p1 = new Polynomial(-1.0, 2.0); // 2x - 1
double value = p1.evaluate(3.0);
assertEquals(5.0, value, E);
}
@Test
public void evaluate2() {
Polynomial p1 = new Polynomial(-5.0, 3.0, 2.0); // 2x^2 + 3x - 5
double value = p1.evaluate(4.0);
assertEquals(39.0, value, E);
}
@Test
public void getCoefficient() {
Polynomial p = new Polynomial(1, 2, 3, 4);
assertEquals(1, p.getCoefficient(0), E);
assertEquals(2, p.getCoefficient(1), E);
assertEquals(3, p.getCoefficient(2), E);
assertEquals(4, p.getCoefficient(3), E);
}
@Test
public void length() {
assertEquals(5, new Polynomial(3, -2, -1, 4, 2).length());
}
@Test
public void sum() {
Polynomial p1 = new Polynomial(3, -1, 2);
Polynomial p2 = new Polynomial(5, 4);
Polynomial sum = p1.sum(p2);
Polynomial expected = new Polynomial(8, 3, 2);
assertEquals(expected, sum);
}
@Test
public void getDegree() {
assertEquals(3, new Polynomial(1, -2, 3, -4).getDegree());
}
@Test
public void multiply() {
Polynomial p1 = new Polynomial(3, -2, 1); // x^2 - 2x + 3
Polynomial p2 = new Polynomial(4, 1); // x + 4
// (x^3 - 2x^2 + 3x) + (4x^2 - 8x + 12) = x^3 + 2x^2 - 5x + 12
Polynomial product = p1.multiply(p2);
assertEquals(3, product.getDegree());
Polynomial expected = new Polynomial(12, -5, 2, 1);
assertEquals(expected, product);
}
@Test
public void testConstructEmptyPolynomial() {
Polynomial p = new Polynomial();
assertEquals(1, p.length());
assertEquals(0, p.getDegree());
assertEquals(0, p.getCoefficient(0), E);
assertEquals(0, p.evaluate(4), E);
assertEquals(0, p.evaluate(-3), E);
}
@Test
public void builder() {
final Polynomial p = Polynomial.getPolynomialBuilder()
.add(10, 10)
.add(5000, 5000)
.build();
assertEquals(5000, p.getDegree());
assertEquals(10, p.getCoefficient(10), E);
assertEquals(5000, p.getCoefficient(5000), E);
}
@Test
public void emptyPolynomial() {
Polynomial p = new Polynomial();
assertEquals(1, p.length());
assertEquals(0, p.getDegree());
assertEquals(0.0, p.getCoefficient(0), E);
assertEquals(0.0, p.evaluate(10.0), E);
}
}
com.github.coderodde.math.demos.PolynomialDemo.java:
package com.github.coderodde.math.demos;
import com.github.coderodde.math.Polynomial;
import java.util.HashMap;
import java.util.Map;
import java.util.Scanner;
/**
* This is a demonstration program for the {@link Polynomial} class.
*
* @version 1.0.0 (Nov 21, 2024)
* @since 1.0.0 (Nov 21, 2024)
*/
public final class PolynomialDemo {
private static final Map<String, Polynomial> environment = new HashMap<>();
private static Polynomial previousPolynomial = null;
public static void main(final String[] args) {
final Scanner scanner = new Scanner(System.in);
while (true) {
System.out.print("> ");
final String line = scanner.nextLine().trim().toLowerCase();
if (line.equals("quit") || line.equals("exit")) {
return;
}
try {
if (line.contains("save")) {
saveImpl(line);
} else if (line.contains("print")) {
print(line);
} else if (line.contains("+")) {
sumImpl(line);
} else if (line.contains("*")) {
productImpl(line);
} else {
parsePolynomial(line);
}
} catch (final Exception ex) {
System.err.printf(">>> Exception: %s!", ex.getMessage());
System.out.println();
}
}
}
private static void saveImpl(final String line) {
final String[] parts = line.split("\\s+");
final String polynomialName = parts[1];
environment.put(polynomialName,
previousPolynomial);
}
private static void print(final String line) {
System.out.println(environment.get(line.split("\\s+")[1]));
}
private static void sumImpl(final String line) {
final String[] parts = line.split("\\+");
if (parts.length != 2) {
final String exceptionMessage =
String.format("Invalid line: \"%s\"", line);
throw new IllegalArgumentException(exceptionMessage);
}
final String polynomialName1 = parts[0].trim();
final String polynomialName2 = parts[1].trim();
final Polynomial polynomial1 = environment.get(polynomialName1);
final Polynomial polynomial2 = environment.get(polynomialName2);
final Polynomial sum = polynomial1.sum(polynomial2);
previousPolynomial = sum;
System.out.println(sum);
}
private static void productImpl(final String line) {
final String[] parts = line.split("\\*");
if (parts.length != 2) {
final String exceptionMessage =
String.format("Invalid line: \"%s\"", line);
throw new IllegalArgumentException(exceptionMessage);
}
final String polynomialName1 = parts[0].trim();
final String polynomialName2 = parts[1].trim();
final Polynomial polynomial1 = environment.get(polynomialName1);
final Polynomial polynomial2 = environment.get(polynomialName2);
final Polynomial product = polynomial1.multiply(polynomial2);
previousPolynomial = product;
System.out.println(product);
}
private static Polynomial parsePolynomial(final String line) {
final String[] termsStrings = line.split("\\s+");
final Polynomial.Builder builder = Polynomial.getPolynomialBuilder();
int coefficientIndex = 0;
for (String termString : termsStrings) {
termString = termString.trim();
final double coefficient;
try {
coefficient = Double.parseDouble(termString);
} catch (final NumberFormatException ex) {
final String exceptionMessage =
String.format(
"coefficient \"%s\" is not a real number",
termString);
throw new IllegalArgumentException(exceptionMessage);
}
builder.add(coefficientIndex++,
coefficient);
}
final Polynomial p = builder.build();
previousPolynomial = p;
System.out.printf(">>> %s\n", p);
return p;
}
}
Typical REPL
> 4 3 -1
>>> -1.0x^2 + 3.0x^1 + 4.0
> save a
> 7 -2 0 1
>>> 1.0x^3 - 2.0x^1 + 7.0
> save b
> print a
-1.0x^2 + 3.0x^1 + 4.0
> print b
1.0x^3 - 2.0x^1 + 7.0
> a + b
1.0x^3 - 1.0x^2 + 1.0x^1 + 11.0
> a * b
-1.0x^5 + 3.0x^4 + 6.0x^3 - 13.0x^2 + 13.0x^1 + 28.0
>
Critique request
As always, I am eager to hear any commentary on how to improve my routine.
2 Answers
Polynomial evaluation
for (int pow = 0; pow < coefficients.length; pow++) {
    value += coefficients[pow] * Math.pow(x, pow);
}
It's fine, though it is the naive algorithm. There are various alternatives; these are just some basic methods:
- Build up powers of x iteratively, costing 2 multiplications and an addition per iteration (more or less replacing the cost of Math.pow by the cost of a multiplication, but that chain of multiplications also forms a new loop-carried dependency with non-trivial latency, so it's not that simple to reason about the difference in cost).
- Horner's method starts from the end and uses only 1 multiplication and an addition per iteration, but with both of them in the loop-carried dependency.
- Estrin's method cuts up that loop-carried dependency into more of a tree structure but it's more complicated.
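To make the first two options concrete, here is a minimal sketch of both. It works on a plain double[] laid out like the coefficients field in the question (constant term first); the class and method names are made up for illustration and are not part of the posted class.

```java
final class EvaluationSketch {

    // Running-power variant: keeps x^i in a variable instead of calling
    // Math.pow, at the cost of two multiplications per iteration.
    static double evaluateWithRunningPower(final double[] coefficients,
                                           final double x) {
        double value = 0.0;
        double power = 1.0; // x^0
        for (final double coefficient : coefficients) {
            value += coefficient * power;
            power *= x;
        }
        return value;
    }

    // Horner's method: starts at the highest degree and folds in one
    // coefficient per step with a single multiplication and addition.
    static double evaluateHorner(final double[] coefficients,
                                 final double x) {
        double value = 0.0;
        for (int i = coefficients.length - 1; i >= 0; i--) {
            value = value * x + coefficients[i];
        }
        return value;
    }

    public static void main(final String[] args) {
        // 2x^2 + 3x - 5 at x = 4, the same case as the evaluate2 test.
        final double[] coefficients = { -5.0, 3.0, 2.0 };
        System.out.println(evaluateWithRunningPower(coefficients, 4.0)); // 39.0
        System.out.println(evaluateHorner(coefficients, 4.0));           // 39.0
    }
}
```

Either loop could replace the body of evaluate without changing the public interface.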
If your primary concern is accuracy, I'm not up to date on what the good options are, but it's not just evaluating monomials and summing them in their original order (or at least not with an uncompensated sum, which can easily lose track of digits that later turn out to be significant due to a cancellation).
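To illustrate what a compensated sum means here, the sketch below applies Kahan summation to the monomial terms. This is only one possible approach and I am not claiming it is the best option for accuracy; it compensates for rounding in the additions only, not in the products. The class and method names are hypothetical.

```java
final class CompensatedEvaluationSketch {

    // Sums the terms c[i] * x^i with Kahan (compensated) summation.
    static double evaluate(final double[] coefficients, final double x) {
        double sum = 0.0;
        double compensation = 0.0; // running estimate of the lost low-order bits
        double power = 1.0;        // x^0
        for (final double coefficient : coefficients) {
            final double term = coefficient * power - compensation;
            final double next = sum + term;
            // (next - sum) is what was actually added; the difference from
            // term is the rounding error carried into the next step.
            compensation = (next - sum) - term;
            sum = next;
            power *= x;
        }
        return sum;
    }

    public static void main(final String[] args) {
        // At x = 1 a plain running sum stays at exactly 1.0, because each
        // 1e-16 is rounded away on its own; the compensated sum keeps them.
        final double[] coefficients = { 1.0, 1e-16, 1e-16, 1e-16 };
        System.out.println(evaluate(coefficients, 1.0));
    }
}
```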
Polynomial multiplication
It's the classic O(n^2) algorithm, which is fine, especially for small polynomials, but you can do better for polynomials that aren't small, for example with Karatsuba multiplication or FFT-based multiplication. Of course, those are more complicated.
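For a feel of the extra complexity, here is a rough sketch of Karatsuba multiplication on plain coefficient arrays (constant term first, as in the question). The names and the cutoff parameter are my own assumptions; cutoff is a hypothetical tuning knob below which the code falls back to the schoolbook product, and a sensible value would have to be measured.

```java
import java.util.Arrays;

final class KaratsubaSketch {

    // Classic O(n^2) product; also the base case of the recursion.
    static double[] schoolbook(final double[] a, final double[] b) {
        final double[] product = new double[a.length + b.length - 1];
        for (int i = 0; i < a.length; i++) {
            for (int j = 0; j < b.length; j++) {
                product[i + j] += a[i] * b[j];
            }
        }
        return product;
    }

    static double[] multiply(final double[] a,
                             final double[] b,
                             final int cutoff) {
        if (a.length == 0 || b.length == 0) {
            return new double[0];
        }
        // Zero-pad both inputs to a common length, then trim the result.
        final int n = Math.max(a.length, b.length);
        final double[] padded = karatsuba(Arrays.copyOf(a, n),
                                          Arrays.copyOf(b, n),
                                          Math.max(1, cutoff));
        return Arrays.copyOf(padded, a.length + b.length - 1);
    }

    // Requires a.length == b.length == n; returns 2n - 1 coefficients.
    private static double[] karatsuba(final double[] a,
                                      final double[] b,
                                      final int cutoff) {
        final int n = a.length;
        if (n <= cutoff) {
            return schoolbook(a, b);
        }
        final int h = (n + 1) / 2; // size of each half, rounded up
        // copyOfRange zero-pads when the upper bound exceeds the length.
        final double[] aLow  = Arrays.copyOfRange(a, 0, h);
        final double[] aHigh = Arrays.copyOfRange(a, h, 2 * h);
        final double[] bLow  = Arrays.copyOfRange(b, 0, h);
        final double[] bHigh = Arrays.copyOfRange(b, h, 2 * h);

        // Three recursive products instead of four.
        final double[] z0 = karatsuba(aLow, bLow, cutoff);
        final double[] z2 = karatsuba(aHigh, bHigh, cutoff);
        final double[] zMid = karatsuba(add(aLow, aHigh),
                                        add(bLow, bHigh),
                                        cutoff);

        final double[] product = new double[4 * h - 1];
        for (int i = 0; i < 2 * h - 1; i++) {
            product[i]         += z0[i];
            product[i + h]     += zMid[i] - z0[i] - z2[i];
            product[i + 2 * h] += z2[i];
        }
        return Arrays.copyOf(product, 2 * n - 1);
    }

    private static double[] add(final double[] a, final double[] b) {
        final double[] sum = new double[a.length];
        for (int i = 0; i < a.length; i++) {
            sum[i] = a[i] + b[i];
        }
        return sum;
    }

    public static void main(final String[] args) {
        // (x^2 - 2x + 3)(x + 4), the case from the multiply test; a cutoff
        // of 1 forces the recursion to run even on this tiny input.
        System.out.println(Arrays.toString(
                multiply(new double[]{ 3, -2, 1 }, new double[]{ 4, 1 }, 1)));
        // prints [12.0, -5.0, 2.0, 1.0]
    }
}
```

Note that with floating-point coefficients the subtraction in the middle term can introduce cancellation error that the schoolbook product does not have, so the speed-up is not free.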
Possible extensions
There are some extra operations that may be useful. For example:
- A fairly typical thing to want to do with a polynomial is find its roots. You could implement Jenkins–Traub or something like that; it's an old algorithm and a bit complicated, and I'm not up to date on the current state of the art.
- Another fairly typical thing to want to do with a polynomial is getting its derivative.
- Polynomials can be divided (with remainder), but I see that more in the context of non-floating-point coefficients.
- There are algorithms that evaluate a polynomial at multiple points in a batch more efficiently than evaluating the polynomial at each point individually.
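Of these, the derivative is the easiest to add. A minimal sketch on a plain coefficient array (constant term first, as in the question) might look like this; the class and method names are hypothetical.

```java
import java.util.Arrays;

final class DerivativeSketch {

    // d/dx of c[0] + c[1]*x + ... + c[n]*x^n is
    // c[1] + 2*c[2]*x + ... + n*c[n]*x^(n-1).
    static double[] derivative(final double[] coefficients) {
        if (coefficients.length <= 1) {
            // The derivative of a constant (or empty) polynomial is zero.
            return new double[]{ 0.0 };
        }
        final double[] result = new double[coefficients.length - 1];
        for (int i = 1; i < coefficients.length; i++) {
            result[i - 1] = i * coefficients[i];
        }
        return result;
    }

    public static void main(final String[] args) {
        // 2x^2 + 3x - 5 differentiates to 4x + 3.
        System.out.println(Arrays.toString(
                derivative(new double[]{ -5.0, 3.0, 2.0 })));
        // prints [3.0, 4.0]
    }
}
```

In the posted class this would presumably wrap the result in new Polynomial(...), which already handles the all-zero case.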
Other
I'm not immediately convinced that a polynomial with floating point coefficients should have an equality test and hash code. I'm not saying that it definitely shouldn't have them, but I would only add them if there is a specific need for them, otherwise they may create more problems (the good old "why are these two things not considered equal even though they are", you know the deal) than they solve.
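A small illustration of the kind of surprise I mean, using plain arrays in place of the class's coefficients field. The approximatelyEquals helper and its tolerance parameter are hypothetical, shown only as one alternative to exact equality.

```java
import java.util.Arrays;

final class ToleranceEqualitySketch {

    // Compares coefficient arrays entry by entry within an absolute tolerance.
    static boolean approximatelyEquals(final double[] a,
                                       final double[] b,
                                       final double tolerance) {
        if (a.length != b.length) {
            return false;
        }
        for (int i = 0; i < a.length; i++) {
            if (Math.abs(a[i] - b[i]) > tolerance) {
                return false;
            }
        }
        return true;
    }

    public static void main(final String[] args) {
        final double[] p = { 0.1 + 0.2 }; // mathematically 0.3
        final double[] q = { 0.3 };
        System.out.println(Arrays.equals(p, q));               // false
        System.out.println(approximatelyEquals(p, q, 1e-9));   // true
    }
}
```

A tolerance-based comparison is not transitive, so it cannot back equals and hashCode consistently; it would have to live in a separate method.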
Precision loss
Binary floating point numbers due to their binary nature are prone to rounding errors. There's a possibility of precision loss while performing calculations with double values.
To avoid that, you can use BigDecimal to represent coefficients instead of double.
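As a rough sketch of what that could look like, here is a Horner-style evaluation over BigDecimal coefficients (constant term first, as in the question). The class and method names are hypothetical, and only addition and multiplication are shown; those are exact for BigDecimal, whereas division would need a MathContext.

```java
import java.math.BigDecimal;

final class BigDecimalPolynomialSketch {

    // Evaluates c[0] + c[1]*x + ... + c[n]*x^n exactly.
    static BigDecimal evaluate(final BigDecimal[] coefficients,
                               final BigDecimal x) {
        BigDecimal value = BigDecimal.ZERO;
        for (int i = coefficients.length - 1; i >= 0; i--) {
            value = value.multiply(x).add(coefficients[i]);
        }
        return value;
    }

    public static void main(final String[] args) {
        // Construct from strings: new BigDecimal(0.1) would capture the
        // binary rounding error of the double literal.
        final BigDecimal[] coefficients = {
            new BigDecimal("0.1"),
            new BigDecimal("0.2")
        };
        // 0.2x + 0.1 at x = 1
        System.out.println(evaluate(coefficients, BigDecimal.ONE)); // 0.3
        System.out.println(0.2 * 1.0 + 0.1); // 0.30000000000000004
    }
}
```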
Naming
Map<Integer, Double> map - yes, that's a map, no doubt about it. But it gives no clue about the purpose of this field.
The names of code elements should communicate their responsibilities to the code reader, not their types. coefficientByIndex or indexToCoefficient are better options.
shrtPolynomial, othr - "Never abbrev" (a joke from Martin Fowler). Please, don't skimp on characters. Using shrt instead of short doesn't improve code readability; on the contrary, it requires a little bit more brainpower.
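For example, the builder's state could be renamed along these lines. This is only a sketch of the naming: the logic follows the posted Builder, but the class name and the toCoefficientArray method are hypothetical stand-ins so that the snippet compiles on its own.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

final class CoefficientBuilderSketch {

    // The name now says what the map holds, not that it is a map.
    private final Map<Integer, Double> coefficientByIndex = new HashMap<>();
    private int maximumCoefficientIndex = 0;

    CoefficientBuilderSketch add(final int coefficientIndex,
                                 final double coefficient) {
        maximumCoefficientIndex = Math.max(maximumCoefficientIndex,
                                           coefficientIndex);
        coefficientByIndex.put(coefficientIndex, coefficient);
        return this;
    }

    // Stand-in for build(): returns the dense coefficient array.
    double[] toCoefficientArray() {
        final double[] coefficients = new double[maximumCoefficientIndex + 1];
        for (final Map.Entry<Integer, Double> entry
                : coefficientByIndex.entrySet()) {
            coefficients[entry.getKey()] = entry.getValue();
        }
        return coefficients;
    }

    public static void main(final String[] args) {
        System.out.println(Arrays.toString(
                new CoefficientBuilderSketch()
                        .add(0, 4.0)
                        .add(2, -1.0)
                        .toCoefficientArray()));
        // prints [4.0, 0.0, -1.0]
    }
}
```

The same idea applies to sum: shorterPolynomial, longerPolynomial and other read without any decoding.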
a * b → -1.0x⁵ + 3.0x⁴ + 6.0x³ - 13.0x² + 13.0x¹ + 28.0). And perhaps consider dropping the ¹ superscript, which is not normally written.