Haskell/Continuation passing style
From Wikibooks, the open-content textbooks collection
Continuation Passing Style (CPS) is a format for expressions in which no function ever returns; instead, each function passes control on to a continuation. Conceptually, a continuation is what happens next: for example, the continuation for x in (x+1)*2 is 'add one, then multiply by two'.
Starting simple
To begin with, we're going to explore two simple examples which illustrate what CPS and continuations are: first a 'first order' example (meaning there are no higher-order functions to CPS-transform), then a higher-order one.
square
Example: A simple module, no continuations
-- We assume some primitives add and square for the example:
add :: Int -> Int -> Int
add x y = x + y
square :: Int -> Int
square x = x * x
pythagoras :: Int -> Int -> Int
pythagoras x y = add (square x) (square y)
And the same function pythagoras, written in CPS looks like this:
Example: A simple module, using continuations
-- We assume CPS versions of the add and square primitives,
-- (note: the actual definitions of add'cps and square'cps are not
-- in CPS form, they just have the correct type)
add'cps :: Int -> Int -> (Int -> r) -> r
add'cps x y k = k (add x y)
square'cps :: Int -> (Int -> r) -> r
square'cps x k = k (square x)
pythagoras'cps :: Int -> Int -> (Int -> r) -> r
pythagoras'cps x y k =
  square'cps x $ \x'squared ->
  square'cps y $ \y'squared ->
  add'cps x'squared y'squared $ \sum'of'squares ->
  k sum'of'squares
The pythagoras'cps example operates as follows:
1) square x and throw the result into the (\x'squared -> ...) continuation
2) square y and throw the result into the (\y'squared -> ...) continuation
3) add x'squared and y'squared and throw the result into the (\sum'of'squares -> ...) continuation
4) throw the sum'of'squares into the toplevel/program continuation
And one can try it out:
*Main> pythagoras'cps 3 4 print
25
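Because the result type r is left open, the same CPS function can be run with different final continuations. The following is a minimal sketch that restates the definitions above so it stands alone; the names asValue and asString are illustrative and not part of the original example.

```haskell
-- CPS primitives, written directly rather than via add and square.
add'cps :: Int -> Int -> (Int -> r) -> r
add'cps x y k = k (x + y)

square'cps :: Int -> (Int -> r) -> r
square'cps x k = k (x * x)

pythagoras'cps :: Int -> Int -> (Int -> r) -> r
pythagoras'cps x y k =
  square'cps x $ \x'squared ->
  square'cps y $ \y'squared ->
  add'cps x'squared y'squared k

-- Passing id as the final continuation recovers the plain result (r = Int).
asValue :: Int
asValue = pythagoras'cps 3 4 id

-- Any other function of an Int works too; here r = String.
asString :: String
asString = pythagoras'cps 3 4 (\n -> "result: " ++ show n)
```

Loaded into GHCi, asValue evaluates to 25 and asString to "result: 25".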
thrice
Example: A simple higher order function, no continuations
thrice :: (o -> o) -> o -> o
thrice f x = f (f (f x))
*Main> thrice tail "foobar"
"bar"
Now the first thing to do, to CPS convert thrice, is compute the type of the CPSd form. We can see that f :: o -> o, so in the CPSd version, f'cps :: o -> (o -> r) -> r, and the whole type will be thrice'cps :: (o -> (o -> r) -> r) -> o -> (o -> r) -> r. Once we have the new type, it helps guide how we write the function.
Example: A simple higher order function, with continuations
thrice'cps :: (o -> (o -> r) -> r) -> o -> (o -> r) -> r
thrice'cps f'cps x k =
  f'cps x $ \f'x ->
  f'cps f'x $ \f'f'x ->
  f'cps f'f'x $ \f'f'f'x ->
  k f'f'f'x
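To try thrice'cps out, we need a CPS function to hand to it. The sketch below assumes a helper tail'cps, which is not in the original text and simply wraps tail in the same way add'cps wraps add.

```haskell
thrice'cps :: (o -> (o -> r) -> r) -> o -> (o -> r) -> r
thrice'cps f'cps x k =
  f'cps x $ \f'x ->
  f'cps f'x $ \f'f'x ->
  f'cps f'f'x k

-- Hypothetical helper: tail with the CPS type (not itself in CPS form).
tail'cps :: [a] -> ([a] -> r) -> r
tail'cps xs k = k (tail xs)

-- The CPS counterpart of: thrice tail "foobar"
thriceDemo :: String
thriceDemo = thrice'cps tail'cps "foobar" id
```

Here thriceDemo evaluates to "bar", matching the direct-style result above.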
Using the Cont monad
By now, you should be used to the (meta-)pattern that whenever we find a pattern we like (here, using continuations) that makes our code a little ugly, we use a monad to encapsulate the 'plumbing'. Indeed, there is a monad for modelling computations which use CPS.
Cont is a newtype whose contents are reached through the function runCont. Removing the newtype and record cruft, we obtain that Cont r a expands to (a -> r) -> r. So how does this fit with the idea of continuations we presented above? Well, remember that a function in CPS basically took an extra parameter which represented 'what to do next'. So, here, Cont r a expands to a function that takes that extra parameter (the continuation) and produces the final result, of type r. The continuation itself is a function from things of type a (what the result of the function would have been, if we were returning it normally instead of throwing it into the continuation) to things of type r.
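The correspondence between a hand-written CPS function and a Cont value can be made concrete with the library functions cont and runCont from Control.Monad.Cont. This is only a sketch: square'cps repeats the earlier definition in simplified form, and the name unwrapped is made up for illustration.

```haskell
import Control.Monad.Cont (Cont, cont, runCont)

-- A function with the CPS type, as in the first section.
square'cps :: Int -> (Int -> r) -> r
square'cps x k = k (x * x)

-- Wrapping: square'cps x has type (Int -> r) -> r, which is exactly
-- what a Cont r Int holds.
square'cont :: Int -> Cont r Int
square'cont x = cont (square'cps x)

-- Unwrapping: runCont exposes the continuation parameter again.
unwrapped :: Int -> (Int -> r) -> r
unwrapped x = runCont (square'cont x)
```

For instance, runCont (square'cont 7) id gives 49, and unwrapped 7 show gives "49".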
Example: The pythagoras example, using the Cont monad
import Control.Monad.Cont
add'cont :: Int -> Int -> Cont r Int
add'cont x y = return (add x y)
square'cont :: Int -> Cont r Int
square'cont x = return (square x)
pythagoras'cont :: Int -> Int -> Cont r Int
pythagoras'cont x y =
  do x'squared <- square'cont x
     y'squared <- square'cont y
     sum'of'squares <- add'cont x'squared y'squared
     return sum'of'squares
*Main> runCont (pythagoras'cont 3 4) print
25
Every function that returns a Cont-value actually takes an extra parameter, which is the continuation. Using return simply throws its argument into the continuation.
How does the Cont implementation of (>>=) work, then? It's easiest to see it at work:
Example: The (>>=) function for the Cont monad
square :: Int -> Cont r Int
square x = return (x ^ 2)
addThree :: Int -> Cont r Int
addThree x = return (x + 3)
main = runCont (square 4 >>= addThree) print
{- Result: 19 -}
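The Monad instance for (Cont r), as defined in older versions of the mtl library, is given below. (Current versions define Cont r as a synonym for ContT r Identity and provide a function cont in place of the Cont constructor.)
instance Monad (Cont r) where
  return n = Cont (\k -> k n)
  m >>= f = Cont (\k -> runCont m (\a -> runCont (f a) k))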
So return n is a Cont-value that throws n straight away into whatever continuation it is applied to. m >>= f is a Cont-value that runs m with the continuation \a -> runCont (f a) k. That continuation receives the result of the computation inside m (bound to a) and applies f to it to get another Cont-value, which is then run with the continuation we got at the top level (bound to k). In essence, m >>= f is a Cont-value that takes the result from m, applies f to it, then throws the result of that into the continuation.
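To experiment with these definitions on a current GHC, one option is to define a stand-alone continuation type. The sketch below uses the made-up name C (so as not to clash with the library's Cont) and adds the Functor and Applicative instances that modern GHC requires before a Monad instance; those two instances are our own additions, not taken from the text above.

```haskell
-- A stand-alone continuation type, mirroring Cont r a = (a -> r) -> r.
newtype C r a = C { runC :: (a -> r) -> r }

instance Functor (C r) where
  fmap g m = C (\k -> runC m (k . g))

instance Applicative (C r) where
  pure n    = C (\k -> k n)
  mf <*> mx = C (\k -> runC mf (\g -> runC mx (k . g)))

instance Monad (C r) where
  -- Run m; its result a is fed to f, and f a is run with the outer k.
  m >>= f = C (\k -> runC m (\a -> runC (f a) k))

square :: Int -> C r Int
square x = return (x ^ 2)

addThree :: Int -> C r Int
addThree x = return (x + 3)

-- runC (square 4 >>= addThree) id
--   = runC (square 4) (\a -> runC (addThree a) id)
--   = runC (addThree 16) id
--   = id 19
bindDemo :: Int
bindDemo = runC (square 4 >>= addThree) id
```

bindDemo evaluates to 19, the same result as the example above.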
callCC
By now you should be fairly confident using the basic notions of continuations and Cont, so we're going to skip ahead to the next big concept in continuation-land. This is a function called callCC, which is short for 'call with current continuation'. We'll start with an easy example.
Example: square using callCC
-- Without callCC
square :: Int -> Cont r Int
square n = return (n ^ 2)
-- With callCC
square :: Int -> Cont r Int
square n = callCC $ \k -> k (n ^ 2)
We pass a function to callCC that accepts one parameter that is in turn a function. This function (k in our example) is our tangible continuation: we can see here we're throwing a value (in this case, n ^ 2) into our continuation. We can see that the callCC version is equivalent to the return version stated above because we stated that return n is just a Cont-value that throws n into whatever continuation that it is given. Here, we use callCC to bring the continuation 'into scope', and immediately throw a value into it, just like using return.
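The claimed equivalence is easy to check on a few inputs. In this sketch the two versions are given distinct, made-up names (squareReturn and squareCallCC) so that both can live in one module.

```haskell
import Control.Monad.Cont (Cont, callCC, runCont)

squareReturn :: Int -> Cont r Int
squareReturn n = return (n ^ 2)

squareCallCC :: Int -> Cont r Int
squareCallCC n = callCC $ \k -> k (n ^ 2)

-- True when both versions agree on every input from 0 to 10.
agree :: Bool
agree =
  all (\n -> runCont (squareReturn n) id == runCont (squareCallCC n) id) [0 .. 10]
```

Evaluating agree gives True.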
However, these versions look remarkably similar, so why should we bother using callCC at all? The power lies in that we now have precise control of exactly when we call our continuation, and with what values. Let's explore some of the surprising power that gives us.
Deciding when to use k
We mentioned above that the point of using callCC in the first place was that it gave us extra power over what we threw into our continuation, and when. The following example shows how we might want to use this extra flexibility.
Example: Our first proper callCC function
foo :: Int -> Cont r String
foo n =
  callCC $ \k -> do
    let n' = n ^ 2 + 3
    when (n' > 20) $ k "over twenty"
    return (show $ n' - 4)
foo is a slightly pathological function that computes the square of its input and adds three; if the result of this computation is greater than 20, then we return from the function immediately, throwing the String value "over twenty" into the continuation that is passed to foo. If not, then we subtract four from our previous computation, show it, and throw it into the continuation. If you're used to imperative languages, you can think of k like the 'return' statement that immediately exits the function. Of course, the advantage of an expressive language like Haskell is that k is just an ordinary first-class function, so you can pass it to other functions like when, or store it in a Reader, etc.
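Running foo on a few inputs shows both exit routes. This is a sketch that restates foo with explicit imports (recent versions of mtl no longer re-export when from Control.Monad.Cont, so it is imported from Control.Monad here); fooResults is an illustrative name.

```haskell
import Control.Monad (when)
import Control.Monad.Cont (Cont, callCC, runCont)

foo :: Int -> Cont r String
foo n =
  callCC $ \k -> do
    let n' = n ^ 2 + 3
    when (n' > 20) $ k "over twenty"  -- early exit through k
    return (show $ n' - 4)            -- normal exit

-- 2 and 4 take the normal exit; 5 gives 28 > 20 and exits early.
fooResults :: [String]
fooResults = map (\n -> runCont (foo n) id) [2, 4, 5]
```

fooResults evaluates to ["3","15","over twenty"].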
Naturally, you can embed calls to callCC within do-blocks:
Example: More developed callCC example involving a do-block
bar :: Char -> String -> Cont r Int
bar c s = do
  msg <- callCC $ \k -> do
    let s' = c : s
    when (s' == "hello") $ k "They say hello."
    let s'' = show s'
    return ("They appear to be saying " ++ s'')
  return (length msg)
When you call k with a value, the entire callCC call takes that value. In other words, k is a bit like a 'goto' statement in other languages: when we call k in our example, it pops the execution out to where you first called callCC, the msg <- callCC $ ... line. No more of the argument to callCC (the inner do-block) is executed. Hence the following example contains a useless line:
Example: Popping out a function, introducing a useless line
bar :: Cont r Int
bar = callCC $ \k -> do
  let n = 5
  k n
  return 25
bar will always return 5, and never 25, because we pop out of bar before getting to the return 25 line.
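Both bar examples can be run side by side. In this sketch the second one is renamed barEarly (a made-up name) so that the two definitions do not clash.

```haskell
import Control.Monad (when)
import Control.Monad.Cont (Cont, callCC, runCont)

bar :: Char -> String -> Cont r Int
bar c s = do
  msg <- callCC $ \k -> do
    let s' = c : s
    when (s' == "hello") $ k "They say hello."
    let s'' = show s'
    return ("They appear to be saying " ++ s'')
  -- Either exit from the callCC block lands here, with msg bound.
  return (length msg)

barEarly :: Cont r Int
barEarly = callCC $ \k -> do
  let n = 5
  k n        -- exits the callCC block with 5
  return 25  -- never reached
```

runCont (bar 'h' "ello") id gives 15 (the length of "They say hello."), runCont (bar 'x' "yz") id gives 30, and runCont barEarly id gives 5.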
A note on typing
Why do we exit using return rather than k the second time within the foo example? It's to do with types. Firstly, we need to think about the type of k. We mentioned that we can throw something into k, and nothing after that call will get run (unless k is run conditionally, as when it is wrapped in a when). So the return type of k doesn't matter; we can never do anything with the result of running k. More precisely, the Cont-value returned by k never calls the continuation that it is itself given. We say, therefore, that the type of k is:
k :: a -> Cont r b
Because the Cont-value returned by k never calls its own continuation, the type b, which is the argument type of that continuation, can be anything at all, independent of the type a. In callCC's type, b is universally quantified. This is possible for the aforementioned reasons, and it is advantageous because a call to k can be used wherever a Cont-value of any result type is expected. In our above code, we use it as part of a when construct:
when :: Monad m => Bool -> m () -> m ()
As soon as the compiler sees k being used in this when, it infers that the continuation of the Cont-value returned by k takes an argument of type (), so that Cont-value has type Cont r (). This type b is independent of the argument type a of k. [1] The Cont-value returned by k does not use the continuation passed to it; it uses the continuation passed to the Cont-value returned by callCC. The callCC call as a whole has type Cont r String, because the final expression in the inner do-block, and hence the do-block itself, has type Cont r String. There are two possible execution routes. If the condition for the when succeeds, k ignores the continuation supplied by the inner do-block (the rest of the block, which would eventually reach callCC's continuation) and invokes callCC's continuation directly; the expressions after k in the do-block are never used, and because Haskell is lazy, unused expressions are never evaluated. If the condition fails, the when reduces to return (), which does use the continuation supplied by the inner do-block, so execution passes on.
If you didn't follow any of that, just make sure you use return at the end of a do-block inside a call to callCC, not k.
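One way to see the typing issue is to write a variant in which k is only ever used where a Cont r String is expected. The sketch below, with the made-up name fooIf, replaces when with if/then/else, so both branches may exit through k.

```haskell
import Control.Monad.Cont (Cont, callCC, runCont)

-- Here every use of k has type Cont r String, so b is fixed to String
-- and ending with k type-checks.
fooIf :: Int -> Cont r String
fooIf n =
  callCC $ \k ->
    let n' = n ^ 2 + 3
    in if n' > 20
         then k "over twenty"
         else k (show (n' - 4))

-- By contrast, in foo the line
--   when (n' > 20) $ k "over twenty"
-- fixes k :: String -> Cont r (), so finishing that do-block with
--   k (show $ n' - 4)
-- would be rejected: the block must end in a Cont r String.
```

fooIf behaves like foo: runCont (fooIf 5) id gives "over twenty" and runCont (fooIf 2) id gives "3".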
The type of callCC
We've deliberately broken a trend here: normally when we've introduced a function, we've given its type straight away, but in this case we haven't. The reason is simple: the type is rather horrendously complex, and it doesn't immediately give insight into what the function does, or how it works. Nevertheless, you should be familiar with it, so now that you've hopefully understood the function itself, here is its type:
callCC :: ((a -> Cont r b) -> Cont r a) -> Cont r a
This seems like a really weird type to begin with, so let's use a contrived example.
callCC $ \k -> k 5
You pass a function to callCC. This in turn takes a parameter, k, which is another function. k, as we remarked above, has the type:
k :: a -> Cont r b
The entire argument to callCC, then, is a function that takes something of the above type and returns Cont r a, where a is whatever the type of the argument to k was. So, callCC's argument has type:
(a -> Cont r b) -> Cont r a
Finally, callCC is therefore a function which takes that argument and returns its result. So the type of callCC is:
callCC :: ((a -> Cont r b) -> Cont r a) -> Cont r a
The implementation of callCC
So far we have looked at the use of callCC and its type. This just leaves its implementation, which is:
callCC f = Cont $ \k -> runCont (f (\a -> Cont $ \_ -> k a)) k
This code is far from obvious. However, the amazing fact is that the implementations for callCC f, return n and m >>= f can all be produced automatically from their type signatures - Lennart Augustsson's Djinn [1] is a program that will do this for you. See Phil Gossart's Google tech talk: [2] for background on the theory behind Djinn; and Dan Piponi's article: [3] which uses Djinn in deriving Continuation Passing Style.
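The implementation can be tried out against the current library by writing it with the cont function, which stands in for the Cont constructor. The following is a sketch under that assumption; callCC' and earlyExit are illustrative names, not library functions.

```haskell
import Control.Monad.Cont (Cont, cont, runCont)

-- f receives an "escape" function. When called with a, it discards the
-- continuation it is given at that point (the underscore) and calls the
-- continuation k captured when callCC' was entered.
callCC' :: ((a -> Cont r b) -> Cont r a) -> Cont r a
callCC' f = cont $ \k -> runCont (f (\a -> cont $ \_ -> k a)) k

earlyExit :: Cont r Int
earlyExit = callCC' $ \k -> do
  _ <- k 5   -- jumps straight to the captured continuation
  return 25  -- never reached
```

runCont earlyExit id gives 5. If the escape function is never called, the block's own result is used instead, so runCont (callCC' (\_ -> return 7)) id gives 7.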
Example: a complicated control structure
This example was originally taken from the 'The Continuation monad' section of the All about monads tutorial, used with permission.
Example: Using Cont for a complicated control structure
import Control.Monad (when)
import Control.Monad.Cont
import Data.Char (digitToInt, intToDigit)
{- We use the continuation monad to perform "escapes" from code blocks.
   This function implements a complicated control structure to process
   numbers:
   Input (n)       Output                       List Shown
   =========       ======                       ==========
   0-9             n                            none
   10-199          number of digits in (n/2)    digits of (n/2)
   200-19999       n                            digits of (n/2)
   20000-1999999   (n/2) backwards              none
   >= 2000000      sum of digits of (n/2)       digits of (n/2)
-}
fun :: Int -> String
fun n = (`runCont` id) $ do
  str <- callCC $ \exit1 -> do                    -- define "exit1"
    when (n < 10) (exit1 $ show n)
    let ns = map digitToInt (show $ n `div` 2)
    n' <- callCC $ \exit2 -> do                   -- define "exit2"
      when (length ns < 3) (exit2 $ length ns)
      when (length ns < 5) (exit2 n)
      when (length ns < 7) $ do
        let ns' = map intToDigit (reverse ns)
        exit1 (dropWhile (=='0') ns')             -- escape 2 levels
      return $ sum ns
    return $ "(ns = " ++ show ns ++ ") " ++ show n'
  return $ "Answer: " ++ str
Because it isn't initially clear what's going on, especially regarding the usage of callCC, we will explore this somewhat.
Analysis of the example
Firstly, we can see that fun is a function that takes an integer n. We basically implement a control structure using Cont and callCC that does different things based on the range that n falls in, as explained with the comment at the top of the function. Let's dive into the analysis of how it works.
- Firstly, the (`runCont` id) at the top just means that we run the Cont block that follows with a final continuation of id. This is necessary as the result type of fun doesn't mention Cont.
- We bind str to the result of the following callCC do-block:
  - If n is less than 10, we exit straight away, just showing n.
  - If not, we proceed. We construct a list, ns, of digits of n `div` 2.
  - n' (an Int) gets bound to the result of the following inner callCC do-block.
    - If length ns < 3, i.e., if n `div` 2 has less than 3 digits, we pop out of this inner do-block with the number of digits as the result.
    - If n `div` 2 has less than 5 digits, we pop out of the inner do-block returning the original n.
    - If n `div` 2 has less than 7 digits, we pop out of both the inner and outer do-blocks, with the result of the digits of n `div` 2 in reverse order (a String).
    - Otherwise, we end the inner do-block, returning the sum of the digits of n `div` 2.
  - We end this do-block, returning the String "(ns = X) Y", where X is ns, the digits of n `div` 2, and Y is the result from the inner do-block, n'.
- Finally, we return out of the entire function, with our result being the string "Answer: Z", where Z is the string we got from the callCC do-block.
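As a cross-check on the walkthrough above, the same table can be written in direct style with guards. This is only a sketch for comparison: funPlain and report are made-up names, and the code is not part of the original tutorial.

```haskell
import Data.Char (digitToInt, intToDigit)

-- Direct-style rendering of the table; each guard corresponds to one
-- of the exits in the callCC version.
funPlain :: Int -> String
funPlain n
  | n < 10        = "Answer: " ++ show n                -- exit1, first when
  | length ns < 3 = report (length ns)                  -- exit2 with digit count
  | length ns < 5 = report n                            -- exit2 with n
  | length ns < 7 = "Answer: " ++ dropWhile (== '0') (map intToDigit (reverse ns))
  | otherwise     = report (sum ns)                     -- fall through
  where
    ns       = map digitToInt (show (n `div` 2))
    report v = "Answer: (ns = " ++ show ns ++ ") " ++ show v
```

For example, funPlain 50 gives "Answer: (ns = [2,5]) 2", funPlain 1000 gives "Answer: (ns = [5,0,0]) 1000", and funPlain 24680 gives "Answer: 4321". Note how the fourth case skips the "(ns = ...)" wrapper, which is what escaping two levels through exit1 achieves in the Cont version.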
Example: exceptions
One use of continuations is to model exceptions. To do this, we hold on to two continuations: one that takes us out to the handler in case of an exception, and one that takes us to the post-handler code in case of a success. Here's a simple function that takes two numbers and does integer division on them, failing when the denominator is zero.
Example: An exception-throwing div
divExcpt :: Int -> Int -> (String -> Cont r Int) -> Cont r Int
divExcpt x y handler =
  callCC $ \ok -> do
    err <- callCC $ \notOk -> do
      when (y == 0) $ notOk "Denominator 0"
      ok $ x `div` y
    handler err
{- For example,
runCont (divExcpt 10 2 error) id --> 5
runCont (divExcpt 10 0 error) id --> *** Exception: Denominator 0
-}
How does it work? We use two nested calls to callCC. The first labels a continuation that will be used when there's no problem. The second labels a continuation that will be used when we wish to throw an exception. If the denominator isn't 0, x `div` y is thrown into the ok continuation, so the execution pops right back out to the top level of divExcpt. If, however, we were passed a zero denominator, we throw an error message into the notOk continuation, which pops us out of the inner do-block, and that string gets assigned to err and given to handler.
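The handler does not have to abort the program. As a sketch, the made-up function divOrZero below passes a handler that recovers with a default value; the choice of 0 is arbitrary and only for illustration.

```haskell
import Control.Monad (when)
import Control.Monad.Cont (Cont, callCC, runCont)

divExcpt :: Int -> Int -> (String -> Cont r Int) -> Cont r Int
divExcpt x y handler =
  callCC $ \ok -> do
    err <- callCC $ \notOk -> do
      when (y == 0) $ notOk "Denominator 0"
      ok $ x `div` y
    handler err

-- The handler ignores the message and continues with 0.
divOrZero :: Int -> Int -> Int
divOrZero x y = runCont (divExcpt x y (\_ -> return 0)) id
```

divOrZero 10 2 gives 5, while divOrZero 10 0 gives 0 rather than raising an exception.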
A more general approach to handling exceptions can be seen with the following function. Pass a computation as the first parameter (which should be a function taking a continuation to the error handler) and an error handler as the second parameter. This example takes advantage of the generic MonadCont class, which covers both Cont and ContT by default, plus any other instances the user has defined.
Example: General try using continuations.
tryCont :: MonadCont m => ((err -> m a) -> m a) -> (err -> m a) -> m a
tryCont c h =
  callCC $ \ok -> do
    err <- callCC $ \notOk -> do
      x <- c notOk
      ok x
    h err
For an example using tryCont, see the following program.
Example: Using tryCont
data SqrtException = LessThanZero deriving (Show, Eq)
sqrtIO :: (SqrtException -> ContT r IO ()) -> ContT r IO ()
sqrtIO throw = do
  ln <- lift (putStr "Enter a number to sqrt: " >> readLn)
  when (ln < 0) (throw LessThanZero)
  lift $ print (sqrt ln)
main = runContT (tryCont sqrtIO (lift . print)) return
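Since tryCont only needs a MonadCont instance, it also works with plain Cont, which makes it easy to test without any input or output. The sketch below restates tryCont and adds two made-up functions, safeSqrt and sqrtOrZero, that are not part of the original program.

```haskell
import Control.Monad.Cont (Cont, MonadCont, callCC, runCont)

tryCont :: MonadCont m => ((err -> m a) -> m a) -> (err -> m a) -> m a
tryCont c h =
  callCC $ \ok -> do
    err <- callCC $ \notOk -> do
      x <- c notOk
      ok x
    h err

-- A pure computation that "throws" on negative input.
safeSqrt :: Double -> (String -> Cont r Double) -> Cont r Double
safeSqrt x throw
  | x < 0     = throw "negative input"
  | otherwise = return (sqrt x)

-- The handler recovers with 0 (an arbitrary default for illustration).
sqrtOrZero :: Double -> Double
sqrtOrZero x = runCont (tryCont (safeSqrt x) (\_ -> return 0)) id
```

sqrtOrZero 9 gives 3.0, and sqrtOrZero (-1) gives 0.0 because the handler takes over.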
Example: coroutines
Notes
- [1] The type a is inferred to be monomorphic because k is bound by a lambda expression, and things bound by lambdas always have monomorphic types. See Polymorphism.