(* Identifier of a type variable; [new_id] hands out fresh ones. *)
type id = int

(* Types of MiniML expressions, together with the mutable type-variable
   machinery used while inferring types by unification. *)
type mltypes = 
    (* concrete types *)
    Tint 
  | Tbool 
  | Tlist of mltypes
  | Tpair of mltypes * mltypes
  | Tarrow of mltypes * mltypes
    (* intermediate types for type synthesis *)
  | TVar of vartype
    (* content of a [vartype] whose variable is not bound yet; it should
       never appear as a type on its own *)
  | TUnknown
(* A type variable: [id] identifies it, [v] is TUnknown while the variable
   is unbound and is overwritten with the type it gets bound to. *)
and vartype = { id : int ; mutable v : mltypes }

(* type schema: [Forall (ids, t)] is [t] with the type variables whose
   ids are listed in [ids] universally quantified *)
type schema = Forall of id list * mltypes

(* type environment: association list from variable names to schemas *)
type typeenv = ( string * schema ) list

open MiniML

(* Fresh type-variable identifiers: returns 0, 1, 2, ... on successive
   calls (the counter is private to this closure). *)
let new_id =
  let counter = ref 0 in
  fun () ->
    let fresh = !counter in
    counter := fresh + 1;
    (fresh : id)

(* Wraps [t] in a type variable with a fresh id; pass TUnknown to obtain a
   fresh unbound variable. *)
let new_vartype t =
  let fresh = new_id () in
  TVar { id = fresh; v = t }

(* [occurs tv ty] is true when the type variable [tv] (compared by [id])
   appears anywhere inside [ty].  Used by [unify] to reject recursive
   (cyclic) types.
   Variables already bound to a type are looked through: a cycle can hide
   behind a bound variable nested inside [ty], since [unify] only shortens
   the top of each type.
   @raise Failure "occurs" if [ty] is TUnknown. *)
let occurs { id = n; v = _ } v2 =
  let rec iter = function
    | TVar { id = m; v = TUnknown } -> n = m
    | TVar { id = m; v = bound } -> n = m || iter bound
    | Tint | Tbool -> false
    | Tarrow (t1, t2) | Tpair (t1, t2) -> iter t1 || iter t2
    | Tlist t1 -> iter t1
    | TUnknown -> failwith "occurs"
  in
  iter v2;;

(* [shorten t] returns the representative of [t]:
   - an unbound variable is returned as is;
   - a non-variable type is returned as is;
   - a bound variable is followed until one of the above is reached.
   While following a chain, the [v] field of [t]'s own cell is overwritten
   to skip the intermediate bound variable (so later lookups are shorter).
   @raise Failure if [t] is TUnknown. *)
let rec shorten t =
  match t with
  | TUnknown -> failwith "shorten: should not happen"
  | TVar cell ->
      (match cell.v with
       | TUnknown -> t
       | TVar { id = _; v = TUnknown } as last -> last
       | TVar next ->
           cell.v <- next.v;
           shorten t
       | bound -> shorten bound)
  | Tint | Tbool | Tlist _ | Tpair _ | Tarrow _ -> t

(* [vars_of_type t] lists, without duplicates, the ids of the unbound type
   variables occurring in [t] (bound variables are looked through).
   The list is in reverse order of first occurrence in a left-to-right
   traversal.
   @raise Failure if TUnknown is met outside a variable cell. *)
let vars_of_type t =
  let rec collect acc ty =
    match ty with
    | TVar { id = n; v = TUnknown } ->
        if List.mem n acc then acc else n :: acc
    | TVar { id = _; v = bound } -> collect acc bound
    | Tint | Tbool -> acc
    | Tlist elt -> collect acc elt
    | Tpair (left, right) | Tarrow (left, right) ->
        collect (collect acc left) right
    | TUnknown -> failwith "vars_of_type: should not happen"
  in
  collect [] t;;

(* [subtract_list a b] keeps, in their original order and with their
   duplicates, the elements of [a] that do not occur in [b]. *)
let subtract_list a b =
  List.filter (fun x -> not (List.mem x b)) a

