diff --git a/DraftRetriever/src/lib.rs b/DraftRetriever/src/lib.rs
index 024295f..96f796b 100644
--- a/DraftRetriever/src/lib.rs
+++ b/DraftRetriever/src/lib.rs
@@ -1,4 +1,4 @@
-// The code for retrival is adapted from https://github.com/Intsights/PySubstringSearch;
+// The code for retrieval is adapted from https://github.com/Intsights/PySubstringSearch;
 // The code for drafft buffer is adapted from https://github.com/FasterDecoding/Medusa/blob/main/medusa/model/utils.py#L31-L124
 use ahash::AHashSet;
 use byteorder::{ReadBytesExt, WriteBytesExt, ByteOrder, LittleEndian};
@@ -67,6 +67,7 @@ impl Writer {
         let index_file = File::create(index_file_path)?;
         let index_file = BufWriter::new(index_file);
 
+        // max_chunk_len can be whatever we want it to be, but the draftretriever Reader() seems to work fastest when we choose something large (e.g. 2e27)
         let max_chunk_len = max_chunk_len.unwrap_or(512 * 1024 * 1024);
         let vocab_size = vocab_size.unwrap_or(35000);
 
@@ -111,10 +112,13 @@ impl Writer {
             return Ok(());
         }
 
-        self.index_file.write_u32::<LittleEndian>((self.buffer.len() * 2) as u32)?;
+        self.index_file.write_u32::<LittleEndian>((self.buffer.len() * 4) as u32)?; // self.buffer.len() is the length of the buffer (in # of integers). This is variable because sometimes we dump_data() early; it's not always self.buffer.capacity().
+        // * 4 because this value tells us how much space this buffer needs in the file, and we store each integer as 4 bytes
+        // For larger vocabularies (i.e. > 65,535), we should write the integers as i32 instead of u16
+        // Keeping i32 instead of u32 so negative values can be used as pad tokens (e.g. pad_path(path, max_length, -2))
         for &item in &self.buffer {
-            self.index_file.write_u16::<LittleEndian>(item as u16)?;
+            self.index_file.write_i32::<LittleEndian>(item as i32)?;
         }
 
         let suffix_array = construct_suffix_array(&self.buffer, self.vocab_size);
@@ -188,8 +192,9 @@ impl Reader {
 
             let mut data: Vec<i32> = Vec::new();
 
-            for i in (0..data_u8.len()).step_by(2) {
-                let int = LittleEndian::read_u16(&data_u8[i..i+2]) as i32;
+            // Step by 4 to read in each 4-byte int (i32) from the index file
+            for i in (0..data_u8.len()).step_by(4) {
+                let int = LittleEndian::read_i32(&data_u8[i..i+4]) as i32;
                 data.push(int);
             }
 
@@ -259,7 +264,7 @@ impl Reader {
             if start_of_indices.is_none() {
                 return;
             }
-
+            // This binary search finds the end of the matching suffixes
             let mut right_anchor = sub_index.suffixes_file_end - 4;
             while left_anchor <= right_anchor {
@@ -300,7 +305,7 @@ impl Reader {
                 let data_index = LittleEndian::read_i32(suffix);
                 if matches_ranges.insert(data_index) {
                     let sub_string_plus = &sub_index.data[data_index as usize + substring_i32.len() ..std::cmp::min(data_index as usize + substring_i32.len() + long as usize, sub_index.data.len())];
-
+
                     local_results.push(sub_string_plus.to_vec());
                     cnt += 1;
                     if cnt >= k as usize {
@@ -328,7 +333,7 @@ impl Reader {
                 *counter += 1;
             }
         }
-
+
         let choices = choices.unwrap_or(64);
         // The items in the heap must be a Trie.
         let mut heap = BinaryHeap::new();
@@ -348,7 +353,7 @@ impl Reader {
        let verified: Vec<_> = verified.into_iter().collect();
 
        // Because multiple nodes in the Trie may have same weights around the threshold, the number of draft tokens may exceed choices
-       // We roughly cut nodes to be less than choices in most cases.
+       // We roughly cut nodes to be less than choices in most cases.
        let paths = cut_to_choices(verified, choices);
        let (draft_choices, max_branch) = get_draft_choices(paths.clone());
@@ -562,4 +567,3 @@ fn draftretriever(
 
     Ok(())
 }
-
diff --git a/datastore/get_datastore_chat.py b/datastore/get_datastore_chat.py
index c86c3a4..e97b5e2 100644
--- a/datastore/get_datastore_chat.py
+++ b/datastore/get_datastore_chat.py
@@ -29,7 +29,7 @@ writer = draftretriever.Writer(
     index_file_path=datastore_path,
     max_chunk_len=512*1024*1024,
-    vocab_size=tokenizer.vocab_size,
+    vocab_size=tokenizer.vocab_size + len(tokenizer.get_added_vocab()),
 )
 
 if args.large_datastore:
     dataset = load_dataset('stingning/ultrachat', split='train')
diff --git a/datastore/get_datastore_code.py b/datastore/get_datastore_code.py
index b8b64b7..a49b5f2 100644
--- a/datastore/get_datastore_code.py
+++ b/datastore/get_datastore_code.py
@@ -41,7 +41,7 @@ writer = draftretriever.Writer(
     index_file_path=datastore_path,
     max_chunk_len=512 * 1024 * 1024,
-    vocab_size=tokenizer.vocab_size,
+    vocab_size=tokenizer.vocab_size + len(tokenizer.get_added_vocab()),
 )
 
 total_length = len(dataset)
@@ -51,4 +51,4 @@
     token_list = tokenizer.encode(sample['content'])
     writer.add_entry(token_list)
 
-writer.finalize()
\ No newline at end of file
+writer.finalize()
diff --git a/rest/model/modeling_llama_kv.py b/rest/model/modeling_llama_kv.py
index 6a83d0c..b733fad 100644
--- a/rest/model/modeling_llama_kv.py
+++ b/rest/model/modeling_llama_kv.py
@@ -1,9 +1,8 @@
-# Source: https://github.com/huggingface/transformers/blob/v4.31-release/src/transformers/models/llama/modeling_llama.py
+# Source: https://github.com/huggingface/transformers/blob/v4.34-release/src/transformers/models/llama/modeling_llama.py
 # Modifications are denoted by the symbol: [MODIFIED]
 # There are mainly two modifications:
 # 1. Using preallocated GPU memory for KVCache
-# 2. Modifying attention mask for integration with Rest
-# adapted from https://github.com/FasterDecoding/Medusa/blob/main/medusa/model/modeling_llama_kv.py
+# 2.
Modifying attention mask for integration with Medusa """ PyTorch LLaMA model.""" import math @@ -17,19 +16,23 @@ # [MODIFIED] Import from transformer library from transformers.activations import ACT2FN -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, - SequenceClassifierOutputWithPast, -) +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS from transformers.utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, + is_flash_attn_2_available, logging, replace_return_docstrings, ) -from transformers import LlamaConfig +from transformers.models.llama.configuration_llama import LlamaConfig + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa logger = logging.get_logger(__name__) @@ -37,24 +40,24 @@ _CONFIG_FOR_DOC = "LlamaConfig" +def _get_unpad_data(padding_mask): + seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + # Copied from transformers.models.bart.modeling_bart._make_causal_mask def _make_causal_mask( - input_ids_shape: torch.Size, - dtype: torch.dtype, - device: torch.device, - past_key_values_length: int = 0, + input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 ): """ - Create a causal mask for bi-directional self-attention. - - Args: - input_ids_shape (torch.Size): The shape of input_ids tensor, typically (batch_size, tgt_len). - dtype (torch.dtype): The data type of the mask. - device (torch.device): The device on which the mask will be placed. - past_key_values_length (int, optional): The length of past key values. Default is 0. - - Returns: - torch.Tensor: The causal mask tensor. + Make causal mask used for bi-directional self-attention. """ bsz, tgt_len = input_ids_shape mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) @@ -63,32 +66,14 @@ def _make_causal_mask( mask = mask.to(dtype) if past_key_values_length > 0: - mask = torch.cat( - [ - torch.zeros( - tgt_len, past_key_values_length, dtype=dtype, device=device - ), - mask, - ], - dim=-1, - ) - return mask[None, None, :, :].expand( - bsz, 1, tgt_len, tgt_len + past_key_values_length - ) + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) # Copied from transformers.models.bart.modeling_bart._expand_mask def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): """ - Expand attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - - Args: - mask (torch.Tensor): The attention mask tensor of shape `[bsz, seq_len]`. - dtype (torch.dtype): The data type of the mask. - tgt_len (Optional[int], optional): The target sequence length. If None, it defaults to the source sequence length. 
- - Returns: - torch.Tensor: The expanded mask tensor. + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. """ bsz, src_len = mask.size() tgt_len = tgt_len if tgt_len is not None else src_len @@ -97,38 +82,19 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] inverted_mask = 1.0 - expanded_mask - return inverted_mask.masked_fill( - inverted_mask.to(torch.bool), torch.finfo(dtype).min - ) + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) -import torch.nn as nn -import torch - class LlamaRMSNorm(nn.Module): - """ - LlamaRMSNorm is equivalent to T5LayerNorm. - - Args: - hidden_size (int): The size of the hidden states. - eps (float, optional): A small value to prevent division by zero. Default is 1e-6. - """ - def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states): - """ - Apply LlamaRMSNorm to the input hidden states. - - Args: - hidden_states (torch.Tensor): Input hidden states. - - Returns: - torch.Tensor: The normalized and scaled hidden states. - """ input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) @@ -136,247 +102,156 @@ def forward(self, hidden_states): return self.weight * hidden_states.to(input_dtype) -class LlamaRotaryEmbedding(nn.Module): - """ - Llama Rotary Positional Embedding Module. +ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm) - Args: - dim (int): The dimension of the embedding. - max_position_embeddings (int, optional): The maximum position for embeddings. Default is 2048. - base (int, optional): The base value for rotational encoding. Default is 10000. - device (str, optional): The device on which the computation will be performed. Default is None. - """ - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): +class LlamaRotaryEmbedding(nn.Module): + def __init__( + self, + dim=None, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + rope_type="default", + config: Optional[LlamaConfig] = None, + ): super().__init__() + # TODO (joao): remove the `if` below, only used for BC + self.rope_kwargs = {} + if config is None: + logger.warning_once( + "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the " + "`config` argument. 
All other arguments will be removed in v4.46" + ) + self.rope_kwargs = { + "rope_type": rope_type, + "factor": scaling_factor, + "dim": dim, + "base": base, + "max_position_embeddings": max_position_embeddings, + } + self.rope_type = rope_type + self.max_seq_len_cached = max_position_embeddings + self.original_max_seq_len = max_position_embeddings + else: + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / ( - self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) - ) - self.register_buffer("inv_freq", inv_freq, persistent=False) + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache( - seq_len=max_position_embeddings, - device=self.inv_freq.device, - dtype=torch.get_default_dtype(), - ) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq - def _set_cos_sin_cache(self, seq_len, device, dtype): + def _dynamic_frequency_update(self, position_ids, device): """ - Set the cosine and sine cache for positional embeddings. - - Args: - seq_len (int): The sequence length. - device (str): The device on which the cache tensors will be stored. - dtype: The data type of the cache tensors. + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) """ - self.max_seq_len_cached = seq_len - t = torch.arange( - self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype - ) + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer( - "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False - ) - self.register_buffer( - "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False - ) + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len - def forward(self, x, seq_len=None): - """ - Forward pass of the LlamaRotaryEmbedding module. + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) - Args: - x (torch.Tensor): Input tensor of shape [bs, num_attention_heads, seq_len, head_size]. - seq_len (int): The sequence length. If greater than the cached length, the cache will be updated. 
+ # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() - Returns: - tuple: A tuple containing two tensors, the cosine and sine embeddings, both of shape [1, 1, seq_len, dim]. - """ - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling - return ( - self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - ) + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): - """ - LlamaRotaryEmbedding extended with linear scaling. + """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" - This class adds linear scaling to LlamaRotaryEmbedding. Credits to the Reddit user /u/kaiokendev. - - Args: - dim (int): The dimension of the embedding. - max_position_embeddings (int, optional): The maximum number of position embeddings. Default is 2048. - base (int, optional): The base value for the rotational embeddings. Default is 10000. - device (str or torch.device, optional): The device where the embeddings should be stored. Default is None. - scaling_factor (float, optional): The scaling factor for the embeddings. Default is 1.0. - """ - - def __init__( - self, - dim, - max_position_embeddings=2048, - base=10000, - device=None, - scaling_factor=1.0, - ): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - """ - Set the cosine and sine cache for the rotary embeddings. - - Args: - seq_len (int): The sequence length. - device (str or torch.device): The device where the cache should be stored. - dtype: The data type for the cache. - """ - self.max_seq_len_cached = seq_len - t = torch.arange( - self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype - ) - t = t / self.scaling_factor - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer( - "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False - ) - self.register_buffer( - "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False + def __init__(self, *args, **kwargs): + logger.warning_once( + "`LlamaLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use " + "`LlamaRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)." ) + kwargs["rope_type"] = "linear" + super().__init__(*args, **kwargs) class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): - """ - LlamaRotaryEmbedding extended with Dynamic NTK scaling. 
- - Credits to the Reddit users /u/bloc97 and /u/emozilla. - """ - - def __init__( - self, - dim, - max_position_embeddings=2048, - base=10000, - device=None, - scaling_factor=1.0, - ): - """ - Initialize the LlamaDynamicNTKScalingRotaryEmbedding. - - Args: - dim (int): The dimensionality of the embedding. - max_position_embeddings (int, optional): Maximum number of position embeddings. Default is 2048. - base (int, optional): Base value for scaling calculations. Default is 10000. - device: The device to place tensors on. If None, uses the default device. - scaling_factor (float, optional): Scaling factor for NTK scaling. Default is 1.0. - """ - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) + """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" - def _set_cos_sin_cache(self, seq_len, device, dtype): - """ - Set the cached values for cosine and sine. - - Args: - seq_len (int): The sequence length. - device: The device to place tensors on. - dtype: The data type of tensors. - """ - self.max_seq_len_cached = seq_len - - if seq_len > self.max_position_embeddings: - base = self.base * ( - (self.scaling_factor * seq_len / self.max_position_embeddings) - - (self.scaling_factor - 1) - ) ** (self.dim / (self.dim - 2)) - inv_freq = 1.0 / ( - base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) - ) - self.register_buffer("inv_freq", inv_freq) - - t = torch.arange( - self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + def __init__(self, *args, **kwargs): + logger.warning_once( + "`LlamaDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use " + "`LlamaRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to " + "__init__)." ) + kwargs["rope_type"] = "dynamic" + super().__init__(*args, **kwargs) - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer( - "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False - ) - self.register_buffer( - "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False - ) def rotate_half(x): - """ - Rotates half the hidden dimensions of the input. - - Args: - x (torch.Tensor): Input tensor. - - Returns: - torch.Tensor: Tensor with half of its hidden dimensions rotated. - """ + """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q, k, cos, sin, position_ids): - """ - Apply rotary position embeddings to query and key tensors. - Args: - q (torch.Tensor): Query tensor. - k (torch.Tensor): Key tensor. - cos (torch.Tensor): Cosine values. - sin (torch.Tensor): Sine values. - position_ids (torch.Tensor): Position IDs. +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. 
+ unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: - torch.Tensor: Query and key tensors with rotary position embeddings applied. + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ - cos = cos.squeeze(1).squeeze(0) - sin = sin.squeeze(1).squeeze(0) - cos = cos[position_ids].unsqueeze(1) - sin = sin[position_ids].unsqueeze(1) + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed class LlamaMLP(nn.Module): - """ - LlamaMLP is a multi-layer perceptron module used in the Llama model. - - Args: - config: The configuration for the MLP. - - Attributes: - pretraining_tp (int): The pretraining time periods. - hidden_size (int): The size of the hidden layer. - intermediate_size (int): The size of the intermediate layer. - gate_proj (nn.Linear): The linear projection for gating. - up_proj (nn.Linear): The linear projection for the up projection. - down_proj (nn.Linear): The linear projection for the down projection. - act_fn: The activation function. - - """ - def __init__(self, config): super().__init__() self.config = config @@ -388,34 +263,20 @@ def __init__(self, config): self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): - """ - Forward pass of the MLP. - - Args: - x: Input tensor. - - Returns: - torch.Tensor: Output tensor. - """ if self.config.pretraining_tp > 1: - slice = self.intermediate_size // self.pretraining_tp + slice = self.intermediate_size // self.config.pretraining_tp gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) up_proj_slices = self.up_proj.weight.split(slice, dim=0) down_proj_slices = self.down_proj.weight.split(slice, dim=1) gate_proj = torch.cat( - [F.linear(x, gate_proj_slices[i]) for i in range(self.pretraining_tp)], - dim=-1, - ) - up_proj = torch.cat( - [F.linear(x, up_proj_slices[i]) for i in range(self.pretraining_tp)], - dim=-1, + [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 ) + up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) down_proj = [ - F.linear(intermediate_states[i], down_proj_slices[i]) - for i in range(self.pretraining_tp) + F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) ] down_proj = sum(down_proj) else: @@ -426,42 +287,18 @@ def forward(self, x): def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ - Repeat key and value tensors n times along the specified dimension. - - Args: - hidden_states (torch.Tensor): Input tensor with shape (batch, num_key_value_heads, seqlen, head_dim). - n_rep (int): Number of times to repeat. 
- - Returns: - torch.Tensor: Repeated tensor with shape (batch, num_key_value_heads * n_rep, seqlen, head_dim). + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) """ batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand( - batch, num_key_value_heads, n_rep, slen, head_dim - ) + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) class LlamaAttention(nn.Module): - """ - LlamaAttention is a multi-headed attention module based on the 'Attention Is All You Need' paper. - - Args: - config (LlamaConfig): Configuration for the attention module. - - Attributes: - config (LlamaConfig): Configuration for the attention module. - hidden_size (int): The size of the hidden layer. - num_heads (int): The number of attention heads. - head_dim (int): The dimension of each attention head. - num_key_value_heads (int): The number of key-value attention heads. - num_key_value_groups (int): The number of key-value groups. - pretraining_tp (int): The pretraining time periods. - max_position_embeddings (int): The maximum position embeddings. - - """ + """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: LlamaConfig): super().__init__() @@ -473,57 +310,23 @@ def __init__(self, config: LlamaConfig): self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta - + if (self.head_dim * self.num_heads) != self.hidden_size: raise ValueError( f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" f" and `num_heads`: {self.num_heads})." 
) - self.q_proj = nn.Linear( - self.hidden_size, self.num_heads * self.head_dim, bias=False - ) - self.k_proj = nn.Linear( - self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False - ) - self.v_proj = nn.Linear( - self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False - ) - self.o_proj = nn.Linear( - self.num_heads * self.head_dim, self.hidden_size, bias=False - ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) self._init_rope() def _init_rope(self): - if self.config.rope_scaling is None: - self.rotary_emb = LlamaRotaryEmbedding( - self.head_dim, max_position_embeddings=self.max_position_embeddings, base=self.rope_theta, - ) - else: - scaling_type = self.config.rope_scaling["type"] - scaling_factor = self.config.rope_scaling["factor"] - if scaling_type == "linear": - self.rotary_emb = LlamaLinearScalingRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - ) - elif scaling_type == "dynamic": - self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - ) - else: - raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + self.rotary_emb = LlamaRotaryEmbedding(config=self.config) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return ( - tensor.view(bsz, seq_len, self.num_heads, self.head_dim) - .transpose(1, 2) - .contiguous() - ) + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() def forward( self, @@ -533,35 +336,25 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, + padding_mask: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() if self.config.pretraining_tp > 1: - key_value_slicing = ( - self.num_key_value_heads * self.head_dim - ) // self.pretraining_tp + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp query_slices = self.q_proj.weight.split( - (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0 + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 ) key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - query_states = [ - F.linear(hidden_states, query_slices[i]) - for i in range(self.pretraining_tp) - ] + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] query_states = torch.cat(query_states, dim=-1) - key_states = [ - F.linear(hidden_states, key_slices[i]) - for i in range(self.pretraining_tp) - ] + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] key_states = torch.cat(key_states, dim=-1) - value_states = [ - F.linear(hidden_states, value_slices[i]) - for i in range(self.pretraining_tp) - ] + value_states = [F.linear(hidden_states, value_slices[i]) for i in 
range(self.config.pretraining_tp)] value_states = torch.cat(value_states, dim=-1) else: @@ -569,23 +362,19 @@ def forward( key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - query_states = query_states.view( - bsz, q_len, self.num_heads, self.head_dim - ).transpose(1, 2) - key_states = key_states.view( - bsz, q_len, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - value_states = value_states.view( - bsz, q_len, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids - ) + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + # change: repeat kv states BEFORE concatenating to past_key_value; otherwise, shape mismatch + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) # [MODIFIED] Using KVCache mechanism for preallocated GPU memory optimization # past_key_value is utilized to leverage previously computed key and value states. @@ -596,13 +385,7 @@ def forward( # Reset past_key_value to avoid return past_key_value. past_key_value = None - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul( - query_states, key_states.transpose(2, 3) - ) / math.sqrt(self.head_dim) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): raise ValueError( @@ -618,9 +401,7 @@ def forward( attn_weights = attn_weights + attention_mask # upcast attention to fp32 - attn_weights = nn.functional.softmax( - attn_weights, dim=-1, dtype=torch.float32 - ).to(query_states.dtype) + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): @@ -630,21 +411,13 @@ def forward( ) attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) if self.config.pretraining_tp > 1: - attn_output = attn_output.split( - self.hidden_size // self.pretraining_tp, dim=2 - ) - o_proj_slices = self.o_proj.weight.split( - self.hidden_size // self.pretraining_tp, dim=1 - ) - attn_output = sum( - [ - F.linear(attn_output[i], o_proj_slices[i]) - for i in range(self.pretraining_tp) - ] - ) + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) else: attn_output = self.o_proj(attn_output) @@ -654,30 +427,196 @@ def forward( return 
attn_output, attn_weights, past_key_value -class LlamaDecoderLayer(nn.Module): +class LlamaFlashAttention2(LlamaAttention): """ - LlamaDecoderLayer represents a single layer of the Llama decoder. - - Args: - config (LlamaConfig): Configuration for the decoder layer. - - Attributes: - hidden_size (int): The size of the hidden layer. - self_attn (LlamaAttention): Multi-headed self-attention module. - mlp (LlamaMLP): Multi-layer perceptron module. - input_layernorm (LlamaRMSNorm): Layer normalization for input. - post_attention_layernorm (LlamaRMSNorm): Layer normalization after self-attention. + Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. """ + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + padding_mask: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # LlamaFlashAttention2 attention does not support output_attentions + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dime x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # TODO: llama does not have dropout in the config?? + # It is recommended to use dropout with FA according to the docs + # when training. + dropout_rate = 0.0 # if not self.training else self.attn_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(LlamaRMSNorm handles it correctly) + input_dtype = query_states.dtype + if input_dtype == torch.float32: + logger.warning_once( + "The input hidden states seems to be silently casted in float32, this might be related to" + " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + " float16." + ) + + query_states = query_states.to(torch.float16) + key_states = key_states.to(torch.float16) + value_states = value_states.to(torch.float16) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + padding_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. 
Default to 1 / sqrt(head_dim) + """ + # Contains at least one padding token in the sequence + if padding_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, padding_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=True, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + padding_mask = padding_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, padding_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class LlamaDecoderLayer(nn.Module): def __init__(self, config: LlamaConfig): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = LlamaAttention(config=config) + self.self_attn = ( + LlamaAttention(config=config) + if not getattr(config, "_flash_attn_2_enabled", False) + else LlamaFlashAttention2(config=config) + ) self.mlp = LlamaMLP(config) self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = LlamaRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, @@ -687,29 +626,20 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, - ) -> Tuple[ - torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] - ]: + padding_mask: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ - Forward pass for the LlamaDecoderLayer. 
- Args: - hidden_states (torch.FloatTensor): Input tensor of shape `(batch, seq_len, embed_dim)`. - attention_mask (torch.FloatTensor, optional): Attention mask of size + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - position_ids (torch.LongTensor, optional): Positional IDs tensor. - past_key_value (Tuple[torch.FloatTensor], optional): Cached past key and value projection states. - output_attentions (bool, optional): Whether or not to return the attentions tensors of all attention layers. - use_cache (bool, optional): If set to `True`, `past_key_values` key-value states are returned and can be - used to speed up decoding. - - Returns: - Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: Tuple containing: - - hidden_states (torch.FloatTensor): Output tensor. - - self_attn_weights (Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]): Self-attention weights if - `output_attentions` is `True`. - - present_key_value (Optional[Tuple[torch.FloatTensor]]): Cached key and value projection states if - `use_cache` is `True`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states """ residual = hidden_states @@ -724,6 +654,7 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + padding_mask=padding_mask, ) hidden_states = residual + hidden_states @@ -771,6 +702,7 @@ class LlamaPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["LlamaDecoderLayer"] _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True def _init_weights(self, module): std = self.config.initializer_range @@ -809,7 +741,7 @@ def _set_gradient_checkpointing(self, module, value=False): Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see `past_key_values`). If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] @@ -831,9 +763,9 @@ def _set_gradient_checkpointing(self, module, value=False): Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. 
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the @@ -869,12 +801,8 @@ def __init__(self, config: LlamaConfig): self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - self.embed_tokens = nn.Embedding( - config.vocab_size, config.hidden_size, self.padding_idx - ) - self.layers = nn.ModuleList( - [LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)] - ) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False @@ -888,30 +816,25 @@ def set_input_embeddings(self, value): self.embed_tokens = value # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask( - self, attention_mask, input_shape, inputs_embeds, past_key_values_length - ): + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): # create causal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] combined_attention_mask = None if input_shape[-1] > 1: combined_attention_mask = _make_causal_mask( input_shape, - # inputs_embeds.dtype, - torch.float32, # [MODIFIED] force to cast to float32 + inputs_embeds.dtype, device=inputs_embeds.device, past_key_values_length=past_key_values_length, ) if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask( - attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ).to(inputs_embeds.device) + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) combined_attention_mask = ( - expanded_attn_mask - if combined_attention_mask is None - else expanded_attn_mask + combined_attention_mask + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask ) # [MODIFIED] add draft mask @@ -937,35 +860,23 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds if input_ids is not 
None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" - ) + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: batch_size, seq_length = input_ids.shape elif inputs_embeds is not None: batch_size, seq_length, _ = inputs_embeds.shape else: - raise ValueError( - "You have to specify either decoder_input_ids or decoder_inputs_embeds" - ) + raise ValueError("You have to specify either input_ids or inputs_embeds") seq_length_with_past = seq_length past_key_values_length = 0 @@ -977,10 +888,7 @@ def forward( if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange( - past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device, + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device ) position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: @@ -991,16 +899,21 @@ def forward( # embed positions if attention_mask is None: attention_mask = torch.ones( - (batch_size, seq_length_with_past), - dtype=torch.bool, - device=inputs_embeds.device, + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device ) + padding_mask = None + else: + if 0 in attention_mask: + padding_mask = attention_mask + else: + padding_mask = None + attention_mask = self._prepare_decoder_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length ) + # [MODIFIED] # not sure if i need this one + self.attention_mask = attention_mask + self.position_ids = position_ids hidden_states = inputs_embeds @@ -1020,25 +933,19 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - past_key_value = ( - past_key_values[idx] if past_key_values is not None else None - ) + past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): # None for past_key_value - return module(*inputs, output_attentions, None) + return module(*inputs, past_key_value, output_attentions, padding_mask=padding_mask) return custom_forward layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), - hidden_states, - attention_mask, - position_ids, - None, + create_custom_forward(decoder_layer), hidden_states, attention_mask, position_ids ) else: layer_outputs = decoder_layer( @@ -1048,6 +955,7 @@ def custom_forward(*inputs): past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + padding_mask=padding_mask, ) hidden_states = layer_outputs[0] @@ -1066,11 +974,7 @@ def custom_forward(*inputs): next_cache = next_decoder_cache if use_cache else None if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] - if v is not None - ) + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, @@ -1085,7 +989,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel): def __init__(self, config): super().__init__(config) self.model = LlamaModel(config) - self.pretraining_tp = config.pretraining_tp self.vocab_size = 
config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) @@ -1111,9 +1014,7 @@ def get_decoder(self): return self.model @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - @replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC - ) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids: torch.LongTensor = None, @@ -1153,19 +1054,11 @@ def forward( "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( @@ -1181,14 +1074,9 @@ def forward( ) hidden_states = outputs[0] - if self.pretraining_tp > 1: - lm_head_slices = self.lm_head.weight.split( - self.vocab_size // self.pretraining_tp, dim=0 - ) - logits = [ - F.linear(hidden_states, lm_head_slices[i]) - for i in range(self.pretraining_tp) - ] + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] logits = torch.cat(logits, dim=-1) else: logits = self.lm_head(hidden_states) @@ -1220,12 +1108,7 @@ def forward( ) def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - **kwargs, + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): if past_key_values: input_ids = input_ids[:, -1:] @@ -1259,10 +1142,7 @@ def _reorder_cache(past_key_values, beam_idx): reordered_past = () for layer_past in past_key_values: reordered_past += ( - tuple( - past_state.index_select(0, beam_idx.to(past_state.device)) - for past_state in layer_past - ), + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), ) return reordered_past @@ -1318,9 +1198,7 @@ def forward( config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict transformer_outputs = self.model( input_ids, @@ -1342,22 +1220,18 @@ def forward( batch_size = inputs_embeds.shape[0] if self.config.pad_token_id is None and batch_size != 1: - raise ValueError( - "Cannot handle batch sizes > 1 if no padding token is defined." 
- ) + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") if self.config.pad_token_id is None: sequence_lengths = -1 else: if input_ids is not None: - sequence_lengths = ( - torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1 - ).to(logits.device) + sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to( + logits.device + ) else: sequence_lengths = -1 - pooled_logits = logits[ - torch.arange(batch_size, device=logits.device), sequence_lengths - ] + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] loss = None if labels is not None: @@ -1365,9 +1239,7 @@ def forward( if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" - elif self.num_labels > 1 and ( - labels.dtype == torch.long or labels.dtype == torch.int - ): + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" @@ -1380,9 +1252,7 @@ def forward( loss = loss_fct(pooled_logits, labels) elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() - loss = loss_fct( - pooled_logits.view(-1, self.num_labels), labels.view(-1) - ) + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(pooled_logits, labels) @@ -1396,4 +1266,4 @@ def forward( past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, - ) + ) \ No newline at end of file
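
A minimal standalone sketch (not from the repository) of the index layout that the Writer/Reader comments above describe, using only the Rust standard library rather than the byteorder calls in lib.rs: token ids are stored as 4-byte little-endian i32 values, so ids above 65,535 and negative pad values such as -2 round-trip correctly, which the old 2-byte u16 layout could not represent. The token values below are illustrative only.

fn main() {
    // Example token ids; -2 mirrors the pad sentinel used by pad_path(path, max_length, -2).
    let tokens: [i32; 4] = [31, 70_000, 131_071, -2];

    // Writer side: a u32 byte-length prefix (len * 4), then each id as a little-endian i32.
    let mut bytes: Vec<u8> = Vec::new();
    bytes.extend_from_slice(&((tokens.len() * 4) as u32).to_le_bytes());
    for &t in &tokens {
        bytes.extend_from_slice(&t.to_le_bytes());
    }

    // Reader side: skip the 4-byte prefix, then step by 4 and decode each i32.
    let decoded: Vec<i32> = bytes[4..]
        .chunks_exact(4)
        .map(|c| i32::from_le_bytes(c.try_into().unwrap()))
        .collect();
    assert_eq!(decoded, tokens);

    // With the old u16 layout, an id above 65,535 would silently wrap around.
    assert_ne!(70_000_i32 as u16 as i32, 70_000);
}

Keeping the element type signed (i32 rather than u32) is what leaves room for the negative pad sentinel mentioned in the Writer comments, and the wider element is why the datastore scripts can now pass vocab_size = tokenizer.vocab_size + len(tokenizer.get_added_vocab()).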