draft AFD implementation for step3 #25162
Conversation
This pull request has merge conflicts that must be resolved before it can be merged.
Code Review
This pull request introduces a preliminary implementation of Step3 AFD (Attention FFN Disaggregation), a significant new feature for distributed inference. The changes are extensive, adding new configurations, communication connectors, and modifying core model execution logic. While the implementation lays a solid foundation, I've identified several critical issues related to correctness in distributed settings, maintainability, and configuration that should be addressed. These include potential race conditions or deadlocks due to incorrect distributed logic, CUDA graph caching issues, hardcoded values that limit portability, and significant code duplication. Addressing these points will improve the robustness and usability of this new feature.
```python
factors: list[Any] = [
    self.afd_connector,
    self.afd_role,
    self.num_afd_stages,
    self.num_attention_servers,
    self.num_ffn_servers,
]
```
The `compute_hash` method for `AFDConfig` omits several fields that could affect the computation graph, such as `afd_server_rank` and `afd_extra_config`. If these fields alter the model's execution path, their omission can lead to CUDA graph cache collisions, causing incorrect results or crashes. The `afd_extra_config` dictionary should be hashed in a deterministic way (e.g., by sorting keys) to ensure a stable hash.
Suggested change:

```python
factors: list[Any] = [
    self.afd_connector,
    self.afd_role,
    self.num_afd_stages,
    self.num_attention_servers,
    self.num_ffn_servers,
    self.afd_server_rank,
]
if self.afd_extra_config:
    factors.append(json.dumps(self.afd_extra_config, sort_keys=True))
```
```python
def recv_ffn_output(
    self,
    handle: Any,
) -> torch.Tensor:
```
The signature of `recv_ffn_output` in this abstract base class (`handle: Any`) does not match the implementations in `DummyAFDConnector` and `StepMeshAFDConnector`, which use `timeout_ms: Optional[float] = None`. This violates the Liskov Substitution Principle and will lead to runtime errors. The call sites in `step3_text.py` also call this method without arguments. The signatures across the base class and all implementations should be consistent.
Suggested change:

```python
def recv_ffn_output(
    self,
    timeout_ms: Optional[float] = None,
) -> torch.Tensor:
```
```python
self.scheduler_process = subprocess.Popen(
    [
        "python",
        "-c",
        "import torch; import fserver_lib as ps; import os; "
        'os.environ["DMLC_ROLE"] = "scheduler"; '
        'os.environ["DMLC_INTERFACE"] = "brainpf_bond0"; '
        "ps.init(); ps.stop()",
    ],
    env=os.environ.copy(),
)
```
The network interface `DMLC_INTERFACE` is hardcoded to `"brainpf_bond0"` for the scheduler subprocess. This is specific to a particular environment and will cause the scheduler to fail in other environments where this interface does not exist. This should be made configurable, or use a more general default like `"auto"`, which is used elsewhere in this file.
Suggested change:

```python
self.scheduler_process = subprocess.Popen(
    [
        "python",
        "-c",
        "import torch; import fserver_lib as ps; import os; "
        'os.environ["DMLC_ROLE"] = "scheduler"; '
        "ps.init(); ps.stop()",
    ],
    env=os.environ.copy(),
)
```
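Alternatively, a minimal sketch of making the interface explicitly configurable; the `afd_extra_config` key name `"dmlc_interface"` and the `self.config` attribute are illustrative assumptions, with `"auto"` as the fallback:

```python
# Hypothetical: take the interface from afd_extra_config, defaulting to "auto".
dmlc_interface = self.config.afd_extra_config.get("dmlc_interface", "auto")
self.scheduler_process = subprocess.Popen(
    [
        "python",
        "-c",
        "import torch; import fserver_lib as ps; import os; "
        'os.environ["DMLC_ROLE"] = "scheduler"; '
        f'os.environ["DMLC_INTERFACE"] = "{dmlc_interface}"; '
        "ps.init(); ps.stop()",
    ],
    env=os.environ.copy(),
)
```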
```python
def send_attn_output(
    self,
    hidden_states: torch.Tensor,
    metadata: AFDConnectorMetadata,
) -> Any:
```
The `send_attn_output` method is defined in the base class `AFDConnectorBase` to return a handle of type `Any`. However, this implementation does not return any value. This violates the interface contract and can lead to `NoneType` errors if the caller expects a handle. The method should return the `event` handle created by `ps.push_pull`.
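A minimal sketch of the fix; the `ps.push_pull` arguments shown are illustrative, only the returned `event` matters here:

```python
def send_attn_output(
    self,
    hidden_states: torch.Tensor,
    metadata: AFDConnectorMetadata,
) -> Any:
    # ... existing setup; the push_pull arguments below are illustrative.
    event = ps.push_pull(hidden_states, metadata)
    # Return the handle so callers can later wait on it (e.g. in recv_ffn_output).
    return event
```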
```python
stage_num_reqs = stage_end_req - stage_start_req
stage_num_actual_tokens = stage_end_token - stage_start_token
stage_max_seq_len = int(
    seq_lens_cpu[stage_start_req:stage_end_req].max())

stage_max_query_len = min(max_query_len, stage_num_actual_tokens)

if stage_num_actual_tokens == 0 or stage_num_reqs == 0:
    stage_metadatas.append(None)
    continue
```
There's a potential `RuntimeError` here. If a stage has no requests (`stage_num_reqs == 0`), `seq_lens_cpu[stage_start_req:stage_end_req]` will be an empty tensor, and calling `.max()` on an empty tensor will raise an error. The check for `stage_num_reqs == 0` should be performed before attempting to calculate `stage_max_seq_len`.
Suggested change:

```python
stage_num_reqs = stage_end_req - stage_start_req
stage_num_actual_tokens = stage_end_token - stage_start_token
if stage_num_actual_tokens == 0 or stage_num_reqs == 0:
    stage_metadatas.append(None)
    continue
stage_max_seq_len = int(
    seq_lens_cpu[stage_start_req:stage_end_req].max())
stage_max_query_len = min(max_query_len, stage_num_actual_tokens)
```
```python
tp_world_size = get_tensor_model_parallel_world_size()
if tp_world_size > 1:
    # Handle TP case: all-gather tensors from all TP ranks
    gathered_hidden_states = tensor_model_parallel_all_gather(
        hidden_states, dim=0)
    ffn_output = self.model.compute_ffn_output(current_layer_idx,
                                               gathered_hidden_states)

    # Extract the output corresponding to current rank
    start_idx = hidden_states.shape[
        0] * get_tensor_model_parallel_rank()
    end_idx = start_idx + hidden_states.shape[0]
    rank_ffn_output = ffn_output[start_idx:end_idx, :]
```
This logic for handling tensor parallelism appears to be incorrect, similar to the issue in `_execute_eager_mode`. The `tensor_model_parallel_all_gather` will unnecessarily replicate the `hidden_states` tensor across TP ranks, leading to incorrect computations. The FFN layers themselves should manage the sharding, so the input `hidden_states` should be passed directly to `self.model.compute_ffn_output`.
Suggested change:

```python
# The input hidden_states is replicated on all TP ranks.
# The FFN layers with tensor parallelism will handle sharding internally.
rank_ffn_output = self.model.compute_ffn_output(
    current_layer_idx, hidden_states)
```
```python
afd_metadata = forward_context.afd_metadata
if afd_metadata is not None:
    afd_stage_idx = afd_metadata.afd_stage_idx
    if afd_stage_idx < len(attn_metadata):
        attn_metadata = attn_metadata[afd_stage_idx]
    else:
        attn_metadata = None  # padding
```
The logic to extract `attn_metadata` based on `afd_metadata` is duplicated in at least five places in this file (here, in the other `forward` path, `maybe_save_kv_layer_to_connector`, `unified_attention`, and `unified_attention_with_output`). This much code duplication is a significant maintainability risk; a bug fix or logic change would need to be applied in all locations, which is error-prone. I recommend refactoring this logic into a helper function (for example, something like the sketch below) to centralize it and improve code clarity.
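A minimal sketch of such a helper, based on the snippet above (the function name `_select_stage_attn_metadata` is a suggestion, not something in the PR):

```python
def _select_stage_attn_metadata(forward_context, attn_metadata):
    """Return the attention metadata for the current AFD stage, or None for padding."""
    afd_metadata = forward_context.afd_metadata
    if afd_metadata is None:
        return attn_metadata
    afd_stage_idx = afd_metadata.afd_stage_idx
    if afd_stage_idx < len(attn_metadata):
        return attn_metadata[afd_stage_idx]
    return None  # padding stage with no metadata
```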
```python
layer_idx=-1,  # Extract from comm_id
stage_idx=-1,  # Extract from comm_id
```
Using hardcoded placeholder values (`-1`) for `layer_idx` and `stage_idx` is risky. The `TODO` comment indicates this is incomplete. If downstream code does not properly handle these negative indices, it could lead to indexing errors or silent correctness issues. This logic should be fully implemented to extract the correct indices from `comm_id` before this feature is considered complete.
```python
try:
    if len(self.events) > 0:
        event, metadata = self.events.popleft()
        ps.wait(event, timeout_ms=50000)
```
```python
if self.profiler:
    self.profiler.start()
for _ in range(1000):  # FIXME: hardcoded profiler iterations
```
```python
def loader() -> type[AFDConnectorBase]:
    module = importlib.import_module(module_path)
    return getattr(module, class_name)
```
Maybe this can directly use `vllm.utils.resolve_obj_by_qualname`.
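A minimal sketch of that alternative, assuming `resolve_obj_by_qualname` resolves a dotted `"module.ClassName"` string as it does in `vllm.utils`:

```python
from vllm.utils import resolve_obj_by_qualname

def loader() -> type[AFDConnectorBase]:
    # module_path and class_name are the same values used in the original loader.
    return resolve_obj_by_qualname(f"{module_path}.{class_name}")
```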
```python
slot_mapping=stage_slot_mapping,
use_cascade=False,
common_prefix_len=0,
scheduler_metadata=stage_scheduler_metadata,
```
If `aot_schedule` is false, `stage_scheduler_metadata` and `FlashAttentionMetadata` cannot be created.
Configs no longer live in `vllm/config/__init__.py`.
Documentation preview: https://vllm--25162.org.readthedocs.build/en/25162/
Hello, I'd like to ask: how does StepMesh implement CUDA Graph? During graph replay, Python code isn't executed, so how is the PushPullWorker thread in StepMesh launched?
Currently the StepMesh connector operations are not captured inside the CUDA graph; piecewise CUDA graph is used instead.
```python
        causal=causal)
    return attn_metadata

def _init_stage_buffers(self, vllm_config: VllmConfig,
```
Please see `vllm/v1/worker/gpu_model_runner.py`, lines 1388 to 1404, at commit `ae9d0e7`:

```python
if ubatch_slices is not None:
    common_attn_metadata_list = split_attn_metadata(
        ubatch_slices, common_attn_metadata
    )
    for ubid, common_attn_metadata in enumerate(
        common_attn_metadata_list
    ):
        attn_metadata_i = attn_group.get_metadata_builder(
            ubatch_id=ubid
        ).build(
            common_prefix_len=common_prefix_len,
            common_attn_metadata=common_attn_metadata,
        )
        for layer_name in kv_cache_group_spec.layer_names:
            assert type(attn_metadata) is list
            attn_metadata[ubid][layer_name] = attn_metadata_i
else:
```
1. Attn

```
vllm fserver /path/step3v -dp 8 --afd-config '{"afd_connector": "dummy", "afd_role": "attention", "afd_host": "127.0.0.0"}' --max-num-batched-tokens 384 --max-num-seqs 384 --compilation-config '{"cudagraph_capture_sizes": [1, 8]}'
```
Should this be `vllm serve`?
Purpose
Design doc: https://docs.google.com/document/d/1GS2g8df7sdPmDvysmsURXN7xDDwnJfM6ERkBryUiTEA/edit?tab=t.0#heading=h.g8s3tkkthjdk
Step Paper: https://arxiv.org/abs/2507.19427
This is the preliminary implementation of Step3 AFD. It currently only supports the StepMesh connector and the Step3 model. In the future, the community will help expand the connectors and add support for the DeepSeek V3 model, as mentioned in the RFC.
The current CUDA Graph implementation for AFD still requires optimization. At present, it involves intrusive modifications to each model and is not compatible with the existing CudaGraphWrapper implementation. Everyone is welcome to join the discussion on finding a more elegant solution.
Test Plan
Test Result
Essential Elements of an Effective PR Description Checklist
`supported_models.md` and `examples` for a new model.