Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 01029b4

Browse files
committed
ATC001-C: Implement NTT
1 parent a8bf065 commit 01029b4

File tree

3 files changed

+773
-13
lines changed

3 files changed

+773
-13
lines changed

‎atc001-c/Main.hs‎

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -37,19 +37,22 @@ main = do
3737
-- Fast Fourier Transform (FFT)
3838
--
3939

40+
halve :: G.Vector vec a => vec a -> vec a
41+
halve v = let n = G.length v
42+
in G.generate (n `quot` 2) $ \j -> v G.! (j * 2)
43+
4044
fft :: forall vec a. (Num a, G.Vector vec a)
41-
=> vec a -- ^ For a primitive n-th root of unity @u@, @[1,u,u^2 .. u^(n-1)]@
45+
=> [vec a] -- ^ For a primitive n-th root of unity @u@, @iterate halve [1,u,u^2 .. u^(n-1)]@
4246
-> vec a -- ^ a polynomial of length n (= 2^k for some k)
4347
-> vec a
44-
fft u f | n == 1 = f
45-
| otherwise = let !n2 = n `quot` 2
46-
r0, r1', u2, t0, t1' :: vec a
47-
r0 = G.generate n2 $ \j -> (f G.! j) + (f G.! (j + n2))
48-
r1' = G.generate n2 $ \j -> ((f G.! j) - (f G.! (j + n2))) * u G.! j
49-
!u2 = G.generate n2 $ \j -> u G.! (j * 2)
50-
!t0 = fft u2 r0
51-
!t1' = fft u2 r1'
52-
in G.generate n $ \j -> if even j then t0 G.! (j `quot` 2) else t1' G.! (j `quot` 2)
48+
fft (u:u2) f | n == 1 = f
49+
| otherwise = let !n2 = n `quot` 2
50+
r0, r1', t0, t1' :: vec a
51+
r0 = G.generate n2 $ \j -> (f G.! j) + (f G.! (j + n2))
52+
r1' = G.generate n2 $ \j -> ((f G.! j) - (f G.! (j + n2))) * u G.! j
53+
!t0 = fft u2 r0
54+
!t1' = fft u2 r1'
55+
in G.generate n $ \j -> if even j then t0 G.! (j `quot` 2) else t1' G.! (j `quot` 2)
5356
where n = G.length f
5457

5558
mulFFT :: U.Vector Int -> U.Vector Int -> U.Vector Int
@@ -59,6 +62,7 @@ mulFFT !f !g = let n' = U.length f + U.length g - 2
5962
n = bit k
6063
u :: U.Vector (Complex Double)
6164
u = U.generate n $ \j -> cis (fromIntegral j * (2 * pi / fromIntegral n))
65+
us = iterate halve u
6266
f' = U.generate n $ \j -> if j < U.length f then
6367
fromIntegral (f U.! j)
6468
else
@@ -67,10 +71,10 @@ mulFFT !f !g = let n' = U.length f + U.length g - 2
6771
fromIntegral (g U.! j)
6872
else
6973
0
70-
f'' = fft u f'
71-
g'' = fft u g'
74+
f'' = fft us f'
75+
g'' = fft us g'
7276
fg = U.generate n $ \j -> (f'' U.! j) * (g'' U.! j)
73-
fg' = fft (U.map conjugate u) fg
77+
fg' = fft (map (U.map conjugate) us) fg
7478
in U.generate n $ \j -> round (realPart (fg' U.! j) / fromIntegral n)
7579

7680
--

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /