11 January 2016

Reading imperative programs is hard. I don't know if this has always been true or if it's only become true since I started using functional programming more. But I've increasingly found that it takes me a long time to understand what a piece of Python code does, and I often get it wrong.

For example, what does this function do?

def foo(xs):
    total = 0
    count = 0
    for x in xs:
        total += x
        count += 1
    return total / count

That's not so hard -- it computes the average of the collection xs. But this function has a problem: in many practical settings you can't use it as written. For example, if you're averaging over some large number of rows returned by a SQL query, you can't afford to spend the memory to store all of those numbers at once, but you can afford the CPU time to iterate over all of them. So in practice this code actually looks more like this:

# Stuff is happening up here
total = 0
count = 0
for row in complicated_query.execute():
    data_point = some_function(row)
    total += data_point
    count += 1
# Do stuff with total / count

As more and more things get inlined into this procedure, it gets harder and harder to see what is being averaged, and possibly even that an average is what's being computed.

It's tempting to say that Haskell solves this problem.

average :: Fractional a => [a] -> a
average xs = sum xs / fromIntegral (length xs)

Since Haskell lists are lazy, iterating over them can be done without storing the entire list in memory, as the tail of the list hasn't been evaluated yet and the head gets garbage collected once we use it. But, as I learned somewhat recently, this implementation of average leads to space leaks.

That's because while the computations of sum xs and length xs can each individually be done in constant memory, their evaluation isn't interleaved. The entire list sticks around while the first one is being evaluated, and it can't be garbage collected because the thunk for the second one still references the head of the whole list.

Thankfully, there are libraries that solve this problem, and the average is even given as an example at the top of the documentation for the foldl package.

import qualified Control.Foldl as L
average = (/) <$> L.sum <*> L.genericLength

Actually, I was unfair to the Python code above. The exact same code unchanged will work with generators as well, which will allow it to work in constant memory. However, it still has a problem when we want to compute the average as well as something else, similar to the problem with the naive Haskell code. In Python, generator objects can only really be iterated over once. If you need to do multiple passes you either need to create the whole generator again or store the results.
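
To make the single-pass limitation concrete, here is a minimal sketch. The readings generator and its values are hypothetical stand-ins for an expensive data source, and average is the foo function from earlier under a clearer name (Python 3 division assumed).

```python
def readings():
    # Hypothetical stand-in for an expensive data source.
    for value in [3.0, 5.0, 10.0]:
        yield value

def average(xs):
    # Same loop as foo above: one pass, constant memory.
    total = 0
    count = 0
    for x in xs:
        total += x
        count += 1
    return total / count

gen = readings()
first_pass = average(gen)  # consumes the generator
leftover = list(gen)       # a second pass sees nothing: the generator is exhausted
```

After the first pass the generator is empty, so computing a second statistic from gen would silently operate on no data.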

Imagine your generator is a recursive search for board layouts that match some particular constraints. Running the generator again would be very costly, but storing the list in memory could also be infeasible. The standard solution seems to be to inline all of your accumulation functions into a single loop, but once you do that your code quickly becomes unreadable.
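
As a rough illustration of that inlining, here is a sketch that computes both an average and a maximum in one loop. The board_scores generator and its values are made up for the example, and Python 3 division is assumed.

```python
def board_scores():
    # Hypothetical stand-in for a costly recursive search.
    for score in [4, 9, 2, 7]:
        yield score

total = 0
count = 0
largest = None
for score in board_scores():
    # Bookkeeping for the average...
    total += score
    count += 1
    # ...tangled together with bookkeeping for the maximum.
    if largest is None or score > largest:
        largest = score
mean = total / count
```

Every extra statistic adds more state and more lines to the same loop body, which is exactly what makes the intent hard to read.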

I wrote a quick version of Foldl in Python. It's not as syntactically pretty as the Haskell version, but it gets the job done quite readably.

import operator
from functools import reduce  # reduce is not a builtin in Python 3

class Fold(object):
    def __init__(self, step, start, done):
        self.step = step
        self.start = start
        self.done = done

    def run(self, iterable):
        return self.done(reduce(self.step, iterable, self.start))

    @classmethod
    def liftA(cls, func, *folds):
        steps = [f.step for f in folds]
        dones = [f.done for f in folds]
        new_start = tuple(f.start for f in folds)
        
        def new_step(acc, x):
            return tuple(substep(subacc, x) for subacc, substep in zip(acc, steps))

        def new_done(acc):
            results = (subdone(subacc) for subacc, subdone in zip(acc, dones))
            return func(*results)

        return cls(new_step, new_start, new_done)

    def premap(self, func):
        return self.__class__(lambda acc, x: self.step(acc, func(x)), self.start, self.done)

    def postmap(self, func):
        return self.__class__(self.step, self.start, lambda acc: func(self.done(acc)))

identity = lambda x: x
maketuple = lambda *args: args

sum = Fold(operator.add, 0, identity)
length = Fold(lambda acc, x: acc + 1, 0, identity)
minimum = Fold(lambda x, y: y if x is None else min(x, y), None, identity)
maximum = Fold(lambda x, y: y if x is None else max(x, y), None, identity)
head = Fold(lambda x, y: y if x is None else x, None, identity)
last = Fold(lambda x, y: y, None, identity)
all = Fold(lambda x, y: x and y, True, identity)
any = Fold(lambda x, y: x or y, False, identity)
average = Fold.liftA(operator.truediv, sum, length)

elem = lambda e: any.premap(lambda x: x == e)
notElem = lambda e: all.premap(lambda x: x != e)
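
Here is a sketch of how these folds might be combined in practice. To keep the snippet runnable on its own it repeats a condensed copy of the Fold class. The sum fold is renamed total here so the builtin sum stays available, and the readings generator is a hypothetical data source.

```python
import operator
from functools import reduce

class Fold(object):
    # Condensed copy of the class above so the snippet runs on its own.
    def __init__(self, step, start, done):
        self.step, self.start, self.done = step, start, done

    def run(self, iterable):
        return self.done(reduce(self.step, iterable, self.start))

    @classmethod
    def liftA(cls, func, *folds):
        def step(acc, x):
            # Advance every sub-fold with the same element.
            return tuple(f.step(a, x) for a, f in zip(acc, folds))
        def done(acc):
            return func(*(f.done(a) for a, f in zip(acc, folds)))
        return cls(step, tuple(f.start for f in folds), done)

identity = lambda x: x
total = Fold(operator.add, 0, identity)
length = Fold(lambda acc, x: acc + 1, 0, identity)
maximum = Fold(lambda x, y: y if x is None else max(x, y), None, identity)
average = Fold.liftA(operator.truediv, total, length)

def readings():
    # Hypothetical generator; it can only be consumed once.
    for value in [3, 5, 10]:
        yield value

# Average and maximum in a single pass over the generator.
summary = Fold.liftA(lambda avg, top: (avg, top), average, maximum)
result = summary.run(readings())
```

Because liftA feeds each element to every sub-fold before moving on, the generator is traversed once and nothing beyond the accumulators is kept in memory.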