(* ----------------------------------------------- *)
(* GNU GPL opensource, C.McCormack and L.Allison, *)
(* Monash U., Clayton, Australia, August 2006 *)
(* ----------------------------------------------- *)
(* A signed arbitrary-precision integer.  Each constructor carries the
   magnitude as a list of digits in base internalBase, least significant
   digit first: clean trims zero digits from the END of the list (the most
   significant end), and compare/divide reverse the list to read the most
   significant digit first. *)
datatype BigInt = Neg of (int list) | Pos of (int list);
(* NB all operations must return canonical values: no zero digits at the
   end of the list, no empty list, and NO (Neg [0]) ! *)
(* strictly speaking, Pos includes zero, i.e. all `non-negative' values *)
val internalBase = 100; (* the digit base; 100 is an arbitrary choice *)
val zero = Pos [0]; (* the one canonical representation of zero *)
(* invert cmp: the reversed ordering, i.e. the answer obtained by comparing
   the same two values with their roles swapped. *)
fun invert cmp =
  case cmp of
      LESS => GREATER
    | EQUAL => EQUAL
    | GREATER => LESS;
(* negate b: the additive inverse of b.  Zero is special-cased so that the
   forbidden value (Neg [0]) is never produced. *)
fun negate b =
  case b of
      Pos [0] => b
    | Pos ds => Neg ds
    | Neg ds => Pos ds;
(* lift f b: apply the digit-list function f to the magnitude of b while
   keeping b's sign constructor.  The result is not cleaned. *)
fun lift f b =
  case b of
      Neg ds => Neg (f ds)
    | Pos ds => Pos (f ds);
(* clean b: put b into canonical form by trimming the zero digits at the
   end of its digit list (the most significant end).  A magnitude that is
   empty or all zeros becomes zero, whatever its sign. *)
fun clean b =
  let
    fun dropZeros (0 :: ds) = dropZeros ds
      | dropZeros ds = ds
    (* digits are stored least significant first, so trim from the back *)
    val trimTop = rev o dropZeros o rev
  in
    case lift trimTop b of
        Pos [] => zero
      | Neg [] => zero
      | trimmed => trimmed
  end;
(* ------------------------------------------------------------------------- *)
(* Make 'compare' first, then 'plus' & 'minus', then 'times', & finally IO *)
(* compare a b: order two canonical BigInts.  A non-negative value beats a
   negative one; two negatives are compared by negating both and inverting
   the answer; two non-negatives are compared by digit count first and then
   digit by digit from the most significant end.  Linear time. *)
fun compare a b =
  case (a, b) of
      (Pos _, Neg _) => GREATER
    | (Neg _, _) => invert (compare (negate a) (negate b))
    | (Pos xs, Pos ys) =>
        let
          (* only called on equal-length lists, most significant digit first *)
          fun scan (x :: xs', y :: ys') =
                (case Int.compare (x, y) of
                     EQUAL => scan (xs', ys')
                   | unequal => unequal)
            | scan _ = EQUAL
        in
          case Int.compare (length xs, length ys) of
              EQUAL => scan (rev xs, rev ys)
            | unequal => unequal
        end;
(* plus m n = m+n and minus m n = m-n.  They are defined together because
   each reduces some mixed-sign cases to the other.  Only the (Pos, Pos)
   clause of each does digit arithmetic, and its result is passed through
   clean; every other clause just rewrites the signs. *)
fun plus (Pos ms) (Pos ns) =
(* f c xs ys: digit-wise sum of xs and ys plus the carry c, least
   significant digit first.  When one list runs out, the carry is turned
   into a one-digit list and added to the remaining list.  When both run
   out, the carry is appended as a final digit (a 0 is removed by clean). *)
let fun f c [] [] = [c]
| f c xs [] = f 0 [c] xs
| f c [] xs = f c xs []
| f c (x::xs) (y::ys) =
let val cxy = c+x+y;
val x' = cxy mod internalBase
and c' = cxy div internalBase;
in x' :: (f c' xs ys)
end
in clean(Pos (f 0 ms ns))
end
| plus (m as (Neg _)) (n as (Neg _)) = negate(plus (negate m) (negate n))
| plus (m as (Pos _)) (n as (Neg _)) = minus m (negate n) (* m+n=m-(-n) *)
| plus (m as (Neg _)) (n as (Pos _)) = plus n m (* m+n=n+m *)
and minus (m as (Pos ms)) (n as (Pos ns)) =
(case compare m n of
LESS => negate(minus n m) | (* m-n=-(n-m) *)
EQUAL => zero |
GREATER =>
(* Here m > n, so (by compare's length test) ms is at least as long as ns,
   and no borrow is left over after the top digit.
   f b ms ns: digit-wise ms - ns - b, least significant digit first, where
   b is the borrow (0 or 1) from the previous digit.  When ns runs out it
   is padded with a 0 digit. *)
let fun f 0 [] [] = []
| f b (m::ms) (n::ns) = (* b ~ borrow *)
let val bn = b+n
in if bn > m then (internalBase+m-bn)::(f 1 ms ns)
else (m-bn)::(f 0 ms ns)
end
| f b ms [] = f b ms [0]
in clean(Pos (f 0 ms ns))
end
)
| minus (m as (Pos _)) (n as (Neg _)) = plus m (negate n) (* m-n=m+(-n) *)
| minus (m as (Neg _)) (n as (Neg _)) =
minus (negate n) (negate m) (* m-n = -n-(-m) *)
| minus (m as (Neg _)) (n as (Pos _)) = negate(minus n m); (* m-n=-(n-m) *)
(* timesDigit n d: the product of n and the single digit d.  The general
   case assumes 1 < d < internalBase (NB); d = 0 and d = 1 are short cuts,
   and a negative n is handled through its magnitude. *)
fun timesDigit _ 0 = zero
  | timesDigit n 1 = n
  | timesDigit (n as (Neg _)) d = negate (timesDigit (negate n) d)
  | timesDigit (Pos n) d =
      let
        (* schoolbook multiplication, least significant digit first; the
           accumulator is (carry, product digits in reverse order) *)
        fun step (digit, (carry, acc)) =
          let val p = digit * d + carry
          in (p div internalBase, (p mod internalBase) :: acc) end
        val (carry, revDigits) = foldl step (0, []) n
        (* a leftover carry becomes the new most significant digit *)
        val digits = if carry = 0 then revDigits else carry :: revDigits
      in
        Pos (rev digits)
      end;
(* times m n: the product m*n.  For two non-negatives, n is consumed one
   digit at a time from the least significant end, using
   m*(d + base*rest) = m*d + base*(m*rest).  The other sign combinations
   are reduced to that case. *)
fun times m n =
  case (m, n) of
      (Pos _, Pos ns) =>
        let
          (* multiply by internalBase: push a zero in as the lowest digit *)
          val shiftUp = lift (fn ds => 0 :: ds)
          fun addDigit (d, acc) = plus (timesDigit m d) (shiftUp acc)
        in
          foldr addDigit zero ns
        end
    | (Neg _, Neg _) => times (negate m) (negate n)
    | (Pos _, Neg _) => negate (times m (negate n))
    | (Neg _, Pos _) => times n m;
(* ------------------------------------------------------------------------- *)
exception BadBase; (* raised when a radix (number base) is out of range *)
exception BadString; (* raised when a string is not a well-formed number *)
(* Character classes used for digits: 0-9, A-Z and a-z.  The Basis Char
   predicates are specified over exactly these character ranges. *)
fun isDecimalDigit c = Char.isDigit c;
fun isUpperDigit c = Char.isUpper c;
fun isLowerDigit c = Char.isLower c;
(* digitVal x base: the numeric value of digit character x, where 0-9 stand
   for themselves and A-Z / a-z both stand for 10-35.  Raises BadString if x
   is not such a character, or if its value is not below base. *)
fun digitVal x base =
  let
    val v =
      if isDecimalDigit x then ord x - ord #"0"
      else if isUpperDigit x then ord x - ord #"A" + 10
      else if isLowerDigit x then ord x - ord #"a" + 10
      else raise BadString
  in
    if v < base then v else raise BadString
  end;
(* digitChar x: the character for digit value x.  Values below 10 map to
   0-9; 10 and above map to upper-case letters starting at A. *)
fun digitChar x =
  if x < 10 then chr (ord #"0" + x)
  else chr (ord #"A" + (x - 10));
(* fromString s: parse a BigInt.  Syntax:
   - an optional sign + or -;
   - one or more digits (0-9, then A-Z or a-z for 10-35);
   - an optional radix suffix "[b]" with b written in decimal, e.g. "-ff[16]".
   Without a suffix the radix is 10.
   Raises BadString for a malformed string, including one that has a sign
   but no digits ("-", "+[16]") and an empty suffix "[]".
   Raises BadBase if the radix is below 2 or above internalBase. *)
fun fromString s =
  let
    (* reads the decimal radix between '[' and the final ']' *)
    fun base' [] _ = raise BadString
      | base' [#"]"] b =
          (* radix 0 or 1 cannot express a digit string; reject it as a
             radix error, as toString does *)
          if b >= 2 andalso b <= internalBase then b else raise BadBase
      | base' (x::xs) b = base' xs (b * 10 + (digitVal x 10));
    (* finds the '[' and passes the rest to base'; no '[' means radix 10 *)
    fun base [] = 10
      | base (#"[" :: #"]" :: _) = raise BadString (* "[]" names no radix *)
      | base (#"[" :: xs) = base' xs 0
      | base (_ :: xs) = base xs;
    (* converts the digits before the '[' (or end): i is the value so far,
       and each digit makes it i * radix + digit *)
    fun digits' [] i _ = i
      | digits' (#"[" :: _) i _ = i
      | digits' (x::xs) i radix =
          digits' xs (plus (times i (Pos [radix])) (Pos [digitVal x radix]))
                  radix;
    (* at least one digit is required, with or without a sign before it *)
    fun digits [] = raise BadString
      | digits (#"[" :: _) = raise BadString
      | digits L = digits' L zero (base L);
    (* parses the optional sign and passes the rest to digits *)
    fun sign (#"+" :: xs) = digits xs
      | sign (#"-" :: xs) = negate (digits xs)
      | sign L = digits L;
  in
    sign (explode s)
  end;
(* divide m n: long division of m by the small positive int n (toString uses
   it with the output radix).  Returns (quotient, remainder).
   - The quotient keeps m's sign constructor and is cleaned, so a zero
     quotient becomes zero.
   - The remainder is that of the magnitude of m.
   n = 0 raises Div. *)
fun divide m n =
  let
    (* one step of long division: bring the next digit down next to the
       running remainder, and emit one quotient digit *)
    fun step (digit, (carry, quot)) =
      let
        val cur = digit + carry * internalBase
        val q = cur div n
      in
        (cur mod n, q :: quot)
      end
    (* foldr visits the least-significant-first list from its last element,
       i.e. from the most significant digit down; consing the quotient
       digits leaves them least significant first again *)
    fun longDivide ds = foldr step (0, []) ds
  in
    case m of
        Pos ds => let val (r, q) = longDivide ds in (clean (Pos q), r) end
      | Neg ds => let val (r, q) = longDivide ds in (clean (Neg q), r) end
  end;
(* toString base m: m written in radix base, which must be 2..36.
   - A leading "-" marks a negative value.
   - A "[base]" suffix is added unless base is 10, which is the suffix
     syntax fromString reads.
   Raises BadBase for any other base. *)
fun toString base m =
  if base < 2 orelse base > 36 then raise BadBase
  else
    let
      (* peels off remainders (least significant digit first) until the
         quotient reaches zero, consing each character onto the front *)
      fun collect acc (Pos [0]) = acc
        | collect acc cur =
            let val (quo, rem) = divide cur base
            in collect (digitChar rem :: acc) quo end
      val chars = collect [] m
      val signPart = (case m of Neg _ => "-" | Pos _ => "")
      val digitPart = (case chars of [] => "0" | _ => implode chars)
      val basePart =
        if base = 10 then "" else "[" ^ Int.toString base ^ "]"
    in
      signPart ^ digitPart ^ basePart
    end;
(* show m: the internal representation of m, written like the ML value,
   e.g. "Pos [1,2,3]": digits least significant first, separated by commas
   with no spaces.  It is meant for debugging, so it also renders
   non-canonical values such as Pos [] instead of failing with Match. *)
fun show m =
  let
    fun digitList ds = "[" ^ String.concatWith "," (map Int.toString ds) ^ "]"
  in
    case m of
        Pos ds => "Pos " ^ digitList ds
      | Neg ds => "Neg " ^ digitList ds
  end;
(* fromInt x: the BigInt equal to the int x.
   Digits are produced least significant first, which is the order that
   plus, compare, divide and toInt all assume.  Consing them in the other
   order would turn 123 into Pos [1,23], i.e. 2301, and 100 into the
   non-canonical Pos [1,0]. *)
fun fromInt x =
  let
    (* Int.quot/Int.rem truncate toward zero, so for a negative y every
       remainder lies in ~(internalBase-1)..0 and Int.abs of it is safe.
       ~x, which overflows for the most negative int, is never computed.
       Recursion stops at 0, so the last (most significant) digit is
       nonzero and the result is canonical. *)
    fun digitsOf 0 = []
      | digitsOf y =
          Int.abs (Int.rem (y, internalBase))
            :: digitsOf (Int.quot (y, internalBase))
  in
    if x = 0 then zero
    else if x < 0 then Neg (digitsOf x)
    else Pos (digitsOf x)
  end;
(* toInt m: the int value of m, by Horner's rule from the most significant
   digit down.
   A negative value is accumulated directly as a negative number.  Negating
   the magnitude at the end would overflow for the most negative int, since
   its magnitude alone is not representable.
   On implementations with fixed-precision int, raises Overflow when m is
   outside the int range. *)
fun toInt m =
  let
    (* sgn is 1 or ~1; every digit is added with that sign, so each partial
       result has magnitude at most |m| *)
    fun horner sgn ds =
      foldr (fn (d, acc) => sgn * d + internalBase * acc) 0 ds
  in
    case m of
        Pos ds => horner 1 ds
      | Neg ds => horner ~1 ds
  end;
(* ------------------------------------------------------------------------- *)