|
| 1 | +{-# LANGUAGE BangPatterns #-} |
| 2 | +{-# LANGUAGE MultiParamTypeClasses #-} |
| 3 | +{-# LANGUAGE TypeFamilies #-} |
| 4 | +{-# LANGUAGE DataKinds #-} |
| 5 | +{-# LANGUAGE TypeOperators #-} |
| 6 | +{-# LANGUAGE NoStarIsType #-} |
| 7 | +module ModularArithmetic_TypeNats where |
| 8 | +import Data.Int |
| 9 | +import GHC.TypeNats (Nat, KnownNat, natVal, type (^), type (+)) |
| 10 | + |
| 11 | +-- |
| 12 | +-- Modular Arithmetic |
| 13 | +-- |
| 14 | + |
| 15 | +-- type N = IntMod (10^9 + 7) |
| 16 | + |
| 17 | +newtype IntMod (m :: Nat) = IntMod { unwrapN :: Int64 } deriving (Eq) |
| 18 | + |
| 19 | +instance Show (IntMod m) where |
| 20 | + show (IntMod x) = show x |
| 21 | + |
| 22 | +instance KnownNat m => Num (IntMod m) where |
| 23 | + t@(IntMod x) + IntMod y |
| 24 | + | x + y >= modulus = IntMod (x + y - modulus) |
| 25 | + | otherwise = IntMod (x + y) |
| 26 | + where modulus = fromIntegral (natVal t) |
| 27 | + t@(IntMod x) - IntMod y |
| 28 | + | x >= y = IntMod (x - y) |
| 29 | + | otherwise = IntMod (x - y + modulus) |
| 30 | + where modulus = fromIntegral (natVal t) |
| 31 | + t@(IntMod x) * IntMod y = IntMod ((x * y) `rem` modulus) |
| 32 | + where modulus = fromIntegral (natVal t) |
| 33 | + fromInteger n = let result = IntMod (fromInteger (n `mod` fromIntegral modulus)) |
| 34 | + modulus = natVal result |
| 35 | + in result |
| 36 | + abs = undefined; signum = undefined |
| 37 | + |
| 38 | +{-# RULES |
| 39 | +"^9/Int" forall x. x ^ (9 :: Int) = let u = x; v = u * u * u in v * v * v |
| 40 | +"^9/Integer" forall x. x ^ (9 :: Integer) = let u = x; v = u * u * u in v * v * v |
| 41 | + #-} |
| 42 | + |
| 43 | +fromIntegral_Int64_IntMod :: KnownNat m => Int64 -> IntMod m |
| 44 | +fromIntegral_Int64_IntMod n = result |
| 45 | + where |
| 46 | + result | 0 <= n && n < modulus = IntMod n |
| 47 | + | otherwise = IntMod (n `mod` modulus) |
| 48 | + modulus = fromIntegral (natVal result) |
| 49 | + |
| 50 | +{-# RULES |
| 51 | +"fromIntegral/Int->IntMod" fromIntegral = fromIntegral_Int64_IntMod . (fromIntegral :: Int -> Int64) :: Int -> IntMod (10^9 + 7) |
| 52 | +"fromIntegral/Int64->IntMod" fromIntegral = fromIntegral_Int64_IntMod :: Int64 -> IntMod (10^9 + 7) |
| 53 | + #-} |
| 54 | + |
| 55 | +--- |
| 56 | + |
| 57 | +exEuclid :: (Eq a, Integral a) => a -> a -> (a, a, a) |
| 58 | +exEuclid !f !g = loop 1 0 0 1 f g |
| 59 | + where loop !u0 !u1 !v0 !v1 !f 0 = (f, u0, v0) |
| 60 | + loop !u0 !u1 !v0 !v1 !f g = |
| 61 | + case divMod f g of |
| 62 | + (q,r) -> loop u1 (u0 - q * u1) v1 (v0 - q * v1) g r |
| 63 | + |
| 64 | +instance KnownNat m => Fractional (IntMod m) where |
| 65 | + recip t@(IntMod x) = IntMod $ case exEuclid x modulus of |
| 66 | + (1,a,_) -> a `mod` modulus |
| 67 | + (-1,a,_) -> (-a) `mod` modulus |
| 68 | + _ -> error "not invertible" |
| 69 | + where modulus = fromIntegral (natVal t) |
| 70 | + fromRational = undefined |
0 commit comments