(See the previous iteration.)
My two previous methods for computing the integer square root of a number \$N\$ ran in \$\mathcal{O}(\sqrt{N})\$ worst-case time. Now I have added a method (intSqrt3) that runs in \$\mathcal{O}(\log \sqrt{N})\$ time:
Main.java:
import java.util.Random;
import java.util.function.Function;

public class Main {

    public static long intSqrt1(long number) {
        long sqrt = 0L;
        while ((sqrt + 1) * (sqrt + 1) <= number) {
            sqrt++;
        }
        return sqrt;
    }

    public static long intSqrt2(long number) {
        if (number <= 0L) {
            return 0L;
        }
        long sqrt = 1L;
        while (4 * sqrt * sqrt <= number) {
            sqrt *= 2;
        }
        while ((sqrt + 1) * (sqrt + 1) <= number) {
            sqrt++;
        }
        return sqrt;
    }

    public static long intSqrt3(long number) {
        if (number <= 0L) {
            return 0L;
        }
        long sqrt = 1L;
        // Do the exponential search.
        while (4 * sqrt * sqrt <= number) {
            sqrt *= 2;
        }
        long left = sqrt;
        long right = 2 * sqrt;
        long middle = 0;
        // Do the binary search over the range that is guaranteed to contain
        // the integer square root.
        while (left < right) {
            middle = left + (right - left) / 2;
            if (middle * middle < number) {
                left = middle + 1;
            } else if (middle * middle > number) {
                right = middle - 1;
            } else {
                return middle;
            }
        }
        // Correct the binary search "noise". This iterates no more than 3
        // times.
        long ret = middle + 1;
        while (ret * ret > number) {
            --ret;
        }
        return ret;
    }

    public static long intSqrt4(long number) {
        return (long) Math.sqrt(number);
    }

    private static void profile(Function<Long, Long> function, Long number) {
        long result = 0L;
        long startTime = System.nanoTime();
        for (int i = 0; i < ITERATIONS; ++i) {
            result = function.apply(number);
        }
        long endTime = System.nanoTime();
        System.out.printf("Time: %.2f, result: %d.\n",
                          (endTime - startTime) / 1e6,
                          result);
    }

    private static final int ITERATIONS = 1_000;
    private static final long UPPER_BOUND = 1_000_000_000_000L;

    public static void main(String[] args) {
        long seed = System.nanoTime();
        Random random = new Random(seed);
        long number = Math.abs(random.nextLong()) % UPPER_BOUND;
        System.out.println("Seed = " + seed);
        System.out.println("Number: " + number);
        profile(Main::intSqrt1, number);
        profile(Main::intSqrt2, number);
        profile(Main::intSqrt3, number);
        profile(Main::intSqrt4, number);
    }
}
The performance figures I get look like this:
Seed = 19608492647714
Number: 54383384696
Time: 531.18, result: 233202.
Time: 218.41, result: 233202.
Time: 1.81, result: 233202.
Time: 0.43, result: 233202.
Above, intSqrt3 took 1.81 milliseconds for 1,000 calls.
Critique request
Is there something I could improve? Naming/coding conventions? Performance? API design?
4 Answers
When using function parameters, use the primitive types when available: Function<Long, Long> function is a red flag and should be a LongUnaryOperator.
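As a sketch of that suggestion, here is profile() rewritten to take a java.util.function.LongUnaryOperator. Its applyAsLong(long) method works on primitives, so no Long boxing happens per call. The class name LongUnaryOperatorProfile is hypothetical, intSqrt1 is copied from the question so the block compiles on its own, and returning the last result is an addition made only so the method is easy to check.

```java
import java.util.function.LongUnaryOperator;

public class LongUnaryOperatorProfile {

    private static final int ITERATIONS = 1_000;

    // Copied from the question so that this sketch is self-contained.
    public static long intSqrt1(long number) {
        long sqrt = 0L;
        while ((sqrt + 1) * (sqrt + 1) <= number) {
            sqrt++;
        }
        return sqrt;
    }

    // Takes a LongUnaryOperator instead of Function<Long, Long>:
    // applyAsLong(long) accepts and returns a primitive long, so no boxing.
    // Returns the last result so callers can check it.
    public static long profile(LongUnaryOperator function, long number) {
        long result = 0L;
        long startTime = System.nanoTime();
        for (int i = 0; i < ITERATIONS; ++i) {
            result = function.applyAsLong(number);
        }
        long endTime = System.nanoTime();
        System.out.printf("Time: %.2f, result: %d.%n",
                          (endTime - startTime) / 1e6, result);
        return result;
    }

    public static void main(String[] args) {
        long number = 54383384696L;
        // Method references and lambdas both fit LongUnaryOperator.
        profile(LongUnaryOperatorProfile::intSqrt1, number);
        profile(n -> (long) Math.sqrt(n), number);
    }
}
```

Call sites do not change: Main::intSqrt1 already has the long -> long shape LongUnaryOperator expects.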
Your code will spin into an infinite loop for 25% of all long values. Anything larger than Long.MAX_VALUE/2 will cause this loop to become infinite, because once sqrt reaches 2^31, 4 * sqrt * sqrt wraps around to 0, which is always <= number:
// Do the exponential search.
while (4 * sqrt * sqrt <= number) {
    sqrt *= 2;
}
About that loop... why do you have a magic number 4? What does it do?
This code needs more testing, and magic numbers need to be removed.
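One way to remove the overflow is to move the multiplication to the other side as a division. For positive integers, sqrt <= number / 4 / sqrt is equivalent to 4 * sqrt * sqrt <= number, but it cannot overflow. The sketch below is a hypothetical rewrite, not the author's code: the class name SafeIntSqrt is made up, and instead of the question's search-plus-correction loop it uses a conventional binary search for the largest m with m * m <= number.

```java
public class SafeIntSqrt {

    // Overflow-safe integer square root (sketch). Returns 0 for number <= 0,
    // like intSqrt2/intSqrt3 in the question.
    public static long intSqrt(long number) {
        if (number <= 0L) {
            return 0L;
        }
        long sqrt = 1L;
        // Exponential search: double while (2 * sqrt)^2 <= number.
        // "sqrt <= number / 4 / sqrt" is the division form of
        // "4 * sqrt * sqrt <= number" and never overflows.
        while (sqrt <= number / 4 / sqrt) {
            sqrt *= 2;
        }
        // Now sqrt^2 <= number < (2 * sqrt)^2, so the answer is in [sqrt, 2 * sqrt - 1].
        long low = sqrt;
        long high = 2 * sqrt - 1;
        // Binary search for the largest m in [low, high] with m * m <= number.
        while (low < high) {
            long middle = low + (high - low + 1) / 2; // round up so low always moves
            if (middle <= number / middle) {          // middle * middle <= number, without overflow
                low = middle;
            } else {
                high = middle - 1;
            }
        }
        return low;
    }

    public static void main(String[] args) {
        long s = 1L << 31;
        System.out.println(4 * s * s); // prints 0: 2^64 wraps around to 0 in a long

        long[] samples = {1L, 4L, 15L, 54383384696L, Long.MAX_VALUE / 2 + 1, Long.MAX_VALUE};
        for (long n : samples) {
            System.out.println(n + " -> " + intSqrt(n));
        }
    }
}
```

The largest possible answer is 3037000499, so every intermediate value stays far below Long.MAX_VALUE.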
The constant 4 makes sure that multiplying the current "guess" by a factor of 2 will not exceed the square root of number. And no, I did not expect that routine to be flawless. – coderodde Commented Jan 17, 2016 at 13:39
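The reasoning in that comment can be written out. Here s is the current guess (sqrt in the code), N is number and r = ⌊√N⌋. The second line also shows why intSqrt3's search range starting at sqrt and ending at 2 * sqrt must contain the answer. This relies on N > 0, which intSqrt3 checks first, so s = 1 starts with s^2 <= N:

```latex
\begin{aligned}
4s^2 \le N &\iff (2s)^2 \le N \iff 2s \le r, \qquad r = \lfloor \sqrt{N} \rfloor \\
s^2 \le N < (2s)^2 &\implies s \le r < 2s
\end{aligned}
```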
You want something fast and efficient.
But did you really check what this method does:
public static long intSqrt1(long number) {
    long sqrt = 0L;
    while ((sqrt + 1) * (sqrt + 1) <= number) {
        sqrt++;
    }
    return sqrt;
}
You are adding 1 to sqrt three times per iteration.
I don't see any reason to do that, but I'm guessing it's to make returning sqrt easy.
Let's refactor just this into a more efficient method.
First of all, a while loop that has to count its steps is what a for loop is for.
public static long intSqrt1(long number) {
    long sqrt;
    for (sqrt = 1; (sqrt * sqrt) <= number; sqrt++) {}
    return --sqrt;
}
This method does the same thing, but it increments sqrt only once per iteration and decrements it once when returning.
Now, I did write a basic test. It's not how a real performance test should be written, but in this case you will see the difference because it's big:
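The refactoring changes both the loop's starting value and the return expression, so it is worth checking that it still returns the same results. Below is a minimal brute-force comparison. The class and method names (Sqrt1Equivalence, original, refactored, firstMismatch) are hypothetical.

```java
public class Sqrt1Equivalence {

    // intSqrt1 as posted in the question.
    public static long original(long number) {
        long sqrt = 0L;
        while ((sqrt + 1) * (sqrt + 1) <= number) {
            sqrt++;
        }
        return sqrt;
    }

    // chillworld's refactoring: one increment per iteration, one decrement on return.
    public static long refactored(long number) {
        long sqrt;
        for (sqrt = 1; (sqrt * sqrt) <= number; sqrt++) {
            // empty body: the loop header does all the work
        }
        return --sqrt;
    }

    // Returns the first input in [from, to] where the two versions disagree, or -1 if none.
    public static long firstMismatch(long from, long to) {
        for (long n = from; n <= to; n++) {
            if (original(n) != refactored(n)) {
                return n;
            }
        }
        return -1L;
    }

    public static void main(String[] args) {
        System.out.println(firstMismatch(-10L, 100_000L)); // prints -1: no mismatch found
    }
}
```

Negative inputs are included on purpose: both versions return 0 for them.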
public static void main(String[] args) {
    long startTime = System.nanoTime();
    for (int i = 0; i < 10000; i++) {
        intSqrt1(902545489); // new one
    }
    long midTime = System.nanoTime();
    for (int j = 0; j < 10000; j++) {
        intSqrt2(902545489); // old one
    }
    long endTime = System.nanoTime();
    System.out.println((midTime - startTime) + " vs " + (endTime - midTime));
}
As you can see, each for loop declares its own counter. I put the new method first because there can be a delay at startup, which could distort the timing of whichever method runs first.
Even so, I got this output (dots added for readability):
175.509.799 vs 360.087.176
As you can see, I halved the time.
return sqrt - 1; looks better, I think (no need for the assignment). – oliverpool Commented Jan 18, 2016 at 14:39
For me that's personal taste. It's possible that - 1 is a little faster, but I don't know. – chillworld Commented Jan 18, 2016 at 18:14
Since you know Long.MAX_VALUE in advance, you can hard-code its square root. You can then perform a binary search between 1 and this pre-computed maximum.
This removes the questionable "exponential search".
You will then also achieve \$\mathcal{O}(\log \sqrt{Long.MAX\_VALUE})\$ complexity, which is actually \$\mathcal{O}(1)\$, as observed by @Simon Forsberg. This means that the execution time can be bounded by a constant (this does not necessarily mean that this algorithm is the fastest).
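A minimal sketch of that idea. The constant 3037000499 is ⌊√Long.MAX_VALUE⌋, because 3037000499^2 = 9223372030926249001 <= Long.MAX_VALUE < 3037000500^2. The class name BoundedIntSqrt is hypothetical. Every middle is at most that constant, so middle * middle cannot overflow, and the loop runs about 32 times whatever the input.

```java
public class BoundedIntSqrt {

    // floor(sqrt(Long.MAX_VALUE)): the largest value whose square fits in a long.
    private static final long MAX_SQRT = 3_037_000_499L;

    public static long intSqrt(long number) {
        if (number <= 0L) {
            return 0L;
        }
        long low = 1L;        // 1 * 1 <= number for every positive number
        long high = MAX_SQRT; // the answer can never be larger than this
        // Binary search for the largest m in [low, high] with m * m <= number.
        while (low < high) {
            long middle = low + (high - low + 1) / 2; // round up so low always moves
            if (middle * middle <= number) {          // safe: middle <= MAX_SQRT
                low = middle;
            } else {
                high = middle - 1;
            }
        }
        return low;
    }

    public static void main(String[] args) {
        System.out.println(intSqrt(54383384696L));
        System.out.println(intSqrt(Long.MAX_VALUE));
    }
}
```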
I'd rather say that this will be O(log(sqrt(CONSTANT))); the complexity does not scale with the input in this case. Technically, because sqrt(CONSTANT) is also a constant, and log(constant) is also a constant, you could even call this O(1). – Simon Forsberg Commented Jan 19, 2016 at 15:51
@SimonForsberg that's right, I updated my answer accordingly. – oliverpool Commented Jan 19, 2016 at 16:46
Since the input is bounded, all algorithms are actually O(1) (it might just not be trivial to find the bound). – oliverpool Commented Jan 19, 2016 at 16:47
I think we proved that P = NP. Money, here we go! – coderodde Commented Jan 26, 2016 at 18:43
Just for the hell of it, I tinkered around with micro-benchmarking, picking up ideas from rolfl and chillworld. I did this without the likes of the Java Microbenchmark Harness or jmicrobench (no idea whether this is official), or that most visible one (for those who don't have issues with empires):
private static void profile(UnaryLongFunc function, Random r) {
    long
        result = 0L,
        iterations = function.count,
        startTime = System.nanoTime();
    StringBuilder results = new StringBuilder(123);
    for (long i = 0; i < iterations; ++i) {
        long t = r.nextLong() % UPPER_BOUND;
        result = function.apply(t < 0 ? t + UPPER_BOUND : t);
        if (i < 7)
            results.append(", ").append(result);
    }
    long endTime = System.nanoTime();
    System.out.printf("%-12s %11.2f%s\n", function.label,
        (double)(endTime - startTime) / iterations, results);
}
// private static final int ITERATIONS = 1_000;
private static final long UPPER_BOUND = 1_000_000_000_000L;//Long.MAX_VALUE;

/** increment sqrt just once per iteration (chillworld) */
public static long intSqrt11(long number) {
    long sqrt;
    for (sqrt = 1; (sqrt * sqrt) <= number; sqrt++) {}
    return --sqrt;
}
/** source-level strength reduction */
public static long intSqrt12(long number) {
    // sq == k * k and inc == 2 * k + 1; stop at the first square above number
    for (long sq = 0, inc = 1 ; ; sq += inc, inc += 2)
        if (number < sq)
            return (inc >>> 1) - 1;
}

/** int for increment & square up to Integer.MAX_VALUE */
static int SQUARE_LIMIT = (Integer.MAX_VALUE-Character.MAX_VALUE)>>>1;
public static long intSqrt13(long number) {
    int sq = 0,
        limit = (int) Math.min(SQUARE_LIMIT, number),
        inc = 1;
    while (sq <= limit) {
        sq += inc;
        inc += 2;
    }
    if (number < sq)
        return (inc >>> 1) - 1;
    long lsq = sq, linc = inc;
    while (lsq <= number) {
        lsq += linc;
        linc += 2;
    }
    return (linc >>> 1) - 1;
}

/** int for increment for number below (Integer.MAX_VALUE/2)**2
 * & for square up to Integer.MAX_VALUE */
public static long intSqrt14(long number) {
    int sq = 0,
        limit = (int) Math.min(SQUARE_LIMIT, number),
        inc = 1;
    while (sq <= limit) {
        sq += inc;
        inc += 2;
    }
    if (number < sq)
        return (inc >>> 1) - 1;
    // while lsq == k * k <= longLimit, k <= Integer.MAX_VALUE/2 - 1,
    // so inc == 2 * k + 3 after the step still fits in an int
    long lsq = sq,
        longLimit = Math.min((Integer.MAX_VALUE/2 - 1L) * (Integer.MAX_VALUE/2 - 1L), number);
    while (lsq <= longLimit) {
        lsq += inc;
        inc += 2;
    }
    if (number < lsq)
        return (inc >>> 1) - 1;
    long linc = inc;
    while (lsq <= number) {
        lsq += linc;
        linc += 2;
    }
    return (linc >>> 1) - 1;
}
static abstract class UnaryLongFunc { // imitates LongUnaryOperator
    final String label;
    long count;
    UnaryLongFunc(String label, long callCount) {
        this.label = label;
        count = callCount;
    }
    abstract long apply(long to);
}

static UnaryLongFunc return1;
static UnaryLongFunc []candy = {
    new UnaryLongFunc("Sqrt1", 5000) {
        @Override long apply(long to) { return intSqrt1(to); }
    },
    new UnaryLongFunc("Sqrt11", 10000) {
        @Override long apply(long to) { return intSqrt11(to); }
    },
    new UnaryLongFunc("Sqrt12", 10000) {
        @Override long apply(long to) { return intSqrt12(to); }
    },
    new UnaryLongFunc("Sqrt13", 10000) {
        @Override long apply(long to) { return intSqrt13(to); }
    },
    new UnaryLongFunc("Sqrt14", 10000) {
        @Override long apply(long to) { return intSqrt14(to); }
    },
    new UnaryLongFunc("Sqrt2", 10000) {
        @Override long apply(long to) { return intSqrt2(to); }
    },
    new UnaryLongFunc("Sqrt3", 20000000) {
        @Override long apply(long to) { return intSqrt3(to); }
    },
    new UnaryLongFunc("Sqrt4", 100000000L) {
        @Override long apply(long to) { return intSqrt4(to); }
    },
    // one call overhead less
    return1 = new UnaryLongFunc("return 1", 1) {
        @Override long apply(long to) { return 0; }
    }
};

public static void main(String[] args) {
    long seed = System.nanoTime();
    System.out.println("Seed = " + seed);
    for (long count = 1000 ; count <= 10000000L ; count *= 10) {
        return1.count = count;
        profile(return1, new Random(seed));
    }
    System.out.println("warmup ...");
    long total = 0;
    for (int i = Integer.MAX_VALUE / 500 ; 0 <= --i ; )
        for (long l = 42 ; 0 < l ; l >>= 1)
            for (UnaryLongFunc luo: candy)
                total += luo.apply(l);
    System.out.println(total);
    System.gc();
    for (int cd = 5 ; 0 <= --cd ; )
        profile(return1, new Random(seed));
    System.out.println(total);
    for (int cc = candy.length, ulf = cc, change = -1 ; (ulf += change) < cc ; ) {
        System.gc();
        profile(candy[ulf], new Random(seed));
        if (ulf <= 0)
            change = 1;
    }
}
In my environment, warming up made a difference of about 100:1 in the execution time of the empty (non-)method.
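intSqrt12 to intSqrt14 rely on the identity (k + 1)^2 = k^2 + (2k + 1). Adding the odd numbers 1, 3, 5, ... walks through the perfect squares without multiplying. Exact squares are the edge case worth testing: returning as soon as number <= sq gives k - 1 when number is exactly k^2. The standalone sketch below stops only once the square exceeds number and checks the result against a brute-force reference. The class name OddSumSqrt is hypothetical. Like intSqrt12, it is only meant for inputs well below Long.MAX_VALUE, where sq would overflow.

```java
public class OddSumSqrt {

    // Strength-reduced integer square root: sq walks through 0, 1, 4, 9, ...
    // by adding the odd numbers inc = 1, 3, 5, ... (no multiplication in the loop).
    // Invariant at the top of each iteration: sq == k * k and inc == 2 * k + 1.
    public static long intSqrt(long number) {
        if (number <= 0L) {
            return 0L;
        }
        long sq = 0, inc = 1;
        while (sq <= number) {  // stop at the first k with k * k > number
            sq += inc;
            inc += 2;
        }
        return (inc >>> 1) - 1; // inc == 2k + 1, so inc >>> 1 == k; the answer is k - 1
    }

    // Brute-force reference (same loop as the question's intSqrt1).
    public static long reference(long number) {
        long r = 0;
        while ((r + 1) * (r + 1) <= number) {
            r++;
        }
        return r;
    }

    public static void main(String[] args) {
        for (long n = 0; n <= 10_000; n++) {
            if (intSqrt(n) != reference(n)) {
                throw new AssertionError("mismatch at " + n);
            }
        }
        System.out.println(intSqrt(16) + " " + intSqrt(17) + " " + intSqrt(0)); // prints 4 4 0
    }
}
```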