Some Advanced Typing Concepts in Python


To start off, let me admit that yes, Python is a dynamically (gradually?) typed language. However, with modern Python (3.6+, and really 3.10+) and static tooling (e.g. mypy, pyright/Pylance), it is a pretty robust statically typed language. Most third-party libraries are either typed or have type stubs distributed separately. From a developer's perspective, this version of the language, with full strict static typing, is essentially a different language from the Python of yesteryear.

I won't fully make the argument here, but treating Python like a statically typed language greatly increases the readability and expressiveness of your code, greatly reduces bugs (and even the number of tests you need to write), and makes it much easier to move on to a "real" statically typed language like Rust (do it, it's great!) or Go (do it, it's pretty good!) or even oldies like C/C++/Java.

While there is good material on more advanced typing concepts in Python (the mypy docs are actually quite good), most tutorials you find online are "hello world" level examples that just introduce types to people who have never used Python or who are coming from the dark days of pre-3.6 Python. This post assumes that you have a decent grasp of Python and already know the basics of type annotations, or that you are coming from another statically typed language and just want to see if Python is up to snuff. OK, rant over, let's get into some semi-advanced typing goodness.

All of the examples in this post are available here if you want to play around with them. I would suggest using VS Code with Pylance's "strict" type checking mode.

Type annotations and the interface segregation principle

Remember the interface segregation principle from Uncle Bob's SOLID principles? If you don't, it basically says that users of your code should not be forced to depend on interfaces/types/classes that they don't need. We will see how Python type annotations can help with this.

Some general good practices

First, let's look at how we can make our code more robust by using loose types for input parameters and strict types for outputs. Let's say we are writing a function to filter out strings that contain a certain word or phrase. You could write your function like this:

from typing import List

def filter_things(things: List[str], filter_string: str) -> List[str]:
    return [t for t in things if filter_string not in t]

This is fine and works like this:

things = ["badger", "snake", "mushroom"]
filtered_things = filter_things(things, "mushroom")
# returns ["badger", "snake"]

But do we really need the input things to be a list? All we are doing is iterating over it, and lots of Python types support that. We could guess and do something like

from typing import Any, Dict, List, Set, Tuple

def filter_things(things: List[str] | Tuple[str,...] | Set[str] | Dict[str, Any], filter_string: str) -> List[str]:
    return [t for t in things if filter_string not in t]

but that would be terrible, and you have just confused your user even more. Instead, we can use the type that specifies the bare minimum constraint for what we are doing. Since we are just iterating over things, we can do

from typing import Iterable, List

def filter_things(things: Iterable[str], filter_string: str) -> List[str]:
    return [t for t in things if filter_string not in t]

Using the Iterable type just states that things is a collection of strings that can be iterated over. Now our function works with lots of different input types, like tuples, sets, dicts (iterating over their keys), and even generators, since all of them are iterable. Note that we do return a concrete type of List[str] here. It is a good idea to make the return type as narrow as possible, so callers know exactly what they get back no matter what kind of input they pass in.
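To see the payoff, here is a small sketch calling the Iterable version of filter_things with several kinds of input (the sample strings are just made up). Each call returns a plain list, whatever went in.

```python
from typing import Iterable, List


def filter_things(things: Iterable[str], filter_string: str) -> List[str]:
    return [t for t in things if filter_string not in t]


# A tuple, a set, a dict (iteration yields its keys) and a generator all satisfy Iterable[str]
from_tuple = filter_things(("badger", "snake", "mushroom"), "mushroom")
from_set = filter_things({"badger", "mushroom"}, "mushroom")
from_dict = filter_things({"badger": 1, "mushroom": 2}, "mushroom")
from_gen = filter_things((s.lower() for s in ["BADGER", "MUSHROOM"]), "mushroom")

# Whatever the input type, the output is always a concrete list
print(from_tuple, from_set, from_dict, from_gen)
```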

To hammer this home, let's look at one more example.

from typing import List

def get_fifth_element(things: List[str]) -> str:
    if len(things) < 5:
        raise ValueError(f"Length of input must be >= 5, found len={len(things)}")
    return things[4]

Again, we don't really need a list here, but this time we aren't iterating. Instead, we are checking the length and grabbing the value at a given index. This means we must be able to get the size of the input collection, and we must be able to get the item at a given index. If we look at the docs, we can see that the right type to use here is Sequence, since it provides the __len__ and __getitem__ methods that let us call len on it and index into it. So we have

from typing import Sequence 

def get_fifth_element(things: Sequence[str]) -> str:
    if len(things) < 5:
        raise ValueError(f"Length of input must be >= 5, found len={len(things)}")
    return things[4]

# good with a tuple
elements = ("earth", "wind", "water", "fire", "multipass")
fifth_element = get_fifth_element(elements)

# or a list
elements = ["earth", "wind", "water", "fire", "multipass"]
fifth_element = get_fifth_element(elements)

As we will see later, even this type is still a bit too narrow, but short of defining our own type, it will usually be fine.
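One way to see why Sequence can be too narrow: it is an abstract base class, so a class only counts as a Sequence if it inherits from it (or is registered with it), even when it provides exactly the __len__ and __getitem__ that get_fifth_element uses. A small sketch with a hypothetical FiveElements container:

```python
from collections.abc import Sequence


class FiveElements:
    """Hypothetical container that supports len() and indexing, nothing more."""

    def __init__(self) -> None:
        self._items = ["earth", "wind", "water", "fire", "multipass"]

    def __len__(self) -> int:
        return len(self._items)

    def __getitem__(self, idx: int) -> str:
        return self._items[idx]


elements = FiveElements()
print(len(elements), elements[4])
# Not a Sequence: it neither subclasses nor is registered with the ABC,
# so a type checker would also reject it as a Sequence[str] argument.
print(isinstance(elements, Sequence))
```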

Protocols

I wrote a previous post about Protocols, but here we will go over them in the context of the interface segregation principle and give a bit more detail on why they are useful.

So let's start with the age-old example of defining an Animal:

import abc

class Animal(abc.ABC):
    @abc.abstractmethod
    def make_sound(self) -> None:
        pass

    @abc.abstractmethod
    def act(self) -> None:
        pass

    @property
    @abc.abstractmethod
    def num_legs(self) -> int:
        pass

Note that I used an ABC here. I could have used a Protocol, but I'm trying to mimic a third-party library that may already have many different implementations built on a standard class hierarchy. Now let's implement some animals:

class Dog(Animal):
    def make_sound(self) -> None:
        print("woof")

    def act(self) -> None:
        print("dog is walking")

    @property
    def num_legs(self) -> int:
        return 4


class Fish(Animal):
    def make_sound(self) -> None:
        print("blub")

    def act(self) -> None:
        print("fish is swimming")

    @property
    def num_legs(self) -> int:
        return 0

and let's write a function so we can hear some animal sounds:

from typing import Iterable

def get_animal_sounds(animals: Iterable[Animal]) -> None:
    for animal in animals:
        animal.make_sound()

animals = [Dog(), Fish()]
get_animal_sounds(animals)
# woof
# blub

So this function takes in an Iterable (see, we learned from the first part!) of Animals. If we run it with these inputs, we get woof and blub. Cool! Now we want to know what the fox says, so we implement a Fox:

class Fox:
    def make_sound(self) -> None:
        print("redacted")

Sweet, let's plug this in:

animals = [Dog(), Fish(), Fox()]
get_animal_sounds(animals)

Uh oh, our type checker gives us an error like this:

Argument of type "list[Dog | Fish | Fox]" cannot be assigned to parameter "animals" of type "Iterable[Animal]" in function "get_animal_sounds"
  "list[Dog | Fish | Fox]" is incompatible with "Iterable[Animal]"
    TypeVar "_T_co@Iterable" is covariant
      Type "Dog | Fish | Fox" cannot be assigned to type "Animal"
        "Fox" is incompatible with "Animal" Pylance(reportGeneralTypeIssues)

Ah yes, because Fox doesn't inherit from Animal, and it only implements make_sound rather than the whole Animal interface. Well, that's dumb: I just want to know what the fox says, so why do I need to implement the whole "interface"? Protocols to the rescue:

from typing import Iterable, Protocol

class CanMakeSound(Protocol):
    def make_sound(self) -> None:
        ...

def get_animal_sounds(animals: Iterable[CanMakeSound]) -> None:
    for animal in animals:
        animal.make_sound()

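Now we can find out what the fox says, since all three animals implement the CanMakeSound Protocol. None of them has to inherit from it; the type checker only checks that each one has a matching make_sound method.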
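The type checker checks Protocols structurally, which is why Fox never has to inherit from anything. If you also want to test for a protocol at runtime, the standard library's typing.runtime_checkable decorator enables isinstance checks against it; note that such checks only verify that the method exists, not its signature. A minimal sketch, where Rock is a hypothetical class added for contrast:

```python
from typing import Iterable, Protocol, runtime_checkable


@runtime_checkable
class CanMakeSound(Protocol):
    def make_sound(self) -> None:
        ...


class Fox:
    def make_sound(self) -> None:
        print("redacted")


class Rock:
    """Hypothetical class with no make_sound method."""


def get_animal_sounds(animals: Iterable[CanMakeSound]) -> None:
    for animal in animals:
        animal.make_sound()


# At runtime, isinstance only checks that a make_sound attribute exists
print(isinstance(Fox(), CanMakeSound))
print(isinstance(Rock(), CanMakeSound))
get_animal_sounds([Fox()])  # prints: redacted
```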

Of course, this is a silly example, but there are many real-life situations where our functions, classes, etc. only need to know about a few methods of an object rather than all of them. If we use Protocols to specify only what we need, users can write their own objects that satisfy just that protocol.

If you really are digging this, here is how we could make our get_fifth_element function even more flexible by using a custom protocol that only requires the bare minimum.

from typing import Any, Protocol, TypeVar

T = TypeVar("T", covariant=True)

class MySequence(Protocol[T]):
    def __getitem__(self, __idx: Any) -> T:
        ...

    def __len__(self) -> int:
        ...

def get_fifth_element(things: MySequence[str]) -> str:
    if len(things) < 5:
        raise ValueError(f"Length of input must be >= 5, found len={len(things)}")
    return things[4]
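A quick sketch of what that buys us, using a hypothetical FiveElements class that defines only __len__ and __getitem__ and does not inherit from Sequence: it satisfies MySequence[str] structurally, and so do tuples and even plain strings (a str is indexable and has a length).

```python
from typing import Any, Protocol, TypeVar

T = TypeVar("T", covariant=True)


class MySequence(Protocol[T]):
    def __getitem__(self, __idx: Any) -> T:
        ...

    def __len__(self) -> int:
        ...


def get_fifth_element(things: MySequence[str]) -> str:
    if len(things) < 5:
        raise ValueError(f"Length of input must be >= 5, found len={len(things)}")
    return things[4]


class FiveElements:
    """Hypothetical container: only __len__ and __getitem__, no Sequence base class."""

    def __len__(self) -> int:
        return 5

    def __getitem__(self, idx: Any) -> str:
        return ["earth", "wind", "water", "fire", "multipass"][idx]


print(get_fifth_element(FiveElements()))
print(get_fifth_element(("a", "b", "c", "d", "e")))
print(get_fifth_element("hello"))  # a str also matches MySequence[str]
```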

Generics with Protocol Constraints

In the examples above we used Protocols to make our functions very general (you could even say generic!); however, we only used them in the arguments and not in the return type. In general, it doesn't really make sense to return a Protocol, since the caller would then only know about the protocol's methods and would lose the concrete type. When the return type depends on the input type, you need generics with custom constraints instead. We can illustrate this by implementing a simple max function that returns the larger of two inputs. First, let's just do it for ints:

def max(a: int, b: int) -> int:
    if a < b:
        return b
    return a

OK, simple, but the Python builtin max function works on lots of other types, like float, str, and even lists and tuples. To make our max generic, we need to figure out the minimum requirement of the function above. The only requirement is that a and b can be compared using the "less than" operator. In Python, that means they must implement the __lt__ magic method. So let's make a bounded generic and rewrite our function:

from typing import Protocol, TypeVar
from typing_extensions import Self

class HasLessThan(Protocol):
    def __lt__(self, __other: Self) -> bool:
        ...

T = TypeVar("T", bound=HasLessThan)

def max(a: T, b: T) -> T:
    if a < b:
        return b
    return a

Now the function works with many different comparable types, and the return type matches the input type:

m = max(3, 4) # m is of type: int
m = max("hello", "world") # m is of type: str
m = max([4, 5, 6], [1, 2]) # m is of type: List[int]

As always, we can also define our own type. Let's define a Line class where we want max to return the longer line:

from dataclasses import dataclass
from typing_extensions import Self

@dataclass
class Line:
    x_min: float
    x_max: float

    def __lt__(self, other: Self) -> bool:
        """Compare based on line length"""
        return (self.x_max - self.x_min) < (other.x_max - other.x_min)

m = max(Line(0, 3), Line(0, 5)) # returns Line(0, 5)

This is a very powerful technique for defining generic functions/classes that specify only the bare minimum type constraints. If you are coming from Rust, this is similar to generics with trait bounds.

Parameter Specification Variables

Because Python is gradually typed, static typing has not always been possible in every situation. Until Parameter Specification Variables, or just "ParamSpec" (PEP 612, added in Python 3.10), we could not get good static typing for decorated functions: once a function was decorated, we lost the type information about its parameters and return value. Let's see how we can fix that now with ParamSpec:

import time
from typing import Callable, ParamSpec, TypeVar

T = TypeVar("T")
P = ParamSpec("P")

def timer(func: Callable[P, T]) -> Callable[P, T]:
    def inner(*args: P.args, **kwargs: P.kwargs) -> T:
        tstart = time.perf_counter()
        out = func(*args, **kwargs)
        print(f"function: {func.__name__} ran in {(time.perf_counter()-tstart)} seconds")
        return out

    return inner

@timer
def add(a: float, b: float) -> float:
    """
    Add two numbers
    """
    return a + b

out = add(3.4, 4.5) 
# function: add ran in 6.902000677655451e-06 seconds

Here we have created a timer decorator that simply prints the time it took to call the decorated function. When we add the decorator to the simple add function, it prints out the runtime. This is a simple example, but let's look at what is going on. Notice that the outer function of the decorator takes in the decorated function func, which we annotate as Callable[P, T]. T is just a generic type variable stating that the function could return any type. P is the ParamSpec, and it is what lets us keep the type information of the decorated function's parameters. In the inner function we use *args and **kwargs as you always would for unknown inputs, but now we annotate them with P.args and P.kwargs, respectively.

This is the magic that preserves the type information of the decorated function. Your code editor (I use VS Code and it is amazing for this, and really for all things...) should show you the original function signature when you hover over it, and the type checker should respect the input and output types. I used this feature quite heavily in a previous post about implementing a fully typed lru-cache.

Overloads

Overloads let you specify different combinations of inputs and/or outputs for a single function. They can also lead to really bloated functions that do way too much, so, like most things, use them in moderation.

Broadly speaking, there are two cases where you would want to consider using overloads:

  • Optional second-order decorators (i.e. decorators that may or may not take arguments)
  • Functions where there are several different combinations of input parameters, which may or may not lead to different output types. Note that overloads should only be used here when the job can't be done with generics.

First, let's look at the second-order decorator case by modifying our example above. What if we want the ability to format the message that prints the runtime, or to use a logger instead of just printing, while still keeping the default behavior when no arguments are passed in? We can modify it as follows:

import time
from typing import Callable, Optional, ParamSpec, TypeVar, overload

T = TypeVar("T")
P = ParamSpec("P")

def _default_display_fn(function_name: str, call_time: float) -> None:
    print(f"function: {function_name} ran in {call_time} seconds")

@overload
def timer(__func: Callable[P, T]) -> Callable[P, T]:
    ...

@overload
def timer(*, display_fn: Callable[[str, float], None]) -> Callable[[Callable[P, T]], Callable[P, T]]:
    ...

def timer(
    __func: Optional[Callable[P, T]] = None,
    *,
    display_fn: Optional[Callable[[str, float], None]] = None,
) -> Callable[[Callable[P, T]], Callable[P, T]] | Callable[P, T]:

    display_fn = display_fn or _default_display_fn

    def decorator(func: Callable[P, T]) -> Callable[P, T]:
        def inner(*args: P.args, **kwargs: P.kwargs) -> T:
            tstart = time.perf_counter()
            out = func(*args, **kwargs)
            display_fn(func.__name__, time.perf_counter() - tstart)
            return out

        return inner

    if __func is not None:
        return decorator(__func)

    return decorator

There is a lot of detail here, but the main point is the two @overload signatures. They specify the two ways the timer decorator can be called. The first is a plain decorator, just like we had before, that you use without parentheses. The second is our second-order decorator, to which we can pass a function that displays the run time. We could use it as follows:

import logging

def logging_display_fn(function_name: str, call_time: float) -> None:
    logging.info(f"function: {function_name} ran in {call_time} seconds")

@timer(display_fn=logging_display_fn)
def add(a: float, b: float) -> float:
    """
    Add two numbers
    """
    return a + b

Now, instead of the default print, we use a logger. In both cases the type signature of the original add function is preserved. Just by looking at the complexity of the actual timer implementation, we can start to see why overloads can be dangerous: the function can get very complicated, and it may be better to split it into separate functions. Nonetheless, using overloads to support both first- and second-order decorator usage is generally a decent idea.

To demonstrate the second case for overloads, let's extend our max function so that a user can pass in types that don't natively support <, as long as they also pass in a key function that determines how to compare the values of a and b:

from typing import Callable, Protocol, TypeVar, overload

from typing_extensions import Self

class HasLessThan(Protocol):
    def __lt__(self, __other: Self) -> bool:
        ...

T = TypeVar("T", bound=HasLessThan)
S = TypeVar("S")

def _default_key(__val):
    return __val

@overload
def max(a: T, b: T) -> T:
    ...

@overload
def max(a: S, b: S, *, key: Callable[[S], T]) -> S:
    ...

def max(a, b, *, key=None):
    key_func = key or _default_key
    if key_func(a) < key_func(b):
        return b
    return a

In the example above we have two overloads. The first is exactly what we had before, where the inputs must support the < operator. The second allows a user to pass in any type, but they are then required to pass in a key function that returns the value used to compare the two inputs (and that value must itself support <).

For example, let's return to our custom Line type, but this time we won't implement the __lt__ method and will instead use a key function in max:

from dataclasses import dataclass

@dataclass
class Line:
    x_min: float
    x_max: float

m = max(Line(0, 3), Line(0, 5)) # this will fail the type check
m = max(Line(0, 3), Line(0, 5), key=lambda line: (line.x_max - line.x_min)) # returns Line(0, 5)

So here our overloads let us be more flexible with our types. Note that in this case I did not include types in the actual implementation. It is up to you whether or not you do this, since callers only ever see the overload signatures, not the implementation's (although Pylance's strict mode will flag the unannotated parameters).

The builtin max function actually has several more overloads (in typeshed's builtins stub) if you want to check them out.

Wrapping Up

In this post we have covered some more advanced use cases for type annotations. The TL;DR version is:

  • Use general input types only specifying the bare minimum of what is needed
  • Use strict concrete output types
  • Use custom Protocols to define the minimum interface needed by functions/methods/classes
  • Use custom Protocols as bounds on generic types to make code more general
  • Use Parameter Specification (i.e. ParamSpec) when making decorators
  • Use overloads sparingly, when generics cannot do the job. Prefer overloads over Union types, especially in return types

As I mentioned in the introduction, none of this is required for valid Python, but making use of the type annotation system will make your code more readable and much nicer to work with.
