
[functools] Chaining callables #114284

Closed as not planned
@ringohoffman

Description


Feature or enhancement

Proposal:

This has been proposed before, multiple times, usually under the name functools.compose (see the links).

My inspiration for this was torch.nn.Sequential (see examples of nn.Sequential usage in the wild on GitHub).

What I am proposing is functionally equivalent to torch.nn.Sequential, but accepts arbitrary Python callables instead of only instances of torch.nn.Module (a subset of Python callables).
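For comparison, here is a minimal sketch of the same idea written in the style of an nn.Sequential container: it stores the callables in order and, when called, feeds each result into the next. `CallableSequential` is a hypothetical name, and the sketch only mirrors nn.Sequential's "apply each item in order" behavior, not torch's actual implementation. One difference: nn.Sequential's forward() takes a single input, while this sketch (like the proposal) passes arbitrary arguments to the first callable.

```python
from typing import Any, Callable


class CallableSequential:
    """Hypothetical nn.Sequential-style container for arbitrary callables."""

    def __init__(self, *funcs: Callable[..., Any]) -> None:
        if not funcs:
            raise ValueError("Expected at least 1 callable")
        self._funcs = funcs

    def __len__(self) -> int:
        return len(self._funcs)

    def __getitem__(self, index: int) -> Callable[..., Any]:
        return self._funcs[index]

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # The first callable receives the original arguments...
        result = self._funcs[0](*args, **kwargs)
        # ...and every later callable receives the previous result.
        for func in self._funcs[1:]:
            result = func(result)
        return result


pipeline = CallableSequential(lambda a, b: a + b, lambda x: x * 10, str)
print(repr(pipeline(2, 1)))  # '30'
print(len(pipeline))         # 3
```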

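# Requires Python 3.12+ (PEP 695 type parameter syntax)
import functools
from typing import Any, Callable, overload

@overload
def sequential[**P, R](*funcs: *tuple[Callable[P, R]]) -> Callable[P, R]:
    ...

@overload
def sequential[**P, R](
    *funcs: *tuple[Callable[P, Any], *tuple[Callable[..., Any], ...], Callable[..., R]],
) -> Callable[P, R]:
    ...

def sequential(*funcs: Callable[..., Any]) -> Callable[..., Any]:
    # Validate eagerly, so that sequential() itself fails at runtime
    # instead of the returned callable failing when it is first used.
    if not funcs:
        raise ValueError("Expected at least 1 callable argument to sequential()")

    def compose(*args: Any, **kwargs: Any) -> Any:
        # The first callable receives the original arguments; each later
        # callable receives the previous callable's return value.
        return functools.reduce(lambda r, f: f(r), funcs[1:], funcs[0](*args, **kwargs))

    return compose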

I am really pleased with the type hinting. It handles the following cases:

no arguments (an error)

sequential() requires at least one callable argument, and type checkers report a corresponding error

sequential()  # No overloads for "sequential" match the provided arguments

overload 1

a single callable: the sequential has the parameters and return type of that callable

def add(a: int, b: int) -> int:
    return a + b

my_sequential = sequential(
    add,
)

my_sequential(2, 1)  # Returns 3
# reveal_type(my_sequential)  # Type of "my_sequential" is "(a: int, b: int) -> int"

overload 2

multiple callables: the sequential has the parameters of the first callable and the return type of the last callable

def square(a: int) -> int:
    return a * a

def int_to_str(a: int) -> str:
    return str(a)

my_sequential = sequential(
    add,
    square,
    square,
    int_to_str,
)

my_sequential(2, 1)  # Returns "81"
# reveal_type(my_sequential)  # Type of "my_sequential" is "(a: int, b: int) -> str"
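To make the arithmetic explicit, the sketch below records every intermediate value of the chain above: add(2, 1) is 3, squaring gives 9, squaring again gives 81, and int_to_str turns that into "81". `traced_sequential` is a hypothetical helper, not part of the proposal; it composes the same way but also returns the list of intermediate results.

```python
from typing import Any, Callable


def add(a: int, b: int) -> int:
    return a + b


def square(a: int) -> int:
    return a * a


def int_to_str(a: int) -> str:
    return str(a)


def traced_sequential(*funcs: Callable[..., Any]) -> Callable[..., tuple[Any, list[Any]]]:
    """Hypothetical helper: compose left to right, also returning each step's result."""
    if not funcs:
        raise ValueError("Expected at least 1 callable")

    def run(*args: Any, **kwargs: Any) -> tuple[Any, list[Any]]:
        # Only the first callable sees the original arguments
        steps = [funcs[0](*args, **kwargs)]
        for func in funcs[1:]:
            steps.append(func(steps[-1]))
        return steps[-1], steps

    return run


result, steps = traced_sequential(add, square, square, int_to_str)(2, 1)
print(steps)         # [3, 9, 81, '81']
print(repr(result))  # '81'
```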

overload 2, gone awry and fixed

if the last callable is overloaded, the sequential itself becomes overloaded, and when it is called pyright seems to pick the return type of the first overload; the user can cast the last callable to fix this

def square_float(a: float) -> float:
    return a * a

@overload
def only_int_to_str(a: int) -> str:
    ...

@overload
def only_int_to_str(a: float) -> float:
    ...

def only_int_to_str(a: int | float) -> str | float:
    if isinstance(a, int):
        return str(a)
    return a

my_sequential = sequential(
    square_float,
    only_int_to_str,
)
# reveal_type(my_sequential)  # Type of "my_sequential" is "Overload[(a: float) -> str, (a: float) -> float]"

my_sequential(2.0)  # returns 4.0
# reveal_type(my_sequential(2.0))  # type of the first overload: Type of "my_sequential(2.0)" is "str"

from typing import cast

my_sequential = sequential(
    square_float,
    cast(Callable[..., float], only_int_to_str),  # rectified?
)
# reveal_type(my_sequential)  # Type of "my_sequential" is "(a: float) -> float"

The typing of this solution does nothing to validate that each callable is compatible with the next in the sequence, but I don't think that is a problem.
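Because nothing checks adjacent callables against each other, an incompatible chain is only detected when it is called. A minimal sketch, using a compact local copy of the reduce-based composition so it runs on its own (`compose_left` is a hypothetical name with the same left-to-right semantics as the proposed sequential()):

```python
import functools
from typing import Any, Callable


def compose_left(*funcs: Callable[..., Any]) -> Callable[..., Any]:
    # Same left-to-right semantics as the proposed sequential()
    def composed(*args: Any, **kwargs: Any) -> Any:
        return functools.reduce(lambda r, f: f(r), funcs[1:], funcs[0](*args, **kwargs))
    return composed


def int_to_str(a: int) -> str:
    return str(a)


def square(a: int) -> int:
    return a * a


# Building the chain succeeds even though square() cannot handle a str...
broken = compose_left(int_to_str, square)

# ...the mistake only surfaces when the chain is called ("3" * "3" raises TypeError).
try:
    broken(3)
except TypeError as exc:
    print(f"TypeError: {exc}")
```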

My use case

I am processing inputs to an LLM using Hugging Face datasets. datasets.Dataset objects contain data points that can be modeled as TypedDicts. To process a dataset, I chain calls to datasets.Dataset.map(), passing in callables that map one TypedDict to another.

When I need to operate outside of the datasets.Dataset.map() method-chaining paradigm, functools.sequential() would allow me to write this beautifully:

data = {
    "system_prompt": "You are a Blockchain Development Tutor. Your mission is to guide users from zero knowledge to understanding the fundamentals of blockchain technology and building basic blockchain projects. Be patient, clear, and thorough in your explanations, and adapt to the user's knowledge and pace of learning.",
    "instruction": "I'm new to blockchain technology. Can you help me understand what it is and how it works?",
    "response": "Sure! Blockchain is a distributed ledger technology that allows for the secure and decentralized storage of data. It's a type of database that is shared across a network of computers, rather than being stored in a single location.",
}

# SchemaPreprocessor, Tokenizer, and TrainFormatter are callable classes (definitions not shown)
train_tokenizer = functools.sequential(
    SchemaPreprocessor(),  # DatasetSchema -> ModelSchema
    Tokenizer(),  # ModelSchema -> RawModelInput
    TrainFormatter(),  # RawModelInput -> ModelTrainInput
)

llm_input = train_tokenizer(data)
# reveal_type(llm_input)  # Type of "llm_input" is "ModelTrainInput"
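The real SchemaPreprocessor, Tokenizer, and TrainFormatter are not shown in this issue, so here is a self-contained sketch of the same shape with hypothetical stand-ins: each step is a callable class mapping one TypedDict to the next, and the chain is composed left to right. All field names, the class bodies, and the character-level "tokenizer" (Unicode code points as token ids) are made up for illustration; `compose_left` is a hypothetical local stand-in for the proposed functools.sequential().

```python
import functools
from typing import Any, Callable, TypedDict


class DatasetSchema(TypedDict):
    system_prompt: str
    instruction: str
    response: str


class ModelSchema(TypedDict):
    prompt: str
    completion: str


class RawModelInput(TypedDict):
    prompt_ids: list[int]
    completion_ids: list[int]


class ModelTrainInput(TypedDict):
    input_ids: list[int]
    num_prompt_tokens: int


class SchemaPreprocessor:
    """Hypothetical stand-in: DatasetSchema -> ModelSchema."""

    def __call__(self, row: DatasetSchema) -> ModelSchema:
        return {
            "prompt": f"{row['system_prompt']}\n{row['instruction']}",
            "completion": row["response"],
        }


class Tokenizer:
    """Hypothetical toy tokenizer: ModelSchema -> RawModelInput (code points as ids)."""

    def __call__(self, row: ModelSchema) -> RawModelInput:
        return {
            "prompt_ids": [ord(c) for c in row["prompt"]],
            "completion_ids": [ord(c) for c in row["completion"]],
        }


class TrainFormatter:
    """Hypothetical stand-in: RawModelInput -> ModelTrainInput."""

    def __call__(self, row: RawModelInput) -> ModelTrainInput:
        return {
            "input_ids": row["prompt_ids"] + row["completion_ids"],
            "num_prompt_tokens": len(row["prompt_ids"]),
        }


def compose_left(*funcs: Callable[..., Any]) -> Callable[..., Any]:
    # Local stand-in for the proposed functools.sequential() (single-argument case)
    return lambda x: functools.reduce(lambda r, f: f(r), funcs, x)


train_tokenizer = compose_left(SchemaPreprocessor(), Tokenizer(), TrainFormatter())
llm_input = train_tokenizer({"system_prompt": "S", "instruction": "I", "response": "R"})
print(llm_input)  # {'input_ids': [83, 10, 73, 82], 'num_prompt_tokens': 3}
```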

Has this already been discussed elsewhere?

No response given

Links to previous discussion of this feature:

bugs.python.org

https://bugs.python.org/issue1660179
https://bugs.python.org/issue35853

github.com/python/cpython/issues

#44584 (comment)
#80034 (comment)

github.com/python/cpython/pull

#11699

stackoverflow.com

https://stackoverflow.com/questions/51385909/python-passing-arguments-through-a-chain-of-function-calls
https://stackoverflow.com/questions/20454118/better-way-to-call-a-chain-of-functions-in-python
https://stackoverflow.com/questions/72498798/python-chain-several-functions-into-one
https://stackoverflow.com/questions/73996773/most-beautiful-way-of-chaining-a-list-of-functions-in-python-numpy
https://stackoverflow.com/questions/68480108/execute-chain-of-functions-in-python
https://stackoverflow.com/questions/54089072/how-to-consume-chain-functions-pythonically
https://stackoverflow.com/questions/46897852/how-do-you-chain-multiple-functions-and-return-multiple-tuples-from-each-one
https://stackoverflow.com/questions/34613543/is-there-a-chain-calling-method-in-python
https://stackoverflow.com/questions/49112201/python-class-methods-chaining
https://stackoverflow.com/questions/68908195/chaining-functions-calls-on-some-list
https://stackoverflow.com/questions/58473775/chaining-output-between-diffrent-functions
https://stackoverflow.com/questions/17122644/compute-a-chain-of-functions-in-python
https://stackoverflow.com/questions/4021731/execute-functions-in-a-list-as-a-chain
I chose to stop here, but I could probably go on; I just searched Google for "chain functions python site:stackoverflow.com".

My question

I recently read this from @Yhg1s in #96145 (comment):

Adding something to the standard library doesn't just mean "we think this may be a useful tool". It is an endorsement of the tool and the technique. It's seen, not unreasonably so, as a signal that the tool is the right thing to use, and doing the thing it does is the right thing to do

A variation of this implementation using functools.reduce() was suggested as answers to 8 out of the 13 stackoverflow.com questions I linked above. Practically the same solution was previously proposed in #11699.
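For reference, another common shape of the reduce()-based answer folds the functions themselves into one nested callable when the chain is built, rather than folding the value when it is called. A sketch (`compose_pairwise` is a hypothetical name, and this is not copied from any particular answer):

```python
import functools
from typing import Any, Callable


def compose_pairwise(*funcs: Callable[..., Any]) -> Callable[..., Any]:
    # Combine f and g into a single callable x -> g(f(x)), once, at construction time.
    def chain_two(f: Callable[..., Any], g: Callable[..., Any]) -> Callable[..., Any]:
        return lambda *args, **kwargs: g(f(*args, **kwargs))

    # reduce() without an initializer raises TypeError on an empty sequence
    return functools.reduce(chain_two, funcs)


pipeline = compose_pairwise(lambda a, b: a + b, lambda x: x * x, str)
print(repr(pipeline(2, 1)))  # '9'
```

One trade-off of this shape: every call passes through one nested lambda per callable, so very long chains build correspondingly deep call stacks, whereas the value-folding version in the proposal iterates over the callables without nesting.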

My question is: does Python endorse this solution? In my opinion, some solution is warranted, even if it is not this one.

Metadata

Assignees: No one assigned
Labels: type-feature (A feature request or enhancement)
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
