When working with syntax trees (such as in a type theory interpreter) you often want to apply some operation to all subtrees of a node, or to all nodes of a certain type. Of course you can do this easily by writing a recursive function. But then you would need to have a case for every constructor, and there can be many constructors.

Instead of writing a big recursive function for each operation, it is often easier to use a traversal function, which is what this post is about. In particular, I will describe my favorite way to handle such traversals, in the hope that it is useful to others as well.

As a running example we will use the following data type, which represents expressions in a simple lambda calculus:

    -- Lambda calculus with de Bruijn indices
    data Exp
      = Var !Int
      | Lam Exp
      | App Exp Exp
      | Global String
      deriving Show

    example1 :: Exp
    example1 = Lam $ Var 0  -- The identity function

    example2 :: Exp
    example2 = Lam $ Lam $ Var 1  -- The const function

    example3 :: Exp
    example3 = Lam $ Lam $ Lam $ App (Var 2) (App (Var 1) (Var 0))  -- Function composition

Now, what do I mean by a traversal function?
The base library comes with the `Traversable` class, but that doesn't quite fit our purposes, because that class is designed for containers that can contain any type `a`, whereas expressions can only contain other sub-expressions.
Instead we need a monomorphic variant of `traverse` for our expression type:

    traverseExp :: Applicative f => (Exp -> f Exp) -> (Exp -> f Exp)

The idea is that `traverseExp` applies a given function to all direct children of an expression.

The uniplate package defines a similar function, `descendM`. But it has two problems: 1) `descendM` has a `Monad` constraint instead of `Applicative`, and 2) the class actually requires you to implement a `uniplate` method, which is more annoying to do.

The ever intimidating lens package has a closer match in `plate`. But aside from the terrible name, that function also lacks a way to keep track of bound variables.

For a language with binders, like the lambda calculus, many operations need to know which variables are bound. In particular, when working with de Bruijn indices, it is necessary to keep track of the number of bound variables. To do that we define:

    type Depth = Int

    -- Traverse over immediate children, with depth
    traverseExpD :: Applicative f => (Depth -> Exp -> f Exp) -> (Depth -> Exp -> f Exp)
    traverseExpD _ _ (Var i)    = pure (Var i)
    traverseExpD f d (Lam x)    = Lam <$> f (d+1) x
    traverseExpD f d (App x y)  = App <$> f d x <*> f d y
    traverseExpD _ _ (Global x) = pure (Global x)

Once we have written this function, other traversals can be defined in terms of `traverseExpD`:

    -- Traverse over immediate children
    traverseExp :: Applicative f => (Exp -> f Exp) -> (Exp -> f Exp)
    traverseExp f = traverseExpD (const f) 0

And map and fold are just traversals with a specific applicative functor, `Identity` and `Const a` respectively. Recent versions of GHC are smart enough to know that it is safe to `coerce` from a traversal function to a mapping or folding one.

    -- Needs ScopedTypeVariables, plus coerce (Data.Coerce),
    -- Identity (Data.Functor.Identity) and Const (Data.Functor.Const)

    -- Map over immediate children, with depth
    mapExpD :: (Depth -> Exp -> Exp) -> (Depth -> Exp -> Exp)
    mapExpD = coerce (traverseExpD :: (Depth -> Exp -> Identity Exp) -> (Depth -> Exp -> Identity Exp))

    -- Map over immediate children
    mapExp :: (Exp -> Exp) -> (Exp -> Exp)
    mapExp = coerce (traverseExp :: (Exp -> Identity Exp) -> (Exp -> Identity Exp))

    -- Fold over immediate children, with depth
    foldExpD :: forall a. Monoid a => (Depth -> Exp -> a) -> (Depth -> Exp -> a)
    foldExpD = coerce (traverseExpD :: (Depth -> Exp -> Const a Exp) -> (Depth -> Exp -> Const a Exp))

    -- Fold over immediate children
    foldExp :: forall a. Monoid a => (Exp -> a) -> (Exp -> a)
    foldExp = coerce (traverseExp :: (Exp -> Const a Exp) -> (Exp -> Const a Exp))
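To make the claim that map and fold are traversals at `Identity` and `Const a` more concrete, here is a minimal sketch that wraps and unwraps the newtypes by hand instead of using `coerce`. The primed names and `countChildren` are made up for this illustration, and `Eq` is derived only so that results can be compared.

```haskell
import Data.Functor.Const (Const (..))
import Data.Functor.Identity (Identity (..))
import Data.Monoid (Sum (..))

-- Same constructors as in the post; Eq is an addition for comparing results.
data Exp = Var !Int | Lam Exp | App Exp Exp | Global String
  deriving (Eq, Show)

type Depth = Int

traverseExpD :: Applicative f => (Depth -> Exp -> f Exp) -> (Depth -> Exp -> f Exp)
traverseExpD _ _ (Var i)    = pure (Var i)
traverseExpD f d (Lam x)    = Lam <$> f (d + 1) x
traverseExpD f d (App x y)  = App <$> f d x <*> f d y
traverseExpD _ _ (Global x) = pure (Global x)

-- Map: run the traversal in Identity, then unwrap the result.
mapExpD' :: (Depth -> Exp -> Exp) -> (Depth -> Exp -> Exp)
mapExpD' f d = runIdentity . traverseExpD (\d' -> Identity . f d') d

-- Fold: run the traversal in Const a. Its Applicative instance ignores the
-- rebuilt expression and combines the collected values with (<>).
foldExpD' :: Monoid a => (Depth -> Exp -> a) -> (Depth -> Exp -> a)
foldExpD' f d = getConst . traverseExpD (\d' -> Const . f d') d

-- Hypothetical helper: count the direct children (not all descendants).
countChildren :: Exp -> Int
countChildren = getSum . foldExpD' (\_ _ -> Sum 1) 0
```

The `coerce` versions should behave the same way; they only skip the explicit wrapping and unwrapping.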

After doing all this work, it is easy to answer questions like "how often is a variable used?"

    varCount :: Depth -> Exp -> Sum Int
    varCount i (Var j)
      | i == j = Sum 1
    varCount i x = foldExpD varCount i x

or "what is the set of all free variables?"

    freeVars :: Depth -> Exp -> Set Int
    freeVars d (Var i)
      | i < d     = Set.empty            -- bound variable
      | otherwise = Set.singleton (i - d)  -- free variable
    freeVars d x = foldExpD freeVars d x

Or to perform (silly) operations like changing all globals to lower case

    lowerCase :: Exp -> Exp
    lowerCase (Global x) = Global (map toLower x)
    lowerCase x = mapExp lowerCase x

These functions follow a common pattern: specify how a particular constructor, in this case `Var` or `Global`, is handled, and for all other constructors traverse over the child expressions.
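The same pattern also works with an applicative that is neither `Identity` nor `Const a`. As a sketch, the hypothetical `renameGlobals` below renames globals through a lookup table and fails with `Either String` on the first unknown name. The function, the table representation and the error message are assumptions made for this example.

```haskell
-- Same constructors as in the post; Eq is an addition for comparing results.
data Exp = Var !Int | Lam Exp | App Exp Exp | Global String
  deriving (Eq, Show)

type Depth = Int

traverseExpD :: Applicative f => (Depth -> Exp -> f Exp) -> (Depth -> Exp -> f Exp)
traverseExpD _ _ (Var i)    = pure (Var i)
traverseExpD f d (Lam x)    = Lam <$> f (d + 1) x
traverseExpD f d (App x y)  = App <$> f d x <*> f d y
traverseExpD _ _ (Global x) = pure (Global x)

traverseExp :: Applicative f => (Exp -> f Exp) -> (Exp -> f Exp)
traverseExp f = traverseExpD (const f) 0

-- Hypothetical example: handle Global specially, traverse everything else.
-- The table is a plain association list to avoid extra dependencies.
renameGlobals :: [(String, String)] -> Exp -> Either String Exp
renameGlobals table (Global x) =
  maybe (Left ("unknown global: " ++ x)) (Right . Global) (lookup x table)
renameGlobals table x = traverseExp (renameGlobals table) x
```

Only the `Global` clause mentions the effect; the default clause threads it through the children via `traverseExp`.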

As another example, consider substitution, a very important operation on syntax trees. In its most general form, we can combine substitution with raising expressions to a larger context (also called weakening). And we should also consider leaving the innermost, bound, variables alone. This means that there are three possibilities for what to do with a variable.

    substRaiseByAt :: [Exp] -> Int -> Depth -> Exp -> Exp
    substRaiseByAt ss r d (Var i)
      | i < d             = Var i                      -- A bound variable, leave it alone
      | i - d < length ss = raiseBy d (ss !! (i - d))  -- substitution
      | otherwise         = Var (i - length ss + r)    -- free variable, raising
    substRaiseByAt ss r d x = mapExpD (substRaiseByAt ss r) d x

Similarly to `varCount`, we use `mapExpD` to handle all constructors besides variables.
Plain substitution and raising are just special cases.

    -- Substitute the first few free variables, weaken the rest
    substRaiseBy :: [Exp] -> Int -> Exp -> Exp
    substRaiseBy ss r = substRaiseByAt ss r 0

    raiseBy :: Int -> Exp -> Exp
    raiseBy r = substRaiseBy [] r

    subst :: [Exp] -> Exp -> Exp
    subst ss = substRaiseBy ss 0

    λ> raiseBy 2 (App (Var 1) (Var 2))
    App (Var 3) (Var 4)
    λ> subst [Global "x"] (App (Var 0) (Lam (Var 0)))
    App (Global "x") (Lam (Var 0))
    λ> substRaiseBy [App (Global "x") (Var 0)] 2 $ App (Lam (App (Var 1) (Var 0))) (Var 2)
    App (Lam (App (App (Global "x") (Var 1)) (Var 0))) (Var 3)
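One use of `subst` is beta reduction: substituting the argument for variable 0 of a lambda body also lowers the remaining free variables by one, which is what removing the binder requires. The sketch below repeats the definitions so that it stands alone; `betaStep` is a name made up for this example, and it only reduces a redex at the root of the expression.

```haskell
import Data.Coerce (coerce)
import Data.Functor.Identity (Identity (..))

-- Same constructors as in the post; Eq is an addition for comparing results.
data Exp = Var !Int | Lam Exp | App Exp Exp | Global String
  deriving (Eq, Show)

type Depth = Int

traverseExpD :: Applicative f => (Depth -> Exp -> f Exp) -> (Depth -> Exp -> f Exp)
traverseExpD _ _ (Var i)    = pure (Var i)
traverseExpD f d (Lam x)    = Lam <$> f (d + 1) x
traverseExpD f d (App x y)  = App <$> f d x <*> f d y
traverseExpD _ _ (Global x) = pure (Global x)

mapExpD :: (Depth -> Exp -> Exp) -> (Depth -> Exp -> Exp)
mapExpD = coerce (traverseExpD :: (Depth -> Exp -> Identity Exp) -> (Depth -> Exp -> Identity Exp))

substRaiseByAt :: [Exp] -> Int -> Depth -> Exp -> Exp
substRaiseByAt ss r d (Var i)
  | i < d             = Var i                      -- bound variable
  | i - d < length ss = raiseBy d (ss !! (i - d))  -- substitution
  | otherwise         = Var (i - length ss + r)    -- free variable, raising
substRaiseByAt ss r d x = mapExpD (substRaiseByAt ss r) d x

raiseBy :: Int -> Exp -> Exp
raiseBy r = substRaiseByAt [] r 0

subst :: [Exp] -> Exp -> Exp
subst ss = substRaiseByAt ss 0 0

-- Hypothetical helper: reduce a beta redex at the root, if there is one.
betaStep :: Exp -> Maybe Exp
betaStep (App (Lam body) arg) = Just (subst [arg] body)
betaStep _                    = Nothing
```

For instance, `betaStep (App (Lam (Lam (Var 1))) (Var 0))` gives `Just (Lam (Var 1))`: the argument `Var 0` is raised by one when it moves under the remaining binder.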

As a slight generalization, it can also make sense to put `traverseExpD` into a type class. That way we can traverse over the subexpressions inside other data types. For instance, if the language uses a separate data type for case alternatives, we might write:

    data Exp = ... | Case [Alt]
    data Alt = Alt Pat Exp

    class TraverseExp a where
      traverseExpD :: Applicative f => (Depth -> Exp -> f Exp) -> (Depth -> a -> f a)

    instance TraverseExp a => TraverseExp [a] where
      traverseExpD f d = traverse (traverseExpD f d)

    instance TraverseExp Exp where
      traverseExpD f d ...
      traverseExpD f d (Case xs) = Case <$> traverseExpD f d xs

    instance TraverseExp Alt where
      -- The body is itself a direct subexpression, so f is applied to it
      traverseExpD f d (Alt x y) = Alt x <$> f (d + varsBoundByPat x) y
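A complete, compilable version of this idea might look like the sketch below. The `Pat` type, its constructors and `varsBoundByPat` are hypothetical stand-ins, since the post does not define them, and `childDepths` is a helper made up to show which depth each direct subexpression is visited at. Note that the `Alt` instance applies `f` directly to the body of the alternative, because that body is itself a direct subexpression.

```haskell
import Data.Functor.Const (Const (..))

-- Hypothetical pattern type: a wildcard, or a constructor binding n variables.
data Pat = PWild | PCon String Int
  deriving (Eq, Show)

varsBoundByPat :: Pat -> Int
varsBoundByPat PWild      = 0
varsBoundByPat (PCon _ n) = n

data Exp = Var !Int | Lam Exp | App Exp Exp | Global String | Case [Alt]
  deriving (Eq, Show)

data Alt = Alt Pat Exp
  deriving (Eq, Show)

type Depth = Int

class TraverseExp a where
  traverseExpD :: Applicative f => (Depth -> Exp -> f Exp) -> (Depth -> a -> f a)

instance TraverseExp a => TraverseExp [a] where
  traverseExpD f d = traverse (traverseExpD f d)

instance TraverseExp Exp where
  traverseExpD _ _ (Var i)    = pure (Var i)
  traverseExpD f d (Lam x)    = Lam <$> f (d + 1) x
  traverseExpD f d (App x y)  = App <$> f d x <*> f d y
  traverseExpD _ _ (Global x) = pure (Global x)
  traverseExpD f d (Case xs)  = Case <$> traverseExpD f d xs

instance TraverseExp Alt where
  -- The body sits under the variables bound by the pattern.
  traverseExpD f d (Alt p y) = Alt p <$> f (d + varsBoundByPat p) y

-- Hypothetical helper: the depths at which the direct subexpressions are visited.
childDepths :: TraverseExp a => a -> [Depth]
childDepths = getConst . traverseExpD (\d _ -> Const [d]) 0
```

With these definitions, `childDepths (Case [Alt (PCon "Pair" 2) (Var 0), Alt PWild (Var 0)])` is `[2,0]`: the first body sits under two pattern variables, the second under none.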

Another variation is to track other things besides the number of bound variables. For example we might track the names and types of bound variables for better error messages. And with a type class it is possible to track different aspects of bindings as needed:

    class Env env where
      extend :: VarBinding -> env -> env

    instance Env Depth where
      extend _ = (+1)
    instance Env [VarBinding] where
      extend = (:)
    instance Env () where
      extend _ _ = ()

    traverseExpEnv :: (Env env, Applicative f) => (env -> Exp -> f Exp) -> (env -> Exp -> f Exp)
    traverseExpEnv f env (Lam name x) = Lam name <$> f (extend name env) x
    traverseExpEnv f env ...
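As a sketch of how the environment variant could be used, the block below fills in the remaining cases under some assumptions: `VarBinding` is taken to be just a name (a `String`), `Lam` carries that name, and `foldExpEnv` and `usedNames` are helpers made up for this example. The instances for type synonyms and for `[VarBinding]` rely on `FlexibleInstances`.

```haskell
{-# LANGUAGE FlexibleInstances #-}
import Data.Functor.Const (Const (..))

-- Assumption: a binding is just the variable's name.
type VarBinding = String

data Exp = Var !Int | Lam VarBinding Exp | App Exp Exp | Global String
  deriving (Eq, Show)

type Depth = Int

class Env env where
  extend :: VarBinding -> env -> env

instance Env Depth where          -- track only the number of binders
  extend _ = (+ 1)
instance Env [VarBinding] where   -- track the names, innermost first
  extend = (:)
instance Env () where             -- track nothing
  extend _ _ = ()

traverseExpEnv :: (Env env, Applicative f) => (env -> Exp -> f Exp) -> (env -> Exp -> f Exp)
traverseExpEnv _ _   (Var i)      = pure (Var i)
traverseExpEnv f env (Lam name x) = Lam name <$> f (extend name env) x
traverseExpEnv f env (App x y)    = App <$> f env x <*> f env y
traverseExpEnv _ _   (Global x)   = pure (Global x)

-- Fold over immediate children, with environment.
foldExpEnv :: (Env env, Monoid a) => (env -> Exp -> a) -> (env -> Exp -> a)
foldExpEnv f env = getConst . traverseExpEnv (\e -> Const . f e) env

-- Hypothetical example: names of the bound variables that are actually used.
-- Free variables have no entry in the environment and contribute nothing.
usedNames :: [VarBinding] -> Exp -> [VarBinding]
usedNames env (Var i) = take 1 (drop i env)
usedNames env x       = foldExpEnv usedNames env x
```

For example, `usedNames [] (Lam "x" (Lam "y" (App (Var 1) (Var 0))))` yields `["x","y"]`.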

Overall, I have found that after writing `traverseExpD` once, I rarely have to look at all constructors again. I can just handle the default cases by traversing the children.

A nice thing about this pattern is that it is very efficient. The `traverseExpD` function is not recursive, which means that the compiler can inline it. So after optimization, a function like `lowerCase` or `varCount` is exactly what you would have written by hand.
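For comparison, this is roughly what a hand-written `lowerCase` looks like, next to the traversal-based one. The name `lowerCaseByHand` is made up for this sketch. Whether GHC produces exactly this code depends on the optimization settings; the block only illustrates the shape that inlining is meant to recover and that the two functions agree.

```haskell
import Data.Char (toLower)
import Data.Functor.Identity (Identity (..))

-- Same constructors as in the post; Eq is an addition for comparing results.
data Exp = Var !Int | Lam Exp | App Exp Exp | Global String
  deriving (Eq, Show)

type Depth = Int

traverseExpD :: Applicative f => (Depth -> Exp -> f Exp) -> (Depth -> Exp -> f Exp)
traverseExpD _ _ (Var i)    = pure (Var i)
traverseExpD f d (Lam x)    = Lam <$> f (d + 1) x
traverseExpD f d (App x y)  = App <$> f d x <*> f d y
traverseExpD _ _ (Global x) = pure (Global x)

mapExp :: (Exp -> Exp) -> (Exp -> Exp)
mapExp f = runIdentity . traverseExpD (\_ -> Identity . f) 0

-- Traversal-based version, as in the post.
lowerCase :: Exp -> Exp
lowerCase (Global x) = Global (map toLower x)
lowerCase x          = mapExp lowerCase x

-- Hand-written version: one clause per constructor.
lowerCaseByHand :: Exp -> Exp
lowerCaseByHand (Var i)    = Var i
lowerCaseByHand (Lam x)    = Lam (lowerCaseByHand x)
lowerCaseByHand (App x y)  = App (lowerCaseByHand x) (lowerCaseByHand y)
lowerCaseByHand (Global x) = Global (map toLower x)
```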

## Comments

I've found a lot of use in (unlawful) traversals like this:

(Bound variables are untouched) where

This also works if you do the type encoding where `Lam :: Exp (Maybe a) -> Exp a`, and gives you an easy way of implementing functor, foldable, traversable, and monad (i.e. `free :: Traversal (Exp a) (Exp b) a (Exp b)`).