Merged
Changes from all commits
25 commits
8663e21
feat: insert cached attn for transformers mode
h-guo18 Aug 17, 2025
af41e7e
Add transformers.yaml; load weights for factory_model
h-guo18 Aug 18, 2025
4da28dc
minor: clean irrelevant
h-guo18 Aug 18, 2025
c8950b5
address part of review comments;
h-guo18 Aug 23, 2025
3931bac
address review comments
h-guo18 Aug 27, 2025
d8e080a
add sharding; refine insert_cache by adding profiler;
h-guo18 Aug 27, 2025
2ed8bb7
polish: use list instead of ptr for shape collection
h-guo18 Sep 4, 2025
e796d40
feat: flexible cached attn for transformers mode
h-guo18 Sep 8, 2025
37a74c1
configurable default yaml or mode field substituting default yaml
lucaslie Sep 8, 2025
450c44f
transformers mode refined
lucaslie Sep 9, 2025
f3d0448
transformers mode refined
lucaslie Sep 9, 2025
0b46c8c
transformers mode refined
lucaslie Sep 9, 2025
e74aeaf
VLM debugging
lucaslie Sep 9, 2025
6864333
better model_kwargs and from_pretrained init
lucaslie Sep 9, 2025
5c02398
transformers+graph refined with args/kwargs handling
lucaslie Sep 10, 2025
975d023
config fixes
lucaslie Sep 10, 2025
5dcba38
reviewer feedback and unit tests
lucaslie Sep 10, 2025
6bc7a66
more reviewer feedback
lucaslie Sep 10, 2025
b72f41d
Merge branch 'main' into ll/haoguo/transformers_mode
lucaslie Sep 14, 2025
a92e160
correct handling of mistral3 factory
lucaslie Sep 15, 2025
d20ff7e
Merge branch 'main' into ll/haoguo/transformers_mode
lucaslie Sep 15, 2025
f6a35f3
Merge remote-tracking branch 'upstream/main' into ll/haoguo/transform…
lucaslie Sep 17, 2025
e724615
Merge remote-tracking branch 'upstream/main' into ll/haoguo/transform…
lucaslie Sep 17, 2025
59aaab6
unit test skip fix
lucaslie Sep 17, 2025
948727b
Merge branch 'main' into ll/haoguo/transformers_mode
lucaslie Sep 17, 2025
1 change: 1 addition & 0 deletions examples/auto_deploy/.vscode/launch.json
@@ -19,6 +19,7 @@
"--args.model-kwargs.num-hidden-layers=3",
"--args.model-kwargs.num-attention-heads=32",
"--prompt.sp-kwargs.max-tokens=128",
// "--yaml-extra=config.yaml", // uncomment to load a custom extra yaml config file
// "--dry-run", // uncomment to print the final config and return
],
"console": "integratedTerminal",
7 changes: 4 additions & 3 deletions examples/auto_deploy/build_and_run_ad.py
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Dict, Iterator, List, Optional, Union

import torch
import yaml
from omegaconf import OmegaConf
from pydantic import BaseModel, Field, field_validator, model_validator
from pydantic_settings import (
@@ -243,7 +244,7 @@ def build_llm_from_config(config: ExperimentConfig) -> LLM:
"demollm": DemoLLM,
"trtllm": LLM,
}
llm = llm_lookup[config.args.runtime](**config.args.to_dict())
llm = llm_lookup[config.args.runtime](**config.args.to_llm_kwargs())
return llm


@@ -260,8 +261,8 @@ def print_outputs(outs: Union[RequestOutput, List[RequestOutput]]) -> List[List[

def main(config: Optional[ExperimentConfig] = None):
if config is None:
config = CliApp.run(ExperimentConfig)
ad_logger.info(f"{config=}")
config: ExperimentConfig = CliApp.run(ExperimentConfig)
ad_logger.info(f"AutoDeploy Experiment Config:\n{yaml.dump(config.model_dump())}")

if config.dry_run:
return
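For reference, a minimal sketch of driving the example above programmatically instead of via CliApp; the import path, nested field spellings, and model id are assumptions inferred from the launch.json flags and the functions shown in the diff:

from examples.auto_deploy.build_and_run_ad import ExperimentConfig, main  # assumed import path

config = ExperimentConfig(
    args={
        "model": "meta-llama/Llama-3.1-8B-Instruct",  # hypothetical model id
        "runtime": "demollm",  # or "trtllm", per llm_lookup in build_llm_from_config
        "model_kwargs": {"num_hidden_layers": 3},
    },
    prompt={"sp_kwargs": {"max_tokens": 128}},
)
main(config)  # logs the resolved config as YAML, then builds and runs the LLM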
12 changes: 11 additions & 1 deletion tensorrt_llm/_torch/auto_deploy/config/default.yaml
@@ -1,8 +1,13 @@
# Additional default args for AutoDeployConfig/LlmArgs in _torch/auto_deploy/llm_args.py
# This is the set of transforms running in "graph" mode. In this mode, we capture the full graph
# of the model and optimize it for inference.
transforms:
############################################################################################
# BUILD MODEL, EXPORT TO GRAPH MODULE, AND CLEAN UP
############################################################################################
build_model:
stage: factory
device: meta
use_strict_forward: true
# nothing to clean up
run_graph_cleanup: false
requires_clean_graph: false
@@ -74,6 +79,8 @@ transforms:
############################################################################################
load_weights:
stage: weight_load
move_inputs_to_device:
stage: weight_load
############################################################################################
# RUN POST-LOAD FUSION AND OPTIMIZATIONS
############################################################################################
@@ -123,3 +130,6 @@ transforms:
############################################################################################
compile_model:
stage: compile
forward_with_cached_sequence_interface:
stage: compile
args_only: true
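Since `transforms` is also exposed as a plain dict on `AutoDeployConfig` (see llm_args.py below) with `nested_model_default_partial_update=True`, individual entries of this file can be partially overridden from Python. A minimal sketch, assuming `AutoDeployConfig` accepts a `model` field and the transform keys shown above:

from tensorrt_llm._torch.auto_deploy.llm_args import AutoDeployConfig

cfg = AutoDeployConfig(
    model="meta-llama/Llama-3.1-8B-Instruct",  # hypothetical model id
    transforms={
        # override selected entries; unspecified transforms keep their YAML defaults
        "build_model": {"stage": "factory", "device": "meta", "use_strict_forward": True},
        "compile_model": {"stage": "compile"},
    },
)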
33 changes: 33 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/config/transformers.yaml
@@ -0,0 +1,33 @@
# This is the set of transforms running in "transformers" mode. In this mode, we hook into the
# HF attention mechanism and replace it with our custom cached attention mechanism.
transforms:
############################################################################################
# BUILD MODEL, LOAD WEIGHTS, AND WRAP IT INTO FAKE GRAPH MODULE
############################################################################################
build_and_load_factory_model:
stage: factory
use_strict_forward: false
############################################################################################
# MOVE ARGUMENTS TO DEVICE
############################################################################################
move_inputs_to_device:
stage: weight_load
############################################################################################
# SWITCH TO CACHED+FLATTENED ATTENTION + INITIALIZE CACHES
############################################################################################
detect_hf_attn_layers:
stage: cache_init
transformers_replace_cached_attn:
stage: cache_init
attn_backend: flashinfer
initialize_cache:
stage: cache_init
resize_kv_cache:
stage: cache_init
args_only: false # use kwargs instead of args
############################################################################################
# COMPILE MODEL
############################################################################################
forward_with_cached_sequence_interface:
stage: compile
args_only: false # use kwargs instead of args
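A minimal usage sketch of selecting this pipeline through the new `mode` field introduced in llm_args.py below (the model id is hypothetical, and `model` being the only required field is an assumption):

from tensorrt_llm._torch.auto_deploy.llm_args import AutoDeployConfig

# mode="transformers" resolves transformers.yaml instead of default.yaml as the default
# transform pipeline, so the HF cached-attention transforms above become active.
args = AutoDeployConfig(
    model="meta-llama/Llama-3.1-8B-Instruct",  # hypothetical model id
    mode="transformers",
)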
@@ -141,6 +141,9 @@ def __init__(
# indicator if extra args are activated that are needed for cached attention backends
self._is_cached_attn = False

# indicator how to handle the "None" input for extra args
self._use_strict_args = True

# container for dynamic shapes
self._dynamic_shapes: Optional[Dict[str, DynamicShape]] = None

@@ -166,7 +169,7 @@
############################################################################################

# EXTRA TENSOR FIELDS ######################################################################
self._extra_args: Dict[str, torch.Tensor] = {}
self._extra_args: Dict[str, Optional[torch.Tensor]] = {}
self._extra_none_inputs: Dict[str, torch.Tensor] = {}
self._extra_dynamic_shapes: Optional[Dict[str, DynamicShape]] = None
self._extra_dynamic_shapes_callbacks: Dict[str, DynamicShapeCallback] = {}
@@ -179,6 +182,33 @@
def device(self) -> torch.device:
return self._args_device["input_ids"].device

@property
def use_strict_args(self) -> bool:
return self._use_strict_args

@use_strict_args.setter
def use_strict_args(self, val: bool) -> None:
"""Configure whether to use strict graph arguments only.

Args:
val: whether to use strict graph arguments only.

In strict arguments mode,
* only stock arguments (like input_ids, position_ids, etc.) or extra
arguments that are explicitly added via the ``add_extra_arg`` interface are allowed.
Other arguments that are provided in ``nest_sequences`` will be rejected and throw an
error.
* registered extra arguments that are not provided to ``nest_sequences`` will be added to
the argument list automatically using the registered None-like tensor.

In non-strict argument mode,
* all arguments, including all **kwargs that are provided to ``nest_sequences``, will
simply be passed to the model in the order received.
* registered extra arguments that are not provided to ``nest_sequences`` will _not_ be
added to the argument list.
"""
self._use_strict_args = val

def _shape_for_forward(self, tnsr: torch.Tensor) -> torch.Tensor:
"""Shape the tensor for the forward pass based on the current attention mode.

@@ -458,7 +488,8 @@ def switch_to_cached_attn_inputs(self) -> List[str]:
def to(self, *args, **kwargs) -> None:
def _move_dict(d: Dict[str, torch.Tensor]) -> None:
for k, v in d.items():
d[k] = v.to(*args, **kwargs)
if v is not None:
d[k] = v.to(*args, **kwargs)

_move_dict(self._args_device)
_move_dict(self._extra_args)
@@ -557,8 +588,10 @@ def _store_extra_arg(
else:
tnsr_like = tnsr_like[0]
self._extra_args[name] = tnsr_like.to(self.device, non_blocking=True)
else:
elif self.use_strict_args:
self._extra_args[name] = self._extra_none_inputs[name]
else:
self._extra_args[name] = None

@nvtx_range("ad_nest_sequences")
def nest_sequences(
@@ -615,10 +648,16 @@ def nest_sequences(
self._store_arg("position_ids", self._flatten(position_ids))

### UPDATE EXTRA INPUTS ####################################################################
# go through all extra tensor arguments and update them
for name in self._extra_none_inputs.keys():
self._store_extra_arg(name, extra_args.pop(name, None))
assert not extra_args, f"Extra arguments {extra_args.keys()} not found"
self._extra_args = {}
# in strict argument mode, we only accept registered extra arguments
if self.use_strict_args:
for name in self._extra_none_inputs.keys():
self._store_extra_arg(name, extra_args.pop(name, None))
assert not extra_args, f"Extra arguments {extra_args.keys()} not found"
# otherwise, we simply pass in all extra arguments
else:
for key, value in extra_args.items():
self._store_extra_arg(key, value)

@nvtx_range("ad_rescatter_input_ids")
def rescatter_input_ids(
@@ -670,7 +709,7 @@ def add_extra_arg(
assert name not in self._named_args().keys(), f"Extra argument {name} already exists"

self._extra_args[name] = none_input.to(self.device)
self._extra_none_inputs[name] = none_input.to(self.device)
self._extra_none_inputs[name] = self._extra_args[name]

if dynamic_shape_callback is None:
self._extra_dynamic_shapes_callbacks[name] = lambda: {}
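A standalone sketch (not the actual interface code) of the strict vs. non-strict extra-argument dispatch that `use_strict_args`, `_store_extra_arg`, and `nest_sequences` implement above:

from typing import Dict, Optional

import torch


def resolve_extra_args(
    provided: Dict[str, torch.Tensor],
    none_inputs: Dict[str, torch.Tensor],
    strict: bool,
) -> Dict[str, Optional[torch.Tensor]]:
    if strict:
        # only registered names are accepted; missing ones fall back to their none-like tensor
        resolved = {name: provided.pop(name, none_inputs[name]) for name in none_inputs}
        assert not provided, f"Extra arguments {list(provided)} not registered"
        return resolved
    # non-strict: pass everything through as received; missing registered args are simply omitted
    return dict(provided)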
8 changes: 7 additions & 1 deletion tensorrt_llm/_torch/auto_deploy/distributed/common.py
@@ -136,12 +136,18 @@ def initialize(rank: int = 0, world_size: int = 1, port: Optional[int] = None) -
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = str(port)
os.environ["LOCAL_RANK"] = str(local_rank)

# Necessary to assign a device to each rank.
torch.cuda.set_device(local_rank)

# We use nccl backend
dist.init_process_group("nccl", world_size=world_size, rank=local_rank)
dist.init_process_group(
"nccl",
world_size=world_size,
rank=local_rank,
device_id=torch.device(local_rank),
)

# Register cleanup function to be called at exit
atexit.register(cleanup)
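A minimal sketch of how the per-rank `initialize()` helper above might be driven, assuming standard `torch.multiprocessing` spawn semantics (the worker body and port value are hypothetical):

import torch
import torch.multiprocessing as mp

from tensorrt_llm._torch.auto_deploy.distributed.common import initialize


def _worker(rank: int, world_size: int, port: int):
    # sets LOCAL_RANK, pins the CUDA device, and creates the NCCL group with device_id
    initialize(rank=rank, world_size=world_size, port=port)
    # ... per-rank (e.g., sharded) model execution goes here ...


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(_worker, args=(world_size, 29500), nprocs=world_size)  # port is arbitrary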
44 changes: 34 additions & 10 deletions tensorrt_llm/_torch/auto_deploy/llm_args.py
@@ -11,7 +11,6 @@
from ...llmapi.llm_args import BaseLlmArgs, BuildConfig, _ParallelConfig
from ...llmapi.utils import get_type_repr
from .models import ModelFactory, ModelFactoryRegistry
from .transform.interface import TransformConfig
from .utils._config import DynamicYamlMixInForSettings

PathLike = Union[str, Path]
@@ -21,7 +20,6 @@ def _get_config_dict() -> SettingsConfigDict:
return SettingsConfigDict(
arbitrary_types_allowed=True,
extra="forbid",
yaml_file=str(files("tensorrt_llm._torch.auto_deploy.config") / "default.yaml"),
nested_model_default_partial_update=True,
)

@@ -184,7 +182,14 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
visualize: bool = Field(default=False, description="Whether to visualize the model graph.")

### NEW INFERENCE OPTIMIZER CONFIG #############################################################
transforms: Dict[str, TransformConfig] = Field(
mode: Literal["graph", "transformers"] = Field(
default="graph",
description="The mode to use for the inference optimizer. Currently, we "
"support only the 'graph' and 'transformers' modes, i.e., full-graph capture + optimization "
"or transformers-only cached attention optimization.",
)

transforms: Dict[str, Any] = Field(
default_factory=dict,
description="A dictionary of transform configurations. The key is the transform name and "
"the value is the transform configuration.",
@@ -205,6 +210,7 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):

### VALIDATION #################################################################################
@model_validator(mode="after")
# TODO: discuss what to do with this once we fully transition to the new inference optimizer
def update_attn_page_size(self):
# NOTE force attn_page_size to equal max_seq_len for triton backend
if self.attn_backend == "triton" or self.attn_backend == "torch":
@@ -240,9 +246,27 @@ def to_dict(self) -> Dict[str, Any]:
"""Convert the arguments to a dictionary."""
return self.model_dump()

def to_llm_args(self) -> "LlmArgs":
"""Convert the arguments to a LlmArgs instance that is used for the LLM API."""
return LlmArgs(**self.to_dict())
def to_llm_kwargs(self) -> Dict[str, Any]:
"""Convert the arguments to a dictionary that can be used as kwargs for the LLM API."""
kwargs = self.to_dict()

# remove mode and yaml_default when they were not explicitly set, since the two fields
# may otherwise conflict with each other.
if "mode" not in self.model_fields_set:
kwargs.pop("mode")
if "yaml_default" not in self.model_fields_set:
kwargs.pop("yaml_default")
return kwargs

### PRIVATE METHODS ############################################################################
@classmethod
def _get_yaml_default_from_mode(cls, mode: Optional[str]) -> Optional[str]:
config_path = files("tensorrt_llm._torch.auto_deploy.config")
mapping = {
"graph": str(config_path / "default.yaml"),
"transformers": str(config_path / "transformers.yaml"),
}
return mapping.get(mode)


class LlmArgs(AutoDeployConfig, BaseLlmArgs, BaseSettings):
@@ -352,11 +376,11 @@ def validate_and_init_tokenizer(self):
def get_pytorch_backend_config(self) -> "LlmArgs":
"""Return the LlmArgs (self) object."""
# TODO: can we just pass through self directly??
return type(self)(**self.to_dict())
return type(self)(**self.to_llm_kwargs())

def to_dict(self) -> Dict:
"""Convert model to a dictionary such that cls(**self.to_dict()) == self."""
self_dict = dict(self)
self_dict.pop("build_config")
self_dict.pop("mpi_session")
self_dict = super().to_dict()
self_dict.pop("build_config", None)
self_dict.pop("mpi_session", None)
return self_dict
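A small sketch of the resulting `to_llm_kwargs()` behavior: `mode` and `yaml_default` are forwarded only when explicitly set, so the receiving side re-derives its own defaults instead of receiving a possibly conflicting pair (the model id is hypothetical, and `model` being the only required field is an assumption):

from tensorrt_llm._torch.auto_deploy.llm_args import AutoDeployConfig

cfg_default = AutoDeployConfig(model="meta-llama/Llama-3.1-8B-Instruct")
cfg_explicit = AutoDeployConfig(model="meta-llama/Llama-3.1-8B-Instruct", mode="transformers")

assert "mode" not in cfg_default.to_llm_kwargs()  # not explicitly set -> dropped
assert cfg_explicit.to_llm_kwargs()["mode"] == "transformers"  # explicitly set -> kept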
31 changes: 31 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/models/factory.py
@@ -106,6 +106,35 @@ def _build_model(self, device: str) -> nn.Module:
"""Factory-specific model building logic."""
raise NotImplementedError("Subclasses must implement this method.")

def _set_strict_forward(self, model: nn.Module):
"""Set the strict (args-only) forward method for the model.

For some factories, the regular forward is sufficient. For others, this needs to be set.
The strict forward method should precisely define a fixed args-only, tensor-only signature
compatible with the model's forward method AND the export behavior, which requires fixed
tensor-only positional arguments.

The function should overwrite the `model.forward` method.

The overwritten forward should have `input_ids` and `position_ids` as initial positional
arguments as defined by the sequence interface. Hence the signature should be something like

.. code-block:: python

def _strict_forward(
self, input_ids: torch.Tensor, position_ids: torch.Tensor, *extra_args: torch.Tensor
) -> Sequence[torch.Tensor]: ...

where `extra_args` are the extra arguments that are defined by the factory and should also
be defined in the `get_extra_inputs` + `get_example_inputs` methods. The actual
`_strict_forward` method should not use `*args` or `**kwargs` but instead use the defined
extra arguments in the order they are defined.

This is necessary because graph export flattens arguments into a list of tensors; using a
strict forward convention simplifies the export behavior and the subsequent handling of the
arguments in the graph module.
"""

def get_quant_config(self) -> Dict:
"""Returns the quantization config for this model or None if not quantized."""
return {}
@@ -173,6 +202,8 @@ def load_or_random_init(self, model: nn.Module, device: DeviceLikeType):
the same model that is built above but it needs to have a state dict compatible with
the model built above.
device: The device to load the model on.
load_factory_model: If True, will load weights for the factory model in addition to the
main gm. This is useful for the transformers mode.

NOTE: we always call ``self._to_maybe_random(model, device)`` as a preprocessing step
to ensure the model parameters already exist on the right device and have the desired dtype
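A minimal sketch of how a factory might implement `_set_strict_forward` following the convention documented above; the factory class, the extra `pixel_values` argument, and the output handling are hypothetical:

import torch
import torch.nn as nn

from tensorrt_llm._torch.auto_deploy.models import ModelFactory


class MyVLMFactory(ModelFactory):  # hypothetical factory; other required methods omitted
    def _set_strict_forward(self, model: nn.Module):
        """Pin model.forward to a fixed, tensor-only positional signature for export."""
        original_forward = model.forward

        def _strict_forward(
            input_ids: torch.Tensor,
            position_ids: torch.Tensor,
            pixel_values: torch.Tensor,  # extra arg; also exposed via get_extra_inputs()
        ):
            out = original_forward(
                input_ids=input_ids,
                position_ids=position_ids,
                pixel_values=pixel_values,
            )
            return (out.logits,)  # fixed sequence of output tensors

        model.forward = _strict_forward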