Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit bf187ef

Browse files
committed
TDPC-T: Karatsuba
1 parent 260b53c commit bf187ef

File tree

1 file changed

+280
-0
lines changed

1 file changed

+280
-0
lines changed

‎tdpc-t/Karatsuba.hs‎

Lines changed: 280 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
-- https://github.com/minoki/my-atcoder-solutions
2+
{-# LANGUAGE BangPatterns #-}
3+
{-# LANGUAGE DataKinds #-}
4+
{-# LANGUAGE NoStarIsType #-}
5+
{-# LANGUAGE ScopedTypeVariables #-}
6+
{-# LANGUAGE TypeFamilies #-}
7+
{-# LANGUAGE TypeOperators #-}
8+
import Control.Monad
9+
import Control.Monad.ST
10+
import Data.Bits
11+
import qualified Data.ByteString.Char8 as BS
12+
import Data.Char (isSpace)
13+
import Data.Coerce
14+
import Data.Int (Int64)
15+
import Data.List (foldl', tails, unfoldr)
16+
import qualified Data.Vector.Generic as G
17+
import qualified Data.Vector.Unboxing as U
18+
import qualified Data.Vector.Unboxing.Mutable as UM
19+
import GHC.TypeNats (type (+), KnownNat, Nat,
20+
type (^), natVal)
21+
22+
type P = U.Vector (IntMod (10^9 + 7))
23+
type PM s = UM.MVector s (IntMod (10^9 + 7))
24+
25+
{-
26+
sum' :: KnownNat m => [IntMod m] -> IntMod m
27+
sum' = fromIntegral . foldl' (\x y -> x + unwrapN y) 0
28+
{-# INLINE sum' #-}
29+
-}
30+
31+
-- 多項式は
32+
-- U.fromList [a,b,c,...,z] = a + b * X + c * X^2 + ... + z * X^(k-1)
33+
-- により表す。
34+
35+
-- 多項式を X^k - X^(k-1) - ... - X - 1 で割った余りを返す。
36+
reduceM :: Int -> PM s -> ST s (PM s)
37+
reduceM !k !v = loop (UM.length v)
38+
where loop !l | l <= k = return (UM.take l v)
39+
| otherwise = do b <- UM.read v (l - 1)
40+
forM_ [l - k - 1 .. l - 2] $ \i -> do
41+
UM.modify v (+ b) i
42+
loop (l - 1)
43+
44+
-- 多項式の積を X^k - X^(k-1) - ... - X - 1 で割った余りを返す。
45+
mulP :: Int -> P -> P -> P
46+
mulP !k !v !w = {- U.force $ -} U.create $ do
47+
let !vl = U.length v
48+
!wl = U.length w
49+
-- s <- UM.new (vl + wl - 1)
50+
-- forM_ [0 .. vl + wl - 2] $ \i -> do
51+
-- let !x = sum' [(v U.! (i-j)) * (w U.! j) | j <- [max 0 (i - vl + 1) .. min (wl - 1) i]]
52+
-- UM.write s i x
53+
let n = ceiling ((log (fromIntegral (vl .|. wl)) :: Double) / log 2) :: Int
54+
s <- U.thaw (doMulP (2^n) v w)
55+
reduceM k s
56+
57+
-- 多項式に X をかけたものを X^k - X^(k-1) - ... - X - 1 で割った余りを返す。
58+
mulByX :: Int -> P -> P
59+
mulByX !k !v
60+
| U.length v == k = let !v_k = v U.! (k-1)
61+
in U.generate k $ \i -> if i == 0 then
62+
v_k
63+
else
64+
v_k + (v U.! (i - 1))
65+
| otherwise = U.cons 0 v
66+
67+
-- X の(mod X^k - X^(k-1) - ... - X - 1 での)n 乗
68+
powX :: Int -> Int -> P
69+
powX !k !n = doPowX n
70+
where
71+
doPowX 0 = U.fromList [1] -- 1
72+
doPowX 1 = U.fromList [0,1] -- X
73+
doPowX i = case i `quotRem` 2 of
74+
(j,0) -> let !f = doPowX j -- X^j mod P
75+
in mulP k f f
76+
(j,_) -> let !f = doPowX j -- X^j mod P
77+
in mulByX k (mulP k f f)
78+
79+
main :: IO ()
80+
main = do
81+
[k,n] <- unfoldr (BS.readInt . BS.dropWhile isSpace) <$> BS.getLine
82+
-- 2 <= k <= 1000
83+
-- 1 <= n <= 10^9
84+
if n <= k then
85+
print 1
86+
else do
87+
let f = powX k (n - k) -- X^(n-k) mod X^k - X^(k-1) - ... - X - 1
88+
let seq = replicate k 1 ++ map (sum . take k) (tails seq) -- 数列
89+
print $ sum $ zipWith (*) (U.toList f) (drop (k-1) seq)
90+
91+
--
92+
-- Modular Arithmetic
93+
--
94+
95+
newtype IntMod (m :: Nat) = IntMod { unwrapN :: Int64 } deriving (Eq)
96+
97+
instance Show (IntMod m) where
98+
show (IntMod x) = show x
99+
100+
instance KnownNat m => Num (IntMod m) where
101+
t@(IntMod x) + IntMod y
102+
| x + y >= modulus = IntMod (x + y - modulus)
103+
| otherwise = IntMod (x + y)
104+
where modulus = fromIntegral (natVal t)
105+
t@(IntMod x) - IntMod y
106+
| x >= y = IntMod (x - y)
107+
| otherwise = IntMod (x - y + modulus)
108+
where modulus = fromIntegral (natVal t)
109+
t@(IntMod x) * IntMod y = IntMod ((x * y) `rem` modulus)
110+
where modulus = fromIntegral (natVal t)
111+
fromInteger n = let result = IntMod (fromInteger (n `mod` fromIntegral modulus))
112+
modulus = natVal result
113+
in result
114+
abs = undefined; signum = undefined
115+
{-# SPECIALIZE instance Num (IntMod 1000000007) #-}
116+
117+
fromIntegral_Int64_IntMod :: KnownNat m => Int64 -> IntMod m
118+
fromIntegral_Int64_IntMod n = result
119+
where
120+
result | 0 <= n && n < modulus = IntMod n
121+
| otherwise = IntMod (n `mod` modulus)
122+
modulus = fromIntegral (natVal result)
123+
124+
{-# RULES
125+
"fromIntegral/Int->IntMod" fromIntegral = fromIntegral_Int64_IntMod . (fromIntegral :: Int -> Int64) :: Int -> IntMod (10^9 + 7)
126+
"fromIntegral/Int64->IntMod" fromIntegral = fromIntegral_Int64_IntMod :: Int64 -> IntMod (10^9 + 7)
127+
#-}
128+
129+
instance U.Unboxable (IntMod m) where
130+
type Rep (IntMod m) = Int64
131+
132+
--
133+
-- Univariate polynomial
134+
--
135+
136+
newtype Poly vec a = Poly { coeffAsc :: vec a } deriving Eq
137+
138+
normalizePoly :: (Eq a, Num a, G.Vector vec a) => vec a -> vec a
139+
normalizePoly v | G.null v || G.last v /= 0 = v
140+
| otherwise = normalizePoly (G.init v)
141+
142+
addPoly :: (Eq a, Num a, G.Vector vec a) => vec a -> vec a -> vec a
143+
addPoly v w = case compare n m of
144+
LT -> G.generate m $ \i -> if i < n
145+
then v G.! i + w G.! i
146+
else w G.! i
147+
GT -> G.generate n $ \i -> if i < m
148+
then v G.! i + w G.! i
149+
else v G.! i
150+
EQ -> normalizePoly $ G.zipWith (+) v w
151+
where n = G.length v
152+
m = G.length w
153+
154+
subPoly :: (Eq a, Num a, G.Vector vec a) => vec a -> vec a -> vec a
155+
subPoly v w = case compare n m of
156+
LT -> G.generate m $ \i -> if i < n
157+
then v G.! i - w G.! i
158+
else negate (w G.! i)
159+
GT -> G.generate n $ \i -> if i < m
160+
then v G.! i - w G.! i
161+
else v G.! i
162+
EQ -> normalizePoly $ G.zipWith (-) v w
163+
where n = G.length v
164+
m = G.length w
165+
166+
naiveMulPoly :: (Num a, G.Vector vec a) => vec a -> vec a -> vec a
167+
naiveMulPoly v w = G.generate (n + m - 1) $
168+
\i -> sum [(v G.! (i-j)) * (w G.! j) | j <- [max (i-n+1) 0..min i (m-1)]]
169+
where n = G.length v
170+
m = G.length w
171+
172+
doMulP :: (Eq a, Num a, G.Vector vec a) => Int -> vec a -> vec a -> vec a
173+
doMulP n !v !w | n <= 16 = naiveMulPoly v w
174+
doMulP n !v !w
175+
| G.null v = v
176+
| G.null w = w
177+
| G.length v < n2 = let (w0, w1) = G.splitAt n2 w
178+
u0 = doMulP n2 v w0
179+
u1 = doMulP n2 v w1
180+
in G.generate (G.length v + G.length w - 1)
181+
$ \i -> case () of
182+
_ | i < n2 -> u0 `at` i
183+
| i < n -> (u0 `at` i) + (u1 `at` (i - n2))
184+
| i < n + n2 -> (u1 `at` (i - n2))
185+
| G.length w < n2 = let (v0, v1) = G.splitAt n2 v
186+
u0 = doMulP n2 v0 w
187+
u1 = doMulP n2 v1 w
188+
in G.generate (G.length v + G.length w - 1)
189+
$ \i -> case () of
190+
_ | i < n2 -> u0 `at` i
191+
| i < n -> (u0 `at` i) + (u1 `at` (i - n2))
192+
| i < n + n2 -> (u1 `at` (i - n2))
193+
| otherwise = let (v0, v1) = G.splitAt n2 v
194+
(w0, w1) = G.splitAt n2 w
195+
v0_1 = v0 `addPoly` v1
196+
w0_1 = w0 `addPoly` w1
197+
p = doMulP n2 v0_1 w0_1
198+
q = doMulP n2 v0 w0
199+
r = doMulP n2 v1 w1
200+
-- s = (p `subPoly` q) `subPoly` r -- p - q - r
201+
-- q + s*X^n2 + r*X^n
202+
in G.generate (G.length v + G.length w - 1)
203+
$ \i -> case () of
204+
_ | i < n2 -> q `at` i
205+
| i < n -> ((q `at` i) + (p `at` (i - n2))) - ((q `at` (i - n2)) + (r `at` (i - n2)))
206+
| i < n + n2 -> ((r `at` (i - n)) + (p `at` (i - n2))) - ((q `at` (i - n2)) + (r `at` (i - n2)))
207+
| otherwise -> r `at` (i - n)
208+
where n2 = n `quot` 2
209+
at :: (Num a, G.Vector vec a) => vec a -> Int -> a
210+
at v i = if i < G.length v then v G.! i else 0
211+
{-# INLINE doMulP #-}
212+
213+
mulPoly :: (Eq a, Num a, G.Vector vec a) => vec a -> vec a -> vec a
214+
mulPoly !v !w = let k = ceiling ((log (fromIntegral (max n m)) :: Double) / log 2) :: Int
215+
in doMulP (2^k) v w
216+
where n = G.length v
217+
m = G.length w
218+
{-# INLINE mulPoly #-}
219+
220+
zeroPoly :: (G.Vector vec a) => Poly vec a
221+
zeroPoly = Poly G.empty
222+
223+
constPoly :: (Eq a, Num a, G.Vector vec a) => a -> Poly vec a
224+
constPoly 0 = Poly G.empty
225+
constPoly x = Poly (G.singleton x)
226+
227+
scalePoly :: (Eq a, Num a, G.Vector vec a) => a -> Poly vec a -> Poly vec a
228+
scalePoly a (Poly xs)
229+
| a == 0 = zeroPoly
230+
| otherwise = Poly $ G.map (* a) xs
231+
232+
valueAtPoly :: (Num a, G.Vector vec a) => Poly vec a -> a -> a
233+
valueAtPoly (Poly xs) t = G.foldr' (\a b -> a + t * b) 0 xs
234+
235+
instance (Eq a, Num a, G.Vector vec a) => Num (Poly vec a) where
236+
(+) = coerce (addPoly :: vec a -> vec a -> vec a)
237+
(-) = coerce (subPoly :: vec a -> vec a -> vec a)
238+
negate (Poly v) = Poly (G.map negate v)
239+
(*) = coerce (mulPoly :: vec a -> vec a -> vec a)
240+
fromInteger = constPoly . fromInteger
241+
abs = undefined; signum = undefined
242+
243+
divModPoly :: (Eq a, Fractional a, G.Vector vec a) => Poly vec a -> Poly vec a -> (Poly vec a, Poly vec a)
244+
divModPoly f g@(Poly w)
245+
| G.null w = error "divModPoly: divide by zero"
246+
| degree f < degree g = (zeroPoly, f)
247+
| otherwise = loop zeroPoly (scalePoly (recip b) f)
248+
where
249+
g' = toMonic g
250+
b = leadingCoefficient g
251+
-- invariant: f == q * g + scalePoly b p
252+
loop q p | degree p < degree g = (q, scalePoly b p)
253+
| otherwise = let q' = Poly (G.drop (degree' g) (coeffAsc p))
254+
in loop (q + q') (p - q' * g')
255+
256+
toMonic :: (Fractional a, G.Vector vec a) => Poly vec a -> Poly vec a
257+
toMonic f@(Poly xs)
258+
| G.null xs = zeroPoly
259+
| otherwise = Poly $ G.map (* recip (leadingCoefficient f)) xs
260+
261+
leadingCoefficient :: (Num a, G.Vector vec a) => Poly vec a -> a
262+
leadingCoefficient (Poly xs)
263+
| G.null xs = 0
264+
| otherwise = G.last xs
265+
266+
degree :: G.Vector vec a => Poly vec a -> Maybe Int
267+
degree (Poly xs) = case G.length xs - 1 of
268+
-1 -> Nothing
269+
n -> Just n
270+
271+
degree' :: G.Vector vec a => Poly vec a -> Int
272+
degree' (Poly xs) = case G.length xs of
273+
0 -> error "degree': zero polynomial"
274+
n -> n - 1
275+
276+
-- 組立除法
277+
-- second constPoly (divModByDeg1 f t) = divMod f (Poly (G.fromList [-t, 1]))
278+
divModByDeg1 :: (Eq a, Num a, G.Vector vec a) => Poly vec a -> a -> (Poly vec a, a)
279+
divModByDeg1 f t = let w = G.postscanr (\a b -> a + b * t) 0 $ coeffAsc f
280+
in (Poly (G.tail w), G.head w)

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /