Building a fully typed LRU Cache in Python

In this post we are going to build a fully typed LRU (least recently used) cache (almost) from scratch using Python. We will then create a function decorator that mirrors the built-in functools implementation.

This exercise will cover several advanced Python concepts including generic types and other typing features, decorators (including second-order decorators), and magic methods.

This exercise is for learning and demonstration purposes and is therefore not optimized for performance. All source code can be found here.
{: .notice--info}

What we will end up with

What we will have at the end of this exercise is a clone (at least in terms of the public interface) of the functools lru_cache decorator. We should get these results:

@lru_cache(maxsize=100)
def fib(n: int) -> int:
    if n < 2:
        return n
    return fib(n - 1) + fib(n - 2)

>>> [fib(n) for n in range(16)]
[0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]

>>> fib.cache_info()
CacheInfo(hits=28, misses=16, maxsize=100, currsize=16)

In addition to this, we also want to retain all type information from the decorated function. In older versions of functools.lru_cache this information was lost, so static type checkers and IntelliSense could not function correctly. In our version we will retain this information so that static type checkers work correctly.

Building the cache class

First, we will build a standalone LruCache class to handle the actual heavy lifting. In most implementations of an LRU cache, a hash map (i.e. dictionary) and a doubly linked list are used.

In this case, since the main point of this article is how to use some of the more advanced Python features, we will use a single built-in data type: the OrderedDict.

The full implementation is below:

from collections import OrderedDict
from typing import Generic, Hashable, Optional, TypeVar

T = TypeVar("T")


class LruCache(Generic[T]):
    def __init__(self, capacity: int):
        self.capacity = capacity
        self.__cache: OrderedDict[Hashable, T] = OrderedDict()

    def get(self, key: Hashable) -> Optional[T]:
        if key not in self.__cache:
            return None
        # mark this key as the most recently used
        self.__cache.move_to_end(key)
        return self.__cache[key]

    def insert(self, key: Hashable, value: T) -> None:
        # evict the least recently used entry when at capacity,
        # unless the key is already present (overwriting it does not grow the cache)
        if len(self.__cache) >= self.capacity and key not in self.__cache:
            self.__cache.popitem(last=False)
        self.__cache[key] = value
        self.__cache.move_to_end(key)

    def __len__(self) -> int:
        return len(self.__cache)

    def clear(self) -> None:
        self.__cache.clear()

Ok, there is a lot to unpack here. First, let's look at the basic functionality.

# build a new cache with capacity 2 (i.e. we can store at most two values at a time)
>>> cache = LruCache[str](capacity=2)

# put two values into the cache
# cache keys are now [2, 1] since 2 was inserted most recently
>>> cache.insert(1, "hello")
>>> cache.insert(2, "hello again")

# length is 2 since that is the number of values in the cache
>>> print(len(cache))
2

# can retrieve 2 since both 1 and 2 are in the cache
# cache keys are now [2, 1] since 2 was called more recently than 1
>>> print(cache.get(2))
hello again

# this will evict key 1 with value "hello"
# cache keys are now [3, 2]
>>> cache.insert(3, "goodbye")

# we now return None since 1 is no longer in the cache
>>> print(cache.get(1))
None

# but we can get the cached value for 3
# cache keys are still [3, 2] since 3 was called more recently than 2
>>> print(cache.get(3))
goodbye

# cache keys are now [2, 3]
>>> print(cache.get(2))
hello again

# this will drop key 3
# cache keys are now [1, 2]
>>> cache.insert(1, "I'm back!")

# get None since 3 was dropped
>>> print(cache.get(3))
None

# finally we can clear the cache
>>> cache.clear()
>>> print(len(cache))
0

We can see that the capacity determines how large the cache is, and that every call to cache.get or cache.insert moves that key to the top of the queue, effectively keeping track of the most recently used keys. When we insert a value while the cache is full (i.e. the length equals the capacity), the least recently used value is dropped. This is exactly what we want.

There are a few important things to notice about the code. First, LruCache itself is generic over a type T, which represents the value type of the cache. In the example above we specify LruCache[str], so the value parameter of the insert method and the return type of the get method are strings. We can pass in any type here, and our IDE and static type checkers are able to infer and check the input and return types. Pretty awesome huh?
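
For example, with the cache parameterized as LruCache[str], a static type checker such as pyright will flag a value of the wrong type (the exact wording of the diagnostic depends on the tool):

cache = LruCache[str](capacity=2)

cache.insert(1, "hello")  # ok, the value is a str
cache.insert(2, 42)       # flagged: int is not assignable to str
value = cache.get(1)      # inferred as Optional[str]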

The next thing to notice is the Hashable type. Since we want this class to be generic, we give it the least amount of information possible about what kinds of keys will be used. At the very least, dictionary keys need to be hashable (i.e. implement the __hash__ method), and that is exactly what Hashable checks for. If we try to pass in a key that is not hashable, like a list, then we will get an error from a static type checker like pyright:

Argument of type "list[int]" cannot be assigned to parameter "key" of type "Hashable" in function "get"
    "list[int]" is incompatible with protocol "Hashable"
      "__hash__" is an incompatible type

Another thing to notice here is the use of the magic method __len__. It is a fairly common one; implementing it allows us to call len on an instance of LruCache.

One other bit of fanciness here is the use of the double leading underscore on __cache. We don't want outside users to be able to modify the cache, and while there are no truly private variables in Python, the double underscore triggers name mangling, so you will actually get an AttributeError if you try to access __cache from outside the class.
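
A quick demonstration of what that name mangling does in practice (the mangled name is still reachable, so this is a strong convention rather than true privacy):

cache = LruCache[str](capacity=2)

cache.__cache           # AttributeError: 'LruCache' object has no attribute '__cache'
cache._LruCache__cache  # the mangled name still works, but you should not rely on it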

As for the implementation, as mentioned above, we make use of the OrderedDict provided by the collections module. For the most part, OrderedDict is almost obsolete since regular Python dictionaries now retain insertion order; however, it is very useful here because of the move_to_end method, which lets us move any key-value pair to the end of the dictionary and is how we keep track of the least/most recently used keys in our LRU cache. Finally, the popitem method accepts last=False, which drops a key-value pair in FIFO (first-in-first-out) order as opposed to the default LIFO (last-in-first-out) order. This lets us easily pop off the least recently used key.
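
Here is a tiny standalone illustration of the two OrderedDict features we rely on:

from collections import OrderedDict

d = OrderedDict([("a", 1), ("b", 2), ("c", 3)])

d.move_to_end("a")     # "a" becomes the newest entry
print(list(d))         # ['b', 'c', 'a']

d.popitem(last=False)  # pops ("b", 2), the oldest entry (FIFO)
print(list(d))         # ['c', 'a']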

Building the function wrapper class

Now that we have a generic LruCache implementation, we move on to building a class that will wrap the function we wish to cache. What we want is something that has the exact same call signature as the original function but also has cache_info and cache_clear methods as well as a __wrapped__ attribute containing a reference to the wrapped function itself. This all mirrors the functools.lru_cache API.

from collections.abc import Callable
from typing import Final, Generic, NamedTuple, ParamSpec, TypeVar

T = TypeVar("T")
P = ParamSpec("P")


class CacheInfo(NamedTuple):
    hits: int
    misses: int
    maxsize: int
    currsize: int


class _LruCacheFunctionWrapper(Generic[P, T]):
    def __init__(self, func: Callable[P, T], maxsize: int):
        self.__wrapped__ = func
        self.__cache = LruCache[T](capacity=maxsize)
        self.__hits = 0
        self.__misses = 0
        self.__maxsize: Final = maxsize

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
        # build a hashable key from the positional and keyword arguments
        call_args = args + tuple(kwargs.items())

        ret = self.__cache.get(call_args)

        if ret is None:
            # cache miss: call the wrapped function and store the result
            self.__misses += 1
            ret = self.__wrapped__(*args, **kwargs)
            self.__cache.insert(call_args, ret)
        else:
            # cache hit
            self.__hits += 1

        return ret

    def cache_info(self) -> CacheInfo:
        return CacheInfo(
            hits=self.__hits,
            misses=self.__misses,
            currsize=len(self.__cache),
            maxsize=self.__maxsize,
        )

    def cache_clear(self) -> None:
        self.__cache.clear()
        self.__hits = 0
        self.__misses = 0

Ok, let's walk through this. First, we define the CacheInfo named tuple to match the signature of the functools.lru_cache implementation. NamedTuples are good for things like this, where you want an immutable but easy-to-read return type for some function or method. If you need to add methods or want to modify fields, it's probably better to use a dataclass.
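
As a quick illustration, the fields of the named tuple can be read by name or by index but cannot be reassigned:

info = CacheInfo(hits=28, misses=16, maxsize=100, currsize=16)

print(info.hits)  # 28
print(info[1])    # 16, i.e. misses
info.hits = 0     # AttributeError: can't set attribute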

The next thing to notice is that _LruCacheFunctionWrapper (the leading underscore indicates that this is a "private" class) is generic over types P and T, which, as we see from the __init__, correspond to the arguments and return type of the wrapped function, respectively. For the P generic we make use of ParamSpec, new in Python 3.10, which will make things much nicer as we will see. From Python 3.8 up you can import it from typing_extensions; in Python 3.10 it was moved into the built-in typing module.
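
If you need to support interpreters older than 3.10, a conditional import along these lines works (assuming the typing_extensions package is installed):

import sys

if sys.version_info >= (3, 10):
    from typing import ParamSpec
else:
    from typing_extensions import ParamSpec

P = ParamSpec("P")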

We also use double-underscore "private" attributes again here because we really don't want these attributes to be touched outside of this class. Note that __wrapped__ is still accessible since it is a "dunder" attribute (leading and trailing underscores) and therefore not name mangled. For the cache itself we use our LruCache, parameterized with the generic type variable T.

In this class we make use of another quite popular magic method, __call__. The __call__ magic method allows us to call an instance of this class like a function. Furthermore, we have used the power of generics and the new ParamSpec to indicate what the argument names and types are via the P.args and P.kwargs attributes, as well as the return type T. If we use this to wrap our function above, then static type checkers and IntelliSense will know the correct types and will actually raise errors if the wrong type is passed in, just like they would on the original function.

>>> wrapped = _LruCacheFunctionWrapper(fib, maxsize=4)
>>> x = wrapped(3) # type signature (int) -> int

The __call__ method itself is pretty simple since most of the work is handled in LruCache. We turn the arguments and keyword arguments into a tuple (i.e. something hashable) and use that as the key in our LRU cache. This does assume that the arguments and keyword arguments of the original function are hashable. Unfortunately, as far as I know, there is no way to bound a ParamSpec so that it checks whether the arguments are hashable. Maybe one day... Nonetheless, we create the call_args tuple and try to get the function's return value from the LRU cache. If we get a non-None value back, we increment the hits counter (we hit the cache) and return it. If the value is not in the cache, we call the wrapped function, insert the result into the cache, and increment the misses counter (we missed the cache).
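
To make the key construction concrete, here is the same expression pulled out into a small helper (make_key is purely illustrative, it is not part of the wrapper):

def make_key(*args, **kwargs):
    # identical to the call_args expression in __call__
    return args + tuple(kwargs.items())

print(make_key(5))       # (5,)
print(make_key(n=5))     # (('n', 5),)
print(make_key(1, b=2))  # (1, ('b', 2))

Note that make_key(5) and make_key(n=5) produce different keys even though the calls would be equivalent; that is one of the limitations discussed below.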

The cache_info method simply returns a CacheInfo named tuple with the current state of the function cache. Lastly, the cache_clear method just clears the LRU cache and resets the counters.

One last thing to note before moving on to the final step: this way of hashing the function arguments is far from optimal, in that it assumes positional arguments are never passed as keyword arguments and that keyword arguments are always passed in the same order. It also does not take default keyword arguments into account. These things are all fixable through the use of something like a frozenset and the inspect module. Maybe in a follow-up post...
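
As a rough sketch of that idea (not part of the implementation above), inspect.signature can normalize a call before the key is built, at the cost of some extra overhead per call:

import inspect
from collections.abc import Callable, Hashable
from typing import Any


def normalized_key(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Hashable:
    # bind the arguments to the function's signature so that positional vs.
    # keyword spelling and keyword order no longer matter, then fill in defaults
    bound = inspect.signature(func).bind(*args, **kwargs)
    bound.apply_defaults()
    return frozenset(bound.arguments.items())

With this, normalized_key(fib, 5) and normalized_key(fib, n=5) produce the same key.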

Building the decorator

Ok, we are almost there. We have our generic LruCache and a function wrapper that uses that cache. Now we just need the function decorator. In this case, since our decorator takes an input argument maxsize, what we are really constructing is a decorator factory, also called a second-order decorator.

from collections.abc import Callable
from typing import ParamSpec, TypeVar

T = TypeVar("T")
P = ParamSpec("P")


def lru_cache(maxsize: int) -> Callable[[Callable[P, T]], _LruCacheFunctionWrapper[P, T]]:
    def decorator(func: Callable[P, T]) -> _LruCacheFunctionWrapper[P, T]:
        return _LruCacheFunctionWrapper(func, maxsize)
    return decorator

Without the type annotations this is just a function that returns another function which wraps our original function. However, through the use of generic types, as we have done above, we can retain and pass forward all of the function's argument and return types.

The outer function lru_cache returns a callable that takes a generic callable with parameters P and return type T and returns a generic instance of the _LruCacheFunctionWrapper defined above. The inner decorator takes the actual function to be cached and returns that _LruCacheFunctionWrapper instance.

This may look somewhat messy, but the generic types are quite magical and really will help make your code safer and more "correct". But that's it, we have now created a decorator that will cache function calls with a maximum size of maxsize! Feel free to go back to the beginning, use this on our Fibonacci function, and play around in an IDE to see the full power of types.
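
As a quick sanity check of what we gained, decorating the Fibonacci function from the start of the post now gives a fully typed callable (the exact diagnostics below are the type checker's, so treat the comments as approximate):

@lru_cache(maxsize=100)
def fib(n: int) -> int:
    if n < 2:
        return n
    return fib(n - 1) + fib(n - 2)

fib(10)           # ok, inferred return type is int
fib("ten")        # flagged by pyright: str is not assignable to int
fib.cache_info()  # available, returns a CacheInfo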

Summary

And there we have it: a fully typed version of lru_cache built through heavy use of generics and OrderedDict. We have also made things Pythonic by using magic methods like __init__ and __call__, and followed some best practices by separating the caching implementation from the function wrapper.
