In recent weeks, I’ve contributed a prototype module that implements partial evaluation (PE) in a deep learning DSL for encoding neural networks. The implementation is fairly intricate, so I think a blog post is the proper place to discuss the relevant code, design, and examples. I hope it can help you learn about implementing an eDSL in Haskell, as well as the principles of partial evaluation.

DeepDarkFantasy

A friend of mine, Marisa Kirisame, is the creator of this project. It started out as DeepLearning.scala, but later on the author went back to school and developed a Haskell version.

The idea is quite simple. When designing neural network (NN) layers, we need to know how to adjust the edge weights according to the change in loss. That is, we need to know the derivative of the loss w.r.t. the weights. I recommend another tutorial here if you are not already familiar with this.

More materials:

PE in a De Bruijn indexed, final tagless eDSL

Honestly speaking, I haven’t done much theoretical innovation so far, and my work in DDF is heavily based on the tutorial by Oleg. But I found his tutorial too contrived, more or less like:

Here is how to draw a circle, got it? Let’s draw the Mona Lisa now!

In this post, I will try to show more examples, and be more verbose about details.

Second, while Oleg’s interpretation is already quite beautiful, I believe that another interpretation, in a slightly different scenario, will help more people understand the material. It should also give you an idea of the common considerations when implementing PE for a certain language.

Basics

In this section, I am going to talk about the de-Bruijn-index based, tagless approach towards a more extensible eDSL. Let’s take DDF.DBI as an example:

class DBI (r :: * -> * -> *) where
  z :: r (a, h) a                       -- the variable bound by the nearest binder (index 0)
  s :: r h b -> r (a, h) b              -- weakening: shift every index up by one
  abs :: r (a, h) b -> r h (a -> b)     -- lambda abstraction
  app :: r h (a -> b) -> r h a -> r h b -- application

Here, DBI encodes what is needed to be a typed lambda calculus embedded inside the host language. This might look confusing to you now, so let’s instead go for the hello-world version of a deeply embedded DSL first:

data Expr = EVar String
          | EAbs String Expr
          | EApp Expr Expr

An example is EApp (EAbs "f" (EVar "f")) (EAbs "x" (EVar "x")), or $(\lambda f. f) (\lambda x. x)$, which will evaluate to $(\lambda x. x)$ under the classical lambda calculus semantics.

Next, we are going to write two interpreters: a pretty-printer and an evaluator for this language.

prettyAST :: Expr -> String
evalAST :: Expr -> Maybe Expr

In each interpreter, we need to do case analysis on the expression. So, if we add a new constructor, every interpreter needs to be updated.
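For instance, a minimal prettyAST must pattern-match on every constructor (a sketch, with naive parenthesization):

prettyAST :: Expr -> String
prettyAST (EVar v)   = v
prettyAST (EAbs v b) = "(\\" ++ v ++ ". " ++ prettyAST b ++ ")"
prettyAST (EApp f a) = "(" ++ prettyAST f ++ " " ++ prettyAST a ++ ")"

Adding, say, a literal constructor would force us to revisit this and every other interpreter.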

Note: Another way to implement this is through type class, for example:

class Pretty a where
    pretty :: a -> String

In this way, we need fewer names when we have a complex hierarchy: Expr over Term, Stmt over Expr, etc. But the extensibility problem is still not solved.

We hope there is a way to make syntax and semantics highly composable. Think of a language that is composed of different smaller languages, possibly forming a dependency graph. You can pick and choose a subset of it, so long as the dependencies are satisfied (which usually means that your final interpreter implements whatever each smaller language requires). Moreover, each piece of syntax is associated with its own semantics, either implemented by the interpreter directly, or built on top of the primitives of the languages it depends on.

 F  B  E
  \ | /
    A
   / \
 C     D

For the above example, when you want to instantiate the language B, you just need to define your “ground rules” for C and D. A’s derivation is automatic, since its instance is defined w.r.t. C and D.

So we want two things:

  1. Orthogonal composition of AST (and its defined semantics)
  2. No more boilerplate code for the upper-level derivation (like A’s).
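As a hypothetical sketch of this style (the class and function names here are invented for illustration, not DDF’s actual module layout): each language fragment is a type class over the representation r, and a derived fragment costs the interpreter nothing, because it is defined once over the required primitives:

-- Fragment C: integer literals.
class IntLit r where
  int :: Int -> r h Int

-- Fragment D: addition.
class IntPlus r where
  plusI :: r h (Int -> Int -> Int)

-- A fragment like "A", derived from C and D once and for all: every
-- interpreter implementing IntLit and IntPlus gets incr for free.
incr :: (DBI r, IntLit r, IntPlus r) => r h Int -> r h Int
incr x = app (app plusI x) (int 1)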

To do so, we must have a flexible enough representation, of which we can require that it contains certain sub-representations, rather than fixing everything from the very beginning.

Another advantage of this encoding is the ability to piggyback on the host language’s type system. For example:

  abs :: r (a, h) b -> r h (a -> b)

This means you can’t use an arbitrary sub-expression as the body of a lambda expression, unless it really has the right type for being such a body (here, it is typed as having a free variable of the same type as the lambda’s binder).

But other parts of the DSL look weird: that’s because we encode de Bruijn indexing at the type level. It’s like we implemented the rules of de Bruijn encoding using the type system of the host language, which enforces the type-environment checking for us. It is all static – if the host program compiles, then the guest program is well-typed. In contrast, the STLC example in Pierce’s TAPL has a compiler that checks types at runtime.

Let’s explain each component in details:

  z :: r (a, h) a
  s :: r h b -> r (a, h) b

For an arbitrary outer environment h, we use z to represent the variable bound by the closest binder (index 0). For example, in $\lambda x . \lambda y. y + x$, the $y$ in $y + x$ binds to the closest binder, so we encode it using z, and the environment would be (ty, (tx, h')) for arbitrary h'. The other variable, x, binds to the second-closest binder, so we shift z once with s: s z :: r (ty, (tx, h')) tx.

Below is a workable snippet:

t1 :: forall r h. Double r => r (M.Double, (M.Double, h)) M.Double
t1 = app2 doublePlus (z :: r (M.Double, (M.Double, h)) M.Double)
                     (s (z :: r (M.Double, h) M.Double)
                        :: r (M.Double, (M.Double, h)) M.Double)
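Here app2 is the evident helper for applying a two-argument DSL function; assuming it is defined in the obvious way:

app2 :: DBI r => r h (a -> b -> c) -> r h a -> r h b -> r h c
app2 f a b = app (app f a) b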

For abstraction:

  abs :: r (a, h) b -> r h (a -> b)

Under the de Bruijn encoding, we don’t have to give names to binders, so abs takes the lambda body expression as its sole parameter. To make sense of its type parameters, we provide another example:

id   = abs (z :: r (a, h) a) :: r h (a -> a)
fxy  = abs (abs (z
                    :: r (ty, (tx, h)) ty)
                :: r (tx, h) (ty -> ty))
       :: r h (tx -> (ty -> ty))
fxy' = abs (abs (s (z
                    :: r (tx, h) tx)
                  :: r (ty, (tx, h)) tx)
                :: r (tx, h) (ty -> tx))
       :: r h (tx -> (ty -> tx))

Note the difference between the types of the two z’s in fxy and fxy' above.

app is much more straightforward.

  app :: r h (a -> b) -> r h a -> r h b
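For example, DSL-level application itself can be written from these four combinators; the annotation below is what the typing rules force (a small sketch):

apply :: DBI r => r h ((a -> b) -> a -> b)
apply = abs (abs (app (s z) z))   -- \f -> \x -> f x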

Going through the hard part

Now that we have a rough idea of how such an advanced AST construction works, we can consider its partial evaluation. You might have heard about it in a compilers course, where PE is one of the common techniques for code optimization. For example, in your C code, you can write an expression like 1 + 2. The PE phase can automatically detect such statically reducible expressions and evaluate them at compile time, emitting 3 rather than the more expensive equivalent 1 + 2. The process is called partial because most programs have unknown input parameters: given x + 1, we can’t evaluate it, since we don’t know x statically.

PE is like a bridge between static time and dynamic time. JIT compilation is as well.

There is also an interesting research model developed around it, called the Futamura projections: specializing an interpreter to a source program yields a compiled program. Just think about it:

compile p = pe (interpret p)

Intuition

Let’s try to formulate the principles of partial evaluation step by step:

The simplest example: consider the PE’ed result of 1 + 2. Let’s analyze it bottom-up: 1 can be PE’ed trivially, and so can 2, but for 1 + 2, we need information telling us that both operands are statically known, so that we can replace the entire expression with another statically known value. Let’s use $(e)_s$ to denote a static expression $e$ and $(e)_u$ to denote an unknown expression $e$. This is how $1 + 2$ is processed in this framework:

  1. 1 =pe=> $(1)_s$
  2. 2 =pe=> $(2)_s$
  3. 1 + 2 =pe=> $(1)_s + (2)_s = (1 +_s 2)_s = (3)_s$
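Below is a minimal, self-contained sketch of this two-tag framework on a toy arithmetic language (all names are hypothetical, invented just to make the $(e)_s$ / $(e)_u$ rules concrete):

data AExpr = ALit Int | AVar String | AAdd AExpr AExpr

data PVal = S Int       -- (e)_s : statically known
          | U AExpr     -- (e)_u : unknown, kept as residual code

residual :: PVal -> AExpr
residual (S n) = ALit n
residual (U e) = e

peA :: AExpr -> PVal
peA (ALit n)   = S n
peA (AVar v)   = U (AVar v)
peA (AAdd a b) = case (peA a, peA b) of
  (S m, S n) -> S (m + n)  -- 1 + 2 =pe=> (3)_s
  (pa, pb)   -> U (AAdd (residual pa) (residual pb))

Here peA (AAdd (ALit 1) (ALit 2)) yields S 3, while peA (AAdd (AVar "x") (ALit 1)) stays residual.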

Apparently, for $(e)_s$, there mustn’t be any free variable in $e$. That seems plausible, but consider this: $(\lambda x . x + y) \, 2$. Since $y$ is free, $(\lambda x . x + y)$ can’t be static, so the whole expression would just be returned as it is. But there is apparently a chance for PE: the optimal result should be $2 + y$.

Aha, our naive formulation turns out to be too restrictive. We should be able to identify a “static function” like $\lambda x. x + y$: no matter what is inside it, as long as we can track where $x$ binds, we can do a step of $\beta$-reduction PE when this static function is applied to some other PE’ed result. Note that for functions with multiple parameters, we apply them one by one. So we only need two rules beyond the unknown case: one for a static value, and one for a static lambda.

Now we consider this third possibility, beyond static values and unknown terms: the static lambda. The big $\Lambda$ below is the abstraction operator of the host language. During the PE of a static lambda, we lift the static function from DSL space to host-language space, so we can apply it later.

Now, we have $\lambda x . x + y$ PE’ed as:

$(\Lambda x . x + y)_{\lambda \tau_x}$

The operand it is applied to can be any PE’ed form, which we will denote as $(e)_\text{?}$; then we can perform the application, resulting in $(e)_\text{?} + y$.

Clearly, the ? here can only be either static or unknown. For the first case, we could wrap the result into $(e +_s y)_s$, and for the second case, we could wrap the whole thing as an unknown value – wait a minute, an unknown value? What if the $y$ binder can also be statically applied in the future? Isn’t there some information loss again?

The question is: how do we wrap the above expression into a form that preserves the $y$ in it? In fact, the astute reader might already have asked this a few minutes ago, since we gave no instruction on how we actually get $\Lambda$.

You can imagine that, each time we have such a free variable like $y$, we store its location info (de Bruijn index) along with the wrapped expression. So when we process a $\lambda$ in the DSL, we check whether the function body has at least one free variable; if so, we take out the head variable’s location information, which acts like a substitution operator, and abstract over it – so later, when we enter the application-processing step, we can simply apply it at the host level.

Consider $y$, the simplest expression containing a free variable. The $y$ in it might be bound by some binder in a succeeding PE step. We will replace it with a special host-level closure $\Lambda h_y. h_y[z]$.

Each free variable, or expression containing free variables, can react to the fed-in environment. A free variable checks whether the given environment is at the same index level; if so, we perform the substitution.

Now think about $z + z$: how do we compose the two closures together?

  1. $(\Lambda h . h[z])_f + (\Lambda h . h[z])_f$
  2. $\Lambda h . h[z] + h[z]$

After step 2, if $h$ represents the index-0 environment (binding, say, $k$), then $h[z]$ will succeed, and the final result will be $k + k$. Otherwise, $h$ has index > 0, thus not on the same level, so the result is still $z + z$.

This is the intuition. Beyond this, we need to handle more complex cases caused by composition. In Oleg’s tutorial, he names three challenges:

  1. The environment depth of a function body will change when the function is transformed by PE. Just consider $(\lambda x. (\lambda y. x + y + 1) x)$. Initially, x + y + 1 is of type r (Int, (Int, h)) Int, but after the y binder is applied with $x$, the result x + x + 1 is of type r (Int, h) Int.
  2. When we substitute free variables under more binders in the function body, the binding depth of the variables changes, so their host-level de Bruijn indexing should also change. The intuition is: we should increase the index by one, to bind the right variables outside the closest binder.
  3. The variable to substitute may be s (s ... (s z)) rather than just z. That is, the closure formed by an open expression might need to be transformed when wrapped inside another binder.

Partially-evaluated Term

Actually, we have already given enough motivation to talk about my implementation of Oleg’s idea in DeepDarkFantasy. First, we give the full partially-evaluated representation P r h a, which carries static information:

data P r h a where
  Unk    :: r h a -> P r h a
  Static :: a -> (forall h. r h a) -> P r h a
  StaFun :: (forall hout. EnvT r (a, h) hout -> P r hout b) ->
            P r h (a -> b)
  Open   :: (forall hout. EnvT r h hout -> P r hout a) ->
            P r h a

  • Unk means “unknown”: a term that carries zero static information.
  • Static means a fully evaluated, statically known value (hence the host-level type a here). Since it is environment-irrelevant, it also carries an h-universal injection forall h. r h a.
  • StaFun means a statically known function that is not yet fully applied. It is a host-level closure which, when given a proper environment EnvT, returns the resulting PE’ed representation P r hout b.
  • Open means a PE’ed term that contains free variables. It is also a host-level closure which, when given a proper environment, instantiates to the substituted form.
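For intuition, here is how the first two constructors are typically built (a sketch; the Double fragment used here is shown later, and StaFun/Open terms are produced by the DBI (P r) instance below):

unkEx :: DBI r => P r (M.Double, h) M.Double
unkEx = Unk z                     -- no static information at all

staEx :: Double r => P r h M.Double
staEx = Static 1.0 (double 1.0)   -- host value plus its environment-universal injection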

Environment

data EnvT repr hin hout where
  Dyn  :: EnvT repr hin hin
  Arg  :: P repr hout a -> EnvT repr (a, hout) hout
  Weak :: EnvT repr h (a, h)
  Next :: EnvT repr hin hout -> EnvT repr (a, hin) (a, hout)

This is how the compile-time de Bruijn type environment is encoded: an EnvT repr hin hout transforms terms over the environment hin into terms over hout.

The environment is provided on a per-variable basis, together with its effect on the entire type environment. That is to say, we substitute free variables one by one, and either weaken or shrink the type environment gradually according to the input/output change.

Let’s check out some examples to explain what Oleg means in his tutorial:

  1. Dyn is the identity transformer; it also requests the forgetting of all statically known data, converting P repr h a to the form Unk (x :: repr h a).
    • Suppose we end the PE with a function-valued expression. Then we need to fill in some “unknown dynamic parameters” in order to get the final PE’ed expression out. These parameters are Dyn. (Note that dynamic (StaFun f) = abs $ dynamic (f Dyn).)
    • Because it doesn’t really eliminate/instantiate the free variables, it is an identity function.
  2. “Arg p asks to substitute p for the z free variable and hence removes its type a from the environment”
    • In contrast to Dyn, this time we know the parameter statically, so this can shrink the free-variable type environment from (a, hout) to hout.
  3. “Weak requests weakening of the environment, adding some type a at the top”
    • Weakening the environment doesn’t require any condition. Just consider that $(\lambda x.e)() \equiv e$ (where $x$ is not free in $e$).
    • This is useful when we substitute an open expression $e$ into another lambda $\lambda x : a$, and the outer lambda weakens $e$ by adding the type $a$ to the type environment.
  4. “Since we may need to weaken and substitute for free variables other than z, Next increments the environment level at which Weak and Arg should apply”
    • Consider what Next (Arg p) means: it assumes a closest binder of arbitrary type a, and lifts the transformation’s effect by one level.
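For example, here is a sketch of an environment transformer that substitutes the static value 1.0 for the variable one binder down (i.e. for s z); its type records how the type environment shrinks:

envEx :: Double r => EnvT r (b, (M.Double, h)) (b, h)
envEx = Next (Arg (Static 1.0 (double 1.0)))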

Substitution

app_open :: DBI repr =>
            P repr hin a -> EnvT repr hin hout -> P repr hout a

app_open means substituting into (or “applying”) a PE’ed term with the environment.

Two simple rules

app_open e Dyn            = Unk (dynamic e)
-- If we know nothing about parameter

app_open (Static es ed) _ = Static es ed
-- If we already know its value statically, then we can just ignore the environment.

Open term

app_open (Open f) h       = f h

If we have an explicitly Open term (the simplest one: z), then we just apply the meta-level closure to this environment parameter.

Static function

app_open (StaFun fs) h    = abs (fs (Next h))

If we have a static function fs, e.g. $\lambda x . e$, then any environment h (which targets the variables free outside the $\lambda$) has to be shifted one level up with Next when instantiating under $\lambda x$. After that, we wrap the result back with an abs.

The process:

StaFun fs :: P r h (a -> b), with h :: EnvT r h hout =>
fs (Next h) :: P r (a, hout) b =>
abs (fs (Next h)) :: P r hout (a -> b)

Unknown term

app_open (Unk e) h        = Unk (app_unk e h) where
  app_unk :: DBI repr =>
             repr hin a -> EnvT repr hin hout -> repr hout a
  app_unk e Dyn      = e
  app_unk e (Arg p)  = app (abs e) (dynamic p)
  app_unk e (Next h) = app (s (app_unk (abs e) h)) z
  app_unk e Weak     = s e

Finally, consider how to refine an unknown (fully dynamic) term with an environment. First, when we have an unknown term, can we return a PE’ed term that is not unknown? There are only three other possibilities: an open term, a static function, or a fully static value. The last is impossible: given only the r h a body in Unk, we can’t know more than the plain term. And we can’t make up an open term or a static function from such a fully general type either. So the result must be Unk again.

If the parameter is Dyn, then e won’t change, since no useful information is supplied. If the parameter is an Arg, we lift the substitution to the representation level, as an app of an abs. Next and Weak are similar. Let’s consider app_unk (z + s z + s s z) (Next (Arg (Static _ x))) as an example. The application substitutes s z, but after the substitution, s s z changes as well, becoming s z, since one layer of $\lambda$ is consumed. The ideal result is z + x + s z.


This is how the above example works in this framework:

app_unk (z + s z + s s z) (Next (Arg (Static _ x))) =>
app (s (app_unk (abs (z + s z + s s z)) (Arg (Static _ x)))) z =>
app (s (app (abs (abs (z + s z + s s z))) (dynamic (Static _ x)))) z =>
app (s (app (abs (abs (z + s z + s s z))) x)) z =>
app (s (abs (z + x + s z))) z =>
app (abs (z + x + s s z)) z =>
z + x + s z

During this process, we also used another helper function dynamic:

dynamic :: DBI repr => P repr h a -> repr h a
dynamic (Unk x)      = x
dynamic (Static _ x) = x
dynamic (StaFun f)   = abs $ dynamic (f Dyn)
dynamic (Open f)     = dynamic (f Dyn)

pe Function

Next comes the final pe function: from the intermediate, hybrid representation back to the source form. Note how pe executes: first, we write the syntactic form of the AST and force its type to be P r h a, so that it is automatically interpreted as the PE’ed form. Then, using dynamic, it is collapsed back to r h a.

pe :: Double repr => P repr () a -> repr () a
pe = dynamic
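As a quick sanity check, here is a sketch of the expected behavior (using the DBI (P r) and Double (P r) instances defined below): a β-redex is reduced away entirely at PE time.

peEx :: Double r => r () M.Double
peEx = pe (app (abs z) (double 1.0))   -- == double 1.0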

PE main process

instance DBI r => DBI (P r) where
  z = Open f where
    f :: EnvT r (a,h) hout -> P r hout a
    f Dyn       = Unk z                 -- turn to dynamic as requested
    f (Arg x)   = x                     -- substitution
    f (Next _)  = z                     -- not my level
    f Weak      = s z

For the z free variable, we turn it into an open term, which responds to the environment’s request for instantiation.

  • If the parameter (request) is fully dynamic Dyn, then we just return the old term as it is, wrapped in Unk.
  • If it is a static Arg x, forcing hout to be h, we just take x out (this parameter value is already PE’ed).
  • If it is Next _, forcing hout ~ (a, h), we just skip it.
  • If it is Weak, forcing hout ~ (a', (a, h)), then we point one binder further out with s z, leaving space for the weakened environment.
  s :: forall h a any. P r h a -> P r (any, h) a
  s (Unk x) = Unk (s x)
  s (Static a ar) = Static a ar
  s (StaFun fs) = abs (fs (Next Weak))
  s p = Open f where
    f :: EnvT r (any, h) hout -> P r hout a
    -- Nothing is statically known, dynamize
    f Dyn              = Unk (s (dynamic p))
    f (Arg _)          = p
    f (Next h)         = s (app_open p h)
    f Weak             = s (s p)

For a weakened (s) variable, we must consider how the body shifts. If the body is just unknown, we shift the inner expression and keep it unknown; a static value is not affected by shifting; for a static function, we wrap one more layer of Weak environment, so every inner free variable has its de Bruijn index increased by 1.

abs and app are intuitive:

  abs (Unk f) = Unk (abs f)                       -- nothing static: residual lambda
  abs (Static k ks) = StaFun $ \_ -> Static k ks  -- the body is a constant
  abs body = StaFun (app_open body)               -- wait for the argument, then substitute

  app (Unk f) (Unk x) = Unk (app f x)             -- fully dynamic application
  app (StaFun fs) p   = fs (Arg p)                -- static β-reduction at PE time
  app (Static _ fs) p = Unk (app fs (dynamic p))  -- statically known but not a StaFun: dynamize
  app e1 e2           = Open (\h -> app (app_open e1 h) (app_open e2 h))

PE of binary operators

binaryPE :: forall h a r.
            DBI r =>
            (a -> a -> a) ->                    -- op:    the host-level operator
            r h (a -> a -> a) ->                -- opM:   the dynamic DSL operator
            (forall h. P r h (a -> a -> a)) ->  -- opPM:  the PE'ed operator, for any environment
            (forall h. a -> P r h a) ->         -- liftM: inject a static result back
            (a -> M.Bool) ->                    -- isLeftUnit
            (a -> M.Bool) ->                    -- isRightUnit
            P r h (a -> a -> a)

This is a general mechanism for PE of binary operations, like Double’s doublePlus operator:

class DBI r => Double r where
  double :: M.Double -> r h M.Double
  doublePlus :: r h (M.Double -> M.Double -> M.Double)

We use it like this:

instance Double r => Double (P r) where
  double x = Static x (double x)
  doublePlus = binaryPE (+) doublePlus doublePlus double (== 0.0) (== 0.0)
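With this instance, constant folding and unit pruning behave as we would expect. A sketch of the expected results (using the app2 helper from earlier):

-- pe (app2 doublePlus (double 1.0) (double 2.0))  ==>  double 3.0   (both operands static)
-- pe (abs (app2 doublePlus (double 0.0) z))       ==>  abs z        (left identity pruned)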

The tricky thing is that doublePlus has a DSL-level function type, rather than the meta-level type r h M.Double -> r h M.Double -> r h M.Double. So its currying happens in DSL semantics.

This means that we have to expand it into the meta-level using StaFun (rationale: a built-in operator like (+) is a static function).

binaryPE op opM opPM liftM isLeftUnit isRightUnit = StaFun binaryPE'
  where
    binaryPE' :: forall hout. EnvT r (a, h) hout -> P r hout (a -> a)

After we get access to the compile-time environment EnvT r (a, h) hout, we need to case-analyze it.

    binaryPE' (Arg a)  = StaFun (binaryPE'' a)
      where
        binaryPE'' :: forall hout. P r h a -> EnvT r (a, h) hout -> P r hout a

The first possibility is that the argument is a PE’ed concrete form; in this case, we save it and pass it on to future processing. For the remaining possibilities, we just need to consider how to form the curried one-argument function. For Dyn, which forces hout to be (a, h): with the environment-universal opPM :: (forall h. P r h (a -> a -> a)), we can app it with z (the result type a -> a is the returned function type). For Next _, we similarly skip. For Weak, which forces hout to be (a', (a, h)), we need to use s z to refer to the weakened variable.

    binaryPE' Dyn      = app opPM z
    binaryPE' (Next _) = app opPM z
    binaryPE' Weak     = app opPM (s z)

For the right operand, it is basically the same. We process the case of two PE’ed concrete operands in a separate function f:

    binaryPE'' a (Arg b)   = f a b
    binaryPE'' a Dyn       = app2 opPM (s a) z
    binaryPE'' a (Next h') = app2 opPM (s (app_open a h')) z
    binaryPE'' a Weak      = app2 opPM (s (s a)) (s z)

The s a in the Dyn case is to skip over the z binder. For Next h', we must first instantiate the free variables in a with h', hence app_open a h'. In fact, Dyn is like Next with app_open behaving as const, i.e. leaving the term unchanged. For Weak, we just add one more s to both operands of the Dyn case.

Finally, we come to f. What might the PE’ed concrete forms look like?

If both operands are static, then we apply the operator statically and lift the result back in:

    f (Static d1 _) (Static d2 _)      = liftM (d1 `op` d2)

We use the left/right identity rules to prune out computation:

    f (Static d1 _) x | isLeftUnit d1  = x
    f x (Static d2 _) | isRightUnit d2 = x

If both arguments are unknown (i.e. they cannot be substituted into and hence improved), there is nothing else we can do:

    f (Unk x) (Unk y)                  = Unk (app2 opM x y)

Otherwise, at least one argument may be improved by a substitution. We wrap the result as Open, so it will be looked at again to see whether static addition has become possible:

    f e1 e2                            = Open (\h -> app2 opPM (app_open e1 h) (app_open e2 h))
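To close the loop, this last Open case is exactly what lets nested arithmetic fold once a static argument flows in. A sketch of the expected end-to-end result:

-- pe (app (abs (app2 doublePlus
--                    (app2 doublePlus z (double 1.0))
--                    (double 2.0)))
--         (double 3.0))
--   ==>  double 6.0
-- While z is free, both sums stay Open; once Arg (double 3.0) arrives,
-- the additions become static and fold all the way down to 6.0.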