Given an array, find any three numbers which sum to zero.
This problem is a continuation of a code kata in Python. I'm learning Scala and would like feedback on more efficient, cleaner, and more functional ways of solving (and unit testing) this problem.
SumToZero.scala
```scala
class SumToZeroNotFound(msg: String = null, cause: Throwable = null)
  extends java.lang.Exception(msg, cause) {}

class SumToZero {
  def process(arr: Array[Int], size: Int = 3): Array[Int] = {
    for (combination <- arr.combinations(size))
      if (combination.sum == 0) // note: can yield combinations here if we want to find all combinations that sum to zero
        return combination
    throw new SumToZeroNotFound()
  }
}
```
TestSumToZero.scala
```scala
import org.junit.runner.RunWith
import org.scalatest.junit._
import org.scalatest._
import scala.runtime.ScalaRunTime.stringOf

@RunWith(classOf[JUnitRunner])
class TestSumToZero extends FunSuite {
  test("should return entire array that sums to zero") {
    val stz = new SumToZero
    val res = stz.process(Array(1, 2, -3))
    val str_res = stringOf(res)
    assert(str_res == "Array(1, 2, -3)", "res: " + str_res)
  }

  test("should raise NotFound exception if no combination sums to zero") {
    val stz = new SumToZero
    intercept[SumToZeroNotFound] {
      stz.process(Array(1, 2, 3, 4))
    }
  }

  test("should return first set of nums that sum to zero") {
    val stz = new SumToZero
    val res = stz.process(Array(125, 124, -100, -25, 1, 2, -3, 10, 100))
    val str_res = stringOf(res)
    assert(str_res == "Array(125, -100, -25)", "res: " + str_res)
  }
}
```
Updated SumToZero.scala
```scala
class SumToZeroNotFound(msg: String = null, cause: Throwable = null)
  extends java.lang.Exception(msg, cause) {}

trait SumToZero {
  @throws(classOf[SumToZeroNotFound])
  def process(arr: Array[Int], size: Int = 3): Array[Int]
}

object SumToZeroBruteForce extends SumToZero {
  // brute force: process any size combination
  override def process(arr: Array[Int], size: Int = 3): Array[Int] = {
    arr.combinations(size).find(_.sum == 0).getOrElse(throw new SumToZeroNotFound)
  }
}

object SumToZeroBinSearch extends SumToZero {
  // optimized with sorting and binary search
  // scala does not seem to have a native binary search routine
  implicit private class Search(val arr: Array[Int]) {
    // search only arr(from until arr.length), so an element is never reused
    def binSearch(target: Int, from: Int) = {
      java.util.Arrays.binarySearch(arr, from, arr.length, target)
    }
  }

  @throws(classOf[IllegalArgumentException])
  override def process(arr: Array[Int], size: Int = 3): Array[Int] = size match {
    case 3 =>
      val sorted_arr = arr.sortWith(_ < _)
      val max_len = sorted_arr.length
      for ((i_val, i) <- sorted_arr.takeWhile(_ <= 0).zipWithIndex) {
        val maxj = sorted_arr(max_len - 1) - i_val
        for ((j_val, j) <- sorted_arr.zipWithIndex.slice(i + 1, max_len).takeWhile(_._1 <= maxj)) {
          val temp_sum = i_val + j_val
          // search after index j; searching the whole array could return j itself,
          // e.g. Array(1, -2, 5) would wrongly give Array(-2, 1, 1)
          val res_idx = sorted_arr.binSearch(-temp_sum, j + 1)
          if (res_idx > -1) {
            return Array(i_val, j_val, sorted_arr(res_idx))
          }
        }
      }
      throw new SumToZeroNotFound
    case _ => throw new IllegalArgumentException("only support size of three for now")
  }
}

object SumToZero {
  def apply(impl: String) = impl match {
    case "with_bin_search" => SumToZeroBinSearch
    case _ => SumToZeroBruteForce
  }
}
```
Updated TestSumToZero.scala
```scala
import org.junit.runner.RunWith
import org.scalatest.junit._
import org.scalatest._
import scala.runtime.ScalaRunTime.stringOf

@RunWith(classOf[JUnitRunner])
class TestSumToZero extends FunSuite {
  final val methods = Array("brute_force", "with_bin_search")

  def testAllMethods(arr: Array[Int], expected_res: String = "",
                     methods: Array[String] = this.methods) = {
    for (method <- methods) {
      println("testing: " + method)
      val stz = SumToZero(method)
      val res = stz.process(arr)
      val str_res = stringOf(res.sortWith(_ < _))
      assert(str_res == expected_res, "res: " + str_res)
    }
  }

  test("should return entire array that sums to zero") {
    testAllMethods(Array(1, 2, -3), "Array(-3, 1, 2)")
  }

  test("should raise NotFound exception if no combination sums to zero") {
    this.methods.foreach(method =>
      intercept[SumToZeroNotFound] {
        testAllMethods(Array(1, 2, 3), "", Array(method))
      })
  }

  test("should return first set of nums that sum to zero") {
    testAllMethods(Array(125, 124, -100, -25, 1, 2, -3, 10, 100), "Array(-100, -25, 125)")
  }
}
```
I'd like to use Scala spec for unit testing, but I was not able to get it working with Gradle. Please comment if you are able to build Scala projects with Gradle and produce JUnit-style output.
Comment (Flack, Dec 29, 2013): Thanks for the feedback. I was not able to package functions outside an object or class. I also ended up playing around with Scala traits and pattern matching to test out multiple implementations of SumToZero. I've added the updated code above and in the repo. Maybe there's a more functional way to implement the optimized version... I'll have to tackle that another time.
2 Answers
Since this is tagged 'algorithm', I am going to review your algorithm....
Your algorithm here is wrong. Looking at all combinations is a very expensive way of doing it: for triples there are C(n, 3) of them, which is O(n^3). I know very little about Scala, but I can assure you that a much better (the best?) algorithm would be something like:
- sort the data (O(n log(n)))
- 'first' loop through all the values
- 'nested' loop through all values after the value in the 'first' loop
- binary search for the value -(first + nested) in all the values after the nested value
In Java it would be something like:
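```java
import java.util.Arrays;

public class ZeroSum {
    // Prints every data[i] + data[j] + data[pos] == 0 with i < j < pos.
    static void printZeroSumTriples(int[] data) {
        Arrays.sort(data);
        for (int i = 0; i < data.length; i++) {
            for (int j = i + 1; j < data.length; j++) {
                // search only the values after j for -(data[i] + data[j])
                int pos = Arrays.binarySearch(data, j + 1, data.length, - (data[i] + data[j]));
                if (pos >= 0) {
                    System.out.printf("Zero-sum values are [%d,%d,%d]\n", data[i], data[j], data[pos]);
                }
            }
        }
    }
}
```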
This has a much better performance characteristic than checking every combination: about n^2/2 pairs, each followed by an O(log n) binary search, so O(n^2 log n) instead of O(n^3).
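Translating the Java loop above into Scala might look roughly like the following sketch. The object and method names (`ZeroSumBinarySearch`, `allTriples`) are hypothetical, and, like the Java version, it reports every zero-sum triple rather than stopping at the first one. It reuses `java.util.Arrays.binarySearch` with a `fromIndex`/`toIndex` range so that the search only looks after `j`.

```scala
import java.util.Arrays

// Hypothetical helper: sort + nested loop + binary search, as described above.
object ZeroSumBinarySearch {
  // Every (data(i), data(j), data(pos)) with i < j < pos summing to zero.
  // Inputs containing repeated values can yield the same triple more than
  // once; call .distinct on the result if that matters.
  def allTriples(input: Array[Int]): Seq[(Int, Int, Int)] = {
    val data = input.sorted // sorts a copy, O(n log n)
    for {
      i <- data.indices
      j <- (i + 1) until data.length
      // binary search only data(j + 1 until length), so no element is reused
      pos = Arrays.binarySearch(data, j + 1, data.length, -(data(i) + data(j)))
      if pos >= 0
    } yield (data(i), data(j), data(pos))
  }
}
```

For the question's test input, `allTriples(Array(125, 124, -100, -25, 1, 2, -3, 10, 100))` finds both `(-100, -25, 125)` and `(-3, 1, 2)`.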
You can also add some tricks / optimizations to the loops to quit when impossible combinations present themselves, like:
```java
if (data[i] > 0) {
    // if the smallest value is positive then no possible combination exists....
    break;
}
```
Additionally, you can short-circuit the 'j' loop:
```java
for (int i = 0; i < data.length; i++) {
    if (data[i] > 0) {
        break;
    }
    int maxj = data[data.length - 1] - data[i];
    for (int j = i + 1; j < data.length && data[j] <= maxj; j++) {
        // .......
    }
}
```
How you do this in Scala is your problem... ;-)
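One possible Scala take on the two short-circuits above is a for-comprehension over lazy iterators, which answers the "more functional" part of the question without `return` or exceptions. This is only a sketch: `ZeroSumPruned` and `firstTriple` are hypothetical names, and returning `Option` instead of throwing `SumToZeroNotFound` is my own design choice, not something from the answer.

```scala
import java.util.Arrays

// Hypothetical helper: sort + binary search with both early exits.
object ZeroSumPruned {
  // First zero-sum triple found (values in sorted order), or None.
  def firstTriple(input: Array[Int]): Option[(Int, Int, Int)] = {
    val data = input.sorted
    val n = data.length
    // Iterator-based, so the search stops as soon as the first match is produced.
    val matches = for {
      // outer loop stops once data(i) > 0: no zero-sum triple can start there
      i <- Iterator.range(0, n).takeWhile(k => data(k) <= 0)
      maxj = data(n - 1) - data(i)
      // inner loop stops once data(j) exceeds maxj
      j <- Iterator.range(i + 1, n).takeWhile(k => data(k) <= maxj)
      // binary search only the values after j, so no element is used twice
      pos = Arrays.binarySearch(data, j + 1, n, -(data(i) + data(j)))
      if pos >= 0
    } yield (data(i), data(j), data(pos))
    if (matches.hasNext) Some(matches.next()) else None
  }
}
```

With the question's third test input this returns `Some((-100, -25, 125))`, and for `Array(1, 2, 3)` the outer loop stops immediately and it returns `None`.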
Comment (Jianmin Chen, Dec 27, 2016): The algorithm is also known as LeetCode 15, 3Sum. I implemented the idea you presented; its time complexity is O(n^2 log n), where n is the size of the array, and the LeetCode online judge reports TLE (Time Limit Exceeded). The optimal solution does not use binary search; it uses two pointers instead, which lowers the time complexity to O(n^2).
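The comment does not include code, so here is a sketch of the two-pointer idea as I understand it, written with tail recursion. After sorting, fix `data(i)` and scan the rest of the array from both ends: if the total is too small, move the low pointer right; if it is too large, move the high pointer left. Each scan is linear, giving O(n^2) overall. The names `ZeroSumTwoPointers`, `firstTriple`, `scan` and `outer` are hypothetical.

```scala
import scala.annotation.tailrec

// Hypothetical helper: two-pointer 3-sum on a sorted copy of the input.
object ZeroSumTwoPointers {
  def firstTriple(input: Array[Int]): Option[(Int, Int, Int)] = {
    val data = input.sorted
    val n = data.length

    // Look for lo < hi with data(i) + data(lo) + data(hi) == 0.
    @tailrec
    def scan(i: Int, lo: Int, hi: Int): Option[(Int, Int, Int)] =
      if (lo >= hi) None
      else {
        val sum = data(i) + data(lo) + data(hi)
        if (sum == 0) Some((data(i), data(lo), data(hi)))
        else if (sum < 0) scan(i, lo + 1, hi) // total too small: move lo right
        else scan(i, lo, hi - 1)              // total too large: move hi left
      }

    // Outer loop over i; stop when fewer than 3 values remain or data(i) > 0.
    @tailrec
    def outer(i: Int): Option[(Int, Int, Int)] =
      if (i > n - 3 || data(i) > 0) None
      else scan(i, i + 1, n - 1) match {
        case found @ Some(_) => found
        case None            => outer(i + 1)
      }

    outer(0)
  }
}
```

For example, `firstTriple(Array(-1, 0, 1, 2, -1, -4))` returns `Some((-1, -1, 2))`.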
If you only want to return the first such combination, what about doing something like this instead of the for loop?
```scala
def process(arr: Array[Int], size: Int = 3): Array[Int] = {
  arr.combinations(size).find(_.sum == 0).getOrElse(throw new SumToZeroNotFound)
}
```
I'd also suggest that, if there's no reason to put it in a class, you don't. Unlike Java, not everything needs to be a class you instantiate (a singleton `object` is enough to hold the function), and this seems like something that is in a class for the sake of being in a class.
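As a sketch of that suggestion, the brute-force function can live in a singleton `object`, so callers write `SumToZero.process(...)` without `new`. Nesting the exception class inside the object and the extra Option-returning `findZeroSum` method are my own assumptions for illustration; `findZeroSum` is a hypothetical name that simply leaves the "not found" case to the caller.

```scala
object SumToZero {
  class SumToZeroNotFound(msg: String = null, cause: Throwable = null)
    extends Exception(msg, cause)

  // Brute force over all size-element combinations, as in the answer above.
  def process(arr: Array[Int], size: Int = 3): Array[Int] =
    arr.combinations(size).find(_.sum == 0).getOrElse(throw new SumToZeroNotFound())

  // Hypothetical variant: None instead of an exception when nothing sums to zero.
  def findZeroSum(arr: Array[Int], size: Int = 3): Option[Array[Int]] =
    arr.combinations(size).find(_.sum == 0)
}
```

Usage is then `SumToZero.process(Array(1, 2, -3))`, and `SumToZero.findZeroSum(Array(1, 2, 3))` returns `None`.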