(* [unique l] removes duplicate elements from [l].  The LAST occurrence of
   each element is the one kept, and kept elements stay in their original
   relative order. *)
let unique l =
  List.fold_right
    (fun x kept -> if List.mem x kept then kept else x :: kept)
    l []

(* [vars_of_tenv tenv] lists the ids of the unbound type variables that are
   free in the environment, i.e. not quantified by their own schema.
   Duplicates across entries are kept. *)
let vars_of_tenv tenv =
  let free_in_entry (_, Forall (quantified, ty)) =
    subtract_list (vars_of_type ty) quantified
  in
  List.fold_right (fun entry acc -> free_in_entry entry @ acc) tenv []

(* [generalize tenv t] builds the schema of [t] by quantifying over its
   unbound type variables that are not free in [tenv]. *)
let generalize tenv t =
  let env_vars = vars_of_tenv tenv in
  let local_vars = subtract_list (vars_of_type t) env_vars in
  Forall (unique local_vars, t)

(* [instantiate schema] returns a copy of the schema's type in which every
   quantified type variable is replaced by a fresh unbound variable (one
   per quantified id).  Bound variables are looked through; unbound
   variables that are not quantified are shared with the original.
   A schema with no quantified variables returns its type itself, uncopied.
   @raise Failure if the type contains TUnknown outside a variable cell. *)
