
Conversation


@chopper0126 chopper0126 commented Oct 14, 2025

What this PR does / why we need it?

This PR corresponds to the RFC vllm-project#22799 and is a follow-up to vllm-project#25162 and Oliver-ss/vllm#2.

This is the preliminary implementation of DeepSeek V2 Lite AFD for Ascend. It currently only supports the P2P connector and the DeepSeek V2 Lite model.

  • AFD for the DeepSeek V2 Lite model, plus a P2P connector for A2E/E2A communication on Ascend.
  • Extended AFD connector metadata so that AFD can work across different hardware (GPU, NPU, and more).
  • Online serving requests.

Later, we plan to support the following features:

  • TBO (Triple Batch Overlap) based on the DBO extension
  • Graph mode
  • Offline serving requests in a batched manner
  • Multi-node support for the full DeepSeek-V3/R1 models on GPU/NPU

How was this patch tested?

Use the following scripts for testing:
online_attn.sh

export ASCEND_RT_VISIBLE_DEVICES=4,5
vllm serve /home/data/DeepSeek-V2-Lite \
    --tensor-parallel-size 2 \
    --enable_expert_parallel \
    --enforce_eager          \
    --afd-config \
    '{"afd_connector":"p2pconnector", "afd_role": "attention", "num_afd_stages":"1","afd_extra_config":{"afd_size":"2A2F"}}'

ffn.sh

export ASCEND_RT_VISIBLE_DEVICES=6,7

python -m vllm.entrypoints.afd_ffn_server /home/data/DeepSeek-V2-Lite \
        --tensor-parallel-size 2 \
        --enable_expert_parallel \
        --enforce_eager          \
        --afd-config '{"afd_connector":"p2pconnector", "num_afd_stages":"1", "afd_role": "ffn", "afd_extra_config":{"afd_size":"2A2F"}}'

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a preliminary implementation of Attention/FFN Decoupling (AFD) for DeepSeek models on Ascend NPUs. The changes are extensive, adding new FFN workers, model runners, and communication utilities, and modifying the core model and forward context to support the decoupled architecture. While the implementation lays the groundwork for AFD, there are several critical issues that need to be addressed, including hardcoded network configurations and a duplicated block of multimodal fields. There are also high-severity issues related to performance and best practices, such as an inefficient tensor-parallelism implementation and setting environment variables inside the code. These issues should be resolved to make the implementation robust, configurable, and performant.

ffn.py Outdated
Comment on lines 116 to 117
#TODO:remove hard code
init_method = 'tcp://127.0.0.1:29505'

critical

The init_method for process group initialization is hardcoded with a local IP address and port. This is also the case on lines 94 and 101. This makes the script inflexible and difficult to use in a real distributed environment. These values should be parameterized or read from a configuration file to allow for different network setups.
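
As a sketch of one way to address this (flag names here are illustrative, not part of this PR), the rendezvous address could be read from CLI arguments, with the current values kept as defaults:

# Sketch only: --afd-master-addr/--afd-master-port are hypothetical names.
import argparse

def build_init_method() -> str:
    """Build the process-group init_method from CLI flags instead of a hardcoded literal."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--afd-master-addr", default="127.0.0.1")
    parser.add_argument("--afd-master-port", type=int, default=29505)
    args, _ = parser.parse_known_args()
    return f"tcp://{args.afd_master_addr}:{args.afd_master_port}"

The hardcoded 'tcp://127.0.0.1:29505' literals (and those on lines 94 and 101) would then become init_method=build_init_method().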

Comment on lines 40 to 46
new_default_group = init_process_group(
    init_method='tcp://127.0.0.1:29500',
    backend='gloo',
    rank=rank,
    world_size=world_size,
    group_name="new_hccl"
)

critical

The init_method for process group initialization is hardcoded with a local IP address and port. This makes this function inflexible and difficult to use in different distributed environments. This value should be passed as an argument or read from a configuration.
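
A minimal sketch of the environment-variable approach (assuming the conventional MASTER_ADDR/MASTER_PORT names; rank, world_size, and the init_process_group wrapper are the ones already used in the snippet above):

# Sketch only: reads the rendezvous endpoint from the environment, keeping the
# current values as fallbacks instead of hardcoding them.
import os

addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
port = os.environ.get("MASTER_PORT", "29500")

new_default_group = init_process_group(
    init_method=f"tcp://{addr}:{port}",
    backend='gloo',
    rank=rank,
    world_size=world_size,
    group_name="new_hccl",
)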

Comment on lines 68 to 72
mm_features: Optional[list[MultiModalFeatureSpec]] = None
# for back-compatibility, will be removed in next major release
mm_kwargs: Optional[list[MultiModalKwargsItem]] = None
mm_positions: Optional[list[PlaceholderRange]] = None
mm_hashes: Optional[list[PlaceholderRange]] = None

critical

These fields (mm_features, mm_kwargs, mm_positions, mm_hashes) duplicate the definitions on lines 62-66; the later annotations silently shadow the earlier ones. Please remove the duplicated block.

ffn.py Outdated
from vllm_ascend.distributed.afd_communicators import send_object,recv_object,FFNNeedForwardData

import os
os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"

high

Setting environment variables within a library or script is generally considered bad practice as it can have unintended side effects on other parts of the application or other libraries. It's better to have the user set this environment variable before running the application. Please move this to the documentation or a startup script.

Comment on lines +1970 to +1980
if afd_metadata:
    # Padding for AFD
    num_input_tokens = num_input_tokens
    (num_pad_afd, afd_tokens_start_loc,
     afd_tokens_lens) = self.get_afd_padding(
         afd_metadata.afd_tokens_start_loc,
         afd_metadata.afd_tokens_lens)
    afd_metadata.afd_tokens_start_loc = afd_tokens_start_loc
    afd_metadata.afd_tokens_lens = afd_tokens_lens
    num_input_tokens += num_pad_afd
    num_tokens_across_dp = None


high

The logic for padding for AFD seems to be duplicated. A similar block of code exists in _prepare_inputs (lines 1542-1551). This could be refactored into a helper function to avoid code duplication and improve maintainability.
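
One possible shape for such a helper (a sketch only; the method name is hypothetical, and the attribute names mirror the snippet above):

def _apply_afd_padding(self, afd_metadata, num_input_tokens: int) -> int:
    """Pad the AFD token bookkeeping in place and return the padded token count."""
    (num_pad_afd, afd_tokens_start_loc,
     afd_tokens_lens) = self.get_afd_padding(
         afd_metadata.afd_tokens_start_loc,
         afd_metadata.afd_tokens_lens)
    afd_metadata.afd_tokens_start_loc = afd_tokens_start_loc
    afd_metadata.afd_tokens_lens = afd_tokens_lens
    return num_input_tokens + num_pad_afd

Both call sites could then reduce to num_input_tokens = self._apply_afd_padding(afd_metadata, num_input_tokens).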

Comment on lines +176 to +193
tp_world_size = get_tensor_model_parallel_world_size()
if tp_world_size > 1:
    # All-gather hidden states from all TP ranks
    gathered_hidden_states = tensor_model_parallel_all_gather(
        hidden_states, dim=0)
    ffn_output = self.model.compute_ffn_output(current_layer_idx,
                                               gathered_hidden_states)
    # Extract the output corresponding to current rank
    start_idx = hidden_states.shape[
        0] * get_tensor_model_parallel_rank()
    end_idx = start_idx + hidden_states.shape[0]
    rank_ffn_output = ffn_output[start_idx:end_idx, :]
else:
    # Single TP case
    rank_ffn_output = self.model.compute_ffn_output(
        current_layer_idx, hidden_states)

return rank_ffn_output

high

The logic for handling tensor parallelism in _execute_eager_mode seems inefficient. It performs an all_gather on the input hidden_states, then computes the FFN output (which likely involves an all_reduce in the final RowParallelLinear layer), and finally manually slices the output. This results in redundant communication (all_gather followed by all_reduce). A more efficient approach would be to use a reduce_scatter operation in the final layer of the FFN computation, which would directly produce the sliced output for each rank. This would avoid the unnecessary all_gather and the manual slicing. Since this is a hot path, this inefficiency could significantly impact performance.
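
To illustrate the suggestion, here is a sketch using raw torch.distributed collectives; the TP process-group argument is an assumption, and the FFN's final RowParallelLinear would need to return its partial sums instead of performing its internal all_reduce:

import torch
import torch.distributed as dist

def rank_slice_from_partial_ffn(partial_ffn_output: torch.Tensor,
                                tp_group: dist.ProcessGroup) -> torch.Tensor:
    """Sum partial FFN outputs across TP ranks and keep only this rank's
    token slice, replacing all_gather + all_reduce + manual slicing."""
    tp_size = dist.get_world_size(tp_group)
    num_tokens = partial_ffn_output.shape[0] // tp_size
    rank_output = torch.empty(num_tokens,
                              partial_ffn_output.shape[1],
                              dtype=partial_ffn_output.dtype,
                              device=partial_ffn_output.device)
    # reduce_scatter splits the input into tp_size chunks along dim 0, reduces
    # each chunk across ranks, and leaves rank i with chunk i.
    dist.reduce_scatter_tensor(rank_output, partial_ffn_output, group=tp_group)
    return rank_output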


👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling in the PR description to help reviewers and future developers understand the change.

If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.


This pull request has conflicts, please resolve those before we can evaluate the pull request.

@chopper0126 chopper0126 force-pushed the ascendmain-1011 branch 3 times, most recently from 422cc62 to 1ef3c29 Compare October 15, 2025 02:47