The code below is a Haskell implementation of the Needleman-Wunsch algorithm for sequence alignment (and string edit distance). It's an experiment in closely imitating the dynamic programming method as it would be implemented in an imperative language, hence my use of mutable arrays and the ST monad.
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE LambdaCase #-}
import Data.Array.ST
import Data.List
import Control.Monad.ST
import Control.Monad
-- Primitive edit operations
data Operation a = Subst a a -- match is a special case of substitution
                 | Insert a
                 | Delete a
                 deriving (Show, Eq)

-- Operation costs
data Costs a = Costs {
    delete :: a -> Int,
    insert :: a -> Int,
    subst :: a -> a -> Int
}

defaultCost :: (Eq a) => Costs a
defaultCost = Costs {
    delete = const 1, insert = const 1,
    subst = \x y -> if x == y
                    then 0
                    else 10
}

data Transform a = Transform [Operation a] Int

-- Given cost functions, describe a series of transformations taking the input string to the output string
alignWithCost :: Costs a
              -> [a] -- ^ input string
              -> [a] -- ^ output string
              -> Transform a
alignWithCost Costs{..} inputString outputString = let
    inputLen = length inputString
    outputLen = length outputString
    alignWithCostST = do
        array <- newArray ((0, 0), (inputLen, outputLen)) $ Transform [] 0 :: ST s (STArray s (Int, Int) (Transform a))
        -- To get from x1 ... xn to the empty string, perform n deletions
        -- The sequence of transforms is kept in reverse order; later transforms can then be appended with (:)
        let initials = map reverse $ inits inputString
        forM_ (zip [0 .. inputLen] initials) $ \(i, initial) -> do
            writeArray array (i, 0) $ Transform
                (map Delete initial)
                (sum $ map delete initial)
        -- To get from the empty string to y1 ... ym, perform m insertions
        let initials = map reverse $ inits outputString
        forM_ (zip [0 .. outputLen] initials) $ \(i, initial) -> do
            writeArray array (0, i) $ Transform
                (map Insert initial)
                (sum $ map insert initial)
        -- To get from x1 ... xn to y1 ... ym, you either
        -- * go from x1 ... xn to y1 ... y(m-1) and insert ym (top)
        -- * delete xn and go from x1 ... x(n-1) to y1 ... ym (left)
        -- * go from x1 ... x(n-1) to y1 ... y(m-1) and substitute xn with ym (diag)
        -- Pick whichever way is least costly
        forM_ (zip [1 .. inputLen] inputString) $ \(i, cIn) -> do
            forM_ (zip [1 .. outputLen] outputString) $ \(j, cOut) -> do
                Transform alignmentSeqLeft initCostLeft <- readArray array (i - 1, j)
                Transform alignmentSeqTop initCostTop <- readArray array (i, j - 1)
                Transform alignmentSeqDiag initCostDiag <- readArray array (i - 1, j - 1)
                let costLeft = initCostLeft + delete cIn
                    costTop = initCostTop + insert cOut
                    costDiag = initCostDiag + subst cIn cOut
                    minCost = min costLeft $ min costTop costDiag
                let toWrite
                        | minCost == costLeft = (Delete cIn):alignmentSeqLeft
                        | minCost == costTop = (Insert cOut):alignmentSeqTop
                        | minCost == costDiag = (Subst cIn cOut):alignmentSeqDiag
                writeArray array (i, j) $ Transform toWrite minCost
        Transform operations cost <- readArray array (inputLen, outputLen)
        -- Reverse the sequence of transforms to original order: earlier transforms first
        return $ Transform (reverse operations) cost
  in runST alignWithCostST

align :: Costs a
      -> [a] -- ^ input string
      -> [a] -- ^ output string
      -> [Operation a]
align costs inputString outputString = operations
    where Transform operations _ = alignWithCost costs inputString outputString

alignDefault :: (Eq a)
             => [a] -- ^ input string
             -> [a] -- ^ output string
             -> [Operation a]
alignDefault = align defaultCost

printAlignSeq :: (a -> Char)
              -> [Operation a]
              -> IO ()
printAlignSeq display tokens = do
    -- First row: the input string, with gaps where the output has insertions
    forM_ tokens $ \case
        Insert _ -> putStr $ "_"
        Delete x -> putStr $ [display x]
        Subst x _ -> putStr $ [display x]
    putStrLn ""
    -- Second row: the output string, with gaps where the input has deletions
    forM_ tokens $ \case
        Insert x -> putStr $ [display x]
        Delete _ -> putStr $ "_"
        Subst _ y -> putStr $ [display y]
    putStrLn ""

Here are some test runs:
*Main> printAlignSeq id $ alignDefault "a beautiful bike" "a big bike"
a beauti_ful bike
a b____ig___ bike
*Main> printAlignSeq id $ alignDefault "accgcag" "ccgacag"
accg_cag
_ccgacag
*Main> printAlignSeq id $ alignDefault "sinusoidal" "cosine"
__sin_usoidal
cosine_______
Any comments or feedback are welcome! I am particularly interested in the following points:
- Is this a good way to implement dynamic programming algorithms in Haskell?
- Is this a good way to mimic mutability in a language like Haskell?
- I am worried about memory issues: is there any unnecessary thunk build-up? how can I tell? (I really struggle to understand that!)
1 Answer
I wouldn't worry about using mutability here. Sometimes it pays to just be straightforward and stick closely to something that you know works.
You're not using ScopedTypeVariables.

As far as the interface, Costs has no reason to be a data type. It is clearer from the user's perspective to say simply

    type Costs a = Operation a -> Int
    -- replacements for record accessors
    -- delete costs x = costs (Delete x)
    -- insert costs x = costs (Insert x)
    -- subst costs x y = costs (Subst x y)
    -- e.g.
    defaultCost :: Eq a => Costs a
    defaultCost (Subst x y)
        | x == y = 0
        | otherwise = 10
    defaultCost (Insert _) = 1
    defaultCost (Delete _) = 1

No more RecordWildCards.

Transform currently also has no reason to exist: the only point it might have is documentation, but you haven't given it any! The Int is particularly mysterious at first sight. You should make it into a record with named fields (which will probably be self-documenting enough). While you're at it, the total cost should be strict (I believe your code has a minor space leak in this field). "Accumulator" fields often should be strict.

    data Transform a = Transform { transformOps :: [Operation a], transformCost :: !Int }
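
To see what the strictness annotation on the cost field changes, here is a minimal sketch. The types LazyT and StrictT and the helper throwsWhenForced are hypothetical names made up for this illustration; they are not part of the reviewed code. The idea is that forcing a value with a strict field also forces that field, so an unevaluated cost cannot pile up behind the constructor.

```haskell
import Control.Exception (SomeException, evaluate, try)

-- Hypothetical stand-ins for Transform: same shape, differing only in
-- whether the cost field carries a strictness annotation.
data LazyT   = LazyT   String Int
data StrictT = StrictT String !Int

-- Force a value to weak head normal form and report whether that threw.
throwsWhenForced :: a -> IO Bool
throwsWhenForced x = do
  r <- try (evaluate (x `seq` ())) :: IO (Either SomeException ())
  return (either (const True) (const False) r)

-- The lazy constructor stores the cost as an unevaluated thunk, so forcing
-- the constructor does not touch it.
lazyKeepsThunk :: IO Bool
lazyKeepsThunk = not <$> throwsWhenForced (LazyT "" undefined)

-- The strict constructor evaluates the cost as soon as the value itself is
-- forced, so a bottom cost is detected immediately.
strictForcesCost :: IO Bool
strictForcesCost = throwsWhenForced (StrictT "" undefined)
```

Both actions should return True. In the alignment code the cost is never undefined, of course; the point is only that with the bang the running total is evaluated each time a Transform is forced, rather than growing as a chain of pending additions.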
Taking a look at the main function:

    -- (needs minimumBy from Data.List and comparing from Data.Ord)
    alignWithCost costs input output = runST $ do
        -- no need to do strange things with lets
        let inputLen = length input
            outputLen = length output
        -- (array? really?)
        -- (if you're going to comment about how the algorithm works, you should explain what this is)
        -- each partials[i, j] will eventually be the minimal transform from (take i input) to (take j output)
        -- also, the transformOps of the partials are reversed for efficient appending with (:)
        partials <- newArray ((0, 0), (inputLen, outputLen)) $ Transform [] 0 :: ST s (STArray s (Int, Int) (Transform a))
        let -- (a little vocabulary goes a long way)
            addOp op (Transform ops totalCost) = Transform (op : ops) (costs op + totalCost)
        -- (zip truncates; I find zipping with the infinite list cleaner)
        -- (there is no need to reverse each of the inits separately when we can write this iteratively)
        -- (your wasted space was here: I believe each Transform made in these loops would hang onto the corresponding list from the inits until you would start evaluating things down in the nested loop)
        -- (in this version, without the strictness annotation from before, this would still build some annoying thunks in the cost field)
        -- (compared to the total space usage though, I suspect it might be moot either way)
        -- for each partials[i, 0]: to get from x1 ... xi to the empty string, perform i deletions
        forM_ (zip [1..] input) $ \(i, del) ->
            -- (if we're mimicking imperative languages, may as well use all the application operators we have to sell it ;))
            writeArray partials (i, 0) =<< addOp (Delete del) <$> readArray partials (i - 1, 0)
        -- for each partials[0, i]: to get from the empty string to y1 ... yi, perform i insertions
        forM_ (zip [1..] output) $ \(i, ins) ->
            writeArray partials (0, i) =<< addOp (Insert ins) <$> readArray partials (0, i - 1)
        -- for all the rest of the partials[i > 0, j > 0]:
        -- to get from x1 ... xi to y1 ... yj, either:
        -- * transform x1 ... x(i-1) to y1 ... yj, then delete xi (using partials[i - 1, j] = "left")
        -- * transform x1 ... xi to y1 ... y(j-1), then insert yj (using partials[i, j - 1] = "up")
        -- * transform x1 ... x(i-1) to y1 ... y(j-1), then substitute xi with yj (using partials[i - 1, j - 1] = "diag")
        -- take the one with minimal cost
        forM_ (zip [1..] input) $ \(i, xi) -> forM_ (zip [1..] output) $ \(j, yj) -> do
            -- (so! very! clean!)
            left <- addOp (Delete xi   ) <$> readArray partials (i - 1, j    )
            up   <- addOp (Insert    yj) <$> readArray partials (i    , j - 1)
            diag <- addOp (Subst  xi yj) <$> readArray partials (i - 1, j - 1)
            writeArray partials (i, j) $ minimumBy (comparing transformCost) [left, up, diag]
        Transform ops cost <- readArray partials (inputLen, outputLen)
        return $ Transform (reverse ops) cost
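
On the thunk question specifically: if only the total cost is needed and not the list of operations, one option is an unboxed mutable array, whose elements are always fully evaluated. The following is a minimal sketch under that assumption, not a replacement for the function above. It swaps STArray for STUArray, hard-codes unit costs (plain Levenshtein distance, rather than defaultCost with its substitution cost of 10), and the names editCostTable and editCost are made up for this example.

```haskell
import Control.Monad (forM_)
import Data.Array.ST (newArray, readArray, runSTUArray, writeArray)
import Data.Array.Unboxed (UArray, (!))

-- Fill the whole cost table in ST, then freeze it into an immutable UArray.
-- Assumption: insert, delete and mismatching substitution all cost 1.
editCostTable :: Eq a => [a] -> [a] -> UArray (Int, Int) Int
editCostTable input output = runSTUArray $ do
  let n = length input
      m = length output
  table <- newArray ((0, 0), (n, m)) 0
  -- first column: i deletions turn (take i input) into the empty string
  forM_ [1 .. n] $ \i -> writeArray table (i, 0) i
  -- first row: j insertions turn the empty string into (take j output)
  forM_ [1 .. m] $ \j -> writeArray table (0, j) j
  forM_ (zip [1 ..] input) $ \(i, x) ->
    forM_ (zip [1 ..] output) $ \(j, y) -> do
      left <- readArray table (i - 1, j)
      up   <- readArray table (i, j - 1)
      diag <- readArray table (i - 1, j - 1)
      -- unboxed elements are stored evaluated, so no thunk survives the write
      writeArray table (i, j) $
        minimum [left + 1, up + 1, diag + (if x == y then 0 else 1)]
  return table

-- The bottom-right cell holds the cost of the full alignment.
editCost :: Eq a => [a] -> [a] -> Int
editCost input output =
  editCostTable input output ! (length input, length output)
```

For example, editCost "kitten" "sitting" evaluates to 3. Unboxed arrays cannot hold the operation lists, so this only applies when the alignment itself is recovered some other way or is not needed.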
Having written all that, I do think that perhaps mutability actually is entirely unnecessary here. We can just use a single immutable array with an intricate recursion pattern:

    -- (needs import Data.Array, plus minimumBy from Data.List and comparing from Data.Ord)
    alignWithCost costs input output = Transform (reverse ops) cost
      where
        inputLen = length input
        outputLen = length output
        xs = listArray (0, inputLen - 1) input
        ys = listArray (0, outputLen - 1) output
        -- partials[i, j] is the minimum cost transform from (take i input) to (take j output)
        -- the transformOps are stored in reverse for efficient appending with (:)
        partials = array ((0, 0), (inputLen, outputLen)) [((i, j), partial i j) | i <- [0..inputLen], j <- [0..outputLen]]
        addOp op (Transform ops cost) = Transform (op : ops) (costs op + cost)
        -- this calculates each of the partials[i, j], potentially in terms of some of the partials[n <= i, m <= j]
        partial 0 0 = Transform [] 0 -- empty string to empty string
        partial i 0 = addOp (Delete $ xs ! (i - 1)) $ partials ! (i - 1, 0) -- (take i input) to []: delete each element
        partial 0 j = addOp (Insert $ ys ! (j - 1)) $ partials ! (0, j - 1) -- [] to (take j output): insert each element
        -- (take (i - 1) input ++ [del]) to (take (j - 1) output ++ [ins]); choose the minimal-cost one of the following:
        partial i j = minimumBy (comparing transformCost) [left, up, diag]
          where
            del = xs ! (i - 1)
            ins = ys ! (j - 1)
            left = addOp (Delete del    ) $ partials ! (i - 1, j    ) -- transform (take (i - 1) input) to (take j output) and delete del
            up   = addOp (Insert     ins) $ partials ! (i    , j - 1) -- transform (take i input) to (take (j - 1) output) and insert ins
            diag = addOp (Subst  del ins) $ partials ! (i - 1, j - 1) -- transform (take (i - 1) input) to (take (j - 1) output) and substitute del with ins
        Transform ops cost = partials ! (inputLen, outputLen)

Yes, that's legal! (And it seems to work.) This is basically getting to the heart of what memoization/dynamic programming is: partial is a recursive algorithm "at heart", but we've simply taken the recursive calls and redirected them to a lookup table (the transformation partials ! (x, y) <-> partial x y). In Haskell, we can trust laziness to populate that table as needed instead of writing it ourselves. We use an array over a list because lists are not meant for random access.

For the helper functions, I can't say much, except that now align can just be

    align costs input output = transformOps $ alignWithCost costs input output

and printAlignSeq can probably be

    printAlignSeq display ops = do
        putStrLn $ flip map ops $ \case
            Insert _ -> '_'
            Delete x -> display x
            Subst x _ -> display x
        putStrLn $ flip map ops $ \case
            Insert y -> display y
            Delete _ -> '_'
            Subst _ y -> display y
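
The "redirect the recursive calls to a lazy lookup table" idea described above is easier to see on a problem with a one-dimensional table. Here is a minimal sketch using Fibonacci numbers; the name fibMemo is made up for this illustration, and the function assumes a non-negative argument.

```haskell
import Data.Array (Array, listArray, (!))

-- Memoized Fibonacci via a lazy immutable array.
-- Assumption: n >= 0 (a negative n would index outside the table).
fibMemo :: Int -> Integer
fibMemo n = table ! n
  where
    -- Each cell is a thunk that is evaluated at most once, on first lookup.
    table :: Array Int Integer
    table = listArray (0, n) (map fib [0 .. n])

    -- The "recursive calls" are lookups into the table, not calls to fib.
    fib :: Int -> Integer
    fib 0 = 0
    fib 1 = 1
    fib i = table ! (i - 1) + table ! (i - 2)
```

Here fibMemo 50 evaluates to 12586269025 after a linear number of additions, because each table cell is computed once and then shared. The two-dimensional partials table in the alignment code works the same way, with partial playing the role of fib.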

Thanks for taking the time to write this excellent review! I can see the benefit of array+recursion+laziness to get some form of memoization going on. I'd be interested to know how this solution compares to the ST-based solution in terms of memory footprint. – Ahmad B, Jun 4, 2021 at 16:01