[None][feat] Add max_total_draft_tokens #8366
base: main
Conversation
📝 Walkthrough
Replaces max_draft_len usage with max_total_draft_tokens across configs, executors, engines, CUDA graph paths, samplers, and speculative components. Adds fields and parameters, updates constructors and call flow to propagate the total draft-token count, adjusts warmup/padding/runtime decisions, and updates model code to derive predicted tokens from the new config value.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    actor App
    participant Creator as _util.create_py_executor_instance
    participant Engine as ModelEngine
    participant Exec as PyExecutor
    participant Sampler as TorchSampler
    App->>Creator: create_py_executor_instance(spec_config,…)
    Note right of Creator: Derive max_total_draft_tokens (default 0 if None)
    Creator->>Exec: PyExecutor(max_total_draft_tokens,…)
    Creator->>Sampler: TorchSampler.Args(max_total_draft_tokens,…)
    Creator->>Engine: ModelEngine(spec_config with max_total_draft_tokens)
    Engine-->>Creator: instance with original_max_total_draft_tokens
    Exec-->>App: executor ready
```
```mermaid
sequenceDiagram
    autonumber
    participant Exec as PyExecutor
    participant Drafter as Drafter/ModelDrafter
    participant Cuda as CudaGraphRunner
    participant Eng as ModelEngine
    Exec->>Drafter: should_use_spec_decode(max_total_draft_tokens,…)
    alt Spec decode enabled
        Exec->>Eng: runtime_draft_len = max_total_draft_tokens
        Exec->>Cuda: get_graph_key(max_total_draft_tokens,…)
        Cuda-->>Exec: key reflecting total draft tokens
        Exec-->>Exec: allocate/pad draft tokens up to total
    else Autoregressive
        Exec-->>Exec: set draft tokens to 0
    end
```
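A condensed Python sketch of the derivation shown in the first diagram. The helper name below is illustrative and not part of the codebase; only the fall-back-to-zero behaviour comes from the walkthrough.

```python
def derive_max_total_draft_tokens(spec_config) -> int:
    """Illustrative only: mirrors the creator-side default described above."""
    if spec_config is None:
        # No speculative decoding configured: no draft-token budget.
        return 0
    # For linear trees this equals max_draft_len; for token trees it is the
    # total number of draft tokens produced across all draft layers.
    return spec_config.max_total_draft_tokens
```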
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (19)
tensorrt_llm/_torch/pyexecutor/model_engine.py (2)
704-723: Bug: draft_lengths logic overwrites the “0” capture and misuses total-tokens for CDL loop iterations

Two issues:
- Draft lengths for CUDA graph warmup (non-draft model) overwrite a previously appended 0, losing the ability to replay graphs with spec decode disabled at runtime (contradicting the comment).
- For the CDL/Eagle3 path, draft_len represents the number of drafting iterations. It should use original_max_draft_len, not original_max_total_draft_tokens.

Fix with this diff:

```diff
@@
-        if self.model_is_wrapped and self.is_spec_decode and spec_resource_manager is not None and isinstance(
-                spec_resource_manager, Eagle3ResourceManager):
-            # The CDL path uses draft_len > 0 for the number of iterations in the drafting loop.
-            draft_lengths.append(self.original_max_total_draft_tokens)
+        if self.model_is_wrapped and self.is_spec_decode and spec_resource_manager is not None and isinstance(
+                spec_resource_manager, Eagle3ResourceManager):
+            # The CDL path uses draft_len to mean "number of drafting iterations"
+            draft_lengths.append(self.original_max_draft_len)
         else:
             draft_lengths.append(self.max_draft_len)
@@
-        if (self.max_total_draft_tokens > 0
-                and not self.spec_config.spec_dec_mode.use_one_engine()
-                # Assume that speculation is always on if the user didn't give us a max_concurrency
-                # value. This will save on memory.
-                and self.spec_config.max_concurrency is not None):
-            draft_lengths.append(0)
-        draft_lengths = [self.max_total_draft_tokens]
+        # Always capture the current speculative setting. Optionally also capture draft_len=0
+        # to allow disabling speculative decoding at runtime in the two-engine path.
+        draft_lengths = [self.max_total_draft_tokens]
+        if (self.max_total_draft_tokens > 0
+                and not self.spec_config.spec_dec_mode.use_one_engine()
+                and self.spec_config.max_concurrency is not None):
+            draft_lengths.append(0)
```
1-5: Missing NVIDIA Apache-2.0 header

Per coding guidelines, prepend the NVIDIA Apache-2.0 header with the current year to this source file.

Apply at file top:

```diff
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
```

tensorrt_llm/_torch/speculative/save_hidden_state.py (1)
1-5: Missing NVIDIA Apache-2.0 header

Please add the required header at the top of the file (same header diff as shown above).

tensorrt_llm/_torch/speculative/drafting_loops.py (1)
1-5: Missing NVIDIA Apache-2.0 header

Add the required header at the top (same header diff as shown above).

tensorrt_llm/_torch/speculative/drafter.py (3)
1-5: Missing NVIDIA Apache-2.0 header

Add the required header at the top (same header diff as shown above).
29-55: Tests: rename max_draft_len to max_total_draft_tokens

In tests/unittest/_torch/speculative/test_dynamic_spec_decode.py, update the mock signature and all calls (e.g. lines 75, 198–234) to use max_total_draft_tokens instead of max_draft_len.
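A minimal sketch of the rename, assuming the test patches should_use_spec_decode with a simple stand-in; apart from the renamed parameter, the signature below is illustrative and may not match the real test exactly.

```python
# Illustrative stand-in for the patched should_use_spec_decode in
# test_dynamic_spec_decode.py; only the renamed parameter comes from the review.
def mock_should_use_spec_decode(requests, max_batch_size, max_num_tokens,
                                max_total_draft_tokens):  # was: max_draft_len
    # Force speculation on in the test regardless of the token budget.
    return True
```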
65-71: Pad draft tokens to max_total_draft_tokens in CUDA-graph padding

In drafter.py (57-71), pad_draft_tokens_for_cuda_graph uses self.max_draft_len. For tree-based speculation (where max_total_draft_tokens > max_draft_len) this under-pads py_draft_tokens and breaks CUDA-graph static shapes. Use max_total_draft_tokens instead, or guard by tree mode.
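A standalone sketch of the suggested behaviour, not the actual Drafter method; request objects are assumed to expose a py_draft_tokens list, and the free-function form is used only to keep the example self-contained.

```python
def pad_draft_tokens_for_cuda_graph(requests, max_total_draft_tokens: int) -> None:
    """Pad each request's draft tokens up to the total draft-token budget so
    tree speculation matches the static shapes captured by CUDA graphs."""
    for request in requests:
        num_draft_tokens = len(request.py_draft_tokens)
        # Zero-pad the remainder of the budget (no-op when already full).
        request.py_draft_tokens.extend(
            0 for _ in range(max_total_draft_tokens - num_draft_tokens))
```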
tensorrt_llm/_torch/pyexecutor/sampler.py (1)

1-5: Missing NVIDIA Apache-2.0 header

Please add the required header at the top (same header diff as shown above).

tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py (2)
1-1: Missing NVIDIA Apache-2.0 header

Add the NVIDIA Apache-2.0 copyright header (current year) at the top of this source file, as per coding guidelines.
224-226: Bug: position_ids becomes a tuple due to trailing comma

The trailing comma makes position_ids a tuple (tensor,) instead of a tensor, likely breaking consumers.

```diff
-            sliced_static_tensors["position_ids"] = self.shared_static_tensors[
-                "position_ids"][:, :, :num_tokens_for_capture],
+            sliced_static_tensors["position_ids"] = self.shared_static_tensors[
+                "position_ids"][:, :, :num_tokens_for_capture]
```

tensorrt_llm/_torch/pyexecutor/py_executor.py (1)
1-1: Missing NVIDIA Apache-2.0 header

Add the NVIDIA Apache-2.0 copyright header (current year) at the top of this source file, as per coding guidelines.

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py (1)

1-1: Missing NVIDIA Apache-2.0 header

Add the NVIDIA Apache-2.0 copyright header (current year) at the top of this source file, as per coding guidelines.

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (1)

1-1: Missing NVIDIA Apache-2.0 header

Add the NVIDIA Apache-2.0 copyright header (current year) at the top of this source file, as per coding guidelines.

tensorrt_llm/_torch/speculative/interface.py (1)

1-1: Missing NVIDIA Apache-2.0 header

Add the NVIDIA Apache-2.0 copyright header (current year) at the top of this source file, as per coding guidelines.

tensorrt_llm/_torch/speculative/utils.py (4)

1-1: Missing NVIDIA Apache-2.0 header

Add the NVIDIA Apache-2.0 copyright header (current year) at the top of this source file, as per coding guidelines.
81-93: Bug: duplicate keyword argument max_total_draft_tokens

max_total_draft_tokens is passed twice; Python will raise “multiple values for keyword”. Decide the intended value (likely 1 for SaveHiddenStates) and keep only one.

```diff
-        return Eagle3SpecMetadata(
-            max_draft_len=spec_config.max_draft_len,
-            max_total_draft_tokens=spec_config.max_total_draft_tokens,
+        return Eagle3SpecMetadata(
+            max_draft_len=spec_config.max_draft_len,
             spec_dec_mode=spec_config.spec_dec_mode,
             max_num_requests=max_num_requests,
             num_layers=model_config.num_hidden_layers,
             hidden_size=model_config.hidden_size,
             max_num_tokens=max_num_tokens,
             dtype=model_config.torch_dtype,
             is_draft_model=is_draft_model,
             eagle3_resource_manager=spec_resource_manager,
             layers_to_capture=spec_config.eagle3_layers_to_capture,
             max_total_draft_tokens=1,
         )
```
94-101: Missing max_total_draft_tokens when constructing SpecMetadata

SpecMetadata now requires max_total_draft_tokens. Include it for these modes.

```diff
-        return SpecMetadata(
+        return SpecMetadata(
             max_draft_len=spec_config.max_draft_len,
+            max_total_draft_tokens=spec_config.max_total_draft_tokens,
             spec_dec_mode=spec_config.spec_dec_mode,
             max_num_requests=max_num_requests,
         )
```
134-144: Include the new is_spec_dec_tree argument in attention_need_spec_dec_mode calls

In tensorrt_llm/_torch/pyexecutor/model_engine.py:2166, the call to spec_dec_mode.attention_need_spec_dec_mode only passes four parameters; append the new boolean flag at the end to match the updated signature.

tensorrt_llm/_torch/speculative/model_drafter.py (1)
473-474: Critical: AttributeError on self.max_draft_tokens

Line 473 references self.max_draft_tokens, but the constructor (lines 74-75) no longer sets this attribute. After the parameter rename, the attributes are now self.max_draft_len and self.max_total_draft_tokens. This code will fail at runtime with AttributeError: 'ModelDrafter' object has no attribute 'max_draft_tokens'.

The same issue exists at lines 567, 650, and 726. Based on context, these should likely use self.max_total_draft_tokens since they allocate or iterate over draft token buffers.

Apply this diff to fix line 473:

```diff
         new_tokens_lens = torch.ones(batch_size, device=device)
         next_draft_tokens = torch.zeros(batch_size,
-                                        self.max_draft_tokens,
+                                        self.max_total_draft_tokens,
                                         device=device)
```

Similar fixes are needed at:
- Line 567: `for token_idx in range(self.max_draft_tokens):` → `range(self.max_total_draft_tokens):`
- Line 650: `for i in range(self.max_draft_tokens - 1):` → `range(self.max_total_draft_tokens - 1):`
- Line 726: `draft_length=self.max_draft_tokens,` → `draft_length=self.max_total_draft_tokens,`
🧹 Nitpick comments (4)
tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py (1)
106-108: Clarify/typo in comment

Nit: use the plural “max_total_draft_tokens” and tighten the wording.

```diff
-        # If 'is_first_draft' is True, even with tree decoding, the length of draft_len will only be 'max_draft_len', not 'max_total_draft_token'.
-        # Because we will pad the input to 'max_draft_len' length for the first draft layer.
+        # For the first draft in tree decoding, draft_len equals 'max_draft_len' (number of layers), not 'max_total_draft_tokens',
+        # because we pad inputs to 'max_draft_len' for the first draft layer.
```

tensorrt_llm/_torch/pyexecutor/py_executor.py (2)
150-171: Constructor: add brief docstring for new arg

New parameter max_total_draft_tokens is public API. Add a short docstring entry.

```diff
 class PyExecutor:
     def __init__(self,
                  ...
-                 max_draft_len: int = 0,
+                 max_draft_len: int = 0,
+                 max_total_draft_tokens: int = 0,
                  ...
                  ):
+        """
+        Args:
+            ...
+            max_draft_len: Number of draft layers (linear tree).
+            max_total_draft_tokens: Total draft tokens across layers (tree or linear).
+            ...
+        """
```
195-201: Prefer a single source of truth for total draft tokens

You store self.max_total_draft_tokens but later read model_engine.spec_config.max_total_draft_tokens in multiple places. For consistency and easier overrides, prefer using self.max_total_draft_tokens (falling back to spec_config if zero).

```diff
-        self.max_total_draft_tokens = max_total_draft_tokens
+        self.max_total_draft_tokens = max_total_draft_tokens
+        # Optional: normalize from spec_config if not provided
+        if self.max_total_draft_tokens == 0 and getattr(self.model_engine, "spec_config", None):
+            self.max_total_draft_tokens = getattr(self.model_engine.spec_config, "max_total_draft_tokens", 0)
```

Then use self.max_total_draft_tokens where applicable (e.g., lines 1055-1059 and 1229-1234).
tensorrt_llm/_torch/speculative/interface.py (1)
128-135: Unused argument is_spec_dec_tree

The parameter is never used. Either incorporate it into the decision or prefix it with an underscore to avoid ARG002 warnings until it is used.

```diff
-        use_chain_drafter: bool,
-        is_spec_dec_tree: bool,
+        use_chain_drafter: bool,
+        is_spec_dec_tree: bool,  # TODO: use or drop
```

Or:

```diff
-        is_spec_dec_tree: bool,
+        _is_spec_dec_tree: bool,
```
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (15)
- tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py (2 hunks)
- tensorrt_llm/_torch/models/modeling_deepseekv3.py (1 hunks)
- tensorrt_llm/_torch/pyexecutor/_util.py (3 hunks)
- tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py (2 hunks)
- tensorrt_llm/_torch/pyexecutor/model_engine.py (11 hunks)
- tensorrt_llm/_torch/pyexecutor/py_executor.py (6 hunks)
- tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (4 hunks)
- tensorrt_llm/_torch/pyexecutor/sampler.py (2 hunks)
- tensorrt_llm/_torch/speculative/drafter.py (3 hunks)
- tensorrt_llm/_torch/speculative/drafting_loops.py (1 hunks)
- tensorrt_llm/_torch/speculative/interface.py (3 hunks)
- tensorrt_llm/_torch/speculative/model_drafter.py (4 hunks)
- tensorrt_llm/_torch/speculative/save_hidden_state.py (1 hunks)
- tensorrt_llm/_torch/speculative/utils.py (4 hunks)
- tensorrt_llm/llmapi/llm_args.py (11 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Use only spaces, no tabs; indent with 4 spaces.
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).
🧠 Learnings (1)
📚 Learning: 2025-09-04T07:33:10.618Z
Learnt from: MrGeva
PR: NVIDIA/TensorRT-LLM#7219
File: tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py:162-168
Timestamp: 2025-09-04T07:33:10.618Z
Learning: When users explicitly provide cuda_graph_batch_sizes in TorchCudagraphCompiler, respect their choices and only sanitize the values (clamp, dedupe, sort) without forcing additional batch sizes like 1 or max_batch_size. Only add commonly-used batch sizes when falling back to the heuristic.
Applied to files:
tensorrt_llm/_torch/pyexecutor/model_engine.py
🧬 Code graph analysis (4)
tensorrt_llm/_torch/speculative/drafter.py (1)
- tensorrt_llm/runtime/generation.py (1): max_draft_tokens (1319-1322)

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (3)
- tensorrt_llm/llmapi/llm_args.py (1): is_linear_tree (609-612)
- tensorrt_llm/_torch/speculative/drafting_loops.py (1): ChainDrafter (108-170)
- tensorrt_llm/_torch/speculative/utils.py (1): get_num_extra_kv_tokens (226-235)

tensorrt_llm/_torch/pyexecutor/model_engine.py (3)
- tensorrt_llm/_torch/pyexecutor/scheduler.py (1): batch_size (35-36)
- tensorrt_llm/_torch/speculative/interface.py (2): without_logits (62-63), use_one_engine (41-42)
- tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py (1): enable_spec_decode (90-91)

tensorrt_llm/llmapi/llm_args.py (3)
- tensorrt_llm/_utils.py (1): from_dict (805-816)
- tensorrt_llm/mapping.py (1): from_dict (325-326)
- tensorrt_llm/models/modeling_utils.py (3): from_dict (253-263), from_dict (325-334), from_dict
🪛 Ruff (0.14.0)
tensorrt_llm/_torch/speculative/model_drafter.py
62-62: Avoid specifying long messages outside the exception class
(TRY003)
64-64: Avoid specifying long messages outside the exception class
(TRY003)
tensorrt_llm/_torch/speculative/interface.py
134-134: Unused method argument: is_spec_dec_tree
(ARG002)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (23)
tensorrt_llm/_torch/models/modeling_deepseekv3.py (1)
513-513: Approve max_total_draft_tokens usage

max_total_draft_tokens is always initialized when spec_config is provided, and applying +1 here correctly supports both one-model and two-model speculative decoding paths.

tensorrt_llm/_torch/pyexecutor/model_engine.py (5)
91-99: CUDA graph batch-size filter update looks correct

Using 1 + max_total_draft_tokens to bound the max CUDA graph batch size is consistent with worst-case per-request token consumption. No issues spotted. Based on learnings.
268-297
: Draft token buffers and runtime_draft_len alignment
- Allocating draft_tokens_cuda to batch_size * original_max_total_draft_tokens and switching runtime_draft_len to max_total_draft_tokens align with tree semantics.
- Passing original_max_total_draft_tokens into _filter_cuda_graph_batch_sizes is consistent.
LGTM.
Also applies to: 315-320, 356-359
463-468: Good clarification on using max_draft_len for drafting loop iterations

Calling out that the drafting loop iteration count must use max_draft_len (not total tokens) avoids tree/linear confusion in fused paths. Keep this invariant enforced through the file.
Please ensure all places that iterate drafting loops use max_draft_len, not max_total_draft_tokens.
160-167: Run this to check the PyTorchModelEngine init signature:

```bash
#!/bin/bash
rg -nP "class PyTorchModelEngine" -A5 -C0 tensorrt_llm/_torch/pyexecutor/model_engine.py
```
604-606: KV cache API parameters correctly represent draft tokens

Both add_dummy_requests(max_num_draft_tokens=…) and get_num_available_tokens(max_num_draft_tokens) consistently treat their argument as the number of draft tokens, not layers; no change needed.

tensorrt_llm/_torch/speculative/save_hidden_state.py (1)
24-25: Propagate max_total_draft_tokens into SaveHiddenStatesDrafter

Storing spec_config.max_total_draft_tokens is consistent with the broader shift. No issues.
tensorrt_llm/_torch/pyexecutor/sampler.py (2)
1005-1012: Updated docstrings/comments for tree semantics

The new shape/dimension comments for new_tokens and tree sampling are accurate and helpful. LGTM.
Also applies to: 1106-1113
866-875: Approve TorchSampler Args usage and buffer sizing

All TorchSampler.Args invocations include max_total_draft_tokens, and max_tokens = args.max_total_draft_tokens + 1 is correct for tree mode.

tensorrt_llm/_torch/speculative/drafting_loops.py (1)
110-118: Constructor API update verified

The ChainDrafter instantiation in py_executor_creator.py now includes both max_draft_len and max_total_draft_tokens.

tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py (3)
96-98: LGTM: use total draft tokens for max_possible_draft_len

Switching to engine.original_max_total_draft_tokens when spec decode is on matches the new semantics.

111-113: LGTM: graph key reflects total draft tokens for non-draft path

Using spec_config.max_total_draft_tokens in the key coordinates with padding/token-per-request logic.

347-352: No change required: engine.runtime_draft_len already equals spec_config.max_total_draft_tokens when enable_spec_decode is true.

Likely an incorrect or invalid review comment.
tensorrt_llm/_torch/pyexecutor/py_executor.py (4)
1043-1046: LGTM: pass total draft tokens to speculation gate decision

Passing max_total_draft_tokens to should_use_spec_decode aligns with the new budgeting.

1055-1059: LGTM: initialize request.draft_tokens length to total draft tokens

Correctly sizes per-request draft token placeholders for scheduling.

1229-1234: LGTM: py_draft_tokens/pages sized to total draft tokens

Matches CUDA-graph padding and KV allocation expectations.

1616-1623: LGTM: ADP dummy request uses total draft tokens

max_num_draft_tokens now reflects total draft-token capacity.
tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py (2)
321-325: LGTM: concise computation of max_total_draft_tokens

Null-safe selection from speculative_config is clear and correct.

380-385: LGTM: propagate total draft tokens to sampler and executor

Threading through TorchSampler.Args and PyExecutor ensures consistent capacity.
Also applies to: 399-403
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (2)
360-363: LGTM: ChainDrafter receives both layer count and total draft tokens

Constructor signature matches the new drafting semantics.

476-478: LGTM: guided decoder capped by total draft tokens

Using max_total_draft_tokens for guided decoding aligns with token budgeting.
tensorrt_llm/_torch/speculative/utils.py (2)
47-63: LGTM: pass total draft tokens into Eagle3SpecMetadata

Correctly threads max_total_draft_tokens and tree flags into Eagle3 metadata.

189-196: LGTM: propagate total draft tokens to ModelDrafter

ModelDrafter receives both counts (layers, total tokens), consistent with tree drafting.
/bot run --disable-fail-fast
PR_Github #21357 [ run ] triggered by Bot
PR_Github #21357 [ run ] completed with state
Signed-off-by: Yue Weng <[email protected]>
1715d06 to 55a0d9e
/bot run --disable-fail-fast
PR_Github #21362 [ run ] triggered by Bot
PR_Github #21362 [ run ] completed with state
Overall LGTM.
Signed-off-by: Yue Weng <[email protected]>
/bot run --disable-fail-fast
PR_Github #21406 [ run ] triggered by Bot
/bot kill
PR_Github #21419 [ kill ] triggered by Bot
PR_Github #21406 [ run ] completed with state
PR_Github #21419 [ kill ] completed with state
/bot run --disable-fail-fast
PR_Github #21427 [ run ] triggered by Bot
PR_Github #21427 [ run ] completed with state
LGTM
Signed-off-by: Yue Weng <[email protected]>
b6ee71d to 2ffcf19
/bot run --disable-fail-fast
PR_Github #21436 [ run ] triggered by Bot
Summary by CodeRabbit
Description

The purpose of this PR is to add a max_total_draft_tokens field to support draft token trees. max_total_draft_tokens indicates the total number of draft tokens generated after passing through the draft layers. The existing max_draft_len will still exist, indicating the number of draft layers.

For linear trees, max_draft_len will be equal to max_total_draft_tokens (each draft layer will only generate one draft token).

Note: This has not been modified or adapted for the one-model path, as max_total_draft_tokens currently only works on the two-model path.
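A small illustrative sketch of the distinction between the two fields; the tree shape and branching numbers below are made up for the example and are not taken from this PR.

```python
# Hypothetical draft-token tree: each entry is the number of draft tokens
# produced by one draft layer. With 3 layers, this tree yields 7 tokens total.
tokens_per_draft_layer = [2, 4, 1]

max_draft_len = len(tokens_per_draft_layer)           # number of draft layers -> 3
max_total_draft_tokens = sum(tokens_per_draft_layer)  # total draft tokens     -> 7

# Linear tree: every layer emits exactly one token, so the two values coincide.
linear = [1, 1, 1]
assert len(linear) == sum(linear) == 3
```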
Test Coverage

PR Checklist
Please review the following before submitting your PR:
- PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
- PR follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
- Test cases are provided for new code paths (see test instructions).
- Any new dependencies have been scanned for license and vulnerabilities.
- CODEOWNERS updated if ownership changes.
- Documentation updated as needed.
- The reviewers assigned automatically/manually are appropriate for the PR.
- Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...
Provide a user friendly way for developers to interact with a Jenkins server.
Run /bot [-h|--help] to print this help message. See details below for each supported subcommand.
run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]
Launch build/test pipelines. All previously running jobs will be killed.
--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md and the scripts/test_to_stage_mapping.py helper.

kill
kill
Kill all running builds associated with pull request.
skip
skip --comment COMMENT
Skip testing for latest commit on pull request.
--comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline
reuse-pipeline
Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.