I created the following function, permutations, to produce all permutations of a List[A].
Example:
scala> net.Permutations.permutations("ab".split("").toList)
res3: List[List[String]] = List(List(a, b), List(b, a))
Code:
object Permutations {
  def permutations[A](str: List[A]): List[List[A]] =
    str match {
      case Nil => List(Nil)
      case list @ _ :: _ =>
        val shifteds: List[List[A]] =
          shiftN(list, list.length)
        shifteds.flatMap {
          case head :: tail =>
            permutations(tail).map { lists: List[A] =>
              head :: lists
            }
          case Nil => Nil
        }
    }
  private def shiftN[A](list: List[A], n: Int): List[List[A]] = {
    if (n <= 0) Nil
    else {
      val shifted: List[A] = shift(list)
      shifted :: shiftN(shifted, n - 1)
    }
  }
  private def shift[A](arr: List[A]): List[A] = arr match {
    case head :: tail => tail ++ List(head)
    case Nil => Nil
  }
}
I think it's correct since the following property-based check succeeds:
import munit.ScalaCheckSuite
import org.scalacheck.Prop._
import org.scalacheck.Gen
class PermutationsSpec extends ScalaCheckSuite {
  private val listGen: Gen[List[Int]] =
    for {
      n <- Gen.choose(0, 7)
      list <- Gen.listOfN(n, Gen.posNum[Int])
    } yield list
  property("permutations works") {
    forAll(listGen) { list: List[Int] =>
      val mine: List[List[Int]] = Permutations.permutations(list)
      val stdLib: List[List[Int]] = list.permutations.toList
      assert(stdLib.diff(mine).isEmpty)
    }
  }
}
Please evaluate for correctness, concision and performance.
The code looks correct, and the testing looks good too. I was wondering whether List(List()) makes sense as the output for the input List(), but it seems like a sensible result.
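A short check of why List(List()) is the sensible result for an empty input: the empty list has exactly one ordering (0! = 1), the standard library's Seq#permutations returns the same thing, and the recursion relies on it, because the one-element case is built by prepending to the permutations of Nil. The object name EmptyInputCheck is only for illustration.

```scala
object EmptyInputCheck {
  // 0! = 1: the empty list has exactly one ordering, the empty list itself,
  // and the standard library's Seq#permutations agrees.
  val viaStdLib: List[List[Int]] = List.empty[Int].permutations.toList

  // The recursive step computes permutations(tail).map(head :: _).
  // For a one-element list the tail is Nil, so the base case must be List(Nil):
  val withListOfNil: List[List[Int]] = List(List.empty[Int]).map(42 :: _) // List(List(42))

  // If the base case were Nil instead, every result would collapse to empty:
  val withNil: List[List[Int]] = List.empty[List[Int]].map(42 :: _) // List()
}
```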
For readability I'd rename str, especially since it's not really a string but a list. The complicated pattern list @ _ :: _ in permutations can simply be case _, and the body can reuse the original argument.
I'd also inline the shifteds value, since it's just a single call and the name doesn't really tell me anything. On that note, docstrings for the functions might be a nice touch, especially for the shift and shiftN methods.
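As an illustration of the docstring suggestion, here is one possible way to document shift and shiftN with Scaladoc. The wording and the wrapper name DocumentedShifts are my own; the bodies are unchanged apart from formatting.

```scala
object DocumentedShifts {
  /** Rotates `list` left by one position, e.g. `List(1, 2, 3)` becomes `List(2, 3, 1)`.
    * Returns `Nil` for an empty list.
    */
  def shift[A](list: List[A]): List[A] = list match {
    case head :: tail => tail ++ List(head)
    case _            => Nil
  }

  /** Returns the `n` successive left rotations of `list`: the first element is
    * `list` rotated by one, the second by two, and so on. With `n == list.length`
    * the result holds every rotation, ending with `list` itself.
    * Returns `Nil` when `n <= 0`.
    */
  def shiftN[A](list: List[A], n: Int): List[List[A]] =
    if (n <= 0) Nil
    else {
      val shifted = shift(list)
      shifted :: shiftN(shifted, n - 1)
    }
}
```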
The fallback cases that don't bind anything can also just be _ everywhere. That's partly a matter of taste, but for me it makes clearer that there are really only two cases, an empty list or a non-empty one, and no third case.
It would then look like this:
object Permutations {
  def permutations[A](list: List[A]): List[List[A]] =
    list match {
      case Nil => List(Nil)
      case _ =>
        shiftN(list, list.length).flatMap {
          case head :: tail =>
            permutations(tail).map(head :: _)
          case _ => Nil
        }
    }
  def shiftN[A](list: List[A], n: Int): List[List[A]] = {
    if (n <= 0) Nil
    else {
      val shifted: List[A] = shift(list)
      shifted :: shiftN(shifted, n - 1)
    }
  }
  def shift[A](list: List[A]): List[A] = list match {
    case head :: tail => tail ++ List(head)
    case _ => Nil
  }
}
Lastly, performance-wise it depends on what your constraints are. For List input and List output, that is, restricted to singly linked lists, this is fine, though of course I haven't benchmarked it. Converting the shifted :: shiftN(...) call to use an accumulator, instead of building up a deep call stack, might gain a bit, but again, it will probably only matter for longer inputs.
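A minimal sketch of the accumulator idea, assuming Scala 2.13 or 3: an inner tail-recursive loop collects the rotations in reverse and reverses once at the end, so @tailrec lets the compiler turn the recursion into a loop with constant stack depth. ShiftNAcc and loop are hypothetical names, not part of the original code.

```scala
import scala.annotation.tailrec

object ShiftNAcc {
  /** Rotates `list` left by one position; `Nil` stays `Nil`. */
  def shift[A](list: List[A]): List[A] = list match {
    case head :: tail => tail ++ List(head)
    case Nil          => Nil
  }

  /** Same result as the recursive shiftN, built with an accumulator. */
  def shiftN[A](list: List[A], n: Int): List[List[A]] = {
    @tailrec
    def loop(current: List[A], remaining: Int, acc: List[List[A]]): List[List[A]] =
      if (remaining <= 0) acc.reverse // rotations were prepended, so restore their order
      else {
        val shifted = shift(current)
        loop(shifted, remaining - 1, shifted :: acc)
      }
    loop(list, n, Nil)
  }
}
```

The single reverse at the end is O(n), so the overall cost stays the same as the recursive version.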
That said, there are much faster algorithms, although for those you may want to copy the input into a vector so that each index can be accessed in constant time. (I found the QuickPerm algorithm, as explained here, absolutely straightforward to implement from scratch.)
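For comparison, here is a sketch of an index-based approach. It is not QuickPerm itself but the iterative form of Heap's algorithm, a closely related swap-based method: it copies the input into a mutable ArrayBuffer for constant-time indexed access, and each new permutation differs from the previous one by a single swap. HeapPermutations is a hypothetical name, and the sketch assumes Scala 2.13 or 3 (for ArrayBuffer.from).

```scala
import scala.collection.mutable.ArrayBuffer

object HeapPermutations {
  /** Heap's algorithm, iterative form: returns every ordering of `input`
    * (n! results, duplicates included when `input` has repeated elements).
    */
  def permutations[A](input: Seq[A]): List[Vector[A]] = {
    val a = ArrayBuffer.from(input) // constant-time indexed access and in-place swaps
    val n = a.length
    val c = Array.fill(n)(0)        // c(i) counts the swaps done so far at level i
    val out = List.newBuilder[Vector[A]]
    out += a.toVector               // the unmodified input is the first permutation
    var i = 1
    while (i < n) {
      if (c(i) < i) {
        // Even i: swap with position 0; odd i: swap with position c(i).
        val j = if (i % 2 == 0) 0 else c(i)
        val tmp = a(j)
        a(j) = a(i)
        a(i) = tmp
        out += a.toVector
        c(i) += 1
        i = 1
      } else {
        c(i) = 0
        i += 1
      }
    }
    out.result()
  }
}
```

For example, HeapPermutations.permutations(List(1, 2, 3)) returns all six orderings, starting with Vector(1, 2, 3). Each swap is O(1), but taking a Vector snapshot of every permutation still costs O(n) per result.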
Here is another solution using a fold, with similar performance characteristics. One difference is that it doesn't have to compute the length of the list, which takes linear time for a List.
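// All ways of inserting x into ls, from the front to the back.
def inserts[A](x: A): List[A] => List[List[A]] =
  ls =>
    ls match {
      case Nil => List(List(x))
      case (y :: ys) => (x :: y :: ys) :: (inserts(x)(ys)).map(y :: _)
    }
// Insert each element into every position of every permutation of the elements to its right.
def permutations[A](ls: List[A]): List[List[A]] =
  ls.foldRight(List(List[A]()))((x, xss) => xss.flatMap(inserts(x)))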
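To make the fold easier to follow, here is the same code wrapped in an object (the name FoldPermutations is only for illustration), with the results it produces on small inputs. inserts places x at every position of the list, and the fold starts from the single permutation of the empty list and works from right to left.

```scala
object FoldPermutations {
  // All lists obtained by inserting x at each position of ls, front to back.
  def inserts[A](x: A): List[A] => List[List[A]] =
    ls =>
      ls match {
        case Nil     => List(List(x))
        case y :: ys => (x :: y :: ys) :: inserts(x)(ys).map(y :: _)
      }

  // Fold in the elements right to left, inserting each one into every
  // position of every permutation built so far.
  def permutations[A](ls: List[A]): List[List[A]] =
    ls.foldRight(List(List[A]()))((x, xss) => xss.flatMap(inserts(x)))
}

// FoldPermutations.inserts(1)(List(2, 3))
//   == List(List(1, 2, 3), List(2, 1, 3), List(2, 3, 1))
// FoldPermutations.permutations(List(1, 2, 3))
//   == List(List(1, 2, 3), List(2, 1, 3), List(2, 3, 1),
//           List(1, 3, 2), List(3, 1, 2), List(3, 2, 1))
```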
Welcome to CodeReview@SE. I see an independent solution proposed. "doesn't have to compute [list length]" is a bit thin on insight about the code presented for review: CR is not about insight about the problem the code presented is to solve. (greybeard, commented Feb 24, 2021 at 9:16)
permutations(List(1,2,2)).length vs List(1,2,2).permutations.length. Is that intentional?
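The comment points at a real difference: Scala's Seq#permutations yields only distinct permutations, while the code under review yields one result per ordering of positions, so duplicates appear when the input has repeated elements. The property check in the question still passes because stdLib.diff(mine).isEmpty only checks one direction. The sketch below reproduces this using the compact version from the first answer; DuplicateDemo is a hypothetical name.

```scala
object DuplicateDemo {
  // Compact rotation-based permutations, as in the first answer.
  def permutations[A](list: List[A]): List[List[A]] = list match {
    case Nil => List(Nil)
    case _ =>
      shiftN(list, list.length).flatMap {
        case head :: tail => permutations(tail).map(head :: _)
        case _            => Nil
      }
  }

  def shiftN[A](list: List[A], n: Int): List[List[A]] =
    if (n <= 0) Nil
    else {
      val shifted = shift(list)
      shifted :: shiftN(shifted, n - 1)
    }

  def shift[A](list: List[A]): List[A] = list match {
    case head :: tail => tail ++ List(head)
    case _            => Nil
  }

  val input: List[Int] = List(1, 2, 2)
  val mine: List[List[Int]] = permutations(input)         // 3! = 6 results, each distinct one twice
  val stdLib: List[List[Int]] = input.permutations.toList // 3 distinct results

  // The question's check: every stdlib permutation appears in mine (true here).
  val oneWay: Boolean = stdLib.diff(mine).isEmpty
  // The reverse direction fails: diff removes only one copy per matching element.
  val otherWay: Boolean = mine.diff(stdLib).isEmpty
  // Comparing distinct results and their count checks both directions.
  val sameDistinct: Boolean =
    mine.distinct.length == stdLib.length && mine.toSet == stdLib.toSet
}
```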