diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 2ce9d60f08d8..997555709132 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -1,5 +1,5 @@ """Multi-head attention.""" -from typing import List, Optional +from typing import Any, Dict, List, Optional import torch import torch.nn as nn @@ -13,6 +13,7 @@ from vllm.model_executor.layers.triton_kernel.prefix_prefill import ( context_attention_fwd) from vllm.utils import is_hip +from vllm.model_executor.layers.rotary_embedding import get_rope _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256] # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. @@ -192,7 +193,6 @@ def forward( # Reshape the output tensor. return output.view(batch_size, seq_len, hidden_size) - def _make_alibi_bias( alibi_slopes: torch.Tensor, num_kv_heads: int, diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index fb519b3c0cf9..b38b54d0bca2 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -41,7 +41,8 @@ "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "RWForCausalLM": ("falcon", "FalconForCausalLM"), "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), - "YiForCausalLM": ("yi", "YiForCausalLM") + "YiForCausalLM": ("yi", "YiForCausalLM"), + "CodeShellForCausalLM": ("codeshell", "CodeShellForCausalLM") } # Models not supported by ROCm. diff --git a/vllm/model_executor/models/codeshell.py b/vllm/model_executor/models/codeshell.py new file mode 100644 index 000000000000..0096895ccda2 --- /dev/null +++ b/vllm/model_executor/models/codeshell.py @@ -0,0 +1,313 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py +# Copyright 2023 The vLLM team. +# Copyright 2023 CTranslate2, and Michael Feil +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only CodeShell model compatible with HuggingFace weights.""" +from typing import List, Optional, Tuple + +import torch +import math +from torch import nn +# from transformers import CodeShellConfig +from vllm.transformers_utils.configs.codeshell import CodeShellConfig + + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearMethodBase, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_world_size) +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) +from vllm.sequence import SamplerOutput +from vllm.model_executor.layers.rotary_embedding import get_rope + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class CodeShellAttention(nn.Module): + + def __init__( + self, + config: CodeShellConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.hidden_size = config.hidden_size + total_num_heads = config.num_attention_heads + # print("Total num heads: ", total_num_heads) + self.tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) + assert total_num_heads % self.tensor_model_parallel_world_size == 0 + + self.group_query_attention = config.group_query_attention + + self.num_heads = (total_num_heads // + self.tensor_model_parallel_world_size) + self.head_dim = self.hidden_size // total_num_heads + + self.scaling = 1 / math.sqrt(self.head_dim) + self.rope_scaling = config.rope_scaling + self.rope_theta = 10000 + self.max_position_embeddings = config.max_position_embeddings + + self.total_num_kv_heads = config.num_query_groups + self.num_kv_heads = max(1, self.total_num_kv_heads // self.tensor_model_parallel_world_size) + self.kv_dim = self.head_dim * self.num_kv_heads + + + self.c_attn = QKVParallelLinear( + hidden_size=self.hidden_size, + head_size=self.head_dim, + total_num_heads=total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=True, + linear_method=linear_method, + ) + + # self.c_proj = nn.Linear(self.hidden_size, self.hidden_size) + self.c_proj = RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + linear_method=linear_method, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + rope_scaling=None, + ) + self.attn = PagedAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads) + + def butterfly_transform(self, obj: torch.Tensor) -> torch.Tensor: + query_group = self.num_kv_heads + n_tokens = obj.shape[0] + + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + qkv, _ = self.c_attn(hidden_states) + q, k, v = qkv.split( + [ + self.hidden_size // self.tensor_model_parallel_world_size, + self.kv_dim, self.kv_dim + ], + dim=-1, + ) + q, k = self.rotary_emb(positions, q, k) + + key_cache, value_cache = kv_cache + attn_output = self.attn(query=q, key=k, value=v, key_cache=key_cache, 
value_cache=value_cache, + input_metadata=input_metadata) + attn_output, _ = self.c_proj(attn_output) + return attn_output + + +class CodeShellMLP(nn.Module): + + def __init__( + self, + intermediate_size: int, + config: CodeShellConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + hidden_size = config.hidden_size + self.c_fc = ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=True, + linear_method=linear_method, + ) + self.c_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=True, + linear_method=linear_method, + ) + quant_config = getattr(linear_method, "quant_config", None) + self.act = get_act_fn(config.activation_function, quant_config, + intermediate_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.c_proj(hidden_states) + return hidden_states + + +class CodeShellBlock(nn.Module): + + def __init__( + self, + config: CodeShellConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + hidden_size = config.hidden_size + inner_dim = (config.n_inner if config.n_inner is not None else 4 * + hidden_size) + + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = CodeShellAttention(config, linear_method) + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.mlp = CodeShellMLP(inner_dim, config, linear_method) + + def forward( + self, + hidden_states: torch.Tensor, + positions: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_output = self.attn( + hidden_states=hidden_states, + positions=positions, + kv_cache=kv_cache, + input_metadata=input_metadata, + ) + # residual connection + hidden_states = attn_output + residual + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) + # residual connection + hidden_states = residual + feed_forward_hidden_states + + return hidden_states + + +class CodeShellModel(nn.Module): + + def __init__( + self, + config: CodeShellConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.config = config + assert not config.add_cross_attention + + self.embed_dim = config.hidden_size + + self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) + self.h = nn.ModuleList([ + CodeShellBlock(config, linear_method) + for _ in range(config.num_hidden_layers) + ]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata + ) -> torch.Tensor: + inputs_embeds = self.wte(input_ids) + hidden_states = inputs_embeds + + for i in range(len(self.h)): + layer = self.h[i] + hidden_states = layer(hidden_states=hidden_states, positions=position_ids, kv_cache=kv_caches[i], input_metadata=input_metadata) + + + hidden_states = self.ln_f(hidden_states) + return hidden_states + + +class CodeShellForCausalLM(nn.Module): + + def __init__( + self, + config: CodeShellConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.config = config + self.linear_method = linear_method + self.transformer = CodeShellModel(config, linear_method) + self.lm_head_weight = self.transformer.wte.weight + 
self.sampler = Sampler(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.transformer(input_ids, positions, kv_caches, + input_metadata) + return hidden_states + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(self.lm_head_weight, hidden_states, + sampling_metadata) + return next_tokens + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + if "lm_head.weight" in name: + param = params_dict['lm_head_weight'] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + continue + if ".attn.bias" in name: + continue + if "rotary_emb" in name: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/transformers_utils/configs/codeshell.py b/vllm/transformers_utils/configs/codeshell.py new file mode 100644 index 000000000000..6785344ba3f4 --- /dev/null +++ b/vllm/transformers_utils/configs/codeshell.py @@ -0,0 +1,166 @@ +# coding=utf-8 +# Copyright 2023 WisdomShell Inc. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This code is based on Bigcode's GPTBigCode configuration. It has been modified from +# its original forms to accommodate minor architectural differences compared to +# GPTBigCode Configuration that trained the model. + +# Copyright 2023 The BigCode team and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Shell configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class CodeShellConfig(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`CodeShellModel`]. It is used to instantiate a + CodeShell model according to the specified arguments, defining the model architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. 
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 70144):
+            Vocabulary size of the CodeShell model. Defines the number of different tokens that can be represented
+            by the `input_ids` passed when calling [`CodeShellModel`].
+        n_positions (`int`, *optional*, defaults to 8192):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        n_embd (`int`, *optional*, defaults to 4096):
+            Dimensionality of the embeddings and hidden states.
+        n_layer (`int`, *optional*, defaults to 42):
+            Number of hidden layers in the Transformer decoder.
+        n_head (`int`, *optional*, defaults to 32):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        n_inner (`int`, *optional*, defaults to `None`):
+            Dimensionality of the inner feed-forward layers. `None` will set it to 4 times `n_embd`.
+        activation_function (`str`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+            Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new",
+            "gelu_pytorch_tanh"]`.
+        resid_pdrop (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        embd_pdrop (`float`, *optional*, defaults to 0.1):
+            The dropout ratio for the embeddings.
+        attn_pdrop (`float`, *optional*, defaults to 0.1):
+            The dropout ratio for the attention.
+        layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
+            The epsilon to use in the layer normalization layers.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        scale_attn_weights (`bool`, *optional*, defaults to `True`):
+            Scale attention weights by dividing by sqrt(hidden_size).
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models).
+        attention_softmax_in_fp32 (`bool`, *optional*, defaults to `True`):
+            Whether to call the fused softmax in float32.
+        scale_attention_softmax_in_fp32 (`bool`, *optional*, defaults to `True`):
+            Whether to scale the attention softmax in float32.
+        group_query_attention (`bool`, *optional*, defaults to `True`):
+            Whether to use grouped-query attention (`True`) or multi-head attention (`False`).
+        num_query_groups (`int`, *optional*, defaults to 1):
+            Number of key/value head groups used when `group_query_attention` is enabled.
+
+ Example: + + ```python + >>> from configuration_codeshell import CodeShellConfig + >>> from modeling_codeshell import CodeShellForCausalLM + + >>> # Initializing a CodeShell configuration + >>> configuration = CodeShellConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = CodeShellForCausalLM(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "codeshell" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "hidden_size": "n_embd", + "max_position_embeddings": "n_positions", + "num_attention_heads": "n_head", + "num_hidden_layers": "n_layer", + } + + def __init__( + self, + vocab_size=70144, + n_positions=8192, + n_embd=4096, + n_layer=42, + n_head=32, + n_inner=None, + activation_function="gelu_pytorch_tanh", + resid_pdrop=0.1, + embd_pdrop=0.1, + attn_pdrop=0.1, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + scale_attn_weights=True, + use_cache=True, + bos_token_id=70000, + eos_token_id=70000, + attention_softmax_in_fp32=True, + scale_attention_softmax_in_fp32=True, + group_query_attention=True, + num_query_groups=1, + position_embedding_type="learned_absolute", + rope_scaling=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.n_positions = n_positions + self.n_embd = n_embd + self.n_layer = n_layer + self.n_head = n_head + self.n_inner = n_inner + self.activation_function = activation_function + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attn_pdrop = attn_pdrop + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.scale_attn_weights = scale_attn_weights + self.use_cache = use_cache + self.attention_softmax_in_fp32 = attention_softmax_in_fp32 + self.scale_attention_softmax_in_fp32 = scale_attention_softmax_in_fp32 + self.group_query_attention = group_query_attention + self.num_query_groups = num_query_groups + self.position_embedding_type = position_embedding_type + self.rope_scaling = rope_scaling + assert self.position_embedding_type in [ + "learned_absolute", "rope" + ], "position_embedding_type must be one of ['learned_absolute', 'rope']" + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
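For reference, a minimal smoke-test sketch of the new model path through vLLM's offline `LLM` API. The checkpoint id `WisdomShell/CodeShell-7B` and the sampling settings are assumptions for illustration, not part of this diff:

```python
# Sketch of exercising the new CodeShellForCausalLM registration.
# The checkpoint id and sampling settings below are assumptions.
from vllm import LLM, SamplingParams

# trust_remote_code=True lets the tokenizer/config be resolved from the
# checkpoint repo when they are not bundled with transformers.
llm = LLM(model="WisdomShell/CodeShell-7B", trust_remote_code=True)

sampling_params = SamplingParams(temperature=0.2, max_tokens=128)
outputs = llm.generate(["def quick_sort(arr):"], sampling_params)
print(outputs[0].outputs[0].text)
```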