@@ -37,19 +37,22 @@ main = do
37
37
-- Fast Fourier Transform (FFT)
38
38
--
39
39
40
+ halve :: G. Vector vec a => vec a -> vec a
41
+ halve v = let n = G. length v
42
+ in G. generate (n `quot` 2 ) $ \ j -> v G. ! (j * 2 )
43
+
40
44
fft :: forall vec a . (Num a , G. Vector vec a )
41
- => vec a -- ^ For a primitive n-th root of unity @u@, @[1,u,u^2 .. u^(n-1)]@
45
+ => [ vec a ] -- ^ For a primitive n-th root of unity @u@, @iterate halve [1,u,u^2 .. u^(n-1)]@
42
46
-> vec a -- ^ a polynomial of length n (= 2^k for some k)
43
47
-> vec a
44
- fft u f | n == 1 = f
45
- | otherwise = let ! n2 = n `quot` 2
46
- r0 , r1' , u2 , t0 , t1' :: vec a
47
- r0 = G. generate n2 $ \ j -> (f G. ! j) + (f G. ! (j + n2))
48
- r1' = G. generate n2 $ \ j -> ((f G. ! j) - (f G. ! (j + n2))) * u G. ! j
49
- ! u2 = G. generate n2 $ \ j -> u G. ! (j * 2 )
50
- ! t0 = fft u2 r0
51
- ! t1' = fft u2 r1'
52
- in G. generate n $ \ j -> if even j then t0 G. ! (j `quot` 2 ) else t1' G. ! (j `quot` 2 )
48
+ fft (u: u2) f | n == 1 = f
49
+ | otherwise = let ! n2 = n `quot` 2
50
+ r0 , r1' , t0 , t1' :: vec a
51
+ r0 = G. generate n2 $ \ j -> (f G. ! j) + (f G. ! (j + n2))
52
+ r1' = G. generate n2 $ \ j -> ((f G. ! j) - (f G. ! (j + n2))) * u G. ! j
53
+ ! t0 = fft u2 r0
54
+ ! t1' = fft u2 r1'
55
+ in G. generate n $ \ j -> if even j then t0 G. ! (j `quot` 2 ) else t1' G. ! (j `quot` 2 )
53
56
where n = G. length f
54
57
55
58
mulFFT :: U. Vector Int -> U. Vector Int -> U. Vector Int
@@ -59,6 +62,7 @@ mulFFT !f !g = let n' = U.length f + U.length g - 2
59
62
n = bit k
60
63
u :: U. Vector (Complex Double )
61
64
u = U. generate n $ \ j -> cis (fromIntegral j * (2 * pi / fromIntegral n))
65
+ us = iterate halve u
62
66
f' = U. generate n $ \ j -> if j < U. length f then
63
67
fromIntegral (f U. ! j)
64
68
else
@@ -67,10 +71,10 @@ mulFFT !f !g = let n' = U.length f + U.length g - 2
67
71
fromIntegral (g U. ! j)
68
72
else
69
73
0
70
- f'' = fft u f'
71
- g'' = fft u g'
74
+ f'' = fft us f'
75
+ g'' = fft us g'
72
76
fg = U. generate n $ \ j -> (f'' U. ! j) * (g'' U. ! j)
73
- fg' = fft (U. map conjugate u ) fg
77
+ fg' = fft (map ( U. map conjugate) us ) fg
74
78
in U. generate n $ \ j -> round (realPart (fg' U. ! j) / fromIntegral n)
75
79
76
80
--
0 commit comments