let instantiate = function
  | Forall ([], t) -> t
  | Forall (gvars, t) ->
      (* List.map applies its function left to right, so fresh ids are
         allocated in the order of [gvars]. *)
      let new_vars = List.map (fun n -> (n, new_vartype TUnknown)) gvars in
      let rec iter = function
        | TVar { id = n; v = TUnknown } as tv ->
            (* Only a missing binding is expected here; don't mask others. *)
            (try List.assoc n new_vars with Not_found -> tv)
        | TVar { id = _; v = bound } -> iter bound
        | Tint -> Tint
        | Tbool -> Tbool
        | Tlist t1 -> Tlist (iter t1)
        | Tarrow (t1, t2) -> Tarrow (iter t1, iter t2)
        | Tpair (t1, t2) -> Tpair (iter t1, iter t2)
        | TUnknown -> failwith "instantiate: should not happen"
      in
      iter t

(* [unify t1 t2] makes [t1] and [t2] equal by binding (mutating) unbound
   type variables.
   @raise Failure "unify: recursive type occured" when a variable would be
   bound to a type containing itself.
   @raise Failure "unify: type does not match" on a constructor clash. *)
let rec unify t1 t2 =
  match (shorten t1, shorten t2) with
  | (TVar ({ id = n; v = TUnknown } as tv1),
     (TVar { id = m; v = TUnknown } as var2)) ->
      (* Two unbound variables: link the first to the second, unless they
         are the same variable.  Ids are compared structurally ([<>]), not
         physically ([!=]). *)
      if n <> m then tv1.v <- var2
  | (ty1, TVar ({ id = _; v = TUnknown } as tv2)) ->
      if occurs tv2 ty1 then failwith "unify: recursive type occured"
      else tv2.v <- ty1
  | (TVar ({ id = _; v = TUnknown } as tv1), ty2) ->
      if occurs tv1 ty2 then failwith "unify: recursive type occured"
      else tv1.v <- ty2
  | (Tint, Tint) | (Tbool, Tbool) -> ()
  | (Tarrow (a1, r1), Tarrow (a2, r2)) | (Tpair (a1, r1), Tpair (a2, r2)) ->
      unify a1 a2;
      unify r1 r2
  | (Tlist e1, Tlist e2) -> unify e1 e2
  | _ -> failwith "unify: type does not match"

(* [type_expr tenv e] infers the type of expression [e] under the type
   environment [tenv].  Type variables are unified (and thereby mutated)
   along the way, so the result may still contain bound variables; use
   [shorten_all] to flatten them.
   @raise Failure on a type error, an unbound variable, or a construct not
   handled yet (the optional pattern-based extensions). *)
let rec type_expr tenv = function
  | Const (Int _) -> Tint
  | Const (Bool _) -> Tbool
  | Const Nil -> Tlist (new_vartype TUnknown)
  | Const _ -> failwith "type_expr: should not happen"
  | Plus (e1, e2) | Minus (e1, e2) | Times (e1, e2) | Div (e1, e2) ->
      (* int -> int -> int *)
      unify (type_expr tenv e1) Tint;
      unify (type_expr tenv e2) Tint;
      Tint
  | Equal (e1, e2) ->
      (* Both operands must share one type; the result is bool.
         NOTE(review): this accepts equality at any type; restrict to
         Tint if the evaluator only compares integers. *)
      let t1 = type_expr tenv e1 in
      let t2 = type_expr tenv e2 in
      unify t1 t2;
      Tbool
  | ConsExp (e1, e2) ->
      (* e1 :: e2 requires e2 : t1 list, and has that type. *)
      let t1 = type_expr tenv e1 in
      let t2 = type_expr tenv e2 in
      unify t2 (Tlist t1);
      t2
  | PairExp (e1, e2) -> Tpair (type_expr tenv e1, type_expr tenv e2)
  | IfExp (e1, e2, e3) ->
      unify (type_expr tenv e1) Tbool;
      let u = new_vartype TUnknown in
      unify (type_expr tenv e2) u;
      unify (type_expr tenv e3) u;
      u
  | Var s ->
      let tscheme =
        try List.assoc s tenv
        with Not_found -> failwith ("Unbound variable: " ^ s)
      in
      instantiate tscheme
  | LambdaExp [ (IdentPtn x, e) ] ->
      (* The parameter is monomorphic inside the body. *)
      let alpha = new_vartype TUnknown in
      let ts = Forall ([], alpha) in
      Tarrow (alpha, type_expr ((x, ts) :: tenv) e)
  | App (e1, e2) ->
      (* e1 must be a function accepting e2's type; its result is fresh. *)
      let tfun = type_expr tenv e1 in
      let targ = type_expr tenv e2 in
      let tres = new_vartype TUnknown in
      unify tfun (Tarrow (targ, tres));
      tres
  | LetExp ([ (IdentPtn s, e1) ], e2) ->
      let t1 = type_expr tenv e1 in
      let ts = generalize tenv t1 in
      type_expr ((s, ts) :: tenv) e2
  | LetRecExp ([ (IdentPtn s, e1) ], e2) ->
      (* Recursive uses of [s] inside [e1] see it monomorphically; it is
         generalized only once [e1] has been typed. *)
      let alpha = new_vartype TUnknown in
      let t1 = type_expr ((s, Forall ([], alpha)) :: tenv) e1 in
      unify alpha t1;
      let ts = generalize tenv t1 in
      type_expr ((s, ts) :: tenv) e2
  (* extensions *)
  | LetExp (ptns, e2) ->
      failwith "Unimplemented (optional excesise 4-1)"
  | LetRecExp (ptns, e2) ->
      failwith "Unimplemented (optional excesise 4-2)"
  | LambdaExp ptns ->
      failwith "Unimplemented (optional excesise 3)"
  | MatchExp (e1, ptns) ->
      failwith "Unimplemented (optional excesise 3)"
  | TopLetExp _ | TopLetRecExp _ ->
      failwith "type_expr: top level phrase not supported"

(* [pattern_type p] returns the type a value must have to match [p] and
   the (name, type) bindings the pattern introduces.  Intended for the
   optional exercises. *)
and pattern_type = function
  | ConstPtn v -> (type_expr [] (Const v), [])
  | IdentPtn "_" -> (new_vartype TUnknown, [])
  | IdentPtn s ->
      let alpha = new_vartype TUnknown in
      (alpha, [ (s, alpha) ])
  | ConsPtn (v1, v2) ->
      let (t1, b1) = pattern_type v1 in
      let (t2, b2) = pattern_type v2 in
      unify t2 (Tlist t1);
      (t2, b1 @ b2)
  | PairPtn (v1, v2) ->
      let (t1, b1) = pattern_type v1 in
      let (t2, b2) = pattern_type v2 in
      (Tpair (t1, t2), b1 @ b2)

(* Helpers for trying out the type checker interactively *)
(* [shorten_all t] rebuilds [t] with every bound type variable replaced by
   the type it is bound to, at any depth.  Unbound variables are kept.
   Pairs are traversed too; previously they were returned untouched, so
   bound variables inside them survived.
   @raise Failure if TUnknown is met outside a variable cell. *)
let rec shorten_all t =
  match t with
  | TVar { id = _; v = TUnknown } -> t
  | TVar tv -> shorten_all tv.v
  | TUnknown -> failwith "shorten: should not happen"
  | Tarrow (t1, t2) -> Tarrow (shorten_all t1, shorten_all t2)
  | Tpair (t1, t2) -> Tpair (shorten_all t1, shorten_all t2)
  | Tlist t1 -> Tlist (shorten_all t1)
  | Tint | Tbool -> t

(* Initial environment for experiments: polymorphic hd, tl, fst and snd.
   The negative ids cannot clash with ids from [new_id], which start at 0;
   alpha is id -1, beta is id -2. *)
let test_tenv =
  let tvar n = TVar { id = n; v = TUnknown } in
  let alpha = tvar (-1) in
  let beta = tvar (-2) in
  let pair_ab = Tpair (alpha, beta) in
  [ ("hd", Forall ([-1], Tarrow (Tlist alpha, alpha)));
    ("tl", Forall ([-1], Tarrow (Tlist alpha, Tlist alpha)));
    ("fst", Forall ([-2; -1], Tarrow (pair_ab, alpha)));
    ("snd", Forall ([-2; -1], Tarrow (pair_ab, beta))) ];;

(* Infers the type of [e] under [test_tenv] and flattens bound variables. *)
let type_of_expr e =
  let inferred = type_expr test_tenv e in
  shorten_all inferred;;

(* Pretty-printer for types; you can use this as a black box. *)
open Format
(* [print_mltypes t] prints [t] on the standard formatter in OCaml syntax,
   e.g. "'a list -> 'a".
   - Type variables are named 'a .. 'z in order of first appearance; after
     the 26th they get a numeric suffix ('a1 .. 'z1, 'a2, ...).
   - Parentheses are added only where precedence requires them: [*] binds
     tighter than [->], and [->] is right-associative.
   @raise Failure if [t] contains TUnknown outside a variable cell. *)
let print_mltypes t =
  let n = ref (-1) in
  let vars = ref [] in
  (* Display index of a variable id; indices are allocated on first use. *)
  let index_of id =
    try List.assoc id !vars
    with Not_found ->
      incr n;
      vars := (id, !n) :: !vars;
      !n
  in
  (* [level]: 0 = top, 1 = left of an arrow, 2 = inside a pair or list. *)
  let rec iter level ppf = function
    | Tint -> fprintf ppf "int"
    | Tbool -> fprintf ppf "bool"
    | Tlist x -> fprintf ppf "@[%a list@]" (iter 2) x
    | Tpair (x, y) as z ->
        if level >= 2 then fprintf ppf "(@[%a@])" (iter 0) z
        else fprintf ppf "@[<1>%a@ * %a@]" (iter 2) x (iter 2) y
    | Tarrow (x, y) as z ->
        if level >= 1 then fprintf ppf "(@[%a@])" (iter 0) z
        else fprintf ppf "@[<1>%a ->@ %a@]" (iter 1) x (iter 0) y
    | TVar { id = id; v = TUnknown } ->
        let varnum = index_of id in
        fprintf ppf "'%c" (char_of_int (int_of_char 'a' + (varnum mod 26)));
        (* [>=], not [>]: index 26 must print as 'a1, not as a second 'a. *)
        if varnum >= 26 then fprintf ppf "%d" (varnum / 26)
    | TVar { id = _; v = bound } -> iter level ppf bound
    | TUnknown -> failwith "print_mltypes: should not happen"
  in
  printf "%a" (iter 0) (shorten_all t)

(*

#load "miniML.cmo";;
#load "miniMLLexer.cmo";;
#load "miniMLParser.cmo";;
#load "miniMLReader.cmo";;
open MiniMLReader;;

(* if you have compiled this file *)
#load "miniMLTyping.cmo";;
open MiniMLTyping;;

(* if you like formatted output like "'a -> 'a" *)
#install_printer print_mltypes;;

*)
