From 7c90715592bc5d1e0b71aa25242fba459f8295d6 Mon Sep 17 00:00:00 2001 From: zhengranzeng Date: Wed, 25 Oct 2023 23:24:33 +0800 Subject: [PATCH 1/8] Add CodeShell Model. --- Makefile | 3 + README.md | 10 +- .../text_generation_server/models/__init__.py | 22 + .../custom_modeling/flash_shell_modeling.py | 428 ++++++++++++++++++ .../models/flash_shell.py | 78 ++++ server/text_generation_server/utils/layers.py | 3 +- 6 files changed, 538 insertions(+), 6 deletions(-) create mode 100644 server/text_generation_server/models/custom_modeling/flash_shell_modeling.py create mode 100644 server/text_generation_server/models/flash_shell.py diff --git a/Makefile b/Makefile index 7f534c7ccd7..ea779d1d1b1 100644 --- a/Makefile +++ b/Makefile @@ -48,5 +48,8 @@ run-falcon-7b-instruct: run-falcon-7b-instruct-quantize: text-generation-launcher --model-id tiiuae/falcon-7b-instruct --quantize bitsandbytes --port 8080 +run-codeshell: + text-generation-launcher --model-id /shd/zzr/code_copilot/training/output_models/codeshellchat-special_codeshellchat-v2_copilot-randomtoken/checkpoint-1300 --port 3030 --max-total-tokens 5000 --max-input-length 4096 --num-shard 1 + clean: rm -rf target aml diff --git a/README.md b/README.md index 339b5db7fda..679e9ffdc51 100644 --- a/README.md +++ b/README.md @@ -151,7 +151,7 @@ For example, if you want to serve the gated Llama V2 model variants: or with Docker: ```shell -model=meta-llama/Llama-2-7b-chat-hf +model=WisdomShell/CodeShell-7B volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run token= @@ -225,7 +225,7 @@ Then run: ```shell BUILD_EXTENSIONS=True make install # Install repository and HF/transformer fork with CUDA kernels -make run-falcon-7b-instruct +make run-codeshell ``` **Note:** on some machines, you may also need the OpenSSL libraries and gcc. On Linux machines, run: @@ -241,12 +241,12 @@ the kernels by using the `DISABLE_CUSTOM_KERNELS=True` environment variable. Be aware that the official Docker image has them enabled by default. -## Run Falcon +## Run CodeShell ### Run ```shell -make run-falcon-7b-instruct +make run-codeshell ``` ### Quantization @@ -254,7 +254,7 @@ make run-falcon-7b-instruct You can also quantize the weights with bitsandbytes to reduce the VRAM requirement: ```shell -make run-falcon-7b-instruct-quantize +make run-codeshell ``` 4bit quantization is available using the [NF4 and FP4 data types from bitsandbytes](https://arxiv.org/pdf/2305.14314.pdf). It can be enabled by providing `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` as a command line argument to `text-generation-launcher`. 
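Since the patched quantization section points at the plain `run-codeshell` target, enabling 4-bit quantization means passing the flag to the launcher directly. A minimal sketch follows; the model id, port, and token limits are assumed values for illustration, not settings fixed by this patch:

```shell
# Illustrative only: launch CodeShell with 4-bit NF4 quantization from bitsandbytes.
# Model id, port, and length limits below are assumptions, not values from this patch.
text-generation-launcher \
    --model-id WisdomShell/CodeShell-7B \
    --quantize bitsandbytes-nf4 \
    --port 8080 \
    --max-input-length 4096 \
    --max-total-tokens 5000
```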
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 5b1b5715c5e..f3a41800024 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -51,6 +51,9 @@ from text_generation_server.models.flash_llama import ( FlashLlama, ) + from text_generation_server.models.flash_shell import ( + FlashShell, + ) from text_generation_server.models.flash_santacoder import ( FlashSantacoderSharded, ) @@ -215,6 +218,25 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) + elif model_type in {"shell", "codeshell", "kclgpt"}: + if FLASH_ATTENTION: + return FlashShell( + model_id, + revision, + quantize=quantize, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + elif sharded: + raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Shell")) + else: + return CausalLM( + model_id, + revision, + quantize=quantize, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]: if sharded: diff --git a/server/text_generation_server/models/custom_modeling/flash_shell_modeling.py b/server/text_generation_server/models/custom_modeling/flash_shell_modeling.py new file mode 100644 index 00000000000..67bb05064fc --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/flash_shell_modeling.py @@ -0,0 +1,428 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +import torch.distributed + +from torch import nn +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig +from typing import Optional, List, Tuple + +# Flash attention imports +import dropout_layer_norm + +# vllm imports +import vllm_cache_ops +import vllm_attention_ops + +from text_generation_server.utils.flash_attn import attention +from text_generation_server.utils.layers import ( + TensorParallelRowLinear, + TensorParallelColumnLinear, + TensorParallelEmbedding, + PositionRotaryEmbedding, + TensorParallelHead, + FastLayerNorm, + get_linear, +) + + +class ShellConfig(PretrainedConfig): + def __init__( + self, + vocab_size=70144, + hidden_size=4096, + intermediate_size=16384, + num_hidden_layers=42, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="gelu_pytorch_tanh", + max_position_embeddings=8192, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=70000, + bos_token_id=70000, + eos_token_id=70000, + pretraining_tp=1, + tie_word_embeddings=False, + rope_scaling=None, + rope_theta=10000.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_scaling = rope_scaling + self.rope_theta = rope_theta + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + +def load_attention(config, prefix, weights): + return TensorParallelColumnLinear.load_qkv( + config, + prefix=f"{prefix}.c_attn", + weights=weights, + bias=True, + ) + +def load_row(config, prefix: str, weights, bias: bool): + if config.transpose: + weight = weights.get_sharded(f"{prefix}.weight", dim=0).T + else: + weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) + + if bias and weights.process_group.rank() == 0: + # Rank is only on the first rank process + bias = weights.get_tensor(f"{prefix}.bias") + else: + bias = None + return TensorParallelRowLinear( + get_linear(weight, bias, config.quantize), process_group=weights.process_group + ) + +class FlashShellAttention(torch.nn.Module): + def __init__( + self, + prefix: str, + config, + weights, + ): + super().__init__() + self.num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.num_heads + + # self.rotary_emb = PositionRotaryEmbedding.load( + # config=config, prefix=f"{prefix}.rotary_emb", weights=weights + # ) + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, dim=self.head_size, base=config.rope_theta, device=weights.device + ) + + self.softmax_scale = self.head_size**-0.5 + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + # 
print("config.num_key_value_heads", config.num_key_value_heads) + self.num_key_value_heads = ( + config.num_key_value_heads // weights.process_group.size() + ) + + self.query_key_value = load_attention(config, prefix, weights) + + self.o_proj = load_row( + config, prefix=f"{prefix}.c_proj", weights=weights, bias=True + ) + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ): + print("hidden_states", hidden_states.size()) + qkv = self.query_key_value(hidden_states) + print("qkv", qkv.size(), self.head_size * self.num_heads, 2 * self.head_size * self.num_key_value_heads) + query, kv = qkv.split( + [ + self.head_size * self.num_heads, + 2 * self.head_size * self.num_key_value_heads, + ], + dim=1, + ) + query = query.view(-1, self.num_heads, self.head_size) + kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) + + self.rotary_emb(query, cos, sin) + self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) + + vllm_cache_ops.reshape_and_cache( + kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots + ) + + # output tensor + attn_output = torch.empty_like(query) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + attention( + query, + torch.select(kv, dim=1, index=0), + torch.select(kv, dim=1, index=1), + attn_output, + cu_seqlen_prefill, + max_s, + self.softmax_scale, + ) + # Decode + else: + # kv_cache[1] => [num_blocks, num_heads, head_size, block_size] + block_size = kv_cache[1].shape[3] + vllm_attention_ops.single_query_cached_kv_attention( + attn_output, + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + input_lengths, + block_size, + max_s, + ) + + return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) + + +class ShellMLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + act = config.hidden_act + self.act = ( + ACT2FN[act] + if "gelu" not in act + else lambda x: torch.nn.functional.gelu( + x, + approximate="tanh" + if act in ["gelu_fast", "gelu_pytorch_tanh"] + else "none", + ) + ) + # Fuse gate and up proj + self.gate_up_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.c_fc", + weights=weights, + bias=True, + ) + self.down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.c_proj", + weights=weights, + bias=True, + ) + self.intermediate_size = ( + config.intermediate_size // weights.process_group.size() + ) + + def forward(self, hidden_states): + hidden_states = self.gate_up_proj(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.down_proj(hidden_states) + return hidden_states + + +class FlashShellLayer(nn.Module): + def __init__(self, layer_id, config, weights): + super().__init__() + prefix = f"transformer.h.{layer_id}" + self.self_attn = FlashShellAttention( + prefix=f"{prefix}.attn", config=config, weights=weights + ) + self.mlp = ShellMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) + + self.input_layernorm = FastLayerNorm.load( + prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon + ) + self.post_attention_layernorm = FastLayerNorm.load( + prefix=f"{prefix}.ln_2", weights=weights, eps=config.layer_norm_epsilon + ) + + def forward( + self, + hidden_states, + 
residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ): + normed_hidden_states, res = self.input_layernorm(hidden_states, residual) + + # Self Attention + attn_output = self.self_attn( + normed_hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ) + + # faster post attention rms norm + normed_attn_res_output, attn_res = self.post_attention_layernorm( + attn_output, res + ) + + mlp_output = self.mlp(normed_attn_res_output) + + return mlp_output, attn_res + + +class FlashShellModel(torch.nn.Module): + def __init__(self, config, weights): + super().__init__() + + process_group = weights.process_group + self.tp_rank = process_group.rank() + self.tp_world_size = process_group.size() + self.embed_tokens = TensorParallelEmbedding( + prefix="transformer.wte", weights=weights, reduce=False, + ) + self.layers = nn.ModuleList( + [ + FlashShellLayer( + layer_id, + config, + weights, + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.norm = FastLayerNorm.load( + prefix="transformer.ln_f", weights=weights, eps=config.layer_norm_epsilon + ) + + self.gradient_checkpointing = False + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, max_s, hidden_states.dtype + ) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + block_tables, + slots, + input_lengths, + max_s, + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +class FlashShellForCausalLM(torch.nn.Module): + def __init__(self, config, weights): + super().__init__() + self.config = config + self.model = FlashShellModel(config, weights) + self.lm_head = TensorParallelHead.load( + config, + prefix="transformer.wte", + weights=weights, + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + lm_head_indices: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + hidden_states = self.model( + input_ids, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits = self.lm_head(hidden_states) + return logits diff --git a/server/text_generation_server/models/flash_shell.py b/server/text_generation_server/models/flash_shell.py new file mode 100644 index 00000000000..18819a5ea62 --- /dev/null +++ b/server/text_generation_server/models/flash_shell.py @@ -0,0 +1,78 @@ +import torch +import torch.distributed + +from opentelemetry import trace +from transformers import AutoConfig, 
AutoTokenizer +from typing import Optional + +from text_generation_server.models import FlashCausalLM +from text_generation_server.models.custom_modeling.flash_shell_modeling import ( + FlashShellForCausalLM, + ShellConfig, +) +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, + Weights, +) + +tracer = trace.get_tracer(__name__) + + +class FlashShell(FlashCausalLM): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + ): + self.process_group, rank, world_size = initialize_torch_distributed() + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + dtype = torch.float16 if dtype is None else dtype + else: + raise NotImplementedError("FlashLlama is only available on GPU") + + tokenizer = AutoTokenizer.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + + config = ShellConfig.from_pretrained( + model_id, + revision=revision, + trust_remote_code=True, + ) + config.num_key_value_heads = config.num_query_groups + config.hidden_act = config.activation_function + config.intermediate_size = config.n_embd + # config.max_position_embeddings = config.n_positions + config.quantize = quantize + config.transpose = config.architectures[0].startswith("GPT2") + config.device = device + + torch.distributed.barrier(group=self.process_group) + + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights(filenames, device, dtype, process_group=self.process_group, aliases={"transformer.wte.weight": ["lm_head.weight"]}) + if config.quantize in ["gptq", "awq"]: + weights._set_gptq_params(model_id) + model = FlashShellForCausalLM(config, weights) + + torch.distributed.barrier(group=self.process_group) + super(FlashShell, self).__init__( + model=model, + tokenizer=tokenizer, + num_layers=len(model.model.layers), + num_kv_heads=model.model.num_key_value_heads, + head_size=model.model.head_size, + dtype=dtype, + device=device, + rank=rank, + world_size=world_size, + ) diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index f38f130ef62..2b3250acd1e 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -425,7 +425,8 @@ def load_qkv(cls, config, prefix: str, weights, bias: bool): """Specific method when the QKV was joined after the fact""" weight = weights.get_weights_col_packed_qkv(prefix, quantize=config.quantize) if bias: - raise NotImplementedError("packed_qkv only implemented for baichuan") + bias = weights.get_tensor(f"{prefix}.bias") + # raise NotImplementedError("packed_qkv only implemented for baichuan") else: bias = None linear = get_linear(weight, bias, config.quantize) From 1dbeca5e9a6950ba6b63cd598780017ab45fb5ff Mon Sep 17 00:00:00 2001 From: ZZR0 <38292503+ZZR0@users.noreply.github.com> Date: Thu, 2 Nov 2023 13:41:44 +0800 Subject: [PATCH 2/8] Update README.md --- README.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 679e9ffdc51..f3fca295e48 100644 --- a/README.md +++ b/README.md @@ -69,6 +69,7 @@ to power Hugging Chat, the Inference API and Inference Endpoint. 
- [Llama V2](https://huggingface.co/meta-llama) - [Code Llama](https://huggingface.co/codellama) - [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) +- [CodeShell](https://huggingface.co/WisdomShell/CodeShell-7B) Other architectures are supported on a best effort basis using: @@ -85,10 +86,10 @@ or The easiest way of getting started is using the official Docker container: ```shell -model=tiiuae/falcon-7b-instruct +model=WisdomShell/CodeShell-7B volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run -docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.1.0 --model-id $model +docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/zzr0/text-generation-inference:codeshell-1.1.1 --model-id $model ``` **Note:** To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 11.8 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar. @@ -155,7 +156,7 @@ model=WisdomShell/CodeShell-7B volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run token= -docker run --gpus all --shm-size 1g -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.1.0 --model-id $model +docker run --gpus all --shm-size 1g -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/zzr0/text-generation-inference:codeshell-1.1.1 --model-id $model ``` ### A note on Shared Memory (shm) From d3c25962925d745331065c74d67e296e8fcc0e67 Mon Sep 17 00:00:00 2001 From: ZZR0 <38292503+ZZR0@users.noreply.github.com> Date: Thu, 2 Nov 2023 14:05:06 +0800 Subject: [PATCH 3/8] Update README.md --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index f3fca295e48..1e6af9c19a8 100644 --- a/README.md +++ b/README.md @@ -89,7 +89,7 @@ The easiest way of getting started is using the official Docker container: model=WisdomShell/CodeShell-7B volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run -docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/zzr0/text-generation-inference:codeshell-1.1.1 --model-id $model +docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data zzr0/text-generation-inference:codeshell-1.1.1 --model-id $model ``` **Note:** To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 11.8 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar. 
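To make the CPU-only note above concrete, the docker command from this hunk could be adapted roughly as follows; this is a sketch, not an invocation exercised by the patch:

```shell
# Illustrative CPU-only variant of the command above: drop --gpus all and
# disable the custom CUDA kernels. Expect noticeably slower generation.
model=WisdomShell/CodeShell-7B
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run

docker run --shm-size 1g -p 8080:80 -v $volume:/data \
    zzr0/text-generation-inference:codeshell-1.1.1 \
    --model-id $model --disable-custom-kernels
```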
@@ -156,7 +156,7 @@ model=WisdomShell/CodeShell-7B volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run token= -docker run --gpus all --shm-size 1g -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/zzr0/text-generation-inference:codeshell-1.1.1 --model-id $model +docker run --gpus all --shm-size 1g -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data zzr0/text-generation-inference:codeshell-1.1.1 --model-id $model ``` ### A note on Shared Memory (shm) From 90b344198086a388c00300905062ae3e023cbe9d Mon Sep 17 00:00:00 2001 From: zhengranzeng Date: Thu, 2 Nov 2023 17:59:27 +0800 Subject: [PATCH 4/8] Fix multi shard bug. --- Makefile | 3 + .../custom_modeling/flash_shell_modeling.py | 72 ++++++++++++++++--- server/text_generation_server/utils/layers.py | 3 +- 3 files changed, 65 insertions(+), 13 deletions(-) diff --git a/Makefile b/Makefile index ea779d1d1b1..313b367853e 100644 --- a/Makefile +++ b/Makefile @@ -51,5 +51,8 @@ run-falcon-7b-instruct-quantize: run-codeshell: text-generation-launcher --model-id /shd/zzr/code_copilot/training/output_models/codeshellchat-special_codeshellchat-v2_copilot-randomtoken/checkpoint-1300 --port 3030 --max-total-tokens 5000 --max-input-length 4096 --num-shard 1 +run-llama2: + text-generation-launcher --model-id /shd/zzr/code_copilot/training/original_models/llama2-7b/ --port 3030 --max-total-tokens 5000 --max-input-length 4096 --num-shard 1 + clean: rm -rf target aml diff --git a/server/text_generation_server/models/custom_modeling/flash_shell_modeling.py b/server/text_generation_server/models/custom_modeling/flash_shell_modeling.py index 67bb05064fc..fd61bf35a36 100644 --- a/server/text_generation_server/models/custom_modeling/flash_shell_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_shell_modeling.py @@ -96,13 +96,62 @@ def __init__( **kwargs, ) -def load_attention(config, prefix, weights): - return TensorParallelColumnLinear.load_qkv( - config, - prefix=f"{prefix}.c_attn", - weights=weights, - bias=True, - ) +def load_attention( + config, prefix: str, weights, bias: bool, q_size, kv_size +): + if config.quantize == "gptq": + NotImplementedError("Gptq loading with codeshell is not implemented") + else: + return _load_attention( + config, prefix, weights, bias, q_size, kv_size + ) + +def _load_attention( + config, prefix: str, weights, bias: bool, q_size, kv_size +): + slice_ = weights._get_slice(f"{prefix}.c_attn.weight") + world_size = weights.process_group.size() + rank = weights.process_group.rank() + + assert q_size % world_size == 0 + q_block_size = q_size // world_size + q_start = rank * q_block_size + q_stop = (rank + 1) * q_block_size + + assert kv_size % world_size == 0 + kv_block_size = kv_size // world_size + offset = q_size + k_start = offset + rank * kv_block_size + k_stop = offset + (rank + 1) * kv_block_size + + offset = q_size + kv_size + v_start = offset + rank * kv_block_size + v_stop = offset + (rank + 1) * kv_block_size + + if config.transpose: + q_tensor = slice_[:, q_start:q_stop] + k_tensor = slice_[:, k_start:k_stop] + v_tensor = slice_[:, v_start:v_stop] + else: + q_tensor = slice_[q_start:q_stop] + k_tensor = slice_[k_start:k_stop] + v_tensor = slice_[v_start:v_stop] + + weight = torch.cat([q_tensor, k_tensor, v_tensor], dim=0) + if bias: + slice_ = weights._get_slice(f"{prefix}.c_attn.bias") + q_tensor = slice_[q_start:q_stop] + k_tensor = slice_[k_start:k_stop] + v_tensor = slice_[v_start:v_stop] + + bias = 
torch.cat([q_tensor, k_tensor, v_tensor], dim=0) + else: + bias = None + + weight = weight.to(dtype=weights.dtype).to(device=weights.device) + if bias is not None: + bias = bias.to(dtype=weights.dtype).to(device=weights.device) + return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) def load_row(config, prefix: str, weights, bias: bool): if config.transpose: @@ -151,7 +200,10 @@ def __init__( config.num_key_value_heads // weights.process_group.size() ) - self.query_key_value = load_attention(config, prefix, weights) + q_size = config.num_attention_heads * self.head_size + kv_size = config.num_key_value_heads * self.head_size + self.query_key_value = load_attention(config, prefix, weights, + True, q_size, kv_size) self.o_proj = load_row( config, prefix=f"{prefix}.c_proj", weights=weights, bias=True @@ -173,9 +225,7 @@ def forward( input_lengths, max_s, ): - print("hidden_states", hidden_states.size()) qkv = self.query_key_value(hidden_states) - print("qkv", qkv.size(), self.head_size * self.num_heads, 2 * self.head_size * self.num_key_value_heads) query, kv = qkv.split( [ self.head_size * self.num_heads, @@ -328,7 +378,7 @@ def __init__(self, config, weights): self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() self.embed_tokens = TensorParallelEmbedding( - prefix="transformer.wte", weights=weights, reduce=False, + prefix="transformer.wte", weights=weights, reduce=True, ) self.layers = nn.ModuleList( [ diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index 2b3250acd1e..f38f130ef62 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -425,8 +425,7 @@ def load_qkv(cls, config, prefix: str, weights, bias: bool): """Specific method when the QKV was joined after the fact""" weight = weights.get_weights_col_packed_qkv(prefix, quantize=config.quantize) if bias: - bias = weights.get_tensor(f"{prefix}.bias") - # raise NotImplementedError("packed_qkv only implemented for baichuan") + raise NotImplementedError("packed_qkv only implemented for baichuan") else: bias = None linear = get_linear(weight, bias, config.quantize) From 092962a585aebf8c172cda3fe84437dc53c5b336 Mon Sep 17 00:00:00 2001 From: zhengranzeng Date: Sun, 28 Jan 2024 12:05:25 +0800 Subject: [PATCH 5/8] debug --- Makefile | 4 ++-- entrypoint.sh | 20 +++++++++++++++++++ .../text_generation_server/models/__init__.py | 3 +++ test.sh | 7 +++++++ 4 files changed, 32 insertions(+), 2 deletions(-) create mode 100755 entrypoint.sh create mode 100755 test.sh diff --git a/Makefile b/Makefile index 313b367853e..18a4c5ab62b 100644 --- a/Makefile +++ b/Makefile @@ -49,10 +49,10 @@ run-falcon-7b-instruct-quantize: text-generation-launcher --model-id tiiuae/falcon-7b-instruct --quantize bitsandbytes --port 8080 run-codeshell: - text-generation-launcher --model-id /shd/zzr/code_copilot/training/output_models/codeshellchat-special_codeshellchat-v2_copilot-randomtoken/checkpoint-1300 --port 3030 --max-total-tokens 5000 --max-input-length 4096 --num-shard 1 + text-generation-launcher --model-id /models/codeshell-7b/ --port 9123 --num-shard 1 --rope-scaling dynamic --rope-factor 8 --max-input-length 31000 --max-total-tokens 32768 --max-batch-prefill-tokens 31000 run-llama2: - text-generation-launcher --model-id /shd/zzr/code_copilot/training/original_models/llama2-7b/ --port 3030 --max-total-tokens 5000 --max-input-length 4096 --num-shard 1 + text-generation-launcher --model-id 
/models/codellama-7b/ --port 9123 --num-shard 1 --rope-scaling dynamic --rope-factor 8 --max-input-length 31000 --max-total-tokens 32768 --max-batch-prefill-tokens 31000 clean: rm -rf target aml diff --git a/entrypoint.sh b/entrypoint.sh new file mode 100755 index 00000000000..2af647e924d --- /dev/null +++ b/entrypoint.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# docker run --gpus '"device=0,1"' --shm-size 1g --publish 9123:80 \ +# --volume /shd/zzr/models:/models \ +# --volume /nvme/zzr/text-generation-inference:/usr/src/server \ +# ghcr.nju.edu.cn/huggingface/text-generation-inference:1.4 \ +# --model-id /models/codeshell-7b \ +# --num-shard 2 \ +# --rope-scaling dynamic \ +# --rope-factor 8 \ +# --max-input-length 31000 \ +# --max-total-tokens 32768 \ +# --max-batch-prefill-tokens 31000 \ +# --max-stop-sequences 12 +# curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y +# source "$HOME/.cargo/env" +# cd router && cargo install --path . +# BUILD_EXTENSIONS=True make install-server +make run-llama2 +# make run-codeshell \ No newline at end of file diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index fee26e4a861..2f181b52766 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -288,6 +288,9 @@ def get_model( trust_remote_code=trust_remote_code, ) elif model_type in {"shell", "codeshell", "kclgpt"}: + print("model_type", model_type) + print("FLASH_ATTENTION", FLASH_ATTENTION) + print("sharded", sharded) if FLASH_ATTENTION: return FlashShell( model_id, diff --git a/test.sh b/test.sh new file mode 100755 index 00000000000..4f95cc77ded --- /dev/null +++ b/test.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +docker run --entrypoint ./entrypoint.sh \ + --gpus '"device=0,1"' --shm-size 1g --publish 9123:80 \ + --volume /shd/zzr/models:/models \ + --volume /nvme/zzr/text-generation-inference:/usr/src \ + ghcr.nju.edu.cn/huggingface/text-generation-inference:1.4 \ From fa1ba08501b5a7368a376d1a5f13b2f846e03a0e Mon Sep 17 00:00:00 2001 From: zhengranzeng Date: Mon, 29 Jan 2024 11:29:47 +0800 Subject: [PATCH 6/8] debug --- Makefile | 4 ++-- entrypoint.sh | 1 + test.sh | 20 +++++++++++++++++--- 3 files changed, 20 insertions(+), 5 deletions(-) diff --git a/Makefile b/Makefile index 18a4c5ab62b..75a4a9f3059 100644 --- a/Makefile +++ b/Makefile @@ -49,10 +49,10 @@ run-falcon-7b-instruct-quantize: text-generation-launcher --model-id tiiuae/falcon-7b-instruct --quantize bitsandbytes --port 8080 run-codeshell: - text-generation-launcher --model-id /models/codeshell-7b/ --port 9123 --num-shard 1 --rope-scaling dynamic --rope-factor 8 --max-input-length 31000 --max-total-tokens 32768 --max-batch-prefill-tokens 31000 + text-generation-launcher --model-id /models/codeshell-7b/ --port 80 --num-shard 1 --rope-scaling dynamic --rope-factor 8 --max-input-length 31000 --max-total-tokens 32768 --max-batch-prefill-tokens 31000 run-llama2: - text-generation-launcher --model-id /models/codellama-7b/ --port 9123 --num-shard 1 --rope-scaling dynamic --rope-factor 8 --max-input-length 31000 --max-total-tokens 32768 --max-batch-prefill-tokens 31000 + text-generation-launcher --model-id /models/codellama-7b/ --port 80 --num-shard 1 --rope-scaling dynamic --rope-factor 8 --max-input-length 31000 --max-total-tokens 32768 --max-batch-prefill-tokens 31000 clean: rm -rf target aml diff --git a/entrypoint.sh b/entrypoint.sh index 2af647e924d..455ba1ce5b7 100755 --- a/entrypoint.sh +++ b/entrypoint.sh 
@@ -15,6 +15,7 @@ # curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y # source "$HOME/.cargo/env" # cd router && cargo install --path . +make clean # BUILD_EXTENSIONS=True make install-server make run-llama2 # make run-codeshell \ No newline at end of file diff --git a/test.sh b/test.sh index 4f95cc77ded..468ee364a56 100755 --- a/test.sh +++ b/test.sh @@ -1,7 +1,21 @@ #!/bin/bash -docker run --entrypoint ./entrypoint.sh \ - --gpus '"device=0,1"' --shm-size 1g --publish 9123:80 \ +# docker run --entrypoint ./entrypoint.sh \ +# --gpus '"device=0,1"' --shm-size 1g --publish 9123:80 \ +# --volume /shd/zzr/models:/models \ +# --volume /nvme/zzr/text-generation-inference:/usr/src \ +# ghcr.nju.edu.cn/huggingface/text-generation-inference:1.4 \ + + +docker run --gpus '"device=0,1"' --shm-size 1g --publish 9123:80 \ + --volume /nvme/zzr/text-generation-inference/server/text_generation_server:/opt/conda/lib/python3.10/site-packages/text_generation_server \ --volume /shd/zzr/models:/models \ - --volume /nvme/zzr/text-generation-inference:/usr/src \ ghcr.nju.edu.cn/huggingface/text-generation-inference:1.4 \ + --model-id /models/codellama-7b \ + --num-shard 2 \ + --rope-scaling dynamic \ + --rope-factor 8 \ + --max-input-length 31000 \ + --max-total-tokens 32768 \ + --max-batch-prefill-tokens 31000 \ + --max-stop-sequences 12 \ No newline at end of file From 4e9b31462350e5085e2047624d3340c414aef69c Mon Sep 17 00:00:00 2001 From: zhengranzeng Date: Wed, 31 Jan 2024 18:46:39 +0800 Subject: [PATCH 7/8] update shell to tgi 1.4.0 --- .../custom_modeling/flash_shell_modeling.py | 32 ++++++------------- 1 file changed, 10 insertions(+), 22 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_shell_modeling.py b/server/text_generation_server/models/custom_modeling/flash_shell_modeling.py index fd61bf35a36..92b84d26d7c 100644 --- a/server/text_generation_server/models/custom_modeling/flash_shell_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_shell_modeling.py @@ -26,22 +26,15 @@ from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple -# Flash attention imports -import dropout_layer_norm - -# vllm imports -import vllm_cache_ops -import vllm_attention_ops - -from text_generation_server.utils.flash_attn import attention +from text_generation_server.utils import paged_attention, flash_attn from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, TensorParallelEmbedding, PositionRotaryEmbedding, TensorParallelHead, - FastLayerNorm, get_linear, + FastLayerNorm, ) @@ -180,11 +173,11 @@ def __init__( self.hidden_size = config.hidden_size self.head_size = self.hidden_size // self.num_heads - # self.rotary_emb = PositionRotaryEmbedding.load( - # config=config, prefix=f"{prefix}.rotary_emb", weights=weights - # ) self.rotary_emb = PositionRotaryEmbedding.static( - config=config, dim=self.head_size, base=config.rope_theta, device=weights.device + config=config, + dim=self.head_size, + base=config.rope_theta, + device=weights.device, ) self.softmax_scale = self.head_size**-0.5 @@ -195,7 +188,6 @@ def __init__( f"and `num_shards`: {weights.process_group.size()}" ) self.num_heads = self.num_heads // weights.process_group.size() - # print("config.num_key_value_heads", config.num_key_value_heads) self.num_key_value_heads = ( config.num_key_value_heads // weights.process_group.size() ) @@ -236,10 +228,9 @@ def forward( query = query.view(-1, 
self.num_heads, self.head_size) kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) - self.rotary_emb(query, cos, sin) - self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) + self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - vllm_cache_ops.reshape_and_cache( + paged_attention.reshape_and_cache( kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots ) @@ -249,7 +240,7 @@ def forward( # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + flash_attn.attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), @@ -260,9 +251,7 @@ def forward( ) # Decode else: - # kv_cache[1] => [num_blocks, num_heads, head_size, block_size] - block_size = kv_cache[1].shape[3] - vllm_attention_ops.single_query_cached_kv_attention( + paged_attention.attention( attn_output, query, kv_cache[0], @@ -271,7 +260,6 @@ def forward( self.softmax_scale, block_tables, input_lengths, - block_size, max_s, ) From b63b9482400be8e4db1fc75c3755369ac5db80f9 Mon Sep 17 00:00:00 2001 From: ZZR0 <38292503+ZZR0@users.noreply.github.com> Date: Tue, 5 Mar 2024 23:03:24 +0800 Subject: [PATCH 8/8] Update README.md --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index b5804817d49..b22b5c84ef3 100644 --- a/README.md +++ b/README.md @@ -88,7 +88,7 @@ For a detailed starting guide, please see the [Quick Tour](https://huggingface.c model=WisdomShell/CodeShell-7B volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run -docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data zzr0/text-generation-inference:codeshell-1.4 --model-id $model +docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data zzr0/text-generation-inference:shell-1.4.0 --model-id $model ``` **Note:** To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 11.8 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar. @@ -138,7 +138,7 @@ model=WisdomShell/CodeShell-7B volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run token= -docker run --gpus all --shm-size 1g -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data zzr0/text-generation-inference:codeshell-1.4 --model-id $model +docker run --gpus all --shm-size 1g -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data zzr0/text-generation-inference:shell-1.4.0 --model-id $model ``` ### A note on Shared Memory (shm)
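With either container from the README running, a single generation request is a quick way to confirm the CodeShell weights are being served. This sketch assumes the `-p 8080:80` mapping used throughout the README and TGI's standard `/generate` endpoint:

```shell
# Smoke test against the running container (assumes the -p 8080:80 mapping above).
curl 127.0.0.1:8080/generate \
    -X POST \
    -H 'Content-Type: application/json' \
    -d '{"inputs":"def fibonacci(n):","parameters":{"max_new_tokens":64}}'
```

The response should be a JSON object whose `generated_text` field continues the prompt.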