From 919727b4d7f5540db430135641d4b6ad2cbc729f Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Sat, 10 Jun 2023 08:09:29 +0900 Subject: [PATCH 1/6] Refactor landmark attention patch --- .../monkeypatch/llama_landmark_attn.py | 9 ++++++++ src/axolotl/utils/models.py | 23 +++++++++---------- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/src/axolotl/monkeypatch/llama_landmark_attn.py b/src/axolotl/monkeypatch/llama_landmark_attn.py index 18e913f09..1a130f755 100644 --- a/src/axolotl/monkeypatch/llama_landmark_attn.py +++ b/src/axolotl/monkeypatch/llama_landmark_attn.py @@ -1593,3 +1593,12 @@ def add_mem_tokens(example, mem_freq, mem_id): ret.extend(x[prev_idx:]) # drop attention_mask return {"input_ids": ret} + + +def patch_llama_with_landmark_attn(): + import transformers + + transformers.models.llama.modeling_llama.LlamaForCausalLM = LlamaForCausalLM + transformers.models.llama.modeling_llama.LlamaModel = LlamaModel + transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention + transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index fb363952c..b84597076 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -19,15 +19,6 @@ from transformers import ( # noqa: F401 LlamaConfig, ) -try: - from transformers import ( # pylint: disable=unused-import # noqa: F401 - LlamaForCausalLM, - ) -except ImportError: - logging.warning( - "This version of transformers does not support Llama. Consider upgrading." - ) - from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN if TYPE_CHECKING: @@ -118,14 +109,15 @@ def load_model( logging.info("patching with sdp attention") hijack_llama_sdp_attention() elif cfg.is_llama_derived_model and cfg.landmark_attention: - from axolotl.monkeypatch.llama_landmark_attn import ( # pylint: disable=redefined-outer-name # noqa: F811 + from axolotl.monkeypatch.llama_landmark_attn import ( MEM_TOKEN, - LlamaForCausalLM, + patch_llama_with_landmark_attn, ) logging.info("patching with landmark attention") + patch_llama_with_landmark_attn() - # TODO: Check if this would overwrite previous additional_special_tokens + # Note: This might overwrite previous additional_special_tokens tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]}) if cfg.is_llama_derived_model and cfg.xpos_rope: @@ -211,6 +203,13 @@ def load_model( ) load_in_8bit = False elif cfg.is_llama_derived_model and "LlamaForCausalLM" in globals(): + try: + from transformers import LlamaForCausalLM + except ImportError: + logging.warning( + "This version of transformers does not support Llama. Consider upgrading." 
+ ) + config = LlamaConfig.from_pretrained(base_model_config) model = LlamaForCausalLM.from_pretrained( base_model, From e285e24f7f4e1718c0f26c93eb7eec8e669ae97e Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Sun, 11 Jun 2023 10:52:12 +0900 Subject: [PATCH 2/6] Address PR suggestion Co-authored-by: Wing Lian --- src/axolotl/utils/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index b84597076..b3a5eeb60 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -202,7 +202,7 @@ def load_model( else True, ) load_in_8bit = False - elif cfg.is_llama_derived_model and "LlamaForCausalLM" in globals(): + elif cfg.is_llama_derived_model: try: from transformers import LlamaForCausalLM except ImportError: From 563b6d89e6ee7fc104204ce7cd178b56c35b633d Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Sun, 11 Jun 2023 11:58:31 +0900 Subject: [PATCH 3/6] Fix undefined LlamaForCausalLM and del try except --- src/axolotl/utils/models.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index b3a5eeb60..43d4a6a9c 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -81,7 +81,6 @@ def load_model( Load a model from a base model and a model type. """ - global LlamaForCausalLM # pylint: disable=global-statement # TODO refactor as a kwarg load_in_8bit = cfg.load_in_8bit cfg.is_llama_derived_model = "llama" in base_model or ( @@ -203,12 +202,7 @@ def load_model( ) load_in_8bit = False elif cfg.is_llama_derived_model: - try: - from transformers import LlamaForCausalLM - except ImportError: - logging.warning( - "This version of transformers does not support Llama. Consider upgrading." 
- ) + from transformers import LlamaForCausalLM config = LlamaConfig.from_pretrained(base_model_config) model = LlamaForCausalLM.from_pretrained( From a6190c8094ead6d31570d597b6f2dfdbe83f50ff Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Sun, 11 Jun 2023 11:59:03 +0900 Subject: [PATCH 4/6] Clean up landmark patching --- .../monkeypatch/llama_landmark_attn.py | 439 ++---------------- 1 file changed, 37 insertions(+), 402 deletions(-) diff --git a/src/axolotl/monkeypatch/llama_landmark_attn.py b/src/axolotl/monkeypatch/llama_landmark_attn.py index 1a130f755..51f1b90fe 100644 --- a/src/axolotl/monkeypatch/llama_landmark_attn.py +++ b/src/axolotl/monkeypatch/llama_landmark_attn.py @@ -28,15 +28,23 @@ from typing import List, Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from transformers.activations import ACT2FN +from torch.nn import CrossEntropyLoss from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, - SequenceClassifierOutputWithPast, ) -from transformers.modeling_utils import PreTrainedModel from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import ( + LLAMA_INPUTS_DOCSTRING, + LLAMA_START_DOCSTRING, + LlamaMLP, + LlamaPreTrainedModel, + LlamaRMSNorm, + LlamaRotaryEmbedding, + _expand_mask, + _make_causal_mask, + rotate_half, +) from transformers.utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -51,131 +59,6 @@ _CONFIG_FOR_DOC = "LlamaConfig" MEM_TOKEN = "" # nosec -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, - dtype: torch.dtype, - device: torch.device, - past_key_values_length: int = 0, -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full( - (tgt_len, tgt_len), - torch.tensor(torch.finfo(dtype).min, device=device), - device=device, - ) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat( - [ - torch.zeros( - tgt_len, past_key_values_length, dtype=dtype, device=device - ), - mask, - ], - dim=-1, - ) - return mask[None, None, :, :].expand( - bsz, 1, tgt_len, tgt_len + past_key_values_length - ) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
- """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill( - inverted_mask.to(torch.bool), torch.finfo(dtype).min - ) - - -class LlamaRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - LlamaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - - # convert into half-precision if necessary - if self.weight.dtype in [torch.float16, torch.bfloat16]: - hidden_states = hidden_states.to(self.weight.dtype) - - return self.weight * hidden_states - - -class LlamaRotaryEmbedding(torch.nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) - self.register_buffer("inv_freq", inv_freq) - - # Build here to make `torch.jit.trace` work. - self.max_seq_len_cached = max_position_embeddings - t = torch.arange( - self.max_seq_len_cached, - device=self.inv_freq.device, - dtype=self.inv_freq.dtype, - ) - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer( - "cos_cached", emb.cos()[None, None, :, :], persistent=False - ) - self.register_buffer( - "sin_cached", emb.sin()[None, None, :, :], persistent=False - ) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. - if seq_len > self.max_seq_len_cached: - self.max_seq_len_cached = seq_len - t = torch.arange( - self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype - ) - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1).to(x.device) - self.register_buffer( - "cos_cached", emb.cos()[None, None, :, :], persistent=False - ) - self.register_buffer( - "sin_cached", emb.sin()[None, None, :, :], persistent=False - ) - return ( - self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - ) - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - def apply_rotary_pos_emb(q, k, cos, sin, position_ids): # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. 
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] @@ -190,24 +73,11 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids): return q_embed, k_embed -class LlamaMLP(nn.Module): - def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - ): - super().__init__() - self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) - self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) - self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) - self.act_fn = ACT2FN[hidden_act] - - def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - class LandmarkGroupedSoftmaxFunction(torch.autograd.Function): + """ + Landmark grouped softmax function. + """ + # Note that forward, setup_context, and backward are @staticmethods @staticmethod def forward(ctx, x, dim, mem_cnt, resp_mem_idx): @@ -682,16 +552,14 @@ class LlamaAttention(nn.Module): # upcast attention to fp32 if is_mem is None: raise ValueError("Don't use this without landmarks") - # attn_weights = nn.functional.softmax( - # attn_weights, dim=-1, dtype=torch.float32 - # ).to(query_states.dtype) - else: - attn_weights = landmark_grouped_softmax( - attn_weights, - dim=-1, - is_mem=is_mem.expand(-1, self.num_heads, -1, -1), - last_section_mask=last_section_mask, - ).to(query_states.dtype) + + attn_weights = landmark_grouped_softmax( + attn_weights, + dim=-1, + is_mem=is_mem.expand(-1, self.num_heads, -1, -1), + last_section_mask=last_section_mask, + ).to(query_states.dtype) + if attn_prefix is not None: attn_prefix, attn_weights = torch.split( attn_weights, @@ -722,6 +590,10 @@ class LlamaAttention(nn.Module): class LlamaDecoderLayer(nn.Module): + """ + Llama Decoder layer + """ + def __init__(self, config: LlamaConfig): super().__init__() self.hidden_size = config.hidden_size @@ -802,114 +674,6 @@ class LlamaDecoderLayer(nn.Module): return outputs -LLAMA_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`LlamaConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
-""" - - -@add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", - LLAMA_START_DOCSTRING, -) -class LlamaPreTrainedModel(PreTrainedModel): - config_class = LlamaConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["LlamaDecoderLayer"] - _keys_to_ignore_on_load_unexpected = [r"decoder\.version"] - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, LlamaModel): - module.gradient_checkpointing = value - - -LLAMA_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. 
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - @add_start_docstrings( "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", LLAMA_START_DOCSTRING, @@ -1178,6 +942,10 @@ class LlamaModel(LlamaPreTrainedModel): class LlamaForCausalLM(LlamaPreTrainedModel): + """ + Llama model with a causal language modeling head. + """ + def __init__(self, config): super().__init__(config) self.model = LlamaModel(config) @@ -1448,149 +1216,15 @@ class LlamaForCausalLM(LlamaPreTrainedModel): return reordered_past -@add_start_docstrings( - """ - The LLaMa Model transformer with a sequence classification head on top (linear layer). - - [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). - """, - LLAMA_START_DOCSTRING, -) -class LlamaForSequenceClassification(LlamaPreTrainedModel): - _keys_to_ignore_on_load_missing = [r"lm_head.weight"] - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = LlamaModel(config) - self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. 
Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - transformer_outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size = input_ids.shape[0] - else: - batch_size = inputs_embeds.shape[0] - - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError( - "Cannot handle batch sizes > 1 if no padding token is defined." - ) - if self.config.pad_token_id is None: - sequence_lengths = -1 - else: - if input_ids is not None: - sequence_lengths = ( - torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1 - ).to(logits.device) - else: - sequence_lengths = -1 - - pooled_logits = logits[ - torch.arange(batch_size, device=logits.device), sequence_lengths - ] - - loss = None - if labels is not None: - labels = labels.to(logits.device) - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and ( - labels.dtype == torch.long or labels.dtype == torch.int - ): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct( - pooled_logits.view(-1, self.num_labels), labels.view(-1) - ) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(pooled_logits, labels) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - def add_mem_tokens(example, mem_freq, mem_id): - x = example["input_ids"] + ids = example["input_ids"] ret = [] prev_idx = 0 - for t_idx in range(mem_freq, len(x), mem_freq): - ret.extend(x[prev_idx:t_idx]) + for t_idx in range(mem_freq, len(ids), mem_freq): + ret.extend(ids[prev_idx:t_idx]) ret.append(mem_id) prev_idx = t_idx - ret.extend(x[prev_idx:]) + ret.extend(ids[prev_idx:]) # drop attention_mask return {"input_ids": ret} @@ -1602,3 +1236,4 @@ def patch_llama_with_landmark_attn(): transformers.models.llama.modeling_llama.LlamaModel = LlamaModel transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer + transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb From 572d1141e67789dfdb78dca35b6be6945d4fb782 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Sun, 11 Jun 2023 12:05:37 +0900 Subject: 
[PATCH 5/6] Set mem cache args on inference --- scripts/finetune.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/scripts/finetune.py b/scripts/finetune.py index 8a458890c..cdc4b5e0e 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -77,6 +77,11 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"): importlib.import_module("axolotl.prompters"), prompter ) + if cfg.landmark_attention: + model.set_mem_cache_args( + max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None + ) + while True: print("=" * 80) # support for multiline inputs @@ -90,6 +95,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"): else: prompt = instruction.strip() batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True) + print("=" * 40) model.eval() with torch.no_grad(): From 974dc00a7d966e1b26c7e69aea378c8d325776c8 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Sun, 11 Jun 2023 14:00:54 +0900 Subject: [PATCH 6/6] Fix set mem_id for inference and refactor --- scripts/finetune.py | 3 +++ src/axolotl/monkeypatch/llama_landmark_attn.py | 10 ++++++++++ src/axolotl/utils/trainer.py | 11 +++++++---- 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/scripts/finetune.py b/scripts/finetune.py index cdc4b5e0e..4875256ba 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -78,6 +78,9 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"): ) if cfg.landmark_attention: + from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id + + set_model_mem_id(model, tokenizer) model.set_mem_cache_args( max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None ) diff --git a/src/axolotl/monkeypatch/llama_landmark_attn.py b/src/axolotl/monkeypatch/llama_landmark_attn.py index 51f1b90fe..2a4cdbc36 100644 --- a/src/axolotl/monkeypatch/llama_landmark_attn.py +++ b/src/axolotl/monkeypatch/llama_landmark_attn.py @@ -29,6 +29,7 @@ import torch import torch.utils.checkpoint from torch import nn from torch.nn import CrossEntropyLoss +from transformers import LlamaTokenizer from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -1237,3 +1238,12 @@ def patch_llama_with_landmark_attn(): transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb + + +def set_model_mem_id(model: LlamaForCausalLM, tokenizer: LlamaTokenizer): + mem_id = tokenizer.convert_tokens_to_ids(MEM_TOKEN) + model.set_mem_id(mem_id) + + +def get_mem_id(tokenizer: LlamaTokenizer): + return tokenizer.convert_tokens_to_ids(MEM_TOKEN) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 9ae1e7e93..1250ad4f6 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -239,16 +239,19 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): if cfg.is_llama_derived_model and cfg.landmark_attention: from functools import partial - from axolotl.monkeypatch.llama_landmark_attn import MEM_TOKEN, add_mem_tokens + from axolotl.monkeypatch.llama_landmark_attn import ( + add_mem_tokens, + get_mem_id, + set_model_mem_id, + ) - mem_id = tokenizer.convert_tokens_to_ids(MEM_TOKEN) - model.set_mem_id(mem_id) + set_model_mem_id(model, tokenizer) logging.info("Adding landmark attention tokens to dataset") for dataset in [train_dataset, eval_dataset]: dataset = dataset.map( - partial(add_mem_tokens, 
mem_freq=50, mem_id=mem_id), + partial(add_mem_tokens, mem_freq=50, mem_id=get_mem_id(tokenizer)), batched=False, num_proc=32, )
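
Usage sketch (not part of the patches above): taken together, the series touches three places -- the monkeypatch module that swaps the stock Llama classes for the landmark-attention versions, the trainer that interleaves MEM_TOKEN landmarks into the dataset, and inference, which configures the memory cache on the model. Below is a minimal, hedged sketch of how those pieces fit together. The checkpoint name, the resize_token_embeddings call, and the standalone-script framing are illustrative assumptions; the mem_freq=50, max_seq_len=255, top_k=5 values mirror the defaults used in the patches.

    # Illustrative sketch only; it mirrors how axolotl wires up the patched
    # module, not an official entry point.
    from transformers import LlamaForCausalLM, LlamaTokenizer

    from axolotl.monkeypatch.llama_landmark_attn import (
        MEM_TOKEN,
        add_mem_tokens,
        get_mem_id,
        patch_llama_with_landmark_attn,
        set_model_mem_id,
    )

    # 1. Swap the Llama classes in transformers.models.llama.modeling_llama
    #    for the landmark-attention versions before instantiating the model.
    patch_llama_with_landmark_attn()

    base_model = "huggyllama/llama-7b"  # placeholder checkpoint for the example
    tokenizer = LlamaTokenizer.from_pretrained(base_model)
    tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})

    model = LlamaForCausalLM.from_pretrained(base_model)
    model.resize_token_embeddings(len(tokenizer))  # assumed step: make room for MEM_TOKEN

    # 2. Tell the model which token id marks a landmark (as setup_trainer does).
    set_model_mem_id(model, tokenizer)

    # 3. Training data: insert a landmark token every mem_freq input ids
    #    (add_mem_tokens returns only input_ids and drops attention_mask).
    example = {"input_ids": tokenizer("a long training document ...").input_ids}
    example = add_mem_tokens(example, mem_freq=50, mem_id=get_mem_id(tokenizer))

    # 4. Inference: configure the landmark memory cache (as do_inference does).
    model.set_mem_cache_args(
        max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
    )
    model.eval()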