#include "share/atspre_staload.hats"
staload UN = "prelude/SATS/unsafe.sats"
(* ****** ****** ****** *)
macdef
// freehom(clo): release a linear closure. The unsafe cast erases the
// closure's precise type so that cloptr_free accepts it.
freehom(clo) =
  cloptr_free($UN.castvwtp0(,(clo)))
// hom(a, b): a linear closure-pointer function from a to b.
// In this file such closures are released explicitly with cloptr_free
// (see freehom and ev) once they have been applied.
vtypedef
hom(a: vt0ype, b: vt0ype) = a -<lincloptr1> b
macdef
// ev(f, x): apply the linear closure f to x, free f, and return the result.
// The closure expression is bound to a local exactly once. The previous
// form substituted ,(f) twice (call + free), so a non-variable argument
// such as ev(g(y), k) would evaluate g twice. That freed a different
// closure from the one applied and leaked the first.
// The local names are prefixed with ev_ to reduce the chance of capturing
// a caller's variable inside ,(x), because macdef expansion is textual.
ev(f, x) = ev_y
where
{ val ev_f = ,(f)
  val ev_y = ev_f(,(x))
  val () = cloptr_free($UN.castvwtp0(ev_f)) }
// fvtype: the sort of type constructors taking a viewt@ype (vt0ype) to a
// boxed viewtype (vtype). This is the sort of a monad's carrier `m`.
sortdef
fvtype = vt0ype -> vtype
// Propositional equality on fvtype: the only proof is Eq_fvtype(f, f).
// Pattern-matching on it (prval Eq_fvtype() = ...) lets the typechecker
// identify the two type constructors.
dataprop
EQ_fvtype(fvtype, fvtype) = {f: fvtype} Eq_fvtype(f, f)
// Propositional equality on vtype (not referenced elsewhere in this section).
dataprop
EQ_vtype(vtype, vtype) = {z: vtype} Eq_vtype(z, z)
(* ****** ****** ****** *)
// MONAD(m_name, m): abstract proof that the type constructor `m` is the
// monad identified by the tag type `m_name`. The tag selects which
// template implementation of monad_return / monad_bind is used.
absprop
MONAD(m_name: type, m: fvtype)
// monad_return: inject a value of type `a` into the monad `m`.
// Requires a MONAD(m_name, m) proof.
extern fun{m_name: type}{a: vt0ype}
monad_return
{m: fvtype}
(MONAD(m_name, m)| a): m(a)
// monad_bind: sequence the computation `m(a)` with a linear continuation
// `hom(a, m(b))` to obtain `m(b)`. Requires a MONAD(m_name, m) proof.
extern fun{m_name: type}{a, b: vt0ype}
monad_bind
{m: fvtype}
(MONAD(m_name, m)| m(a), hom(a, m(b))): m(b)
// Unsafely produces a MONAD(m_name, m) proof for any `m` via
// $UN.proof_assert. Nothing checks that `m` really is the monad named by
// `m_name`; the caller is trusted to pick a matching pair.
// NOTE(review): the `.<m>.` termination metric ranges over a non-int sort
// (fvtype). Metrics are normally ints, so confirm the typechecker accepts it
// and that it is intended.
praxi{m_name: type}
trust_me_this_is_the_monad
{m: fvtype}.<m>.
(): MONAD(m_name, m) =
$UN.proof_assert{MONAD(m_name, m)}()
(* ****** ****** ****** *)
// Continuation monads
// cont_r: tag type naming the continuation monad when instantiating the
// monad_return / monad_bind templates.
abstype
cont_r
// cont(r)(a) = hom(a, r) -<lincloptr1> r: a linear closure that, given a
// linear continuation for `a`, produces the final answer of type `r`.
stadef
cont(r: vt0ype) = lam(a: vt0ype) => hom(a, r) -<lincloptr1> r
// Eliminates a MONAD(cont_r, m) proof, revealing m = cont(r) for some
// answer type r. This is an axiom (extern praxi) with no implementation here.
extern praxi
MONAD_cont_elim
{m: fvtype}
(pfm: MONAD(cont_r, m)): [r: vt0ype] EQ_fvtype(cont r, m)
// return x = llam k => k(x). The continuation k is applied to x and then
// freed by the ev macro. The prval match exposes m = cont(r) so the
// closure's type checks against m(a).
implement(a: vt0ype)
monad_return<cont_r><a>(pfm| x) =
let prval Eq_fvtype() = MONAD_cont_elim(pfm)
in llam(k) => ev(k, x) end
// bind mx fopr = llam k => mx(llam x => fopr(x)(k)).
// mx is run with a continuation `bar` that passes its result x to fopr.
// The computation fopr(x) is then run with the outer continuation k.
// Linear closures are freed after use: fopr(x)'s result after it receives k,
// fopr after it is applied inside bar, and mx after it has been run on bar.
implement(a, b: vt0ype)
monad_bind<cont_r><a, b>(pfm| mx, fopr) =
let prval Eq_fvtype() = MONAD_cont_elim(pfm)
in
llam(k) => res
where
{ val bar = llam(x: a) =<cloptr1>
    let
      // Bind fopr(x) so its closure can be released after it is applied.
      // Applying fopr(x)(k) directly leaked that closure on every bind.
      val fx = fopr(x)
      val fxk = fx(k)
    in
      freehom(fx); freehom(fopr); fxk
    end
  val res = mx(bar)
  val () = freehom(mx) }
end
(* ****** ****** ****** *)
implement main0 () = let
// Assert, without checking, that `cont int` is the continuation monad
// (answer type int).
prval pfcont = trust_me_this_is_the_monad<cont_r>{cont int}()
// CPS integer square root. With r = isqrt(n/4), isqrt(n) is 2r+1 when
// (2r+1)^2 <= n and 2r otherwise. For n <= 0 it returns 0.
fun intsqrt_cps
(n: int): cont int (int) =
llam(k) =<cloptr1>
if n > 0 then
let
  val n4 = n / 4
  // Bind the sub-computation so its linear closure can be freed after it
  // runs. Applying intsqrt_cps(n4)(...) directly leaked one closure per
  // recursion level.
  val sub = intsqrt_cps(n4)
  val res = sub
    ( llam(r) => ev(k, r')
      where
      { val r' = if (2*r+1)**2 <= n
        then 2*r+1 else 2*r } )
  val () = freehom(sub)
in
  res
end
else ev(k, 0)
// Re-wrap the recursive function as a linear closure, because monad_bind
// takes its second argument as hom(int, cont int (int)).
val intsqrt_cps: int -<lincloptr1> cont int (int) =
llam(n) =<cloptr1> intsqrt_cps(n)
// CPS squaring: passes n*n to the continuation.
fun square_cps
(n: int): cont int (int) =
llam(k) =<cloptr1> ev(k, n*n)
// Square 5, then take the integer square root of the result.
val five_cps = monad_bind<cont_r><int, int>
(pfcont| square_cps(5), intsqrt_cps)
// Run the computation with the identity continuation to extract the int.
val five = ev(five_cps, llam(x) =<cloptr1> x)
in
println!(five: int)
end