4 changes: 3 additions & 1 deletion vllm_ascend/ascend_forward_context.py
@@ -8,7 +8,7 @@
from vllm.distributed import (get_dp_group, get_ep_group,
get_tensor_model_parallel_world_size)
from vllm.forward_context import (BatchDescriptor, get_forward_context,
set_forward_context)
set_forward_context, AFDMetadata)

import vllm_ascend.envs as envs_ascend
from vllm_ascend.utils import enable_sp
@@ -71,6 +71,7 @@ def set_ascend_forward_context(
batch_descriptor: Optional[BatchDescriptor] = None,
prefetch_stream: torch.npu.Stream = None,
model_instance: torch.nn.Module = None,
afd_metadata: Optional[AFDMetadata] = None,
weight_prefetch_method: Optional[WeightPrefetchMethod] = None):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
@@ -84,6 +85,7 @@
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=aclgraph_runtime_mode,
batch_descriptor=batch_descriptor,
afd_metadata=afd_metadata,
):
forward_context = get_forward_context()

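For reference, the new `afd_metadata` argument is passed straight through to vLLM's `set_forward_context`, so any code running inside the context can read the AFD (attention/FFN disaggregation) state back from the forward context. A minimal usage sketch follows; the leading positional arguments and the `afd_metadata` attribute lookup are assumptions based on the surrounding code, and only the `afd_metadata=` keyword itself comes from this diff:

    with set_ascend_forward_context(attn_metadata,
                                    vllm_config,
                                    num_tokens=num_input_tokens,
                                    afd_metadata=afd_metadata):
        forward_context = get_forward_context()
        if forward_context.afd_metadata is not None:
            # AFD-aware layers can branch on role/topology info here.
            ...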
153 changes: 103 additions & 50 deletions vllm_ascend/models/deepseek_v2.py
@@ -35,7 +35,9 @@
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tp_group, split_tensor_along_last_dim,
tensor_model_parallel_all_reduce)
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -62,6 +64,7 @@
from vllm_ascend.models.layers.mla import AscendMLAModules
from vllm_ascend.models.layers.sfa import (AscendSFAModules,
AscendSparseFlashAttention, Indexer)
from typing import Any, Optional, Union
from vllm_ascend.ops.common_fused_moe import AscendFusedMoE


@@ -439,69 +442,83 @@ def __init__(self,
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.enable_afd = vllm_config.additional_config.get(
"enable_afd", False)

afd_config = vllm_config.afd_config
if afd_config:
self.role = afd_config.afd_role
else:
self.role = None
# DecoderLayers are created with `make_layers` which passes the prefix
# with the layer's index.
layer_idx = int(prefix.split(sep='.')[-1])
self.layer_idx = layer_idx
self.layers = config.num_hidden_layers
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tp_group().rank_in_group
        # TODO: enable mla in vllm-ascend
        if model_config.use_mla:
            if ascend_config.use_sfa:
                attn_cls = CustomDeepseekV2SFAAttention
            else:
                attn_cls = CustomDeepseekV2MLAAttention
        else:
            attn_cls = DeepseekV2Attention
        self.self_attn = attn_cls(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            qk_nope_head_dim=config.qk_nope_head_dim,
            qk_rope_head_dim=config.qk_rope_head_dim,
            v_head_dim=config.v_head_dim,
            q_lora_rank=config.q_lora_rank
            if hasattr(config, "q_lora_rank") else None,
            kv_lora_rank=config.kv_lora_rank,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        ascend_config = get_ascend_config()

        if (config.n_routed_experts is not None
                and layer_idx >= config.first_k_dense_replace
                and layer_idx % config.moe_layer_freq == 0):
            self.mlp = DeepseekV2MoE(
                config=config,
                parallel_config=parallel_config,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
            if self.mlp.gate.e_score_correction_bias is not None:
                self.mlp.gate.e_score_correction_bias.data = (
                    self.mlp.gate.e_score_correction_bias.data.to(
                        dtype=torch.get_default_dtype()))
        else:
            self.mlp = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        if self.role is None or self.role == "attention":
            # TODO: enable mla in vllm-ascend
            if model_config.use_mla:
                if ascend_config.use_sfa:
                    attn_cls = CustomDeepseekV2SFAAttention
                else:
                    attn_cls = CustomDeepseekV2MLAAttention
            else:
                attn_cls = DeepseekV2Attention
            self.self_attn = attn_cls(
                config=config,
                hidden_size=self.hidden_size,
                num_heads=config.num_attention_heads,
                qk_nope_head_dim=config.qk_nope_head_dim,
                qk_rope_head_dim=config.qk_rope_head_dim,
                v_head_dim=config.v_head_dim,
                q_lora_rank=config.q_lora_rank
                if hasattr(config, "q_lora_rank") else None,
                kv_lora_rank=config.kv_lora_rank,
                rope_theta=rope_theta,
                rope_scaling=rope_scaling,
                max_position_embeddings=max_position_embeddings,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f"{prefix}.self_attn",
            )

if self.role is None or self.role == "ffn":
if (config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0):
self.mlp = DeepseekV2MoE(
config=config,
parallel_config=parallel_config,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
if self.mlp.gate.e_score_correction_bias is not None:
self.mlp.gate.e_score_correction_bias.data = (
self.mlp.gate.e_score_correction_bias.data.to(
dtype=torch.get_default_dtype()))
else:
self.mlp = DeepseekV2MLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)

self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.routed_scaling_factor = config.routed_scaling_factor
self.first_k_dense_replace = config.first_k_dense_replace
self.tp_group = get_tp_group().device_group
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
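The net effect of the role gating above: a decoder layer is only partially materialized on each node, so an "attention"-role instance never builds `self.mlp` and an "ffn"-role instance never builds `self.self_attn`, while `role is None` preserves the existing monolithic behavior. A standalone sketch of the same pattern (illustrative only, not code from this PR):

    import torch.nn as nn

    class RoleGatedLayer(nn.Module):
        """Builds only the submodules this node's AFD role needs."""

        def __init__(self, role=None):
            super().__init__()
            # role is None (monolithic), "attention", or "ffn"
            if role is None or role == "attention":
                self.self_attn = nn.MultiheadAttention(64, 4)
            if role is None or role == "ffn":
                self.mlp = nn.Linear(64, 64)

    assert not hasattr(RoleGatedLayer(role="attention"), "mlp")
    assert not hasattr(RoleGatedLayer(role="ffn"), "self_attn")

The flip side is that everything touching `layer.self_attn` or `layer.mlp` (the forward pass, MoE bookkeeping, weight loading) must be gated on the same role, which is what the remaining changes in this file do.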


class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -510,6 +527,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.enable_afd = vllm_config.additional_config.get(
"enable_afd", False)
self.afd_config = vllm_config.afd_config
if self.afd_config:
self.role = self.afd_config.afd_role
else:
self.role = None

# `packed_modules_mapping` needs to be modified before
# initializing DeepseekV2Model, as it is passed inplace to
@@ -537,7 +561,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
self.expert_weights: list[Any] = []

# Set MoE hyperparameters
self.num_moe_layers = (config.num_hidden_layers -
config.first_k_dense_replace)
@@ -550,11 +574,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
continue

assert isinstance(layer, DeepseekV2DecoderLayer)
if isinstance(layer.mlp, DeepseekV2MoE):
            if (self.afd_config is None or self.role == "ffn") and isinstance(
                    layer.mlp, DeepseekV2MoE):
# Pick last one layer since the first ones may be dense layers.
example_moe = layer.mlp
self.moe_layers.append(layer.mlp.experts)

if self.role == "attention":
return
if example_moe is None:
raise RuntimeError("No DeepseekV2MoE layer found in model.layers.")

@@ -564,14 +592,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.num_routed_experts = example_moe.n_routed_experts
self.num_shared_experts = example_moe.n_shared_experts
self.num_redundant_experts = example_moe.n_redundant_experts


# NOTE: This `load_weights` is mainly copied from
# https://github.com/vllm-project/vllm/commit/07b8fae219b1fff51ef115c38c44b51395be5bb5
# to fix CI, and it is different from the implementation in main
# TODO: support eplb style load_weights
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
""""""
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
@@ -584,7 +612,8 @@ def load_weights(self, weights: Iterable[tuple[str,
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts)
num_experts=self.config.n_routed_experts,
)

params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
@@ -594,6 +623,8 @@ def load_weights(self, weights: Iterable[tuple[str,
if "module" in name:
continue

if self.role == "attention" and self.is_moe_weight(name):
continue
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is not None:
continue # skip spec decode layers for main model
@@ -611,6 +642,7 @@ def load_weights(self, weights: Iterable[tuple[str,
if (("mlp.experts." in name) and name not in params_dict):
continue
name = name.replace(weight_name, param_name)

# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
@@ -627,12 +659,16 @@ def load_weights(self, weights: Iterable[tuple[str,
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue

                if self.role == "attention":
                    continue
name = name.replace(weight_name, param_name)

if is_pp_missing_parameter(name, self):
continue

param = params_dict[name]

weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
@@ -642,6 +678,11 @@ def load_weights(self, weights: Iterable[tuple[str,
return_success=False)
break
else:
if self.role == "ffn" and not self.is_moe_weight(
name) and not self.is_common_weight(name):
                    continue

                # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
@@ -659,11 +700,23 @@ def load_weights(self, weights: Iterable[tuple[str,
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)

return loaded_params

    def is_moe_weight(self, name: str) -> bool:
        return ("shared_experts" in name or "experts" in name
                or "gate" in name or "up" in name or "down" in name)

    def is_common_weight(self, name: str) -> bool:
        return ("lm_head" in name or "model.norm.weight" in name
                or "embed_tokens" in name or "input_layernorm" in name
                or "post_attention_layernorm" in name)
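Note that both predicates match bare substrings, so they are deliberately broad: "up" and "down" catch the `up_proj`/`down_proj` of dense MLP layers as well as of experts, and "gate" catches both the router `mlp.gate` and `gate_proj`. A quick standalone check with representative (hypothetical) checkpoint names, mirroring `is_moe_weight` above:

    def is_moe_weight(name: str) -> bool:
        # Same substring test as the method above, as a free function.
        return ("shared_experts" in name or "experts" in name
                or "gate" in name or "up" in name or "down" in name)

    assert is_moe_weight("model.layers.9.mlp.experts.3.up_proj.weight")
    assert is_moe_weight("model.layers.0.mlp.gate_proj.weight")  # dense FFN, also skipped
    assert not is_moe_weight("model.layers.9.self_attn.q_a_proj.weight")
    assert not is_moe_weight("model.embed_tokens.weight")

For an "attention" node this breadth is the point, since every FFN tensor, dense or expert, should be skipped; it does mean the helpers cannot be reused for finer-grained distinctions.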


class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM):
pass


DeepseekV2DecoderLayer.__init__ = CustomDeepseekV2DecoderLayer.__init__
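The final assignment monkey-patches the upstream class in place: every `DeepseekV2DecoderLayer(...)` constructed afterwards, including by vLLM's own model-building code, runs the customized `__init__`, which is how the AFD role gating reaches layers created by upstream code. The mechanism in miniature:

    class Base:
        def __init__(self):
            self.tag = "upstream"

    class Patched(Base):
        def __init__(self):
            self.tag = "patched"

    Base.__init__ = Patched.__init__
    assert Base().tag == "patched"  # every later instance picks up the override

The override is process-global, which is the usual trade-off of this patching style in vllm-ascend.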