1 change: 1 addition & 0 deletions tensorrt_llm/_torch/auto_deploy/config/default.yaml
@@ -64,6 +64,7 @@ transforms:
stage: sharding
simple_shard_only: false
use_sharding_from_factory: false
support_partial_config: false
sharding_dims: ['tp', 'ep', 'bmm']
# TODO: (hg) need to ensure run_shape_prop after sharding.
sharding_transform_executor:
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/auto_deploy/llm_args.py
@@ -166,7 +166,7 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
)

sharding_dims: List[str] = Field(
default=["tp", "ep", "bmm"],
default=["tp", "ep", "dp"],
description="The sharding methods to apply by the heuristic sharding stage.",
)

69 changes: 64 additions & 5 deletions tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
@@ -127,6 +127,7 @@ class ShardingTransformConfig(TransformConfig):

simple_shard_only: bool = Field(default=False)
use_sharding_from_factory: bool = Field(default=False)
support_partial_config: bool = Field(default=False)
# Which sharding families to run: any subset of {"tp", "ep", "bmm"}
sharding_dims: List[str] = Field(default_factory=lambda: ["tp", "ep", "bmm"])
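For context, a minimal sketch of how these flags map onto the config (a plain-pydantic stand-in, not the real class: `ShardingTransformConfig` extends `TransformConfig`, and the exact YAML key of the sharding entry in `default.yaml` sits above the hunk shown here):

```python
from typing import List
from pydantic import BaseModel, Field


class ShardingTransformConfigSketch(BaseModel):
    """Simplified stand-in mirroring only the fields touched in this diff."""

    simple_shard_only: bool = Field(default=False)
    use_sharding_from_factory: bool = Field(default=False)
    support_partial_config: bool = Field(default=False)
    sharding_dims: List[str] = Field(default_factory=lambda: ["tp", "ep", "bmm"])


# The keys shown in default.yaml parse directly into the model:
cfg = ShardingTransformConfigSketch(
    simple_shard_only=False,
    use_sharding_from_factory=False,
    support_partial_config=False,
    sharding_dims=["tp", "ep", "bmm"],
)
assert cfg.support_partial_config is False
```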

@@ -185,6 +186,9 @@ def _apply(
else ShardingConfigSource.UNKNOWN
)
shared_config.sharding_config.simple_shard_only = self.config.simple_shard_only
shared_config.sharding_config.support_partial_config = self.config.support_partial_config
shared_config.sharding_config.sharding_dims = self.config.sharding_dims

shared_config.sharding_config.use_sharding_from_factory = (
self.config.use_sharding_from_factory
)
@@ -200,8 +204,6 @@ def _apply(
factory_info = detect_sharding_from_factory_config(gm, sharding_config)
return gm, factory_info

shared_config.sharding_config.sharding_dims = self.config.sharding_dims

ad_logger.info(
f"Running autodeploy sharding heuristics: {shared_config.sharding_config.sharding_dims}"
)
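A hedged sketch of the branch these hunks feed into (the actual guard is elided by the diff; the assumption here is that the factory path is taken when `use_sharding_from_factory` is set and a predefined config is present, otherwise the heuristics run for the configured dims):

```python
from typing import List


def choose_sharding_path(
    use_sharding_from_factory: bool,
    has_predefined_config: bool,
    sharding_dims: List[str],
) -> str:
    """Illustrative only; names and the exact guard are assumptions."""
    if use_sharding_from_factory and has_predefined_config:
        return "detect_sharding_from_factory_config"
    return f"heuristic detectors over {sharding_dims}"


print(choose_sharding_path(False, False, ["tp", "ep", "bmm"]))
# -> "heuristic detectors over ['tp', 'ep', 'bmm']"
```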
@@ -338,8 +340,39 @@ def detect_sharding_from_factory_config(
# TODO: Sequence parallelism is not supported yet.
ad_logger.warning("Sequence parallelism is not supported yet. Skipping.")
elif "local" in config:
# TODO: local refers to hybrid EP+TP parallelism. Not supported yet.
ad_logger.warning("Local EP+TP sharding is not supported yet. Skipping.")
# Check if this applies to shared experts in EP parallelism.
# If yes, apply the TP col-row shard.
if "shared" in module_name:
col_row_action = config.replace("local_", "")
if col_row_action == "colwise":
sharding_config.tp_transforms.append(
TPShardingInfo(
target_node=lin_node.name,
split_dim=SplitDimension.COLUMN,
rank=rank,
world_size=world_size,
dist_op=None,
min_local_shape=min_local_shape,
)
)
elif col_row_action == "rowwise":
sharding_config.tp_transforms.append(
TPShardingInfo(
target_node=lin_node.name,
split_dim=SplitDimension.ROW,
rank=rank,
world_size=world_size,
dist_op="all_reduce",
min_local_shape=min_local_shape,
)
)
num_row_col_shards += 1
else:
ad_logger.warning("Invalid sharding config. Skipping.")
else:
# TODO: local refers to hybrid EP+TP parallelism. Not supported yet.
ad_logger.warning("Local EP+TP sharding is not supported yet. Skipping.")

elif "gather" in config:
# Simple shard (row + all_gather)
sharding_config.tp_transforms.append(
@@ -362,9 +395,35 @@ def detect_sharding_from_factory_config(
f"Applied {num_shards} TP shards (simple: {num_simple_shards}, "
f"row-col pattern: {num_row_col_shards})"
)

num_matches = len(sharding_config.tp_transforms)

if sharding_config.support_partial_config:
ad_logger.info(
f"Partial factory config applied only for TP. "
f"Applying heuristics for {sharding_config.sharding_dims}."
)

# run EP sharding across ranks
if "ep" in sharding_config.sharding_dims:
ep_info = detect_ep_shard(gm, sharding_config)
else:
ep_info = TransformInfo(
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
)

# run BMM sharding across ranks
if "bmm" in sharding_config.sharding_dims:
dp_bmm_info = detect_dp_bmm_shard(gm, sharding_config)
else:
dp_bmm_info = TransformInfo(
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
)
num_matches += ep_info.num_matches + dp_bmm_info.num_matches

return TransformInfo(
skipped=False,
num_matches=len(sharding_config.tp_transforms),
num_matches=num_matches,
is_clean=False,
has_valid_shapes=False,
)
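To summarize the two behavioural changes in `detect_sharding_from_factory_config`: `local_colwise`/`local_rowwise` entries are now honoured for shared-expert modules (a plain TP column/row shard, with `all_reduce` only on the row side), and with `support_partial_config` enabled the factory config covers TP only while EP/BMM sharding still comes from the heuristics. A self-contained sketch follows; the `Sketch` classes, the helper name, and the example module/node names are illustrative, not the real API.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class SplitDimension(Enum):  # stand-in for the real enum
    COLUMN = 0
    ROW = 1


@dataclass
class TPShardingInfoSketch:  # stand-in for TPShardingInfo
    target_node: str
    split_dim: SplitDimension
    rank: int
    world_size: int
    dist_op: Optional[str]
    min_local_shape: int


def local_shared_expert_shard(
    module_name: str,
    config: str,
    lin_node_name: str,
    rank: int,
    world_size: int,
    min_local_shape: int,
) -> Optional[TPShardingInfoSketch]:
    """Sketch of the new 'local_*' handling: shared experts get a plain
    col/row TP shard; anything else still hits the 'EP+TP not supported
    yet' warning and is skipped."""
    if "shared" not in module_name:
        return None  # hybrid EP+TP for routed experts remains unsupported
    action = config.replace("local_", "")
    if action == "colwise":
        return TPShardingInfoSketch(lin_node_name, SplitDimension.COLUMN,
                                    rank, world_size, None, min_local_shape)
    if action == "rowwise":
        return TPShardingInfoSketch(lin_node_name, SplitDimension.ROW,
                                    rank, world_size, "all_reduce", min_local_shape)
    return None  # invalid entry -> warning + skip


# Illustrative module/node names:
print(local_shared_expert_shard(
    "model.layers.0.mlp.shared_expert.gate_proj", "local_colwise",
    "linear_1", rank=0, world_size=4, min_local_shape=1,
))

# With support_partial_config=True, EP and BMM transforms are still produced by
# detect_ep_shard / detect_dp_bmm_shard, and their num_matches are folded into
# the TransformInfo returned by the factory-config path.
```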
5 changes: 3 additions & 2 deletions tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
@@ -737,6 +737,7 @@ class ShardingConfig(BaseModel):
predefined_config: Optional[Dict[str, Any]] = None
simple_shard_only: bool = Field(default=False)
use_sharding_from_factory: bool = False
support_partial_config: bool = False
sharding_dims: List[str] = Field(default_factory=list)
tp_transforms: List[TPShardingInfo] = Field(default_factory=list)
bmm_transforms: List[BMMShardingInfo] = Field(default_factory=list)
@@ -781,7 +782,7 @@ def validate_config(self) -> bool:
tp_plan = self.predefined_config["tp_plan"]

values = set(tp_plan.values())
allowed_values = {
supported_modes = {
"colwise", # row split and no collective
"rowwise", # column split and all-reduce
"gather", # simple shard (row + all_gather)
@@ -793,7 +794,7 @@ def validate_config(self) -> bool:
# "local_packed_rowwise",
# "local",
}
if not values.issubset(allowed_values):
if not self.support_partial_config and not values.issubset(supported_modes):
ad_logger.warning("Sharding config contains invalid values. Skipping.")
# invalidate the config
self.predefined_config = {}
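And a short sketch of the relaxed validation: previously any `tp_plan` value outside the supported set invalidated the whole predefined config; with `support_partial_config=True` the subset check is bypassed and unsupported entries are handled per-module during detection instead. The function name and the example plan keys below are illustrative.

```python
from typing import Dict

SUPPORTED_MODES = {"colwise", "rowwise", "gather"}  # subset shown in this hunk


def tp_plan_is_acceptable(tp_plan: Dict[str, str], support_partial_config: bool) -> bool:
    """Sketch of the changed guard in ShardingConfig.validate_config."""
    values = set(tp_plan.values())
    return support_partial_config or values.issubset(SUPPORTED_MODES)


plan = {"mlp.shared_expert.up_proj": "local_colwise", "self_attn.qkv_proj": "colwise"}
print(tp_plan_is_acceptable(plan, support_partial_config=False))  # False -> config cleared
print(tp_plan_is_acceptable(plan, support_partial_config=True))   # True  -> kept, applied partially
```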