Building a fully typed LRU Cache in Python
In this post we are going to build a fully typed LRU (least recently used) cache (almost) from scratch using Python. We will then create a function decorator that mirrors the builtin `functools` implementation.

This exercise will cover several advanced Python concepts: generic types and other advanced typing features, decorators (including second-order decorators), and magic methods.

This exercise is for learning and demonstration purposes and thus is not optimized for performance. All source code can be found here.
{: .notice--info}
What we will end up with
What we will have at the end of this exercise is a clone (at least in terms of the public interface) of the `functools.lru_cache` decorator. We should get these results:
```python
@lru_cache(maxsize=100)
def fib(n: int) -> int:
    if n < 2:
        return n
    return fib(n - 1) + fib(n - 2)
```

```python
>>> [fib(n) for n in range(16)]
[0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
>>> fib.cache_info()
CacheInfo(hits=28, misses=16, maxsize=100, currsize=16)
```
In addition to this we also want to retain all type information from the decorated function. In older versions of `functools.lru_cache` this information is lost, so static type checkers and intellisense could not function correctly. In our version we will retain this type information so that static type checkers work correctly.
Building the cache class
First, we will build a standalone `LruCache` class to handle the actual heavy lifting. In most implementations of an LRU cache, a hash map (i.e. a dictionary) and a doubly linked list are used. In this case, since the main point of this article is how to use some of the more advanced Python features, we will use a single built-in data type, the `OrderedDict`. The full implementation is below:
```python
from collections import OrderedDict
from typing import Generic, Hashable, Optional, TypeVar

T = TypeVar("T")


class LruCache(Generic[T]):
    def __init__(self, capacity: int):
        self.capacity = capacity
        self.__cache: OrderedDict[Hashable, T] = OrderedDict()

    def get(self, key: Hashable) -> Optional[T]:
        if key not in self.__cache:
            return None
        # mark this key as the most recently used
        self.__cache.move_to_end(key)
        return self.__cache[key]

    def insert(self, key: Hashable, value: T) -> None:
        # only evict when adding a *new* key to a full cache;
        # re-inserting an existing key should not drop anything
        if key not in self.__cache and len(self.__cache) == self.capacity:
            self.__cache.popitem(last=False)
        self.__cache[key] = value
        self.__cache.move_to_end(key)

    def __len__(self) -> int:
        return len(self.__cache)

    def clear(self) -> None:
        self.__cache.clear()
```
Ok, there is a lot to unpack here. First let's look at the basic functionality.
```python
# build a new cache with capacity 2 (i.e. we can store at most two values at a time)
>>> cache = LruCache[str](capacity=2)

# put two values into the cache
# cache keys will now be [2, 1] (most recently used first)
>>> cache.insert(1, "hello")
>>> cache.insert(2, "hello again")

# length is 2 since that is the number of values in the cache
>>> print(len(cache))
2

# we can retrieve 2 since both 1 and 2 are in the cache
# cache keys are still [2, 1] since 2 was used more recently than 1
>>> print(cache.get(2))
hello again

# this will evict key 1 with value "hello"
# cache keys are now [3, 2]
>>> cache.insert(3, "goodbye")

# we now get None since 1 is no longer in the cache
>>> print(cache.get(1))
None

# but we can get the cached value for 3
# cache keys are still [3, 2] since 3 was used more recently than 2
>>> print(cache.get(3))
goodbye

# cache keys are now [2, 3]
>>> print(cache.get(2))
hello again

# this will drop key 3
# cache keys are now [1, 2]
>>> cache.insert(1, "I'm back!")

# get None since 3 was dropped
>>> print(cache.get(3))
None

# finally we can clear the cache
>>> cache.clear()
>>> print(len(cache))
0
```
We can see that the `capacity` determines how large the cache is, and every call to `cache.get` or `cache.insert` puts that key at the top of the queue, essentially keeping track of the most recently used keys. When we insert a value while the cache is full (i.e. the length is equal to the capacity), the least recently used value is dropped. This is exactly what we want.
A few important things to notice about the code. First, `LruCache` itself is generic over a type `T` which represents the value type of the cache. In the example above we specify `LruCache[str]` to indicate that the type of `value` in the `insert` method and the return type of the `get` method is `str`. We can pass in any type here, and our IDE and static type checkers are able to infer and check the input and return types. Pretty awesome huh?
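For example, here is roughly what pyright infers for an `LruCache[int]` (a quick sketch; the error wording is paraphrased):

```python
cache = LruCache[int](capacity=4)

cache.insert("a", 42)      # fine: the value matches T = int
value = cache.get("a")     # inferred type: int | None

cache.insert("b", "oops")  # pyright error: "str" is not assignable to "int"
```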
The next thing to notice is the `Hashable` type. Since we want this class to be generic, we need to give it the least amount of information possible about what kinds of keys will be used. At the very least, dictionary keys need to be hashable (i.e. implement the `__hash__` method), and that is exactly what `Hashable` checks for. If we try to pass in a key that is not hashable, like a list, then we will get an error from a static type checker like pyright:

```
Argument of type "list[int]" cannot be assigned to parameter "key" of type "Hashable" in function "get"
  "list[int]" is incompatible with protocol "Hashable"
    "__hash__" is an incompatible type
```
Another thing to notice here is the use of the magic method `__len__`. This is a fairly common one, and by implementing it we can call `len` on an instance of `LruCache`.
One other bit of fanciness here is the use of the double leading underscore on `__cache`. We don't want outside users to be able to modify the cache, and while there are no truly private variables in Python, the double underscore triggers name mangling, so you will actually get an `AttributeError` if you try to access it from outside the class.
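A quick demonstration of that name mangling (the exact error message varies slightly by Python version, and note that the mangled name is still reachable, so this is a deterrent rather than real access control):

```python
>>> cache = LruCache[str](capacity=2)
>>> cache.__cache
Traceback (most recent call last):
  ...
AttributeError: 'LruCache' object has no attribute '__cache'
>>> cache._LruCache__cache  # the mangled name still works
OrderedDict()
```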
As for the implementation, as mentioned above, we make use of the `OrderedDict` provided by the `collections` module. For the most part, `OrderedDict` is almost obsolete since regular Python dictionaries now retain insertion order; however, it is very useful in this case because of the `move_to_end` method, which lets us move any key-value pair to the end of the dictionary. This is how we keep track of the least/most recently used keys in our LRU cache. Finally, the `popitem` method allows for `last=False`, which drops the key-value pair in FIFO (first-in-first-out) order as opposed to the default LIFO (last-in-first-out) order. This lets us easily pop off the least recently used key.
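If you have not used these `OrderedDict` methods before, here is a quick standalone demonstration:

```python
>>> from collections import OrderedDict
>>> d = OrderedDict([("a", 1), ("b", 2), ("c", 3)])
>>> d.move_to_end("a")      # "a" becomes the most recently used key
>>> list(d)
['b', 'c', 'a']
>>> d.popitem(last=False)   # FIFO pop: drops the least recently used key
('b', 2)
>>> list(d)
['c', 'a']
```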
Building the function wrapper class
Now that we have a generic `LruCache` implementation, we can move on to building a class that will wrap the function we wish to cache. What we want is something that has the exact same call signature as the original function but also has `cache_info` and `cache_clear` methods, as well as a `__wrapped__` attribute containing a reference to the wrapped function itself. This all mirrors the `functools.lru_cache` API.
```python
from collections.abc import Callable
from typing import Final, Generic, NamedTuple, ParamSpec, TypeVar

T = TypeVar("T")
P = ParamSpec("P")


class CacheInfo(NamedTuple):
    hits: int
    misses: int
    maxsize: int
    currsize: int


class _LruCacheFunctionWrapper(Generic[P, T]):
    def __init__(self, func: Callable[P, T], maxsize: int):
        self.__wrapped__ = func
        self.__cache = LruCache[T](capacity=maxsize)
        self.__hits = 0
        self.__misses = 0
        self.__maxsize: Final = maxsize

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
        # build a hashable key from the positional and keyword arguments
        call_args = args + tuple(kwargs.items())
        ret = self.__cache.get(call_args)
        if ret is None:
            self.__misses += 1
            ret = self.__wrapped__(*args, **kwargs)
            self.__cache.insert(call_args, ret)
        else:
            self.__hits += 1
        return ret

    def cache_info(self) -> CacheInfo:
        return CacheInfo(
            hits=self.__hits,
            misses=self.__misses,
            currsize=len(self.__cache),
            maxsize=self.__maxsize,
        )

    def cache_clear(self) -> None:
        self.__cache.clear()
        self.__hits = 0
        self.__misses = 0
```
Ok, let's walk through this. First, we define the `CacheInfo` named tuple to match the signature of the `functools.lru_cache` implementation. `NamedTuple`s are good for things like this where you want an immutable but easy to read/use return type for some function or method. If you need to add methods or want to modify fields, it's probably better to use a `dataclass`.
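For instance, `CacheInfo` reads like an object but still behaves like a tuple, and "modifying" a field means building a new tuple:

```python
>>> info = CacheInfo(hits=28, misses=16, maxsize=100, currsize=16)
>>> info.hits                   # readable field access
28
>>> hits, misses, *rest = info  # still unpacks like a plain tuple
>>> info._replace(hits=99)      # fields are immutable; _replace returns a new tuple
CacheInfo(hits=99, misses=16, maxsize=100, currsize=16)
```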
The next thing to notice is that `_LruCacheFunctionWrapper` (the leading underscore is used here to indicate that this is a "private" class) is generic over types `P` and `T`, which, as we see from the `__init__`, correspond to the arguments and return type of the wrapped function, respectively. For the `P` generic we make use of the new (in Python 3.10) `ParamSpec`, which will make things much nicer as we will see. From Python 3.8 up you can import this from `typing_extensions`, but it was moved into the builtin `typing` module in Python 3.10.
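If you need to support interpreters older than 3.10, one common pattern is a version-guarded import (this assumes `typing_extensions` is installed on the older versions):

```python
import sys

if sys.version_info >= (3, 10):
    from typing import ParamSpec
else:
    # backport package; must be installed separately on 3.8/3.9
    from typing_extensions import ParamSpec

P = ParamSpec("P")
```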
We also use name-mangled attributes (double leading underscore) again here because we really don't want these attributes to be used outside of this class. Note that `__wrapped__` is still accessible since names that also end with a double underscore ("dunder" names) are not mangled. For the cache itself we use our `LruCache`, specifying the value type with the generic variable `T`.
In this class we make use of another quite popular magic method: `__call__`. The call magic method allows us to call an instance of this class like a function. Furthermore, we have used the power of generics and the new `ParamSpec` to indicate what the argument names and types are, via the `P.args` and `P.kwargs` annotations, as well as the return type `T`. If we were to use this to wrap our function above, then static type checkers and intellisense will know the correct types and will actually raise errors if the wrong type is passed in, just like they would on the original function.
```python
>>> wrapped = _LruCacheFunctionWrapper(fib, maxsize=4)
>>> x = wrapped(3)  # type signature (int) -> int
```
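And if we pass an argument of the wrong type, pyright flags it against the original parameter, just as it would for a direct call to `fib` (error text paraphrased):

```python
>>> wrapped("3")  # pyright: argument of type "str" cannot be assigned to parameter "n" of type "int"
```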
The `__call__` method itself is pretty simple since most of the work is handled in `LruCache`. We turn the arguments and keyword arguments into a tuple (i.e. something hashable) and use that as the key in our LRU cache. This does assume that the arguments and keyword arguments of the original function are hashable. Unfortunately, as far as I know, there is no way to bound a `ParamSpec` so that it would check that the arguments are hashable. Maybe one day... Nonetheless, we create the `call_args` tuple and try to get the function return value from the LRU cache. If we get a non-`None` return value, we increment the `hits` counter (we hit the cache) and return the cached value. If we don't have the value in the cache, we call the wrapped function, insert the result into the cache, and increment the `misses` counter (we missed the cache). Note that this also means a function that actually returns `None` will never register a cache hit with this implementation.
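To make the key construction concrete, here is a tiny sketch of the same logic as a standalone (hypothetical) `make_key` helper:

```python
def make_key(*args, **kwargs):
    # mirrors call_args inside __call__: positional args plus (name, value) pairs
    return args + tuple(kwargs.items())

assert make_key(1, 2) == (1, 2)
assert make_key(1, b=2) == (1, ("b", 2))  # note: a different key than make_key(1, 2)
```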
The `cache_info` method simply returns a `CacheInfo` named tuple with the current state of the function cache. Lastly, the `cache_clear` method just clears the LRU cache and resets the counters.
One last thing to note before moving on to the final step: this way of hashing the function arguments is far from optimal, in that it assumes positional arguments are not passed as keyword arguments and that keyword arguments are always passed in the same order. It also does not take into account default keyword arguments. These things are all fixable through the use of something like a frozenset and the `inspect` module. Maybe in a follow-on post...
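As a teaser, here is a minimal sketch of the `inspect` idea; the `normalized_key` helper is hypothetical and assumes every argument value is itself hashable:

```python
import inspect
from typing import Any, Callable, Hashable


def normalized_key(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Hashable:
    # bind the call to the signature so positional/keyword style no longer matters
    bound = inspect.signature(func).bind(*args, **kwargs)
    bound.apply_defaults()  # fill in defaults so f(1) and f(1, b=0) produce the same key
    return tuple(sorted(bound.arguments.items()))


def f(a: int, b: int = 0) -> int:
    return a + b

assert normalized_key(f, 1) == normalized_key(f, 1, b=0) == normalized_key(f, a=1)
```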
Building the decorator
Ok, we are almost there. We have our generic `LruCache` and a function wrapper that uses that cache. Now we just need our function decorator. In this case, since our decorator takes an input argument `maxsize`, what we are really constructing is a decorator factory, or second-order decorator.
```python
from collections.abc import Callable
from typing import ParamSpec, TypeVar

T = TypeVar("T")
P = ParamSpec("P")


def lru_cache(maxsize: int) -> Callable[[Callable[P, T]], _LruCacheFunctionWrapper[P, T]]:
    def decorator(func: Callable[P, T]) -> _LruCacheFunctionWrapper[P, T]:
        return _LruCacheFunctionWrapper(func, maxsize)

    return decorator
```
Without the type annotations this is just a function that returns another function which wraps our original function. However, through the use of generic types, as we have done above, we retain and pass forward all of the function's argument and return types. The outer function `lru_cache` returns a callable that takes a generic callable with arguments `P` and return type `T`, and returns a generic instance of the `_LruCacheFunctionWrapper` we defined above. The inner `decorator` function takes in the actual function to be cached and returns an instance of `_LruCacheFunctionWrapper`.
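Putting it all together on a small (hypothetical) example, the decorated function keeps its original signature as far as the type checker is concerned:

```python
@lru_cache(maxsize=2)
def greet(name: str, excited: bool = False) -> str:
    return f"Hello, {name}{'!' if excited else '.'}"

greet("world")      # ok: pyright still sees (name: str, excited: bool = False) -> str
greet.cache_info()  # CacheInfo(hits=0, misses=1, maxsize=2, currsize=1)
greet(42)           # flagged by pyright: "int" is not assignable to "str"
```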
This may look somewhat messy, but the generic types are quite magical and really will help make your code safer and more "correct". And that's it: we have created a decorator that will cache function calls with a maximum size of `maxsize`! Feel free to go back to the beginning, use this on our Fibonacci function, and play around in an IDE to see the full power of types.
Summary
And there we have it: we have implemented a fully typed version of `lru_cache` through heavy use of generics and `OrderedDict`. We have also made things pythonic by using magic methods like `__init__` and `__call__`, and followed some best practices by separating the caching implementation from the actual function wrapper.