I wrote a function to find all the possible permutations of a list in Haskell. I know it can definitely be optimized, but I'm not sure how. I'm pretty sure `foldl` or `foldl'` can be used, but I'm not sure how. It gets kind of slow when the argument to `perms` has more than 6 items, but I don't know if this is avoidable.

What could I do to improve this function, mainly to simplify it, improve it stylistically, and boost performance?
```haskell
perms :: [a] -> [[a]]
perms = perms' 0
  where
    perms' _ [x, y] = [[x, y], [y, x]]
    perms' c xs
      | c == length xs = []
      | otherwise = (sub_perm xs) ++ (perms' (c + 1) (shift xs))
    sub_perm (x:xs) = fmap (\a -> x:a) $ perms xs
    shift xs = (last xs):(init xs)
```
2 Answers
It's better to base the recursion on length 1 or 0. That case is usually trivial and reduces the chance of making an error. In your case, the code doesn't work for lists of length 1 (`perms [x]` returns `[]` instead of `[[x]]`), and this can easily be fixed by setting the base case to `perms' _ [x] = [[x]]`.
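As a minimal sketch, here is the question's function with only the base case changed and comments added. Everything else is kept as it was, so it still has the costs discussed next:

```haskell
-- The question's perms with only the base case changed (a sketch).
perms :: [a] -> [[a]]
perms = perms' 0
  where
    -- Base case on one element instead of two, so that
    -- a one-element list gives [[x]] instead of [].
    perms' _ [x] = [[x]]
    perms' c xs
      | c == length xs = []  -- every rotation has been visited
      | otherwise = sub_perm xs ++ perms' (c + 1) (shift xs)
    -- Fix the head and permute the tail.
    sub_perm (x:xs) = fmap (\a -> x : a) $ perms xs
    -- Rotate right: move the last element to the front.
    shift xs = last xs : init xs
```

In GHCi, `perms [1]` now gives `[[1]]`, and `perms [1,2,3]` gives `[[1,2,3],[1,3,2],[3,1,2],[3,2,1],[2,3,1],[2,1,3]]`.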
The costly operations in your code are repeated traversals of the input list. In particular, `length xs` is called on every step, and since Haskell lists are lazy linked lists, each call costs O(n). You could pass the length of the list as an extra argument instead.
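One way to do that, sketched under the assumption that the base-case fix above is applied: compute `length` once and count the remaining rotations down. `permsN` and `go` are hypothetical names, and `last`/`init` are deliberately left in place since they are the next thing to address:

```haskell
-- Sketch: compute the length once instead of on every step.
-- permsN and go are hypothetical names, not from the question.
perms :: [a] -> [[a]]
perms xs = permsN (length xs) xs

-- permsN n xs assumes that n == length xs.
permsN :: Int -> [a] -> [[a]]
permsN _ [x] = [[x]]
permsN n xs = go n xs
  where
    -- k = number of rotations still to visit (replaces c == length xs).
    go 0 _ = []
    go _ [] = []
    go k ys@(y:rest) = map (y :) (permsN (n - 1) rest) ++ go (k - 1) (shift ys)
    -- Still O(n) per call: last and init traverse the list.
    shift ys = last ys : init ys
```

It produces the permutations in the same order as the question's function, only without the repeated `length` calls.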
Similarly, `last` and `init` are O(n). You could use `splitAt` to traverse the list just once, or, even better, rotate the other way around, something like `shift (y:ys) = ys ++ [y]`, where you need to traverse the list only once (for `++`). Pattern matching is also somewhat safer than using partial functions such as `init`/`head`/`last`/..., especially if you cover all cases and compile with `-fwarn-incomplete-patterns`. You might also consider using `Seq`, which has (amortized) O(1) costs for manipulating its ends and O(log n) costs for splitting/merging sequences in the middle, but has a higher constant factor.
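A sketch of the three rotation options mentioned here. The names `shiftRightN`, `shiftLeft` and `shiftLeftSeq` are hypothetical; `shiftRightN` assumes its `n` argument is the list's length (as when the length is passed along), and the `Seq` version uses `viewl` and `|>` from `Data.Sequence` in the containers package:

```haskell
import Data.Sequence (Seq, ViewL (..), (|>))
import qualified Data.Sequence as Seq

-- Rotate right with a single splitAt; n is assumed to be the list's length.
shiftRightN :: Int -> [a] -> [a]
shiftRightN n xs = back ++ front
  where
    (front, back) = splitAt (n - 1) xs

-- Rotate left by pattern matching: total, and one traversal (for ++).
shiftLeft :: [a] -> [a]
shiftLeft [] = []
shiftLeft (y:ys) = ys ++ [y]

-- Rotate left on a Seq: viewl and |> are amortized O(1).
shiftLeftSeq :: Seq a -> Seq a
shiftLeftSeq s = case Seq.viewl s of
  EmptyL -> s
  y :< ys -> ys |> y
```

Rotating left instead of right changes the order in which the permutations come out, not which permutations are produced, since either way each element reaches the front exactly once in n rotations.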
Another source of inefficiency could be the `++` in the `otherwise` branch, as `++` needs to traverse its whole left argument. You might again try out `Seq`, or construct the result using difference lists, which eliminates this problem.
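A sketch of the difference-list idea, assuming left rotation by `ys ++ [y]` as suggested earlier. `permsInto` is a hypothetical name: it returns a function of type `[[a]] -> [[a]]` that prepends all permutations to whatever follows, so partial results are combined by function composition rather than `++`. The prefix chosen so far is kept as a difference list too:

```haskell
-- Sketch: difference lists instead of ++ on the results.
-- permsInto pre xs prepends (pre p) for every permutation p of xs.
permsInto :: ([a] -> [a]) -> [a] -> ([[a]] -> [[a]])
permsInto pre [] = (pre [] :)  -- nothing left to permute: emit the prefix
permsInto pre xs = go (length xs) xs
  where
    -- Visit each of the k remaining left rotations; (.) replaces ++.
    go 0 _ = id
    go _ [] = id
    go k (y:rest) = permsInto (pre . (y :)) rest . go (k - 1) (rest ++ [y])

perms :: [a] -> [[a]]
perms xs = permsInto id xs []
```

Unlike the question's version, this returns `[[]]` for an empty list, which is the usual convention (there is exactly one permutation of nothing).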
You could solve several of these problems by introducing a helper function that returns all possible splits of an input list, something like `splits :: [a] -> [(a, [a])]`, for example `splits [1,2,3] = [(1, [2, 3]), (2, [1, 3]), (3, [1, 2])]`. Then recursively process the second part, prepending the picked element to all sub-results.
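A minimal sketch of `splits` that produces exactly the example above, written with `inits` and `tails` from `Data.List`. This is a simpler variant, not the optimized version given further down, and the `perms` on top of it is one straightforward way to do the recursion:

```haskell
import Data.List (inits, tails)

-- Pick each element in turn, paired with the remaining elements in order.
-- The final (xs, []) pair fails the pattern and is skipped.
splits :: [a] -> [(a, [a])]
splits xs = [ (y, before ++ after) | (before, y : after) <- zip (inits xs) (tails xs) ]

-- Prepend each picked element to every permutation of the rest.
perms :: [a] -> [[a]]
perms [] = [[]]
perms xs = [ y : p | (y, rest) <- splits xs, p <- perms rest ]
```

With this `splits`, `perms [1,2,3]` comes out in lexicographic order: `[[1,2,3],[1,3,2],[2,1,3],[2,3,1],[3,1,2],[3,2,1]]`.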
It's good that you provided the type of the top-level function.
Also, `(\a -> x : a)` can be abbreviated to `(x :)` using η-reduction.
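For illustration, two equivalent hypothetical helpers, one with the lambda from the question and one with the η-reduced operator section; both prepend `x` to every inner list:

```haskell
-- The lambda form used in the question's sub_perm...
prependLambda :: a -> [[a]] -> [[a]]
prependLambda x = fmap (\a -> x : a)

-- ...and the eta-reduced section: \a -> (:) x a is just (:) x, i.e. (x :).
prependSection :: a -> [[a]] -> [[a]]
prependSection x = fmap (x :)
```

For example, both map `[[1],[2,3]]` to `[[0,1],[0,2,3]]` when `x` is `0`.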
Below is code based on the above ideas, with some more optimizations (to improve sub-list sharing), left as an exercise to analyze:
```haskell
perms :: [a] -> [[a]]
perms = go []
  where
    go rs [] = [rs]
    go rs xs = concatMap (\(y, ys) -> go (y : rs) ys) (splits xs)

splits :: [a] -> [(a, [a])]
splits = go []
  where
    go ys [] = []
    go ys (x : xs) = (x, ys ++ xs) : go (x : ys) xs
```
You might be interested in the `permutations` function from `Data.List`:
```haskell
-- | The 'permutations' function returns the list of all permutations of the argument.
--
-- > permutations "abc" == ["abc","bac","cba","bca","cab","acb"]
permutations :: [a] -> [[a]]
permutations xs0 = xs0 : perms xs0 []
  where
    perms [] _ = []
    perms (t:ts) is = foldr interleave (perms ts (t:is)) (permutations is)
      where interleave xs r = let (_,zs) = interleave' id xs r in zs
            interleave' _ [] r = (ts, r)
            interleave' f (y:ys) r = let (us,zs) = interleave' (f . (y:)) ys r
                                     in (y:us, f (t:y:us) : zs)
```
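If you only need the permutations rather than your own implementation, the library function can be used directly. A minimal sketch; `sortedPerms` is a hypothetical helper, and the sort is only there because, as the doc comment's example shows, the library does not produce lexicographic order:

```haskell
import Data.List (permutations, sort)

-- Hypothetical helper: the library's permutations in lexicographic order,
-- handy for checking a hand-written perms against it.
sortedPerms :: Ord a => [a] -> [[a]]
sortedPerms = sort . permutations
```

For example, `sortedPerms "abc"` gives `["abc","acb","bac","bca","cab","cba"]`. Note that `sort` needs an `Ord` instance, while `permutations` itself works for any element type.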
That's not a review of the author's code. – Zeta, Jul 15, 2016 at 6:12