
I created the following function, permutations, to produce all permutations of a List[A].

Example:

scala> net.Permutations.permutations("ab".split("").toList)
res3: List[List[String]] = List(List(a, b), List(b, a))

Code:

object Permutations {
  def permutations[A](str: List[A]): List[List[A]] =
    str match {
      case Nil => List(Nil)
      case list @ _ :: _ =>
        val shifteds: List[List[A]] =
          shiftN(list, list.length)
        shifteds.flatMap {
          case head :: tail =>
            permutations(tail).map { lists: List[A] =>
              head :: lists
            }
          case Nil => Nil
        }
    }

  private def shiftN[A](list: List[A], n: Int): List[List[A]] = {
    if (n <= 0) Nil
    else {
      val shifted: List[A] = shift(list)
      shifted :: shiftN(shifted, n - 1)
    }
  }

  private def shift[A](arr: List[A]): List[A] = arr match {
    case head :: tail => tail ++ List(head)
    case Nil => Nil
  }
}

I think it's correct since the following property-based check succeeds:

import munit.ScalaCheckSuite
import org.scalacheck.Prop._
import org.scalacheck.Gen

class PermutationsSpec extends ScalaCheckSuite {
  private val listGen: Gen[List[Int]] =
    for {
      n <- Gen.choose(0, 7)
      list <- Gen.listOfN(n, Gen.posNum[Int])
    } yield list

  property("permutations works") {
    forAll(listGen) { list: List[Int] =>
      val mine: List[List[Int]] = Permutations.permutations(list)
      val stdLib: List[List[Int]] = list.permutations.toList
      assert(stdLib.diff(mine).isEmpty)
    }
  }
}

Please evaluate for correctness, concision and performance.

asked Feb 23, 2021 at 4:15
  • Your definition of permutation differs from that of the Scala standard library: permutations(List(1,2,2)).length vs List(1,2,2).permutations.length. Is that intentional? Commented Feb 24, 2021 at 4:39
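A minimal sketch of the difference the comment points out, assuming Scala 2.13. It reproduces the question's rotation-based algorithm (condensed) only so that the example runs on its own, and adds a hypothetical helper, sameAsStdLib, that compares the two results in both directions. For List(1, 2, 2) the question's version treats the two 2s as distinct positions and yields 6 lists, while the standard library yields 3 distinct ones; a one-sided check such as stdLib.diff(mine).isEmpty does not notice this.

```scala
object DuplicatesSketch {
  // Condensed copy of the question's algorithm, reproduced for self-containment.
  def permutations[A](list: List[A]): List[List[A]] = list match {
    case Nil => List(Nil)
    case _ =>
      shiftN(list, list.length).flatMap {
        case head :: tail => permutations(tail).map(head :: _)
        case Nil          => Nil
      }
  }

  private def shiftN[A](list: List[A], n: Int): List[List[A]] =
    if (n <= 0) Nil
    else {
      val shifted = shift(list)
      shifted :: shiftN(shifted, n - 1)
    }

  private def shift[A](list: List[A]): List[A] = list match {
    case head :: tail => tail ++ List(head)
    case Nil          => Nil
  }

  // Hypothetical two-sided comparison: both results must contain the same
  // lists with the same multiplicities, not just "every stdlib permutation
  // appears somewhere in mine".
  def sameAsStdLib[A](list: List[A]): Boolean = {
    val mine   = permutations(list)
    val stdLib = list.permutations.toList
    stdLib.diff(mine).isEmpty && mine.diff(stdLib).isEmpty
  }
}
```

With distinct elements (for example List(1, 2, 3)) sameAsStdLib returns true; with List(1, 2, 2) it returns false. If the standard library's behaviour is the intended one, calling .distinct on the result is one simple, if wasteful, way to match it.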

2 Answers


The code looks correct, and the testing looks good too. I wondered whether List(List()) is the right output for the input List(), but it is a sensible result, since the empty list has exactly one permutation (itself).

For readability, I'd rename str, since it's a list rather than a string. The list @ _ :: _ pattern in permutations can be simplified to case _, and the body can then reuse the original argument.

I'd also inline the shifteds value: it's used only once and the name doesn't add any information. On that note, docstrings for the functions might be a nice touch, especially for the shift and shiftN methods.
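As an illustration of the docstring suggestion, here is one way shift and shiftN could be documented with Scaladoc. The wording is only a suggestion and the object name is hypothetical; the method bodies match the refactored version in this answer.

```scala
object DocumentedShifts {

  /** Rotates `list` left by one position: the head moves to the end.
    * An empty list is returned unchanged.
    *
    * {{{
    * shift(List(1, 2, 3)) == List(2, 3, 1)
    * }}}
    */
  def shift[A](list: List[A]): List[A] = list match {
    case head :: tail => tail ++ List(head)
    case _            => Nil
  }

  /** Returns the rotations of `list` by 1, 2, ..., `n` positions, in that order.
    * With `n == list.length` the last rotation is `list` itself, so every
    * position's element comes to the front exactly once.
    * Returns `Nil` when `n <= 0`.
    */
  def shiftN[A](list: List[A], n: Int): List[List[A]] =
    if (n <= 0) Nil
    else {
      val shifted = shift(list)
      shifted :: shiftN(shifted, n - 1)
    }
}
```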

The remaining fallback cases (the ones matching Nil) can also just be _ everywhere. It's a matter of taste, but for me this makes it clearer that there are always exactly two cases: an empty list or a non-empty one, with no third possibility.

With these changes it would look like this:

object Permutations {
  def permutations[A](list: List[A]): List[List[A]] =
    list match {
      case Nil => List(Nil)
      case _ =>
        shiftN(list, list.length).flatMap {
          case head :: tail =>
            permutations(tail).map(head :: _)
          case _ => Nil
        }
    }

  def shiftN[A](list: List[A], n: Int): List[List[A]] = {
    if (n <= 0) Nil
    else {
      val shifted: List[A] = shift(list)
      shifted :: shiftN(shifted, n - 1)
    }
  }

  def shift[A](list: List[A]): List[A] = list match {
    case head :: tail => tail ++ List(head)
    case _ => Nil
  }
}

Lastly, performance-wise it depends on what your constraints are. For List input and List output, that is, restricting yourself to singly linked lists, this is fine, though I haven't benchmarked it. Converting the shifted :: shiftN(...) call to use an accumulator might gain a bit by avoiding a deep call stack, but again, that will probably only matter for longer inputs.
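A minimal sketch of the accumulator idea, assuming Scala 2.13; TailRecShift and loop are hypothetical names. The rotations are prepended to an accumulator (so they end up reversed) and reversed once at the end, which returns the same list as the recursive shiftN while keeping the stack depth constant. The @tailrec annotation makes the compiler reject loop if its recursive call is not in tail position.

```scala
import scala.annotation.tailrec

object TailRecShift {
  // Same behaviour as the question's shift: move the head to the end.
  def shift[A](list: List[A]): List[A] = list match {
    case head :: tail => tail ++ List(head)
    case Nil          => Nil
  }

  // Accumulator-based shiftN: same result as `shifted :: shiftN(shifted, n - 1)`,
  // but each rotation is prepended to `acc` and the list is reversed once at the end.
  def shiftN[A](list: List[A], n: Int): List[List[A]] = {
    @tailrec
    def loop(current: List[A], remaining: Int, acc: List[List[A]]): List[List[A]] =
      if (remaining <= 0) acc.reverse
      else {
        val shifted = shift(current)
        loop(shifted, remaining - 1, shifted :: acc)
      }

    loop(list, n, Nil)
  }
}
```

Whether this is measurably faster has not been checked here; its main effect is on stack usage when n is large.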

But there are much faster algorithms, though for those you will probably want to copy the input into a Vector (or another indexed collection) so that each index can be accessed in (effectively) constant time. (I found the QuickPerm algorithm, as explained here, straightforward to implement from scratch.)
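A rough sketch of the countdown variant of QuickPerm (an iterative scheme closely related to Heap's algorithm), assuming Scala 2.13. The input is copied into a Vector as suggested above, and a small Array[Int] of counters decides which pair of indices to swap next, so each new permutation differs from the previous one by a single swap. QuickPermSketch is a hypothetical name, and this is not benchmarked. Like the question's version, it treats repeated values as distinct positions.

```scala
import scala.collection.mutable.ListBuffer

object QuickPermSketch {
  def permutations[A](list: List[A]): List[List[A]] = {
    var a: Vector[A] = list.toVector
    val n = a.length
    // p(i) counts down from i to 0 and decides which index is swapped next;
    // p(n) == n is never decremented and stops the reset loop below.
    val p: Array[Int] = Array.range(0, n + 1)
    val out = ListBuffer(a.toList) // the unmodified input is the first permutation

    // Swap the elements at indices x and y.
    def swap(x: Int, y: Int): Unit = {
      val tmp = a(x)
      a = a.updated(x, a(y)).updated(y, tmp)
    }

    var i = 1
    while (i < n) {
      p(i) -= 1
      // Odd i: swap with index p(i); even i: swap with index 0.
      val j = if (i % 2 == 1) p(i) else 0
      swap(j, i)
      out += a.toList
      // Restart at i = 1, resetting every counter that has reached 0.
      i = 1
      while (p(i) == 0) {
        p(i) = i
        i += 1
      }
    }
    out.toList
  }
}
```

Converting each state back to a List costs O(n) per permutation; a caller that can consume the Vector directly could skip that step.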

answered Feb 24, 2021 at 10:03

Here is another solution using a fold, with similar performance characteristics. One difference is that it doesn't have to compute the list's length, which takes linear time for a List.


def inserts[A](x: A): List[A] => List[List[A]] =
  ls =>
    ls match {
      case Nil => List(List(x))
      case y :: ys => (x :: y :: ys) :: inserts(x)(ys).map(y :: _)
    }

def permutations[A](ls: List[A]): List[List[A]] =
  ls.foldRight(List(List[A]()))((x, xss) => xss.flatMap(inserts(x)))
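To see how the fold builds its result, the sketch below reuses inserts and swaps foldRight for scanRight, which also returns every intermediate accumulator. The object name FoldTrace and the method trace are hypothetical; Scala 2.13 is assumed.

```scala
object FoldTrace {
  // inserts(x)(ls): every list obtained by inserting x at each position of ls.
  def inserts[A](x: A): List[A] => List[List[A]] =
    ls =>
      ls match {
        case Nil     => List(List(x))
        case y :: ys => (x :: y :: ys) :: inserts(x)(ys).map(y :: _)
      }

  // Same step function as the foldRight version; scanRight keeps each
  // intermediate result. The head is the final answer, the last element
  // is the seed List(Nil).
  def trace[A](ls: List[A]): List[List[List[A]]] =
    ls.scanRight(List(List.empty[A]))((x, xss) => xss.flatMap(inserts(x)))
}
```

For List(1, 2, 3), reading from the seed upward, the stages are List(Nil), then List(List(3)), then List(List(2, 3), List(3, 2)), and finally the six permutations List(1, 2, 3), List(2, 1, 3), List(2, 3, 1), List(1, 3, 2), List(3, 1, 2), List(3, 2, 1).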
answered Feb 24, 2021 at 6:36
  • Welcome to CodeReview@SE. I see an independent solution proposed. "Doesn't have to compute [list length]" is a bit thin on insight about the code presented for review; CR is not about insight about the problem the code presented is to solve. Commented Feb 24, 2021 at 9:16
