
Conversation

@yweng0828 (Collaborator) commented Oct 14, 2025

Summary by CodeRabbit

  • New Features
    • Added a configurable max_total_draft_tokens across speculative decoding modes, enabling finer control of total draft tokens.
    • Introduced automatic detection of linear-tree mode and validation to ensure compatible configurations.
  • Refactor
    • Standardized speculative decoding to use total draft tokens instead of per-draft length for allocation, padding, and execution paths.
    • Streamlined initialization and propagation of draft-token limits through sampling and execution, improving consistency across engines and runners.

Description

The purpose of this PR is to add a max_total_draft_tokens field to support draft token trees.

max_total_draft_tokens indicates the total number of draft tokens generated after passing through all of the draft layers. The existing max_draft_len field remains and indicates the number of draft layers.

For linear trees, max_draft_len is equal to max_total_draft_tokens (each draft layer generates exactly one draft token).

Note: the one-model path has not been modified or adapted yet, since max_total_draft_tokens currently only takes effect on the two-model path.
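
To make the relationship concrete, here is a minimal, hedged sketch of the semantics above (this is an illustration, not the actual DecodingBaseConfig; the field and property names follow this PR's summary, and the branching example is made up):

from dataclasses import dataclass

# Illustrative stand-in for the real config class, showing how the two fields relate.
@dataclass
class SpecConfigSketch:
    max_draft_len: int            # number of draft layers
    max_total_draft_tokens: int   # total draft tokens produced across all layers

    @property
    def is_linear_tree(self) -> bool:
        # Linear tree: every draft layer emits exactly one draft token.
        return self.max_total_draft_tokens == self.max_draft_len

# A linear chain with 3 draft layers produces 3 draft tokens in total.
linear = SpecConfigSketch(max_draft_len=3, max_total_draft_tokens=3)
assert linear.is_linear_tree

# A branching tree with 3 layers might emit, say, 2 + 4 + 8 = 14 draft tokens.
tree = SpecConfigSketch(max_draft_len=3, max_total_draft_tokens=14)
assert not tree.is_linear_tree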

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.
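
For example, a few typical invocations combining the options above (the stage and GPU names reuse the examples from this help text):

/bot run --disable-fail-fast
/bot run --stage-list "A10-PyTorch-1"
/bot run --gpu-type "A30, H100_PCIe" --disable-multi-gpu-test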

kill

kill

Kill all running builds associated with the pull request.

skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

@yweng0828 marked this pull request as ready for review on October 14, 2025 11:48
@yweng0828 requested review from a team as code owners on October 14, 2025 11:48
coderabbitai bot (Contributor) commented Oct 14, 2025

📝 Walkthrough

Replaces max_draft_len usage with max_total_draft_tokens across configs, executors, engines, CUDA graph paths, samplers, and speculative components. Adds fields/params, updates constructors and flow to propagate total draft tokens, adjusts warmup/padding/runtime decisions, and updates model code to derive predicted tokens from the new config value.

Changes

  • LLM API configs: tensorrt_llm/llmapi/llm_args.py
    Adds DecodingBaseConfig.max_total_draft_tokens. Introduces init overrides across decoding configs to set/propagate max_total_draft_tokens (often from max_draft_len). Adds is_linear_tree properties, an assertion for eagle3_one_model, and updates from_dict for MTP.
  • PyExecutor core and utils: tensorrt_llm/_torch/pyexecutor/_util.py, tensorrt_llm/_torch/pyexecutor/py_executor.py, tensorrt_llm/_torch/pyexecutor/py_executor_creator.py, tensorrt_llm/_torch/pyexecutor/sampler.py
    Propagates max_total_draft_tokens into the PyExecutor constructor, sampler args, and creator. Replaces estimation, padding, and allocation logic based on max_draft_len with max_total_draft_tokens. Adds extra gating (is_linear_tree) for ChainDrafter and updates the guided decoding config.
  • Engine and CUDA graph: tensorrt_llm/_torch/pyexecutor/model_engine.py, tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
    Introduces original_max_total_draft_tokens and switches warmup/runtime/CUDA-graph sizing from max_draft_len to total draft tokens. Updates runtime_draft_len and graph-key computations to use total tokens where applicable; preserves first-draft behavior commentary.
  • Speculative interfaces and drafters: tensorrt_llm/_torch/speculative/interface.py, tensorrt_llm/_torch/speculative/utils.py, tensorrt_llm/_torch/speculative/drafter.py, tensorrt_llm/_torch/speculative/drafting_loops.py, tensorrt_llm/_torch/speculative/model_drafter.py, tensorrt_llm/_torch/speculative/save_hidden_state.py
    Adds SpecMetadata.max_total_draft_tokens and is_spec_dec_tree; updates the attention decision signature. Passes max_total_draft_tokens in metadata and drafter construction. Renames/extends drafter params to include total tokens, asserts relationships, and stores attributes.
  • Auto-deploy shim: tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
    Simplifies the max_total_draft_tokens computation and passes it into TorchSampler.Args and PyExecutor. Formatting tweaks.
  • Model update: tensorrt_llm/_torch/models/modeling_deepseekv3.py
    Uses spec_config.max_total_draft_tokens to compute predicted_tokens_per_seq, falling back to 1 when spec_config is None.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor App
  participant Creator as _util.create_py_executor_instance
  participant Engine as ModelEngine
  participant Exec as PyExecutor
  participant Sampler as TorchSampler

  App->>Creator: create_py_executor_instance(spec_config,…)
  Note right of Creator: Derive max_total_draft_tokens (default 0 if None)
  Creator->>Exec: PyExecutor(max_total_draft_tokens,…)
  Creator->>Sampler: TorchSampler.Args(max_total_draft_tokens,…)
  Creator->>Engine: ModelEngine(spec_config with max_total_draft_tokens)
  Engine-->>Creator: instance with original_max_total_draft_tokens
  Exec-->>App: executor ready
sequenceDiagram
  autonumber
  participant Exec as PyExecutor
  participant Drafter as Drafter/ModelDrafter
  participant Cuda as CudaGraphRunner
  participant Eng as ModelEngine

  Exec->>Drafter: should_use_spec_decode(max_total_draft_tokens,…)
  alt Spec decode enabled
    Exec->>Eng: runtime_draft_len = max_total_draft_tokens
    Exec->>Cuda: get_graph_key(max_total_draft_tokens,…)
    Cuda-->>Exec: key reflecting total draft tokens
    Exec-->>Exec: allocate/pad draft tokens up to total
  else Autoregressive
    Exec-->>Exec: set draft tokens to 0
  end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 18.87%, which is below the required threshold of 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
  • Description Check ⚠️ Warning: The pull request description provides a clear explanation of the change in the Description section and includes the PR Checklist, but it omits the required PR title line at the top following the "[ticket][type] Summary" template and leaves the Test Coverage section empty, so it does not fully adhere to the repository's description template. Please add a properly formatted PR title at the top using the "[JIRA/NVBugs/GitHub issue/None][type] Summary" pattern and populate the Test Coverage section with the relevant test cases or coverage details for the new max_total_draft_tokens functionality.
✅ Passed checks (1 passed)
  • Title Check ✅ Passed: The title "[None][feat] Add max_total_draft_tokens" follows the required pattern and succinctly summarizes the primary change of introducing the new field, making it clear and specific for reviewers scanning the history.


coderabbitai bot (Contributor) left a comment

Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (19)
tensorrt_llm/_torch/pyexecutor/model_engine.py (2)

704-723: Bug: draft_lengths logic overwrites “0” capture and misuses total-tokens for CDL loop iterations

Two issues:

  • Draft lengths for CUDA graph warmup (non-draft model) overwrite a previously appended 0, losing the ability to replay graphs with spec decode disabled at runtime (contradicts the comment).
  • For CDL/Eagle3 path, draft_len represents number of drafting iterations. It should use original_max_draft_len, not original_max_total_draft_tokens.

Fix with this diff:

@@
-            if self.model_is_wrapped and self.is_spec_decode and spec_resource_manager is not None and isinstance(
-                    spec_resource_manager, Eagle3ResourceManager):
-                # The CDL path uses draft_len > 0 for the number of iterations in the drafting loop.
-                draft_lengths.append(self.original_max_total_draft_tokens)
+            if self.model_is_wrapped and self.is_spec_decode and spec_resource_manager is not None and isinstance(
+                    spec_resource_manager, Eagle3ResourceManager):
+                # The CDL path uses draft_len to mean "number of drafting iterations"
+                draft_lengths.append(self.original_max_draft_len)
             else:
                 draft_lengths.append(self.max_draft_len)
@@
-            if (self.max_total_draft_tokens > 0
-                    and not self.spec_config.spec_dec_mode.use_one_engine()
-                    # Assume that speculation is always on if the user didn't give us a max_concurrency
-                    # value. This will save on memory.
-                    and self.spec_config.max_concurrency is not None):
-                draft_lengths.append(0)
-            draft_lengths = [self.max_total_draft_tokens]
+            # Always capture the current speculative setting. Optionally also capture draft_len=0
+            # to allow disabling speculative decoding at runtime in the two-engine path.
+            draft_lengths = [self.max_total_draft_tokens]
+            if (self.max_total_draft_tokens > 0
+                    and not self.spec_config.spec_dec_mode.use_one_engine()
+                    and self.spec_config.max_concurrency is not None):
+                draft_lengths.append(0)

1-5: Missing NVIDIA Apache-2.0 header

Per coding guidelines, prepend the NVIDIA Apache-2.0 header with current year to this source file.

Apply at file top:

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tensorrt_llm/_torch/speculative/save_hidden_state.py (1)

1-5: Missing NVIDIA Apache-2.0 header

Please add the required header at the top of the file.

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tensorrt_llm/_torch/speculative/drafting_loops.py (1)

1-5: Missing NVIDIA Apache-2.0 header

Add the required header at the top.

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tensorrt_llm/_torch/speculative/drafter.py (3)

1-5: Missing NVIDIA Apache-2.0 header

Add the required header at the top.

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

29-55: Tests: rename max_draft_len to max_total_draft_tokens
In tests/unittest/_torch/speculative/test_dynamic_spec_decode.py, update the mock signature and all calls (e.g. lines 75, 198–234) to use max_total_draft_tokens instead of max_draft_len.


65-71: Pad draft tokens to max_total_draft_tokens in CUDA-graph padding
In drafter.py (57–71), pad_draft_tokens_for_cuda_graph uses self.max_draft_len. For tree-based speculation (where max_total_draft_tokens > max_draft_len) this under-pads py_draft_tokens and breaks CUDA-graph static shapes—use max_total_draft_tokens instead or guard by tree mode.
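
A minimal sketch of the suggested change, written as a free-standing helper for illustration (the real method lives on the drafter and its exact signature, and the pad value of 0, are assumptions):

from typing import List

def pad_draft_tokens_for_cuda_graph(draft_tokens: List[int],
                                    max_total_draft_tokens: int,
                                    pad_token: int = 0) -> List[int]:
    # CUDA graphs are captured with a static draft-token width, so every request
    # must be padded to the largest number of draft tokens a step can produce,
    # which for tree speculation is max_total_draft_tokens, not max_draft_len.
    return draft_tokens + [pad_token] * (max_total_draft_tokens - len(draft_tokens))

# Example: a 3-layer tree that can emit 7 draft tokens in total must pad to 7, not 3.
assert len(pad_draft_tokens_for_cuda_graph([11, 22, 33], 7)) == 7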

tensorrt_llm/_torch/pyexecutor/sampler.py (1)

1-5: Missing NVIDIA Apache-2.0 header

Please add the required header at the top.

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py (2)

1-1: Missing NVIDIA Apache-2.0 header

Add the NVIDIA Apache-2.0 copyright header (current year) at the top of this source file. As per coding guidelines.


224-226: Bug: position_ids becomes a tuple due to trailing comma

The trailing comma makes position_ids a tuple (tensor,) instead of a tensor, likely breaking consumers.

-            sliced_static_tensors["position_ids"] = self.shared_static_tensors[
-                "position_ids"][:, :, :num_tokens_for_capture],
+            sliced_static_tensors["position_ids"] = self.shared_static_tensors[
+                "position_ids"][:, :, :num_tokens_for_capture]
tensorrt_llm/_torch/pyexecutor/py_executor.py (1)

1-1: Missing NVIDIA Apache-2.0 header

Add the NVIDIA Apache-2.0 copyright header (current year) at the top of this source file. As per coding guidelines.

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py (1)

1-1: Missing NVIDIA Apache-2.0 header

Add the NVIDIA Apache-2.0 copyright header (current year) at the top of this source file. As per coding guidelines.

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (1)

1-1: Missing NVIDIA Apache-2.0 header

Add the NVIDIA Apache-2.0 copyright header (current year) at the top of this source file. As per coding guidelines.

tensorrt_llm/_torch/speculative/interface.py (1)

1-1: Missing NVIDIA Apache-2.0 header

Add the NVIDIA Apache-2.0 copyright header (current year) at the top of this source file. As per coding guidelines.

tensorrt_llm/_torch/speculative/utils.py (4)

1-1: Missing NVIDIA Apache-2.0 header

Add the NVIDIA Apache-2.0 copyright header (current year) at the top of this source file. As per coding guidelines.


81-93: Bug: duplicate keyword argument max_total_draft_tokens

max_total_draft_tokens is passed twice; Python will raise “multiple values for keyword”. Decide the intended value (likely 1 for SaveHiddenStates) and keep only one.

-        return Eagle3SpecMetadata(
-            max_draft_len=spec_config.max_draft_len,
-            max_total_draft_tokens=spec_config.max_total_draft_tokens,
+        return Eagle3SpecMetadata(
+            max_draft_len=spec_config.max_draft_len,
             spec_dec_mode=spec_config.spec_dec_mode,
             max_num_requests=max_num_requests,
             num_layers=model_config.num_hidden_layers,
             hidden_size=model_config.hidden_size,
             max_num_tokens=max_num_tokens,
             dtype=model_config.torch_dtype,
             is_draft_model=is_draft_model,
             eagle3_resource_manager=spec_resource_manager,
             layers_to_capture=spec_config.eagle3_layers_to_capture,
             max_total_draft_tokens=1,
         )

94-101: Missing max_total_draft_tokens when constructing SpecMetadata

SpecMetadata now requires max_total_draft_tokens. Include it for these modes.

-        return SpecMetadata(
+        return SpecMetadata(
             max_draft_len=spec_config.max_draft_len,
+            max_total_draft_tokens=spec_config.max_total_draft_tokens,
             spec_dec_mode=spec_config.spec_dec_mode,
             max_num_requests=max_num_requests,
         )

134-144: Include the new is_spec_dec_tree argument in attention_need_spec_dec_mode calls
In tensorrt_llm/_torch/pyexecutor/model_engine.py:2166, the call to spec_dec_mode.attention_need_spec_dec_mode only passes four parameters—append the new boolean flag at the end to match the updated signature.

tensorrt_llm/_torch/speculative/model_drafter.py (1)

473-474: Critical: AttributeError on self.max_draft_tokens.

Line 473 references self.max_draft_tokens, but the constructor (lines 74-75) no longer sets this attribute. After the parameter rename, the attributes are now self.max_draft_len and self.max_total_draft_tokens. This code will fail at runtime with AttributeError: 'ModelDrafter' object has no attribute 'max_draft_tokens'.

The same issue exists at lines 567, 650, and 726. Based on context, these should likely use self.max_total_draft_tokens since they allocate or iterate over draft token buffers.

Apply this diff to fix line 473:

         new_tokens_lens = torch.ones(batch_size, device=device)
         next_draft_tokens = torch.zeros(batch_size,
-                                        self.max_draft_tokens,
+                                        self.max_total_draft_tokens,
                                         device=device)

Similar fixes are needed at:

  • Line 567: for token_idx in range(self.max_draft_tokens): → for token_idx in range(self.max_total_draft_tokens):
  • Line 650: for i in range(self.max_draft_tokens - 1): → for i in range(self.max_total_draft_tokens - 1):
  • Line 726: draft_length=self.max_draft_tokens, → draft_length=self.max_total_draft_tokens,
🧹 Nitpick comments (4)
tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py (1)

106-108: Clarify/typo in comment

Nit: use plural “max_total_draft_tokens” and tighten wording.

-            # If 'is_first_draft' is True, even with tree decoding, the length of draft_len will only be 'max_draft_len', not 'max_total_draft_token'.
-            # Because we will pad the input to 'max_draft_len' length for the first draft layer.
+            # For the first draft in tree decoding, draft_len equals 'max_draft_len' (number of layers), not 'max_total_draft_tokens',
+            # because we pad inputs to 'max_draft_len' for the first draft layer.
tensorrt_llm/_torch/pyexecutor/py_executor.py (2)

150-171: Constructor: add brief docstring for new arg

New parameter max_total_draft_tokens is public API. Add a short docstring entry.

 class PyExecutor:
 
     def __init__(self,
         ...
-        max_draft_len: int = 0,
+        max_draft_len: int = 0,
         max_total_draft_tokens: int = 0,
         ...
     ):
+        """
+        Args:
+            ...
+            max_draft_len: Number of draft layers (linear tree).
+            max_total_draft_tokens: Total draft tokens across layers (tree or linear).
+            ...
+        """

195-201: Prefer a single source-of-truth for total-draft-tokens

You store self.max_total_draft_tokens but later read model_engine.spec_config.max_total_draft_tokens in multiple places. For consistency and easier overrides, prefer using self.max_total_draft_tokens (falling back to spec_config if zero).

-        self.max_total_draft_tokens = max_total_draft_tokens
+        self.max_total_draft_tokens = max_total_draft_tokens
+        # Optional: normalize from spec_config if not provided
+        if self.max_total_draft_tokens == 0 and getattr(self.model_engine, "spec_config", None):
+            self.max_total_draft_tokens = getattr(self.model_engine.spec_config, "max_total_draft_tokens", 0)

Then use self.max_total_draft_tokens where applicable (e.g., Lines 1055-1059, 1229-1234).

tensorrt_llm/_torch/speculative/interface.py (1)

128-135: Unused argument is_spec_dec_tree

The parameter is never used. Either incorporate it into the decision or prefix with underscore to avoid ARG002 warnings until used.

-        use_chain_drafter: bool,
-        is_spec_dec_tree: bool,
+        use_chain_drafter: bool,
+        is_spec_dec_tree: bool,  # TODO: use or drop

Or:

-        is_spec_dec_tree: bool,
+        _is_spec_dec_tree: bool,
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8733e83 and 1715d06.

📒 Files selected for processing (15)
  • tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py (2 hunks)
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py (1 hunks)
  • tensorrt_llm/_torch/pyexecutor/_util.py (3 hunks)
  • tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py (2 hunks)
  • tensorrt_llm/_torch/pyexecutor/model_engine.py (11 hunks)
  • tensorrt_llm/_torch/pyexecutor/py_executor.py (6 hunks)
  • tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (4 hunks)
  • tensorrt_llm/_torch/pyexecutor/sampler.py (2 hunks)
  • tensorrt_llm/_torch/speculative/drafter.py (3 hunks)
  • tensorrt_llm/_torch/speculative/drafting_loops.py (1 hunks)
  • tensorrt_llm/_torch/speculative/interface.py (3 hunks)
  • tensorrt_llm/_torch/speculative/model_drafter.py (4 hunks)
  • tensorrt_llm/_torch/speculative/save_hidden_state.py (1 hunks)
  • tensorrt_llm/_torch/speculative/utils.py (4 hunks)
  • tensorrt_llm/llmapi/llm_args.py (11 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use only spaces, no tabs; indent with 4 spaces.

Files:

  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
  • tensorrt_llm/_torch/speculative/drafter.py
  • tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
  • tensorrt_llm/_torch/speculative/drafting_loops.py
  • tensorrt_llm/_torch/pyexecutor/sampler.py
  • tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • tensorrt_llm/_torch/speculative/save_hidden_state.py
  • tensorrt_llm/_torch/pyexecutor/py_executor.py
  • tensorrt_llm/_torch/speculative/utils.py
  • tensorrt_llm/_torch/speculative/model_drafter.py
  • tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
  • tensorrt_llm/_torch/pyexecutor/_util.py
  • tensorrt_llm/llmapi/llm_args.py
  • tensorrt_llm/_torch/speculative/interface.py
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.

Files:

  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
  • tensorrt_llm/_torch/speculative/drafter.py
  • tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
  • tensorrt_llm/_torch/speculative/drafting_loops.py
  • tensorrt_llm/_torch/pyexecutor/sampler.py
  • tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • tensorrt_llm/_torch/speculative/save_hidden_state.py
  • tensorrt_llm/_torch/pyexecutor/py_executor.py
  • tensorrt_llm/_torch/speculative/utils.py
  • tensorrt_llm/_torch/speculative/model_drafter.py
  • tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
  • tensorrt_llm/_torch/pyexecutor/_util.py
  • tensorrt_llm/llmapi/llm_args.py
  • tensorrt_llm/_torch/speculative/interface.py
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).

Files:

  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
  • tensorrt_llm/_torch/speculative/drafter.py
  • tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
  • tensorrt_llm/_torch/speculative/drafting_loops.py
  • tensorrt_llm/_torch/pyexecutor/sampler.py
  • tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • tensorrt_llm/_torch/speculative/save_hidden_state.py
  • tensorrt_llm/_torch/pyexecutor/py_executor.py
  • tensorrt_llm/_torch/speculative/utils.py
  • tensorrt_llm/_torch/speculative/model_drafter.py
  • tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
  • tensorrt_llm/_torch/pyexecutor/_util.py
  • tensorrt_llm/llmapi/llm_args.py
  • tensorrt_llm/_torch/speculative/interface.py
🧠 Learnings (1)
📚 Learning: 2025-09-04T07:33:10.618Z
Learnt from: MrGeva
PR: NVIDIA/TensorRT-LLM#7219
File: tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py:162-168
Timestamp: 2025-09-04T07:33:10.618Z
Learning: When users explicitly provide cuda_graph_batch_sizes in TorchCudagraphCompiler, respect their choices and only sanitize the values (clamp, dedupe, sort) without forcing additional batch sizes like 1 or max_batch_size. Only add commonly-used batch sizes when falling back to the heuristic.

Applied to files:

  • tensorrt_llm/_torch/pyexecutor/model_engine.py
🧬 Code graph analysis (4)
tensorrt_llm/_torch/speculative/drafter.py (1)
tensorrt_llm/runtime/generation.py (1)
  • max_draft_tokens (1319-1322)
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (3)
tensorrt_llm/llmapi/llm_args.py (1)
  • is_linear_tree (609-612)
tensorrt_llm/_torch/speculative/drafting_loops.py (1)
  • ChainDrafter (108-170)
tensorrt_llm/_torch/speculative/utils.py (1)
  • get_num_extra_kv_tokens (226-235)
tensorrt_llm/_torch/pyexecutor/model_engine.py (3)
tensorrt_llm/_torch/pyexecutor/scheduler.py (1)
  • batch_size (35-36)
tensorrt_llm/_torch/speculative/interface.py (2)
  • without_logits (62-63)
  • use_one_engine (41-42)
tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py (1)
  • enable_spec_decode (90-91)
tensorrt_llm/llmapi/llm_args.py (3)
tensorrt_llm/_utils.py (1)
  • from_dict (805-816)
tensorrt_llm/mapping.py (1)
  • from_dict (325-326)
tensorrt_llm/models/modeling_utils.py (3)
  • from_dict (253-263)
  • from_dict (325-334)
  • from_dict (487-492)
🪛 Ruff (0.14.0)
tensorrt_llm/_torch/speculative/model_drafter.py

62-62: Avoid specifying long messages outside the exception class

(TRY003)


64-64: Avoid specifying long messages outside the exception class

(TRY003)

tensorrt_llm/_torch/speculative/interface.py

134-134: Unused method argument: is_spec_dec_tree

(ARG002)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (23)
tensorrt_llm/_torch/models/modeling_deepseekv3.py (1)

513-513: Approve max_total_draft_tokens usage

max_total_draft_tokens is always initialized when spec_config is provided, and applying +1 here correctly supports both one-model and two-model speculative decoding paths.
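
For reference, a hedged sketch of the sizing rule described here, written as a standalone helper (the real code computes this inline in modeling_deepseekv3.py):

def predicted_tokens_per_seq(spec_config) -> int:
    # One slot for the token sampled at this step plus, when speculation is
    # configured, every draft token that could be accepted in the same step.
    if spec_config is None:
        return 1
    return spec_config.max_total_draft_tokens + 1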

tensorrt_llm/_torch/pyexecutor/model_engine.py (5)

91-99: CUDA graph batch-size filter update looks correct

Using 1 + max_total_draft_tokens to bound max CUDA graph batch size is consistent with worst-case per-request token consumption. No issues spotted. Based on learnings


268-297: Draft token buffers and runtime_draft_len alignment

  • Allocating draft_tokens_cuda to batch_size * original_max_total_draft_tokens and switching runtime_draft_len to max_total_draft_tokens align with tree semantics.
  • Passing original_max_total_draft_tokens into _filter_cuda_graph_batch_sizes is consistent.

LGTM.

Also applies to: 315-320, 356-359


463-468: Good clarification on using max_draft_len for drafting loop iterations

Calling out that drafting loop iteration count must use max_draft_len (not total tokens) avoids tree/linear confusion in fused paths. Keep this invariant enforced through the file.

Please ensure all places that iterate drafting loops use max_draft_len, not max_total_draft_tokens.


160-167: Run to check PyTorchModelEngine init signature:

#!/bin/bash
rg -nP "class PyTorchModelEngine" -A5 -C0 tensorrt_llm/_torch/pyexecutor/model_engine.py

604-606: KV cache API parameters correctly represent draft tokens
Both add_dummy_requests(max_num_draft_tokens=…) and get_num_available_tokens(max_num_draft_tokens) consistently treat their argument as the number of draft tokens, not layers—no change needed.

tensorrt_llm/_torch/speculative/save_hidden_state.py (1)

24-25: Propagate max_total_draft_tokens into SaveHiddenStatesDrafter

Storing spec_config.max_total_draft_tokens is consistent with the broader shift. No issues.

tensorrt_llm/_torch/pyexecutor/sampler.py (2)

1005-1012: Updated docstrings/comments for tree semantics

The new shape/dimension comments for new_tokens and tree sampling are accurate and helpful. LGTM.

Also applies to: 1106-1113


866-875: Approve TorchSampler Args usage and buffer sizing
All TorchSampler.Args invocations include max_total_draft_tokens, and max_tokens = args.max_total_draft_tokens + 1 is correct for tree mode.

tensorrt_llm/_torch/speculative/drafting_loops.py (1)

110-118: Constructor API update verified. The ChainDrafter instantiation in py_executor_creator.py now includes both max_draft_len and max_total_draft_tokens.

tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py (3)

96-98: LGTM: use total draft tokens for max_possible_draft_len

Switching to engine.original_max_total_draft_tokens when spec decode is on matches the new semantics.


111-113: LGTM: graph key reflects total draft tokens for non-draft path

Using spec_config.max_total_draft_tokens in the key coordinates with padding/token-per-request logic.


347-352: No change required: engine.runtime_draft_len already equals spec_config.max_total_draft_tokens when enable_spec_decode is true.

Likely an incorrect or invalid review comment.

tensorrt_llm/_torch/pyexecutor/py_executor.py (4)

1043-1046: LGTM: pass total draft tokens to speculation gate decision

Passing max_total_draft_tokens to should_use_spec_decode aligns with new budgeting.


1055-1059: LGTM: initialize request.draft_tokens length to total draft tokens

Correctly sizes per-request draft token placeholders for scheduling.


1229-1234: LGTM: py_draft_tokens/pages sized to total draft tokens

Matches CUDA-graph padding and KV allocation expectations.


1616-1623: LGTM: ADP dummy request uses total draft tokens

max_num_draft_tokens now reflects total draft-token capacity.

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py (2)

321-325: LGTM: concise computation of max_total_draft_tokens

Null-safe selection from speculative_config is clear and correct.


380-385: LGTM: propagate total draft tokens to sampler and executor

Threading through TorchSampler.Args and PyExecutor ensures consistent capacity.

Also applies to: 399-403

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (2)

360-363: LGTM: ChainDrafter receives both layer-count and total draft tokens

Constructor signature matches new drafting semantics.


476-478: LGTM: guided decoder capped by total draft tokens

Using max_total_draft_tokens for guided decoding aligns with token budgeting.

tensorrt_llm/_torch/speculative/utils.py (2)

47-63: LGTM: pass total draft tokens into Eagle3SpecMetadata

Correctly threads max_total_draft_tokens and tree flags into Eagle3 metadata.


189-196: LGTM: propagate total draft tokens to ModelDrafter

ModelDrafter receives both counts (layers, total tokens), consistent with tree drafting.

@yweng0828 changed the title from "[None][feat]add max_total_draft_tokens" to "[None][feat] Add max_total_draft_tokens" on Oct 14, 2025
@yweng0828 (Collaborator, Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #21357 [ run ] triggered by Bot

@tensorrt-cicd (Collaborator)

PR_Github #21357 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #16124 completed with status: 'FAILURE'

@yweng0828 force-pushed the yweng/add_max_total_draft_tokens branch from 1715d06 to 55a0d9e on October 14, 2025 13:57
@yweng0828 (Collaborator, Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #21362 [ run ] triggered by Bot

@tensorrt-cicd (Collaborator)

PR_Github #21362 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #16129 completed with status: 'FAILURE'

@ziyixiong-nv (Collaborator) left a comment

Overall LGTM.

Signed-off-by: Yue Weng <[email protected]>
@yweng0828 (Collaborator, Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #21406 [ run ] triggered by Bot

@yweng0828 (Collaborator, Author)

/bot kill

@tensorrt-cicd (Collaborator)

PR_Github #21419 [ kill ] triggered by Bot

@tensorrt-cicd (Collaborator)

PR_Github #21406 [ run ] completed with state ABORTED
LLM/main/L0_MergeRequest_PR #16166 (Blue Ocean) completed with status: ABORTED

@tensorrt-cicd (Collaborator)

PR_Github #21419 [ kill ] completed with state SUCCESS
Successfully killed previous jobs for commit 671b48e

@yweng0828 (Collaborator, Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #21427 [ run ] triggered by Bot

@tensorrt-cicd (Collaborator)

PR_Github #21427 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #16182 completed with status: 'FAILURE'

@syuoni (Collaborator) left a comment

LGTM

Signed-off-by: Yue Weng <[email protected]>
@yweng0828 force-pushed the yweng/add_max_total_draft_tokens branch from b6ee71d to 2ffcf19 on October 15, 2025 05:50
@yweng0828 (Collaborator, Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #21436 [ run ] triggered by Bot
