17 changes: 1 addition & 16 deletions tensorrt_llm/_torch/attention_backend/interface.py
@@ -121,12 +121,7 @@ class AttentionMetadata:
default_factory=AttentionRuntimeFeatures)

# The number of tokens in each rank.
_all_rank_num_tokens: Optional[List[int]] = field(init=False,
default=None,
repr=False)
all_rank_num_tokens: Optional[List[int]]
# The max number of tokens among all ranks.
all_rank_max_num_tokens: Optional[int] = None
all_rank_num_tokens: Optional[List[int]] = None

# These fields are set when changing seq_lens and _num_contexts to avoid computation
# during execution. If the calculation happens during execution, torch compile treats it
@@ -167,16 +162,6 @@ def on_update(self):
elif self._seq_lens is not None:
self._num_tokens = self._seq_lens.sum().item()

@property
def all_rank_num_tokens(self) -> Optional[List[int]]:
return self._all_rank_num_tokens

@all_rank_num_tokens.setter
def all_rank_num_tokens(self, value: Optional[List[int]]):
value = value if value is not AttentionMetadata.all_rank_num_tokens else None
self._all_rank_num_tokens = value
self.all_rank_max_num_tokens = max(value) if value is not None else None

@property
def seq_lens(self) -> Optional[torch.Tensor]:
return self._seq_lens
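Note on the AttentionMetadata change above: the private _all_rank_num_tokens field, its property/setter pair, and the cached all_rank_max_num_tokens are all removed, leaving all_rank_num_tokens as a plain dataclass field. A minimal sketch of the resulting usage, assuming a consumer that still needs the per-rank maximum (the helper below is hypothetical and not part of the diff):

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class AttentionMetadataSketch:
    # Plain field: no property/setter indirection and no cached maximum.
    all_rank_num_tokens: Optional[List[int]] = None

def max_rank_tokens(meta: AttentionMetadataSketch) -> Optional[int]:
    # Hypothetical consumer: derive the maximum on demand instead of
    # reading a stored all_rank_max_num_tokens attribute.
    if meta.all_rank_num_tokens is None:
        return None
    return max(meta.all_rank_num_tokens)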
52 changes: 35 additions & 17 deletions tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py
@@ -1,10 +1,10 @@
from dataclasses import dataclass
from functools import lru_cache
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union

import torch

from tensorrt_llm._torch.utils import (fp4_utils,
from tensorrt_llm._torch.utils import (Fp4QuantizedTensor, fp4_utils,
get_last_power_of_2_num_tokens_buckets,
last_positive_power_of_2,
next_positive_power_of_2)
@@ -269,6 +269,31 @@ def fp4_block_scale_moe_runner(routing_logits: torch.Tensor,
return kernel_runner(inputs, tactic=best_tactic)


def fp4_block_scale_fake_output_without_finalize(
hidden_states: Union[torch.Tensor, Fp4QuantizedTensor],
num_experts: int,
top_k: int,
routing_bias: Optional[torch.Tensor],
):
num_tokens = hidden_states.shape[0]
hidden_size = hidden_states.shape[1] * (2 if isinstance(
hidden_states, Fp4QuantizedTensor) else 1)

tile_tokens_dim = calculate_tile_tokens_dim(num_tokens, num_experts, top_k)

expanded_row_count = num_tokens * top_k
max_padding_required = (tile_tokens_dim - 1) * num_experts
max_num_padded_tokens = fp4_utils.pad_up(
expanded_row_count + max_padding_required, tile_tokens_dim)
wt_dtype = routing_bias.dtype if routing_bias is not None else torch.bfloat16
return [
hidden_states.new_empty((max_num_padded_tokens, hidden_size),
dtype=torch.bfloat16),
hidden_states.new_empty((num_tokens, top_k), dtype=wt_dtype),
hidden_states.new_empty((num_tokens, top_k), dtype=torch.int32)
]


@fp4_block_scale_moe_runner.register_fake
def _(
routing_logits,
@@ -293,27 +318,20 @@ def _(
routing_method_type,
do_finalize,
) -> List[torch.Tensor]:
num_tokens = hidden_states.shape[0]
hidden_size = hidden_states.shape[1] * 2
if do_finalize:
num_tokens = hidden_states.shape[0]
hidden_size = hidden_states.shape[1] * 2
return [
hidden_states.new_empty((num_tokens, hidden_size),
dtype=torch.bfloat16)
]

tile_tokens_dim = calculate_tile_tokens_dim(num_tokens, num_experts, top_k)

expanded_row_count = num_tokens * top_k
max_padding_required = (tile_tokens_dim - 1) * num_experts
max_num_padded_tokens = fp4_utils.pad_up(
expanded_row_count + max_padding_required, tile_tokens_dim)
wt_dtype = routing_bias.dtype if routing_bias is not None else torch.bfloat16
return [
hidden_states.new_empty((max_num_padded_tokens, hidden_size),
dtype=torch.bfloat16),
hidden_states.new_empty((num_tokens, top_k), dtype=wt_dtype),
hidden_states.new_empty((num_tokens, top_k), dtype=torch.int32)
]
return fp4_block_scale_fake_output_without_finalize(
hidden_states,
num_experts,
top_k,
routing_bias,
)


@dataclass(frozen=True)
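The new fp4_block_scale_fake_output_without_finalize helper factors the do_finalize=False output-shape logic out of the registered fake so both call paths share it. A sketch of the shape arithmetic it encodes, using illustrative numbers and a stand-in for fp4_utils.pad_up (assumed to round up to a multiple); the hidden size is doubled for Fp4QuantizedTensor inputs, presumably because two FP4 values are packed per storage element:

# Illustrative values only: 128 tokens, 8 experts, top_k = 2, tile of 8 tokens.
num_tokens, num_experts, top_k, tile_tokens_dim = 128, 8, 2, 8

def pad_up(value: int, multiple: int) -> int:
    # Assumed behaviour of fp4_utils.pad_up: round up to the next multiple.
    return ((value + multiple - 1) // multiple) * multiple

expanded_row_count = num_tokens * top_k                      # 256 expanded rows
max_padding_required = (tile_tokens_dim - 1) * num_experts   # up to 56 pad rows
max_num_padded_tokens = pad_up(expanded_row_count + max_padding_required,
                               tile_tokens_dim)              # 312 for these values

# The fake then allocates (max_num_padded_tokens, hidden_size) activations plus
# (num_tokens, top_k) expert weights and expert indices, matching the helper.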
9 changes: 1 addition & 8 deletions tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -548,8 +548,7 @@ def _get_experts_quant_config(model_config, layer_idx: int) -> QuantConfig:
f"model.layers.{layer_idx}.mlp.experts", model_config.quant_config)

def compute_routed_output(self, hidden_states, hidden_states_fp4,
all_rank_num_tokens, all_rank_max_num_tokens,
do_finalize):
all_rank_num_tokens, do_finalize):
# max-throughput
use_dp_padding = False
if self.use_dp and self.mapping.tp_size > 1:
@@ -568,7 +567,6 @@ def compute_routed_output(self, hidden_states, hidden_states_fp4,
do_finalize=do_finalize,
output_dtype=hidden_states.dtype,
all_rank_num_tokens=all_rank_num_tokens,
all_rank_max_num_tokens=all_rank_max_num_tokens,
use_dp_padding=use_dp_padding,
)

@@ -579,7 +577,6 @@ def forward(
hidden_states: torch.Tensor,
hidden_states_fp4: Optional[Fp4QuantizedTensor] = None,
all_rank_num_tokens: Optional[list[int]] = None,
all_rank_max_num_tokens: Optional[int] = None,
final_all_reduce_params: Optional[AllReduceParams] = None,
do_finalize: Optional[bool] = True,
) -> torch.Tensor:
@@ -598,7 +595,6 @@ def _compute_routed_output():
routed_output = self.compute_routed_output(hidden_states,
hidden_states_fp4,
all_rank_num_tokens,
all_rank_max_num_tokens,
do_finalize)
return routed_output

@@ -840,7 +836,6 @@ def _run_MoE(hidden_states, hidden_states_fp4, do_finalize):
hidden_states,
hidden_states_fp4,
all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
all_rank_max_num_tokens=attn_metadata.all_rank_max_num_tokens,
final_all_reduce_params=AllReduceParams(
enable_allreduce=not (self.fusion_config.POST_MOE_FUSION
or self.mapping.tp_size == 1)),
@@ -1028,7 +1023,6 @@ def forward(
embed_tokens: Embedding,
attn_metadata: AttentionMetadata,
all_rank_num_tokens: Optional[List[int]] = None,
all_rank_max_num_tokens: Optional[int] = None,
**kwargs,
) -> torch.Tensor:

@@ -1087,7 +1081,6 @@ def norm_hidden():
hidden_states = self.mlp(
hidden_states,
all_rank_num_tokens=all_rank_num_tokens,
all_rank_max_num_tokens=all_rank_max_num_tokens,
final_all_reduce_params=AllReduceParams(
enable_allreduce=not (self.fusion_config.POST_MOE_FUSION
or self.mapping.tp_size == 1)),
11 changes: 4 additions & 7 deletions tensorrt_llm/_torch/models/modeling_gpt_oss.py
@@ -258,7 +258,6 @@ def forward_attn_dp(

# Get attention_dp parameters
all_rank_num_tokens = attn_metadata.all_rank_num_tokens
all_rank_max_num_tokens = attn_metadata.all_rank_max_num_tokens

if self.mapping.tp_size > 1 and all_rank_num_tokens is not None:
if (isinstance(self.experts, (TRTLLMGenFusedMoE, TritonFusedMoE))):
@@ -276,12 +275,10 @@

# Let CutlassFusedMoE handle allgather internally
# Pass the normalized tensor (t) as input to experts, not x
expert_output = self.experts(
x=t,
router_logits=g,
all_rank_num_tokens=all_rank_num_tokens,
all_rank_max_num_tokens=all_rank_max_num_tokens,
use_dp_padding=False)
expert_output = self.experts(x=t,
router_logits=g,
all_rank_num_tokens=all_rank_num_tokens,
use_dp_padding=False)

expert_output = expert_output.view(orig_shape)
return expert_output, residual
18 changes: 6 additions & 12 deletions tensorrt_llm/_torch/models/modeling_llama.py
@@ -315,32 +315,27 @@ def __init__(
self.aux_stream = aux_stream

def compute_routed_output(self, hidden_states, all_rank_num_tokens,
all_rank_max_num_tokens,
cutlass_min_latency_mode):
router_logits = self.router(hidden_states)
routed_output = self.experts(
hidden_states,
router_logits,
do_finalize=not cutlass_min_latency_mode,
all_rank_num_tokens=all_rank_num_tokens,
all_rank_max_num_tokens=all_rank_max_num_tokens,
use_dp_padding=False)
routed_output = self.experts(hidden_states,
router_logits,
do_finalize=not cutlass_min_latency_mode,
all_rank_num_tokens=all_rank_num_tokens,
use_dp_padding=False)
return routed_output

def forward(
self,
hidden_states: torch.Tensor,
all_rank_num_tokens=None,
all_rank_max_num_tokens=None,
final_all_reduce_params: Optional[AllReduceParams] = None,
cutlass_min_latency_mode: Optional[bool] = False,
) -> torch.Tensor:
# Only enable multi-stream for cuda graph since switch stream has extra host overhead
# This design is mainly for low latency use case. Need to improve for max throughput use case.
fn0 = lambda: self.shared_expert(hidden_states)
fn1 = lambda: self.compute_routed_output(
hidden_states, all_rank_num_tokens, all_rank_max_num_tokens,
cutlass_min_latency_mode)
hidden_states, all_rank_num_tokens, cutlass_min_latency_mode)
shared_output, routed_output = maybe_execute_in_parallel(
fn0, fn1, self.moe_event[0], self.moe_event[1], self.aux_stream)
if cutlass_min_latency_mode:
@@ -542,7 +537,6 @@ def forward(
hidden_states = self.feed_forward(
hidden_states,
all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
all_rank_max_num_tokens=attn_metadata.all_rank_max_num_tokens,
final_all_reduce_params=AllReduceParams(
enable_allreduce=not self.disable_feed_forward_allreduce),
cutlass_min_latency_mode=cutlass_min_latency_mode,
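The comment in the LLaMA MoE forward above notes that multi-stream execution is only enabled under CUDA graphs because switching streams adds host overhead. As a rough sketch, and purely as an assumption about behaviour (the real maybe_execute_in_parallel lives in the TensorRT-LLM utils and is not shown in this diff), such a helper would overlap the shared-expert and routed-expert lambdas roughly like this:

import torch

def maybe_execute_in_parallel_sketch(fn0, fn1, event0, event1, aux_stream=None):
    # Assumed behaviour, not the library implementation: run fn1 on an
    # auxiliary CUDA stream so it can overlap with fn0 on the current stream,
    # and fall back to plain sequential execution when no aux stream is given.
    if aux_stream is None:
        return fn0(), fn1()
    event0.record()                  # current stream has produced fn1's inputs
    with torch.cuda.stream(aux_stream):
        event0.wait()                # aux stream waits for those inputs
        out1 = fn1()
        event1.record()
    out0 = fn0()
    event1.wait()                    # current stream waits before consuming out1
    return out0, out1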
2 changes: 0 additions & 2 deletions tensorrt_llm/_torch/models/modeling_mixtral.py
@@ -62,13 +62,11 @@ def forward(
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
all_rank_num_tokens = attn_metadata.all_rank_num_tokens
all_rank_max_num_tokens = attn_metadata.all_rank_max_num_tokens
router_logits = self.gate(hidden_states)
final_hidden_states = self.experts(
hidden_states,
router_logits,
all_rank_num_tokens=all_rank_num_tokens,
all_rank_max_num_tokens=all_rank_max_num_tokens,
use_dp_padding=False)
return final_hidden_states

2 changes: 0 additions & 2 deletions tensorrt_llm/_torch/models/modeling_qwen3_moe.py
@@ -127,7 +127,6 @@ def forward(
hidden_states = hidden_states.view(-1, self.hidden_dim)
use_dp_padding = False
all_rank_num_tokens = attn_metadata.all_rank_num_tokens
all_rank_max_num_tokens = attn_metadata.all_rank_max_num_tokens

if not do_finalize:
assert not self.enable_attention_dp
@@ -144,7 +143,6 @@ def forward(
hidden_states,
router_logits,
all_rank_num_tokens=all_rank_num_tokens,
all_rank_max_num_tokens=all_rank_max_num_tokens,
use_dp_padding=use_dp_padding,
do_finalize=do_finalize,
)
2 changes: 0 additions & 2 deletions tensorrt_llm/_torch/models/modeling_qwen_moe.py
@@ -84,13 +84,11 @@ def forward(
hidden_states = hidden_states.view(-1, self.hidden_dim)

all_rank_num_tokens = attn_metadata.all_rank_num_tokens
all_rank_max_num_tokens = attn_metadata.all_rank_max_num_tokens
router_logits = self.gate(hidden_states)
final_hidden_states = self.experts(
hidden_states,
router_logits,
all_rank_num_tokens=all_rank_num_tokens,
all_rank_max_num_tokens=all_rank_max_num_tokens,
use_dp_padding=False)

shared_expert_output = self.shared_expert(hidden_states)
10 changes: 5 additions & 5 deletions tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
@@ -85,6 +85,7 @@ def __init__(
swiglu_alpha=swiglu_alpha,
swiglu_beta=swiglu_beta,
swiglu_limit=swiglu_limit,
layer_idx=layer_idx,
)

# Store original hidden size before any potential padding
@@ -96,8 +97,6 @@
self.intermediate_size_per_partition = (
(self.intermediate_size_per_partition + 127) // 128) * 128

self.layer_idx = layer_idx

self.num_slots = self.num_experts
self.expert_size_per_partition = self.num_experts // self.ep_size
self.initial_global_assignments = [
@@ -449,15 +448,16 @@ def split_chunk(self, split_token_num: int, split_num_chunks: int):
split_num_chunks - val_mod)
return split_chunk_size_list

def forward(
def forward_impl(
self,
x: Union[torch.Tensor, Fp4QuantizedTensor],
router_logits: torch.Tensor,
*,
do_finalize: bool = True, # used by other MoE backends
output_dtype: Optional[torch.dtype] = None,
all_rank_num_tokens: Optional[List[int]] = None,
all_rank_max_num_tokens: Optional[int] = None,
use_dp_padding: Optional[bool] = None,
**kwargs,
) -> torch.Tensor:
assert do_finalize, "CutlassFusedMoE does not support do_finalize=False"
if self.use_dp and self.parallel_size > 1:
@@ -472,7 +472,7 @@ def forward(
1) // self.moe_max_num_tokens

if use_dp_padding:
all_rank_num_tokens_padded = [all_rank_max_num_tokens
all_rank_num_tokens_padded = [max(all_rank_num_tokens)
] * len(all_rank_num_tokens)
else:
all_rank_num_tokens_padded = all_rank_num_tokens
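In the Cutlass backend (and the matching DeepGEMM change below), forward becomes forward_impl with keyword-only arguments and a **kwargs catch-all, presumably so a shared forward wrapper in the MoE base class can dispatch to it while dropped arguments such as all_rank_max_num_tokens are simply ignored. The padded per-rank token list is now built from max(all_rank_num_tokens) instead of the removed parameter; a small self-contained sketch of that step:

from typing import List, Optional

def padded_rank_token_counts(all_rank_num_tokens: List[int],
                             use_dp_padding: Optional[bool]) -> List[int]:
    # Sketch of the padding step in forward_impl: when DP padding is enabled,
    # every rank is padded to the largest rank's token count, which previously
    # arrived as a separate all_rank_max_num_tokens argument.
    if use_dp_padding:
        return [max(all_rank_num_tokens)] * len(all_rank_num_tokens)
    return list(all_rank_num_tokens)

# Example: ranks holding 5, 9 and 7 tokens pad to [9, 9, 9] when enabled.
assert padded_rank_token_counts([5, 9, 7], True) == [9, 9, 9]
assert padded_rank_token_counts([5, 9, 7], False) == [5, 9, 7]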
7 changes: 4 additions & 3 deletions tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
@@ -637,15 +637,16 @@ def forward_chunk(

return final_hidden_states

def forward(
def forward_impl(
self,
x: Union[torch.Tensor, Fp4QuantizedTensor],
router_logits: torch.Tensor,
*,
do_finalize: bool = True, # used by other MoE backends
output_dtype: Optional[torch.dtype] = None,
all_rank_num_tokens: Optional[List[int]] = None,
all_rank_max_num_tokens: Optional[int] = None,
use_dp_padding: Optional[bool] = None,
**kwargs,
) -> torch.Tensor:
assert do_finalize, "CutlassFusedMoE does not support do_finalize=False"
if self.use_dp and self.parallel_size > 1:
@@ -663,7 +664,7 @@ def forward(
1) // self.moe_max_num_tokens

if use_dp_padding:
all_rank_num_tokens_padded = [all_rank_max_num_tokens
all_rank_num_tokens_padded = [max(all_rank_num_tokens)
] * len(all_rank_num_tokens)
else:
all_rank_num_tokens_padded = all_rank_num_tokens
4 changes: 3 additions & 1 deletion tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py
@@ -1287,6 +1287,7 @@ def __init__(
reduce_results=reduce_results,
model_config=model_config,
weight_loading_mode=weight_loading_mode,
layer_idx=layer_idx,
)
if not IS_TRITON_KERNELS_AVAILABLE:
raise ImportError("Triton kernels are not available.")
@@ -1359,10 +1360,11 @@ def create_weights(self):

self._weights_created = True

def forward(
def forward_impl(
self,
x: torch.Tensor,
router_logits: torch.Tensor,
*,
do_finalize: bool = True,
all_rank_num_tokens: Optional[List[int]] = None,
use_dp_padding: Optional[bool] = None,
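The Triton backend follows the same convention: layer_idx is forwarded to the base-class constructor instead of being stored in the subclass, and the entry point is renamed to forward_impl with keyword-only arguments. A minimal sketch of the assumed base-class pattern (the base-class name and its forward wrapper are assumptions, not shown in this diff):

import torch

class MoEBaseSketch(torch.nn.Module):
    # Assumed base-class shape: layer_idx is stored once here, and a stable
    # forward() delegates to the backend-specific forward_impl() so common
    # wrapping logic lives in a single place.
    def __init__(self, layer_idx=None):
        super().__init__()
        self.layer_idx = layer_idx

    def forward(self, x, router_logits, **kwargs):
        return self.forward_impl(x, router_logits, **kwargs)

class TritonBackendSketch(MoEBaseSketch):
    def forward_impl(self, x, router_logits, *, do_finalize=True,
                     all_rank_num_tokens=None, use_dp_padding=None, **kwargs):
        # A real backend would launch Triton kernels here; the sketch just
        # echoes its input so the delegation path can be exercised.
        return x

moe = TritonBackendSketch(layer_idx=0)
out = moe(torch.zeros(2, 4), torch.zeros(2, 8), all_rank_num_tokens=[2])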