
Conversation

@Oliver-ss (Contributor) commented Sep 18, 2025

Purpose

Design doc: https://docs.google.com/document/d/1GS2g8df7sdPmDvysmsURXN7xDDwnJfM6ERkBryUiTEA/edit?tab=t.0#heading=h.g8s3tkkthjdk
Step Paper: https://arxiv.org/abs/2507.19427

This is the preliminary implementation of Step3 AFD. It currently supports only the StepMesh connector and the Step3 model. In the future, the community will help expand the connectors and add support for the DeepSeek V3 model, as mentioned in the RFC.

The current CUDA Graph implementation for AFD still requires optimization. At present, it involves intrusive modifications to each model and is not compatible with the existing CudaGraphWrapper implementation. Everyone is welcome to join the discussion on finding a more elegant solution.

Test Plan

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.


mergify bot commented Sep 18, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @Oliver-ss.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 18, 2025
@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces a preliminary implementation of Step3 AFD (Attention FFN Disaggregation), a significant new feature for distributed inference. The changes are extensive, adding new configurations, communication connectors, and modifying core model execution logic. While the implementation lays a solid foundation, I've identified several critical issues related to correctness in distributed settings, maintainability, and configuration that should be addressed. These include potential race conditions or deadlocks due to incorrect distributed logic, CUDA graph caching issues, hardcoded values that limit portability, and significant code duplication. Addressing these points will improve the robustness and usability of this new feature.

Comment on lines +2420 to +2426
factors: list[Any] = [
    self.afd_connector,
    self.afd_role,
    self.num_afd_stages,
    self.num_attention_servers,
    self.num_ffn_servers,
]

critical

The compute_hash method for AFDConfig omits several fields that could affect the computation graph, such as afd_server_rank and afd_extra_config. If these fields alter the model's execution path, their omission can lead to CUDA graph cache collisions, causing incorrect results or crashes. The afd_extra_config dictionary should be hashed in a deterministic way (e.g., by sorting keys) to ensure a stable hash.

Suggested change
factors: list[Any] = [
    self.afd_connector,
    self.afd_role,
    self.num_afd_stages,
    self.num_attention_servers,
    self.num_ffn_servers,
]
factors: list[Any] = [
    self.afd_connector,
    self.afd_role,
    self.num_afd_stages,
    self.num_attention_servers,
    self.num_ffn_servers,
    self.afd_server_rank,
]
if self.afd_extra_config:
    factors.append(json.dumps(self.afd_extra_config, sort_keys=True))
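
For reference, a tiny sketch of why sorting keys matters for the hash: json.dumps(..., sort_keys=True) serializes a dict identically regardless of insertion order, so it is a stable hash factor. The keys below are made up for illustration only.

```python
import json

# Two dicts with the same contents but different insertion order.
a = {"timeout_ms": 50000, "interface": "auto"}
b = {"interface": "auto", "timeout_ms": 50000}

# Sorted serialization is order-independent, so the hash input is stable.
assert json.dumps(a, sort_keys=True) == json.dumps(b, sort_keys=True)
```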

Comment on lines +94 to +97
def recv_ffn_output(
    self,
    handle: Any,
) -> torch.Tensor:

critical

The signature of recv_ffn_output in this abstract base class (handle: Any) does not match the implementations in DummyAFDConnector and StepMeshAFDConnector, which use timeout_ms: Optional[float] = None. This violates the Liskov Substitution Principle and will lead to runtime errors. The call sites in step3_text.py also call this method without arguments. The signatures across the base class and all implementations should be consistent.

Suggested change
def recv_ffn_output(
    self,
    handle: Any,
) -> torch.Tensor:
def recv_ffn_output(
    self,
    timeout_ms: Optional[float] = None,
) -> torch.Tensor:

Comment on lines +173 to +183
self.scheduler_process = subprocess.Popen(
    [
        "python",
        "-c",
        "import torch; import fserver_lib as ps; import os; "
        'os.environ["DMLC_ROLE"] = "scheduler"; '
        'os.environ["DMLC_INTERFACE"] = "brainpf_bond0"; '
        "ps.init(); ps.stop()",
    ],
    env=os.environ.copy(),
)

critical

The network interface DMLC_INTERFACE is hardcoded to "brainpf_bond0" for the scheduler subprocess. This is specific to a particular environment and will cause the scheduler to fail in other environments where this interface does not exist. This should be made configurable or use a more general default like "auto" which is used elsewhere in this file.

            self.scheduler_process = subprocess.Popen(
                [
                    "python",
                    "-c",
                    "import torch; import fserver_lib as ps; import os; "
                    'os.environ["DMLC_ROLE"] = "scheduler"; '
                    "ps.init(); ps.stop()",
                ],
                env=os.environ.copy(),
            )
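
One way the interface could be wired through configuration, shown as a sketch only: the afd_extra_config key name "dmlc_interface" is hypothetical, and the attribute path to the config may differ in the actual connector.

```python
# Sketch: resolve the interface from config instead of hardcoding it.
# "dmlc_interface" is a hypothetical key, defaulting to "auto".
interface = (self.config.afd_extra_config or {}).get("dmlc_interface", "auto")
self.scheduler_process = subprocess.Popen(
    [
        "python",
        "-c",
        "import torch; import fserver_lib as ps; import os; "
        'os.environ["DMLC_ROLE"] = "scheduler"; '
        f'os.environ["DMLC_INTERFACE"] = "{interface}"; '
        "ps.init(); ps.stop()",
    ],
    env=os.environ.copy(),
)
```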

Comment on lines +218 to +222
def send_attn_output(
    self,
    hidden_states: torch.Tensor,
    metadata: AFDConnectorMetadata,
) -> Any:

critical

The send_attn_output method is defined in the base class AFDConnectorBase to return a handle of type Any. However, this implementation does not return any value. This violates the interface contract and can lead to NoneType errors if the caller expects a handle. The method should return the event handle created by ps.push_pull.
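
A minimal sketch of the fix, assuming the handle is produced by the ps.push_pull call that already exists in this method (its arguments are elided below):

```python
def send_attn_output(
    self,
    hidden_states: torch.Tensor,
    metadata: AFDConnectorMetadata,
) -> Any:
    # ... existing packing of hidden_states and metadata ...
    event = ps.push_pull(...)  # arguments elided; reuse the existing call
    return event  # return the handle so callers can wait on it later
```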

Comment on lines +504 to +513
stage_num_reqs = stage_end_req - stage_start_req
stage_num_actual_tokens = stage_end_token - stage_start_token
stage_max_seq_len = int(
    seq_lens_cpu[stage_start_req:stage_end_req].max())

stage_max_query_len = min(max_query_len, stage_num_actual_tokens)

if stage_num_actual_tokens == 0 or stage_num_reqs == 0:
    stage_metadatas.append(None)
    continue

critical

There's a potential RuntimeError here. If a stage has no requests (stage_num_reqs == 0), seq_lens_cpu[stage_start_req:stage_end_req] will be an empty tensor. Calling .max() on an empty tensor will raise an error. The check for stage_num_reqs == 0 should be performed before attempting to calculate stage_max_seq_len.

Suggested change
stage_num_reqs = stage_end_req - stage_start_req
stage_num_actual_tokens = stage_end_token - stage_start_token
stage_max_seq_len = int(
    seq_lens_cpu[stage_start_req:stage_end_req].max())
stage_max_query_len = min(max_query_len, stage_num_actual_tokens)
if stage_num_actual_tokens == 0 or stage_num_reqs == 0:
    stage_metadatas.append(None)
    continue
stage_num_reqs = stage_end_req - stage_start_req
stage_num_actual_tokens = stage_end_token - stage_start_token
if stage_num_actual_tokens == 0 or stage_num_reqs == 0:
    stage_metadatas.append(None)
    continue
stage_max_seq_len = int(
    seq_lens_cpu[stage_start_req:stage_end_req].max())
stage_max_query_len = min(max_query_len, stage_num_actual_tokens)
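
For context, a quick repro of the failure mode the reordering avoids: calling .max() on an empty tensor raises a RuntimeError.

```python
import torch

empty = torch.empty(0)
try:
    empty.max()  # reduction over an empty tensor raises RuntimeError
except RuntimeError as exc:
    print(f"RuntimeError: {exc}")
```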

Comment on lines +300 to +312
tp_world_size = get_tensor_model_parallel_world_size()
if tp_world_size > 1:
    # Handle TP case: all-gather tensors from all TP ranks
    gathered_hidden_states = tensor_model_parallel_all_gather(
        hidden_states, dim=0)
    ffn_output = self.model.compute_ffn_output(current_layer_idx,
                                               gathered_hidden_states)

    # Extract the output corresponding to current rank
    start_idx = hidden_states.shape[
        0] * get_tensor_model_parallel_rank()
    end_idx = start_idx + hidden_states.shape[0]
    rank_ffn_output = ffn_output[start_idx:end_idx, :]

critical

This logic for handling tensor parallelism appears to be incorrect, similar to the issue in _execute_eager_mode. The tensor_model_parallel_all_gather will unnecessarily replicate the hidden_states tensor across TP ranks, leading to incorrect computations. The FFN layers themselves should manage the sharding. The input hidden_states should be passed directly to self.model.compute_ffn_output.

        # The input hidden_states is replicated on all TP ranks.
        # The FFN layers with tensor parallelism will handle sharding internally.
        rank_ffn_output = self.model.compute_ffn_output(
            current_layer_idx, hidden_states)

Comment on lines +297 to +303
afd_metadata = forward_context.afd_metadata
if afd_metadata is not None:
    afd_stage_idx = afd_metadata.afd_stage_idx
    if afd_stage_idx < len(attn_metadata):
        attn_metadata = attn_metadata[afd_stage_idx]
    else:
        attn_metadata = None  # padding

high

The logic to extract attn_metadata based on afd_metadata is duplicated in at least five places in this file (here, in the other forward path, maybe_save_kv_layer_to_connector, unified_attention, and unified_attention_with_output). This much code duplication is a significant maintainability risk; a bug fix or logic change would need to be applied in all locations, which is error-prone. I recommend refactoring this logic into a helper function to centralize it and improve code clarity.
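
One possible shape for such a helper; the name and placement are illustrative, not part of this PR:

```python
def select_stage_attn_metadata(attn_metadata, afd_metadata):
    """Return the per-stage attention metadata, or None for padding stages."""
    if afd_metadata is None:
        return attn_metadata
    stage_idx = afd_metadata.afd_stage_idx
    if stage_idx < len(attn_metadata):
        return attn_metadata[stage_idx]
    return None  # padding
```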

Comment on lines +305 to +306
layer_idx=-1, # Extract from comm_id
stage_idx=-1, # Extract from comm_id

high

Using hardcoded placeholder values (-1) for layer_idx and stage_idx is risky. The TODO comment indicates this is incomplete. If downstream code does not properly handle these negative indices, it could lead to indexing errors or silent correctness issues. This logic should be fully implemented to extract the correct indices from comm_id before this feature is considered complete.

try:
    if len(self.events) > 0:
        event, metadata = self.events.popleft()
        ps.wait(event, timeout_ms=50000)

high

A hardcoded timeout of 50 seconds is used when waiting for the FFN output. This value may not be suitable for all workloads or environments. For production readiness, this timeout should be configurable, for instance, through the afd_extra_config dictionary in AFDConfig.
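
A sketch of how the timeout could be read from afd_extra_config; the "recv_timeout_ms" key is a hypothetical name, and the attribute path to the config may differ in the actual connector.

```python
# Sketch: fall back to the current 50s default when the key is absent.
timeout_ms = (self.config.afd_extra_config or {}).get("recv_timeout_ms", 50000)
if len(self.events) > 0:
    event, metadata = self.events.popleft()
    ps.wait(event, timeout_ms=timeout_ms)
```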


if self.profiler:
    self.profiler.start()
for _ in range(1000):  # FIXME: hardcoded profiler iterations

high

The number of profiler iterations is hardcoded to 1000. This is inflexible and should be made configurable, for example, through an environment variable or a configuration parameter. The FIXME comment indicates this is a known issue that should be addressed.
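
As a sketch, the iteration count could come from an environment variable; the variable name below is hypothetical and defaults to the current value.

```python
import os

# Hypothetical knob; defaults to the current hardcoded value of 1000.
profiler_iters = int(os.environ.get("VLLM_AFD_PROFILER_ITERS", "1000"))
if self.profiler:
    self.profiler.start()
for _ in range(profiler_iters):
    ...  # existing profiling loop body
```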

Comment on lines +28 to +30
def loader() -> type[AFDConnectorBase]:
    module = importlib.import_module(module_path)
    return getattr(module, class_name)

Maybe this can directly use vllm.utils.resolve_obj_by_qualname.
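
A sketch of what that suggestion would look like, assuming module_path and class_name are in scope as in the current code:

```python
from vllm.utils import resolve_obj_by_qualname


def loader() -> type[AFDConnectorBase]:
    return resolve_obj_by_qualname(f"{module_path}.{class_name}")
```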

slot_mapping=stage_slot_mapping,
use_cascade=False,
common_prefix_len=0,
scheduler_metadata=stage_scheduler_metadata,

If aot_schedule is False, stage_scheduler_metadata and the FlashAttentionMetadata cannot be created.

A Member commented:

Configs no longer live in vllm/config/__init__.py


mergify bot commented Oct 8, 2025

Documentation preview: https://vllm--25162.org.readthedocs.build/en/25162/

@fengidri

Hello, I'd like to ask: how does StepMesh implement CUDA Graph? During graph replay, Python code isn't executed—so how is the PushPullWorker thread in StepMesh launched?

@Oliver-ss (Contributor, Author) commented Oct 10, 2025

Hello, I'd like to ask: how does StepMesh implement CUDA Graph? During graph replay, Python code isn't executed—so how is the PushPullWorker thread in StepMesh launched?

Currently, the StepMesh connector operations are not captured inside the CUDA graph; a piecewise CUDA graph is used instead.

causal=causal)
return attn_metadata

def _init_stage_buffers(self, vllm_config: VllmConfig,
A Collaborator commented:

please see:

if ubatch_slices is not None:
    common_attn_metadata_list = split_attn_metadata(
        ubatch_slices, common_attn_metadata
    )
    for ubid, common_attn_metadata in enumerate(
        common_attn_metadata_list
    ):
        attn_metadata_i = attn_group.get_metadata_builder(
            ubatch_id=ubid
        ).build(
            common_prefix_len=common_prefix_len,
            common_attn_metadata=common_attn_metadata,
        )
        for layer_name in kv_cache_group_spec.layer_names:
            assert type(attn_metadata) is list
            attn_metadata[ubid][layer_name] = attn_metadata_i
else:
I think this infrastructure can be reused instead of having to add all of this to the FA backend.

1. Attn

```
vllm fserver /path/step3v -dp 8 --afd-config '{"afd_connector": "dummy", "afd_role": "attention", "afd_host": "127.0.0.0"}' --max-num-batched-tokens 384 --max-num-seqs 384 --compilation-config '{"cudagraph_capture_sizes": [1, 8]}'
```

This should be for vllm serve?


Labels: documentation, frontend, needs-rebase, v1
