-- Copyright (c) 2006-2011, David Amos. All rights reserved.

{-# LANGUAGE NoMonomorphismRestriction #-}

-- |A module providing functions to test for primality, and find next and previous primes.
module Math.NumberTheory.Prime (primes, isTrialDivisionPrime, isMillerRabinPrime,
                                isPrime, notPrime, prevPrime, nextPrime) where

import System.Random
import System.IO.Unsafe


-- |Exact primality test by trial division: @n@ is prime when @n > 1@ and
-- no prime @p@ with @p*p <= n@ divides it. Slow when @n@ is a large prime
-- or has only large prime factors, since every prime up to @sqrt n@ is tried.
isTrialDivisionPrime :: Integer -> Bool
isTrialDivisionPrime n
    | n <= 1    = False
    | otherwise = all (\p -> n `rem` p /= 0) candidateDivisors
    where candidateDivisors = takeWhile (\p -> p * p <= n) primes

-- |A (lazy) list of the primes, in increasing order.
--
-- After 2, each odd candidate @m@ is kept when no prime @d@ with @d*d <= m@
-- divides it. The list consults itself for those divisors; this is safe
-- because every prime the check inspects is smaller than @m@.
primes :: [Integer]
primes = 2 : [m | m <- [3, 5 ..], hasNoPrimeFactor m]
  where
    -- named so that it does not shadow the top-level 'isPrime'
    hasNoPrimeFactor m =
        and [m `rem` d /= 0 | d <- takeWhile (\d' -> d' * d' <= m) primes]

{-
-- This is just marginally faster, but less elegant
primes2 :: [Integer]
primes2 = 2 : 3 : 5 : 7 : filter isPrime
    (concat [ [m30+11,m30+13,m30+17,m30+19,m30+23,m30+29,m30+31,m30+37] | m30 <- [0,30..] ])
    where isPrime n = not $ any (\p -> n `rem` p == 0) (takeWhile (\p -> p*p <= n) primes2')
          primes2' = drop 3 primes2
-}

{-
-- initial version. This isn't going to be very good if n has any "large" prime factors (eg > 10000)
pfactors1 n | n > 0 = pfactors' n primes
            | n < 0 = -1 : pfactors' (-n) primes
    where pfactors' n (d:ds) | n == 1 = []
                             | n < d*d = [n]
                             | r == 0 = d : pfactors' q (d:ds)
                             | otherwise = pfactors' n ds
                             where (q,r) = quotRem n d
-}

-- MILLER-RABIN TEST
-- Cohen, A Course in Computational Algebraic Number Theory, p422
-- Koblitz, A Course in Number Theory and Cryptography


-- Let n-1 = 2^s * q, q odd
-- Then n is a strong pseudoprime to base b if
-- either b^q == 1 (mod n)
-- or b^(2^r * q) == -1 (mod n) for some 0 <= r < s
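-- (This holds for every prime n and every b not divisible by n: by Fermat,
-- b^(n-1) == 1 (mod n), and since the only square roots of 1 modulo a prime
-- are 1 and -1, the sequence b^q, b^(2q), ..., b^(2^s * q) = b^(n-1) must
-- either start at 1 or pass through -1 before reaching 1.)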

-- |@isStrongPseudoPrime n b@: is @n@ a strong pseudoprime to base @b@?
-- Writes @n-1@ as @2^s * q@ with @q@ odd and hands the decomposition to
-- 'isStrongPseudoPrime''.
isStrongPseudoPrime :: Integral a => a -> a -> Bool
isStrongPseudoPrime n b = isStrongPseudoPrime' n sq b
    where sq = split2s 0 (n - 1)  -- (s, q) with n-1 == 2^s * q, q odd

-- |Strong-pseudoprime test with the decomposition of @n-1@ supplied, so that
-- callers testing many bases only compute it once.
--
-- Given @n-1 == 2^s * q@ with @q@ odd, returns True iff
-- @b^q == 1 (mod n)@, or @b^(2^r * q) == -1 (mod n)@ for some @0 <= r < s@.
isStrongPseudoPrime' :: (Integral a, Integral b) => a -> (Int, b) -> a -> Bool
isStrongPseudoPrime' n (s, q) b = start == 1 || (n - 1) `elem` chain
    where
      start = power_mod b q n                               -- b^q `mod` n
      chain = take s (iterate (\x -> (x * x) `mod` n) start) -- b^(2^r * q), 0 <= r < s

-- |@split2s 0 m@ returns @(s, t)@ such that @2^s * t == m@ with @t@ odd.
-- In general @split2s s0 m@ adds the number of factors of two removed from
-- @m@ to the starting count @s0@.
--
-- @m@ must be non-zero. Zero is divisible by every power of two, so the
-- halving loop would never terminate; we raise an error instead of hanging.
split2s :: (Num a, Integral t) => a -> t -> (a, t)
split2s s t
    | t == 0    = error "split2s: zero has no odd part"
    | r == 0    = split2s (s + 1) q
    | otherwise = (s, t)
    where (q, r) = t `quotRem` 2

-- |@power_mod b t n == b^t `mod` n@, computed by binary (square-and-multiply)
-- exponentiation. Every product is reduced modulo @n@, so intermediate values
-- stay below @n^2@.
--
-- Assumes a positive modulus @n@. The base is reduced with 'mod' first, so a
-- negative or oversized @b@ still yields a result in @[0, n)@. The exponent
-- must be non-negative: a negative power would need a modular inverse, so it
-- is rejected rather than silently mis-computed.
power_mod :: (Integral a, Integral t) => a -> t -> a -> a
power_mod b t n
    | t < 0     = error "power_mod: negative exponent"
    | otherwise = powerMod' (b `mod` n) (1 `mod` n) t
    where
      -- invariant: x^e * y == b^t (mod n), with x and y already in [0, n)
      powerMod' _ y 0 = y
      powerMod' x y e =
          let (q, r) = e `quotRem` 2
              y'     = if r == 0 then y else x * y `rem` n
          in powerMod' (x * x `rem` n) y' q

-- |Probabilistic Miller-Rabin primality test (Cohen, Koblitz).
--
-- For @n >= 4@ it writes @n-1 == 2^s * q@ with @q@ odd. It then requires @n@
-- to be a strong probable prime to 25 bases drawn with 'randomRs' from the
-- generator returned by 'getStdGen'. 2 and 3 are reported prime, and
-- anything below 2 is reported not prime.
--
-- A prime is never reported composite. A composite may, rarely, be reported
-- prime. The global generator is read but never advanced here, so the same
-- bases are drawn for a given @n@ unless something else in the program
-- replaces the global generator.
isMillerRabinPrime' :: (Integral a, Random a) => a -> IO Bool
isMillerRabinPrime' n
    | n >= 4 = do
        g <- getStdGen
        let (s, q) = split2s 0 (n - 1)  -- n-1 == 2^s * q, with q odd
            -- Bases come from [2, n-2]. n-1 == -1 (mod n) passes the strong
            -- test for every odd n, so drawing it would waste a trial.
            bases  = take 25 (randomRs (2, n - 2) g)
        return $ all (isStrongPseudoPrime' n (s, q)) bases
    | n >= 2 = return True
    | otherwise = return False
-- Cohen states that if we restrict the bases to single-word numbers, we can use a more efficient powering algorithm

-- |Pure interface to the Miller-Rabin test 'isMillerRabinPrime''. It works
-- for any 'Integral' type with a 'Random' instance.
--
-- 'unsafePerformIO' lets the test be called from pure code. The wrapped
-- action only reads the global generator via 'getStdGen' and never writes
-- it back. Repeated calls on the same @n@ should therefore agree, unless
-- another part of the program replaces the global generator.
-- NOTE(review): confirm no caller relies on fresh randomness per call.
isMillerRabinPrime :: (Integral a, Random a) => a -> Bool
isMillerRabinPrime n = unsafePerformIO $ isMillerRabinPrime' n


-- |Is this number prime? Trial division by the primes below 100 comes first.
-- It rejects numbers with a small factor and settles every @n < 97^2@
-- exactly. Anything that survives it goes to the probabilistic
-- Miller-Rabin test.
isPrime :: Integer -> Bool
isPrime n
    | n <= 1    = False
    | otherwise = go (takeWhile (< 100) primes)
    where
      go []       = isMillerRabinPrime n
      go (d : ds)
          | d * d > n      = True   -- no factor up to sqrt n: prime
          | n `rem` d == 0 = False
          | otherwise      = go ds
-- The bound 100 was found heuristically: it is roughly where trial division stops paying for itself.

-- |Complement of 'isPrime': @notPrime n == not (isPrime n)@.
notPrime :: Integer -> Bool
notPrime n = not (isPrime n)

-- |Given n, @prevPrime n@ returns the greatest p, p < n, such that p is prime.
-- Calls 'error' for @n <= 2@, since there is no smaller prime.
prevPrime :: Integer -> Integer
prevPrime n
    | n < 3     = error "prevPrime: no previous primes"
    | n == 3    = 2
    | n <= 5    = 3
    | otherwise = head (filter isPrime below)
    where
      -- Every prime above 3 is 6k+1 or 6k+5. Walk those residues downward
      -- from the multiple of 6 at or below n, skipping any that are >= n.
      base  = 6 * (n `div` 6)
      below = dropWhile (>= n) [c | m <- [base, base - 6 ..], c <- [m + 5, m + 1]]

-- |Given n, @nextPrime n@ returns the least p, p > n, such that p is prime.
nextPrime :: Integer -> Integer
nextPrime n
    | n < 2     = 2
    | n < 3     = 3
    | otherwise = head [p | p <- above, isPrime p]
    where
      -- Every prime above 3 is 6k+1 or 6k+5. Walk those residues upward
      -- from the multiple of 6 at or below n, skipping any that are <= n.
      base  = 6 * (n `div` 6)
      above = dropWhile (<= n) [c | m <- [base, base + 6 ..], c <- [m + 1, m + 5]]

{-
-- slightly better version. This is okay so long as n has at most one "large" prime factor (> 10000)
-- if it has more, it does at least tell you, via an error message, that it has run into difficulties
pfactors2 n | n > 0 = pfactors' n $ takeWhile (< 10000) primes
            | n < 0 = -1 : pfactors' (-n) (takeWhile (< 10000) primes)
    where pfactors' n (d:ds) | n == 1 = []
                             | n < d*d = [n]
                             | r == 0 = d : pfactors' q (d:ds)
                             | otherwise = pfactors' n ds
                             where (q,r) = quotRem n d
          pfactors' n [] = if isMillerRabinPrime n then [n] else error ("pfactors2: can't factor " ++ show n)
-}