diff --git a/README.md b/README.md
index db884ec6b..2b14fe94b 100644
--- a/README.md
+++ b/README.md
@@ -22,7 +22,7 @@
| Pythia | ✅ | ✅ | ❓ | ❌ | ❌ | ❌ | ❓ |
| cerebras | ✅ | ✅ | ❓ | ❌ | ❌ | ❌ | ❓ |
| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
-| falcon | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❓ |
+| falcon | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
## Quickstart ⚡
@@ -33,6 +33,7 @@
git clone https://github.com/OpenAccess-AI-Collective/axolotl
pip3 install -e .
+pip3 install -U git+https://github.com/huggingface/peft.git
accelerate config
@@ -53,6 +54,7 @@ accelerate launch scripts/finetune.py examples/lora-openllama-3b/config.yml \
docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.9-cu118-2.0.0
```
- `winglian/axolotl-runpod:main-py3.9-cu118-2.0.0`: for runpod
+ - `winglian/axolotl-runpod:main-py3.9-cu118-2.0.0-gptq`: for gptq
- `winglian/axolotl:dev`: dev branch (not usually up to date)
Or run on the current files for development:
@@ -67,9 +69,19 @@ accelerate launch scripts/finetune.py examples/lora-openllama-3b/config.yml \
2. Install pytorch stable https://pytorch.org/get-started/locally/
3. Install python dependencies with ONE of the following:
- - `pip3 install -e .` (recommended, supports QLoRA, no gptq/int4 support)
- - `pip3 install -e .[gptq]` (next best if you don't need QLoRA, but want to use gptq)
- - `pip3 install -e .[gptq_triton]`
+ - Recommended, supports QLoRA, NO gptq/int4 support
+ ```bash
+ pip3 install -e .
+ pip3 install -U git+https://github.com/huggingface/peft.git
+ ```
+ - gptq/int4 support, NO QLoRA
+ ```bash
+ pip3 install -e .[gptq]
+ ```
+ - gptq/int4 support via Triton (not recommended)
+ ```bash
+ pip3 install -e .[gptq_triton]
+ ```
- LambdaLabs
@@ -78,7 +90,8 @@ accelerate launch scripts/finetune.py examples/lora-openllama-3b/config.yml \
1. Install python
```bash
- sudo apt install python3.9
+ sudo apt update
+ sudo apt install -y python3.9
sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.9 1
sudo update-alternatives --config python # pick 3.9 if given option
@@ -205,14 +218,18 @@ Have dataset(s) in one of the following format (JSONL recommended):
```json
{"conversations": [{"role": "...", "value": "..."}]}
```
-- custom prompts structure:
- 1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example.
- 2. Use your custom file name as the dataset type.
+#### How to add custom prompts
+
+ 1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see the other files there for examples.
+ 2. Use your custom file name as the dataset type (a rough sketch follows below).
+
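For orientation, a custom strategy module might look roughly like the sketch below. This is a hypothetical example: the file name, the class names, the `load()` entry point, and the constructor arguments are assumptions inferred from the existing files in [prompt_strategies](src/axolotl/prompt_strategies), so copy the real signatures from a neighbouring module rather than from here.

```python
# src/axolotl/prompt_strategies/my_format.py  -- hypothetical file name
# Sketch only: the names and signatures below are assumptions modeled on the
# existing strategy modules, not a verbatim copy of axolotl's API.
from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter


class MyFormatTokenizingStrategy(InstructionPromptTokenizingStrategy):
    def parse_instruction_fields(self, prompt):
        # map your dataset's fields onto (instruction, input, response)
        return prompt["question"], "", prompt["answer"]


def load(tokenizer, cfg):
    # `type: my_format` in the dataset config would resolve to this module
    return MyFormatTokenizingStrategy(
        AlpacaPrompter(),
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )
```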
Optionally, download some datasets, see [data/README.md](data/README.md)
+
+
### Config
See sample configs in [configs](configs) folder or [examples](examples) for quick start. It is recommended to duplicate and modify to your needs. The most important options are:
@@ -370,7 +387,7 @@ train_on_inputs: false
# don't use this, leads to wonky training (according to someone on the internet)
group_by_length: false
-# does not work with current implementation of 4-bit LoRA
+# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
gradient_checkpointing: false
# stop training after this many evaluation losses have increased in a row
@@ -400,6 +417,8 @@ flash_attention: # require a100 for llama
# whether to use scaled-dot-product attention
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
sdp_attention:
+# whether to use landmark attention (llama only)
+landmark_attention:
# resume from a specific checkpoint dir
resume_from_checkpoint:
diff --git a/examples/falcon/config-7b-qlora.yml b/examples/falcon/config-7b-qlora.yml
new file mode 100644
index 000000000..6168ff2d5
--- /dev/null
+++ b/examples/falcon/config-7b-qlora.yml
@@ -0,0 +1,92 @@
+# 1b: tiiuae/falcon-rw-1b
+# 40b: tiiuae/falcon-40b
+base_model: tiiuae/falcon-7b
+base_model_config: tiiuae/falcon-7b
+# required by falcon custom model code: https://huggingface.co/tiiuae/falcon-7b/tree/main
+trust_remote_code: true
+model_type: AutoModelForCausalLM
+tokenizer_type: AutoTokenizer
+load_in_8bit: false
+# enable 4bit for QLoRA
+load_in_4bit: true
+gptq: false
+strict: false
+push_dataset_to_hub:
+datasets:
+ - path: QingyiSi/Alpaca-CoT
+ data_files:
+ - Chain-of-Thought/formatted_cot_data/gsm8k_train.json
+ type: "alpaca:chat"
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.01
+# enable QLoRA
+adapter: qlora
+lora_model_dir:
+sequence_len: 2048
+max_packed_sequence_len:
+
+# hyperparameters from QLoRA paper Appendix B.2
+# "We find hyperparameters to be largely robust across datasets"
+lora_r: 64
+lora_alpha: 16
+# 0.1 for models up to 13B
+# 0.05 for 33B and 65B models
+lora_dropout: 0.05
+# add LoRA modules on all linear layers of the base model
+lora_target_modules:
+lora_target_linear: true
+lora_fan_in_fan_out:
+
+wandb_project:
+wandb_watch:
+wandb_run_id:
+wandb_log_model:
+output_dir: ./qlora-out
+
+# QLoRA paper Table 9
+# - 16 for 7b & 13b
+# - 32 for 33b, 64 for 65b
+# Max size tested on A6000
+# - 7b: 40
+# - 40b: 4
+# decrease if OOM, increase for max VRAM utilization
+micro_batch_size: 1
+gradient_accumulation_steps: 2
+num_epochs: 3
+# Optimizer for QLoRA
+optimizer: paged_adamw_32bit
+torchdistx_path:
+lr_scheduler: cosine
+# QLoRA paper Table 9
+# - 2e-4 for 7b & 13b
+# - 1e-4 for 33b & 65b
+learning_rate: 0.0002
+train_on_inputs: false
+group_by_length: false
+bf16: true
+fp16: false
+tf32: true
+gradient_checkpointing: true
+# stop training after this many evaluation losses have increased in a row
+# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
+early_stopping_patience: 3
+resume_from_checkpoint:
+auto_resume_from_checkpoints: true
+local_rank:
+logging_steps: 1
+xformers_attention: true
+flash_attention:
+gptq_groupsize:
+gptq_model_v1:
+warmup_steps: 10
+eval_steps: 5
+save_steps: 10
+debug:
+deepspeed:
+weight_decay: 0.000001
+fsdp:
+fsdp_config:
+special_tokens:
+ pad_token: "<|endoftext|>"
+ bos_token: ">>ABSTRACT<<"
+ eos_token: "<|endoftext|>"
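Outside of axolotl, the QLoRA-related values in this config map roughly onto plain `transformers`/`peft` calls as in the minimal sketch below. This is for orientation only: axolotl's loader handles all of this internally, and the `bnb_4bit_*` settings plus the explicit `target_modules` list (approximating `lora_target_linear: true`) are assumptions, not values read from this file.

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

model_id = "tiiuae/falcon-7b"

# load_in_4bit: true  ->  4-bit (NF4) base weights for QLoRA
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,  # bf16: true in the config above
)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    trust_remote_code=True,  # required by falcon's custom model code
)
model = prepare_model_for_kbit_training(model)

# lora_r / lora_alpha / lora_dropout taken from the config above;
# the module names approximate "all linear layers" for falcon
lora_config = LoraConfig(
    r=64,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"],
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```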
diff --git a/scripts/finetune.py b/scripts/finetune.py
index faf1bb31d..3222afd81 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -12,7 +12,7 @@ from typing import Any, Dict, List, Optional, Union
import fire
import torch
import yaml
-from transformers import GenerationConfig
+from transformers import GenerationConfig, TextStreamer
from axolotl.utils.data import load_prepare_datasets
from axolotl.utils.dict import DictDefault
@@ -64,13 +64,17 @@ def get_multi_line_input() -> Optional[str]:
def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
-    tokenizer.add_special_tokens({"unk_token": "<unk>"})
-    tokenizer.add_special_tokens({"bos_token": "<s>"})
-    tokenizer.add_special_tokens({"eos_token": "</s>"})
+    default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
+
+ for token, symbol in default_tokens.items():
+ # If the token isn't already specified in the config, add it
+ if not (cfg.special_tokens and token in cfg.special_tokens):
+ tokenizer.add_special_tokens({token: symbol})
prompter_module = getattr(importlib.import_module("axolotl.prompters"), prompter)
while True:
+ print("=" * 80)
# support for multiline inputs
instruction = get_multi_line_input()
if not instruction:
@@ -79,7 +83,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
prompter_module().build_prompt(instruction=instruction.strip("\n"))
)
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
-
+ print("=" * 40)
model.eval()
with torch.no_grad():
generation_config = GenerationConfig(
@@ -98,10 +102,13 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
output_hidden_states=False,
output_scores=False,
)
+ streamer = TextStreamer(tokenizer)
generated = model.generate(
inputs=batch["input_ids"].to(cfg.device),
generation_config=generation_config,
+ streamer=streamer,
)
+ print("=" * 40)
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
@@ -183,6 +190,9 @@ def train(
cfg.fp16 = True
cfg.bf16 = False
+ if cfg.tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
# load the tokenizer first
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
logging.info(f"loading tokenizer... {tokenizer_config}")
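For context, `TextStreamer` is a standard `transformers` utility: when passed to `generate`, it prints decoded tokens to stdout as they are produced, which is what the inference loop above now relies on. A minimal standalone sketch (the model id is a placeholder):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

model_id = "openlm-research/open_llama_3b"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Below is an instruction ...", return_tensors="pt")
streamer = TextStreamer(tokenizer, skip_prompt=True)  # don't re-print the prompt
model.generate(**inputs, streamer=streamer, max_new_tokens=64)
```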
diff --git a/src/axolotl/monkeypatch/llama_landmark_attn.py b/src/axolotl/monkeypatch/llama_landmark_attn.py
new file mode 100644
index 000000000..18e913f09
--- /dev/null
+++ b/src/axolotl/monkeypatch/llama_landmark_attn.py
@@ -0,0 +1,1595 @@
+# pylint: skip-file
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+PyTorch LLaMA model.
+Taken from https://github.com/epfml/landmark-attention/blob/main/llama/llama_mem.py and modified.
+"""
+import math
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ SequenceClassifierOutputWithPast,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.models.llama.configuration_llama import LlamaConfig
+from transformers.utils import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "LlamaConfig"
+
+MEM_TOKEN = "<landmark>"  # nosec
+
+
+# Copied from transformers.models.bart.modeling_bart._make_causal_mask
+def _make_causal_mask(
+ input_ids_shape: torch.Size,
+ dtype: torch.dtype,
+ device: torch.device,
+ past_key_values_length: int = 0,
+):
+ """
+ Make causal mask used for bi-directional self-attention.
+ """
+ bsz, tgt_len = input_ids_shape
+ mask = torch.full(
+ (tgt_len, tgt_len),
+ torch.tensor(torch.finfo(dtype).min, device=device),
+ device=device,
+ )
+ mask_cond = torch.arange(mask.size(-1), device=device)
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+ mask = mask.to(dtype)
+
+ if past_key_values_length > 0:
+ mask = torch.cat(
+ [
+ torch.zeros(
+ tgt_len, past_key_values_length, dtype=dtype, device=device
+ ),
+ mask,
+ ],
+ dim=-1,
+ )
+ return mask[None, None, :, :].expand(
+ bsz, 1, tgt_len, tgt_len + past_key_values_length
+ )
+
+
+# Copied from transformers.models.bart.modeling_bart._expand_mask
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
+ bsz, src_len = mask.size()
+ tgt_len = tgt_len if tgt_len is not None else src_len
+
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+ inverted_mask = 1.0 - expanded_mask
+
+ return inverted_mask.masked_fill(
+ inverted_mask.to(torch.bool), torch.finfo(dtype).min
+ )
+
+
+class LlamaRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ LlamaRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+ # convert into half-precision if necessary
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
+ hidden_states = hidden_states.to(self.weight.dtype)
+
+ return self.weight * hidden_states
+
+
+class LlamaRotaryEmbedding(torch.nn.Module):
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+ super().__init__()
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
+ self.register_buffer("inv_freq", inv_freq)
+
+ # Build here to make `torch.jit.trace` work.
+ self.max_seq_len_cached = max_position_embeddings
+ t = torch.arange(
+ self.max_seq_len_cached,
+ device=self.inv_freq.device,
+ dtype=self.inv_freq.dtype,
+ )
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
+ emb = torch.cat((freqs, freqs), dim=-1)
+ self.register_buffer(
+ "cos_cached", emb.cos()[None, None, :, :], persistent=False
+ )
+ self.register_buffer(
+ "sin_cached", emb.sin()[None, None, :, :], persistent=False
+ )
+
+ def forward(self, x, seq_len=None):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
+ if seq_len > self.max_seq_len_cached:
+ self.max_seq_len_cached = seq_len
+ t = torch.arange(
+ self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype
+ )
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+ self.register_buffer(
+ "cos_cached", emb.cos()[None, None, :, :], persistent=False
+ )
+ self.register_buffer(
+ "sin_cached", emb.sin()[None, None, :, :], persistent=False
+ )
+ return (
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+ )
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
+ if q is None:
+ q_embed = None
+ else:
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+class LlamaMLP(nn.Module):
+ def __init__(
+ self,
+ hidden_size: int,
+ intermediate_size: int,
+ hidden_act: str,
+ ):
+ super().__init__()
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
+ self.act_fn = ACT2FN[hidden_act]
+
+ def forward(self, x):
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+
+class LandmarkGroupedSoftmaxFunction(torch.autograd.Function):
+ # Note that forward, setup_context, and backward are @staticmethods
+ @staticmethod
+ def forward(ctx, x, dim, mem_cnt, resp_mem_idx):
+ new_shape = list(x.shape)
+ new_shape[dim] = mem_cnt # max_mem_cnt.item()
+ max_by_group = x.new_zeros((*new_shape,))
+ max_by_group.scatter_reduce_(
+ src=x, index=resp_mem_idx, dim=dim, reduce="amax", include_self=False
+ )
+
+ maxes = torch.gather(max_by_group, dim, resp_mem_idx)
+ # x_exp = torch.exp(x - torch.where(torch.isinf(maxes), 0, maxes))
+ x_exp = torch.exp((x - maxes).to(torch.float32))
+
+ cumsum_by_group = torch.zeros_like(max_by_group, dtype=x_exp.dtype)
+
+ cumsum_by_group.scatter_add_(
+ dim,
+ resp_mem_idx,
+ x_exp,
+ )
+ denom = torch.gather(cumsum_by_group, dim, resp_mem_idx)
+
+ # probs = torch.where(denom < 0.5, 0, x_exp / denom)
+ probs = x_exp / denom
+
+ ctx.mem_cnt = mem_cnt
+ ctx.dim = dim
+ ctx.save_for_backward(resp_mem_idx, probs)
+
+ return probs
+
+ @staticmethod
+ def backward(ctx, grad_probs):
+ mem_cnt = ctx.mem_cnt
+ dim = ctx.dim
+ resp_mem_idx, probs = ctx.saved_tensors
+ grad_x = grad_dim = grad_mem_cnt = grad_resp_mem_idx = None
+
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[4]:
+ grad_pair = grad_probs * probs
+
+ new_shape = list(probs.shape)
+ new_shape[dim] = mem_cnt # max_mem_cnt.item()
+ cumsum_by_group = grad_pair.new_zeros((*new_shape,))
+ cumsum_by_group.scatter_add_(dim, resp_mem_idx, grad_pair)
+
+ if ctx.needs_input_grad[0]:
+ grad_sum = torch.gather(cumsum_by_group, dim, resp_mem_idx)
+ grad_x = grad_pair - probs * grad_sum
+ assert not ctx.needs_input_grad[1]
+ assert not ctx.needs_input_grad[2]
+ assert not ctx.needs_input_grad[3]
+
+ return grad_x, grad_dim, grad_mem_cnt, grad_resp_mem_idx
+
+
+def landmark_grouped_softmax(x, dim, is_mem, last_section_mask):
+ last_and_rest_mask = last_section_mask # | mask
+
+ full_access_mask = is_mem | last_and_rest_mask
+
+ max_mem_cnt = 16
+ mem_group_idx = torch.cumsum(is_mem, dim=dim)
+ mem_bucket_id = max_mem_cnt - 1
+ resp_mem_idx = torch.where(
+ last_and_rest_mask,
+ max_mem_cnt - 1,
+ torch.where(is_mem, mem_bucket_id, mem_group_idx),
+ )
+ probs = LandmarkGroupedSoftmaxFunction.apply(x, dim, max_mem_cnt, resp_mem_idx)
+
+ new_shape = list(x.shape)
+ new_shape[dim] = max_mem_cnt
+ group_prob = probs.new_zeros((*new_shape,))
+ group_prob.scatter_(
+ dim, torch.where(is_mem, mem_group_idx - 1, max_mem_cnt - 1), probs
+ )
+ probs = probs.mul(
+ torch.where(
+ full_access_mask,
+ last_section_mask,
+ torch.gather(group_prob, dim, resp_mem_idx),
+ )
+ )
+
+ return probs
+
+
+class LlamaAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: LlamaConfig):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.max_position_embeddings = config.max_position_embeddings
+
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+ self.q_proj = nn.Linear(
+ self.hidden_size, self.num_heads * self.head_dim, bias=False
+ )
+ self.k_proj = nn.Linear(
+ self.hidden_size, self.num_heads * self.head_dim, bias=False
+ )
+ self.v_proj = nn.Linear(
+ self.hidden_size, self.num_heads * self.head_dim, bias=False
+ )
+ self.o_proj = nn.Linear(
+ self.num_heads * self.head_dim, self.hidden_size, bias=False
+ )
+ self.rotary_emb = LlamaRotaryEmbedding(
+ self.head_dim, max_position_embeddings=self.max_position_embeddings
+ )
+
+ self.mem_freq = None
+ self.top_k = None
+ self.max_cache_size = None
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return (
+ tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ .contiguous()
+ )
+
+ def set_mem_cache_args(self, mem_freq, top_k, max_cache_size):
+ self.mem_freq = mem_freq
+ self.top_k = top_k
+ self.max_cache_size = max_cache_size
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ is_mem: Optional[torch.Tensor] = None,
+ last_section_mask: Optional[torch.Tensor] = None,
+ offload_cache_to_cpu: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = (
+ self.q_proj(hidden_states)
+ .view(bsz, q_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ key_states = (
+ self.k_proj(hidden_states)
+ .view(bsz, q_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ value_states = (
+ self.v_proj(hidden_states)
+ .view(bsz, q_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+ if len(past_key_value) > 2:
+ kv_seq_len += past_key_value[3].shape[2] * past_key_value[3].shape[3]
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ key_states_before_pos = key_states
+ query_states, key_states = apply_rotary_pos_emb(
+ query_states, key_states, cos, sin, position_ids
+ )
+ # [bsz, nh, t, hd]
+
+ attn_prefix = None
+ if past_key_value is not None:
+ # reuse k, v, self_attention
+ if self.mem_freq is None:
+ cache_len = past_key_value[0].shape[2]
+ if self.max_cache_size is not None:
+ cache_len = min(cache_len, self.max_cache_size)
+ if is_mem is not None:
+ is_mem = torch.cat(
+ (is_mem.new_zeros((1, 1, q_len, cache_len)), is_mem), dim=-1
+ )
+ last_section_mask = torch.cat(
+ (
+ last_section_mask.new_ones((1, 1, q_len, cache_len)),
+ last_section_mask,
+ ),
+ dim=-1,
+ )
+
+ past_key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ past_value_states = torch.cat([past_key_value[1], value_states], dim=2)
+ key_states = past_key_states[:, :, -(q_len + cache_len) :]
+ value_states = past_value_states[:, :, -(q_len + cache_len) :]
+ expected_att_size = (bsz, self.num_heads, q_len, cache_len + q_len)
+ else:
+ orig_value_states = value_states
+
+ incomplete_len = past_key_value[0].shape[2] % (self.mem_freq + 1)
+ full_len = past_key_value[0].shape[2] - incomplete_len
+ past_key_mem, past_key_incomplete = torch.split(
+ past_key_value[0], (full_len, incomplete_len), dim=2
+ )
+ past_value_mem, past_value_incomplete = torch.split(
+ past_key_value[1], (full_len, incomplete_len), dim=2
+ )
+
+ if offload_cache_to_cpu:
+ past_key_value = (
+ past_key_incomplete,
+ past_value_incomplete,
+ *past_key_value[2:],
+ )
+
+ if incomplete_len > 0:
+ assert q_len + incomplete_len <= (self.mem_freq + 1)
+ is_mem = torch.cat(
+ (is_mem.new_zeros((1, 1, q_len, incomplete_len)), is_mem), dim=-1
+ )
+ last_section_mask = torch.cat(
+ (
+ last_section_mask.new_ones((1, 1, q_len, incomplete_len)),
+ last_section_mask,
+ ),
+ dim=-1,
+ )
+
+ if len(past_key_value) > 2:
+ full_len += past_key_value[3].shape[2] * past_key_value[3].shape[3]
+ past_key_incomplete_pos = torch.arange(
+ full_len,
+ full_len + incomplete_len,
+ dtype=torch.long,
+ device=position_ids.device,
+ ).unsqueeze(0)
+ _, past_key_incomplete = apply_rotary_pos_emb(
+ None, past_key_incomplete, cos, sin, past_key_incomplete_pos
+ )
+ key_states = torch.cat((past_key_incomplete, key_states), dim=2)
+ value_states = torch.cat((past_value_incomplete, value_states), dim=2)
+
+ past_key_mem = past_key_mem.view(
+ bsz, self.num_heads, -1, self.mem_freq + 1, self.head_dim
+ )
+ past_value_mem = past_value_mem.view(
+ bsz, self.num_heads, -1, self.mem_freq + 1, self.head_dim
+ )
+
+ if len(past_key_value) > 2:
+ mem_key_nopos = torch.cat(
+ (
+ past_key_value[2],
+ past_key_mem.select(dim=3, index=self.mem_freq),
+ ),
+ dim=2,
+ )
+ past_key_mem_offload = past_key_value[3]
+ past_key_mem = torch.cat(
+ (
+ past_key_mem_offload,
+ past_key_mem.to(past_key_mem_offload.device),
+ ),
+ dim=2,
+ )
+ past_value_mem = torch.cat(
+ (
+ past_key_value[4],
+ past_value_mem.to(past_key_mem_offload.device),
+ ),
+ dim=2,
+ )
+ else:
+ mem_key_nopos = past_key_mem.select(dim=3, index=self.mem_freq)
+
+ num_mems = past_key_mem.shape[2]
+ top_k = min(self.top_k, num_mems)
+ prefix_len = full_len - (top_k + 1) * (self.mem_freq + 1)
+ mem_indices = torch.cat(
+ (
+ position_ids.new_zeros((max(0, num_mems - top_k),)),
+ torch.arange(
+ 1,
+ top_k + 1,
+ device=query_states.device,
+ dtype=position_ids.dtype,
+ ),
+ ),
+ dim=0,
+ )
+ mem_pos = (mem_indices * (self.mem_freq + 1) + self.mem_freq).unsqueeze(
+ 0
+ ).expand(bsz, -1) + prefix_len
+ _, mem_key = apply_rotary_pos_emb(
+ None, mem_key_nopos, cos, sin, mem_pos
+ )
+ mem_attn_weights = torch.matmul(
+ query_states, mem_key.transpose(2, 3)
+ ) / math.sqrt(self.head_dim)
+
+ if offload_cache_to_cpu:
+ aggregate = "max_over_tokens"
+ else:
+ aggregate = None
+ if aggregate == "max_over_tokens":
+ token_retrievers = 1
+ head_retrievers = self.num_heads
+ mem_attn_weights = torch.nn.functional.softmax(
+ mem_attn_weights, dim=-1
+ )
+ mem_attn_weights = mem_attn_weights.amax(dim=2, keepdim=True)
+ elif aggregate is None:
+ token_retrievers = q_len
+ head_retrievers = self.num_heads
+ else:
+ raise NotImplementedError()
+
+ mem_selected_idx = (
+ mem_attn_weights.topk(dim=-1, k=top_k)[1]
+ .sort(dim=-1)[0]
+ .view(bsz, head_retrievers, token_retrievers, top_k)
+ )
+
+ selected_indices = torch.arange(
+ 0,
+ top_k * (self.mem_freq + 1),
+ device=query_states.device,
+ dtype=position_ids.dtype,
+ )
+ selected_indices = torch.where(
+ mem_selected_idx >= num_mems - top_k, self.mem_freq + 1, 0
+ ).unsqueeze(-1) + selected_indices.view(
+ 1, 1, 1, top_k, self.mem_freq + 1
+ )
+ selected_indices = (
+ selected_indices.view(
+ bsz, head_retrievers, token_retrievers, -1
+ ).expand(bsz, self.num_heads, q_len, -1)
+ + prefix_len
+ )
+
+ mem_selected_idx = mem_selected_idx.to(past_key_mem.device)
+
+ mem_selected_idx = mem_selected_idx.view(
+ bsz, self.num_heads, token_retrievers, top_k, 1, 1
+ ).expand(
+ bsz,
+ self.num_heads,
+ token_retrievers,
+ top_k,
+ self.mem_freq + 1,
+ self.head_dim,
+ )
+ selected_keys = past_key_mem.unsqueeze(2).expand(
+ bsz,
+ self.num_heads,
+ token_retrievers,
+ -1,
+ self.mem_freq + 1,
+ self.head_dim,
+ )
+ selected_keys = selected_keys.take_along_dim(
+ mem_selected_idx, dim=3
+ ).to(query_states.device)
+ selected_values = (
+ past_value_mem.unsqueeze(2)
+ .expand(
+ bsz,
+ self.num_heads,
+ token_retrievers,
+ -1,
+ self.mem_freq + 1,
+ self.head_dim,
+ )
+ .take_along_dim(mem_selected_idx, dim=3)
+ .to(query_states.device)
+ )
+
+ selected_keys = selected_keys.view(
+ bsz, self.num_heads, token_retrievers, -1, self.head_dim
+ ).expand(bsz, self.num_heads, q_len, -1, self.head_dim)
+ selected_keys = apply_rotary_pos_emb(
+ None, selected_keys.unsqueeze(1), cos, sin, selected_indices
+ )[1].squeeze(1)
+ selected_values = selected_values.view(
+ bsz, self.num_heads, token_retrievers, -1, self.head_dim
+ ).expand(bsz, self.num_heads, q_len, -1, self.head_dim)
+ attn_prefix = torch.matmul(
+ query_states.unsqueeze(3), selected_keys.transpose(3, 4)
+ ).squeeze(3) / math.sqrt(self.head_dim)
+ is_mem_prefix = (
+ torch.cat(
+ (is_mem.new_zeros((self.mem_freq,)), is_mem.new_ones((1,)))
+ )
+ .unsqueeze(0)
+ .repeat((top_k, 1))
+ )
+ is_mem_prefix = is_mem_prefix.view(1, 1, 1, -1).expand(1, 1, q_len, -1)
+ is_mem = torch.cat((is_mem_prefix, is_mem), dim=-1)
+ last_section_mask = torch.cat(
+ (
+ last_section_mask.new_zeros(
+ (1, 1, q_len, top_k * (self.mem_freq + 1))
+ ),
+ last_section_mask,
+ ),
+ dim=-1,
+ )
+ expected_att_size = (bsz, self.num_heads, q_len, q_len + incomplete_len)
+
+ past_key_states = torch.cat(
+ [past_key_value[0], key_states_before_pos], dim=2
+ )
+ past_value_states = torch.cat(
+ [past_key_value[1], orig_value_states], dim=2
+ )
+
+ if offload_cache_to_cpu:
+ past_key_value = (
+ (
+ past_key_states,
+ past_value_states,
+ mem_key_nopos,
+ past_key_mem.to("cpu"),
+ past_value_mem.to("cpu"),
+ *past_key_value[5:],
+ )
+ if use_cache
+ else None
+ )
+ else:
+ past_key_value = (
+ (past_key_states, past_value_states) if use_cache else None
+ )
+
+ else:
+ if self.mem_freq is None:
+ past_key_states = key_states
+ else:
+ past_key_states = key_states_before_pos
+ past_value_states = value_states
+ expected_att_size = (bsz, self.num_heads, q_len, kv_seq_len)
+ past_key_value = (past_key_states, past_value_states) if use_cache else None
+
+ attn_weights = torch.matmul(
+ query_states, key_states.transpose(2, 3)
+ ) / math.sqrt(self.head_dim)
+ if attn_weights.size() != expected_att_size:
+ raise ValueError(
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights + attention_mask[..., -attn_weights.shape[-1] :]
+ attn_weights = torch.max(
+ attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
+ )
+ if attn_prefix is not None:
+ attn_weights = torch.cat((attn_prefix, attn_weights), dim=-1)
+ # upcast attention to fp32
+ if is_mem is None:
+ raise ValueError("Don't use this without landmarks")
+ # attn_weights = nn.functional.softmax(
+ # attn_weights, dim=-1, dtype=torch.float32
+ # ).to(query_states.dtype)
+ else:
+ attn_weights = landmark_grouped_softmax(
+ attn_weights,
+ dim=-1,
+ is_mem=is_mem.expand(-1, self.num_heads, -1, -1),
+ last_section_mask=last_section_mask,
+ ).to(query_states.dtype)
+ if attn_prefix is not None:
+ attn_prefix, attn_weights = torch.split(
+ attn_weights,
+ (attn_prefix.shape[-1], attn_weights.shape[-1] - attn_prefix.shape[-1]),
+ dim=-1,
+ )
+ attn_output = torch.matmul(attn_weights, value_states)
+ if attn_prefix is not None:
+ attn_output += torch.matmul(
+ attn_prefix.unsqueeze(3), selected_values
+ ).squeeze(3)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class LlamaDecoderLayer(nn.Module):
+ def __init__(self, config: LlamaConfig):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.self_attn = LlamaAttention(config=config)
+ self.mlp = LlamaMLP(
+ hidden_size=self.hidden_size,
+ intermediate_size=config.intermediate_size,
+ hidden_act=config.hidden_act,
+ )
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = LlamaRMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps
+ )
+
+ def set_mem_cache_args(self, mem_freq, top_k, max_cache_size):
+ self.self_attn.set_mem_cache_args(mem_freq, top_k, max_cache_size)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ is_mem: Optional[torch.Tensor] = None,
+ last_section_mask: Optional[torch.Tensor] = None,
+ offload_cache_to_cpu: bool = False,
+ ) -> Tuple[
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
+ ]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ """
+
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ is_mem=is_mem,
+ last_section_mask=last_section_mask,
+ offload_cache_to_cpu=offload_cache_to_cpu,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+LLAMA_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`LlamaConfig`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
+ LLAMA_START_DOCSTRING,
+)
+class LlamaPreTrainedModel(PreTrainedModel):
+ config_class = LlamaConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["LlamaDecoderLayer"]
+ _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, LlamaModel):
+ module.gradient_checkpointing = value
+
+
+LLAMA_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+ `past_key_values`).
+
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+ information on the default strategy.
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.n_positions - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
+ LLAMA_START_DOCSTRING,
+)
+class LlamaModel(LlamaPreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
+
+ Args:
+ config: LlamaConfig
+ """
+
+ def __init__(self, config: LlamaConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(
+ config.vocab_size, config.hidden_size, self.padding_idx
+ )
+ self.layers = nn.ModuleList(
+ [LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]
+ )
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ self.mem_id = None
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ def set_mem_id(self, mem_id):
+ self.mem_id = mem_id
+
+ def set_mem_cache_args(self, mem_freq, top_k, max_cache_size):
+ for layer in self.layers:
+ layer.set_mem_cache_args(mem_freq, top_k, max_cache_size)
+
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
+ def _prepare_decoder_attention_mask(
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
+ ):
+ # create causal mask
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ combined_attention_mask = None
+ if input_shape[-1] > 1:
+ combined_attention_mask = _make_causal_mask(
+ input_shape,
+ inputs_embeds.dtype,
+ device=inputs_embeds.device,
+ past_key_values_length=past_key_values_length,
+ )
+
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ expanded_attn_mask = _expand_mask(
+ attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+ ).to(inputs_embeds.device)
+ combined_attention_mask = (
+ expanded_attn_mask
+ if combined_attention_mask is None
+ else expanded_attn_mask + combined_attention_mask
+ )
+
+ return combined_attention_mask
+
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ offload_cache_to_cpu: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ # retrieve input_ids and inputs_embeds
+ is_mem = None
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError(
+ "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
+ )
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ if self.mem_id is not None:
+ with torch.no_grad():
+ is_mem = input_ids == self.mem_id
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ if self.mem_id is not None:
+ raise NotImplementedError
+ else:
+ raise ValueError(
+ "You have to specify either decoder_input_ids or decoder_inputs_embeds"
+ )
+
+ seq_length_with_past = seq_length
+ past_key_values_length = 0
+
+ if past_key_values is not None:
+ if is_mem is not None:
+ pass
+ # raise NotImplementedError
+ past_key_values_length = past_key_values[0][0].shape[2]
+ if len(past_key_values[0]) > 2:
+ past_key_values_length += (
+ past_key_values[0][3].shape[2] * past_key_values[0][3].shape[3]
+ )
+ seq_length_with_past = seq_length_with_past + past_key_values_length
+
+ if position_ids is None:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ position_ids = torch.arange(
+ past_key_values_length,
+ seq_length + past_key_values_length,
+ dtype=torch.long,
+ device=device,
+ )
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+ # embed positions
+ if attention_mask is None:
+ attention_mask = torch.ones(
+ (batch_size, seq_length_with_past),
+ dtype=torch.bool,
+ device=inputs_embeds.device,
+ )
+ attention_mask = self._prepare_decoder_attention_mask(
+ attention_mask,
+ (batch_size, seq_length),
+ inputs_embeds,
+ past_key_values_length,
+ )
+
+ last_section_mask = None
+ if is_mem is not None:
+ is_mem = is_mem.unsqueeze(1).unsqueeze(2)
+ current_len = input_ids.shape[1]
+ mem_ids = torch.where(
+ attention_mask[..., -current_len:] < -1,
+ 0,
+ torch.cumsum(is_mem, -1) - is_mem.int(),
+ )
+ last_section_mask = torch.amax(mem_ids, -1, keepdim=True) == mem_ids
+ attention_mask[..., -current_len:].masked_fill_(
+ last_section_mask & is_mem,
+ torch.tensor(
+ torch.finfo(inputs_embeds.dtype).min, device=inputs_embeds.device
+ ),
+ )
+ last_section_mask.logical_and_(attention_mask[..., -current_len:] > -1)
+ is_mem = is_mem.logical_and(attention_mask[..., -current_len:] > -1)
+
+ hidden_states = inputs_embeds
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = () if use_cache else None
+
+ for idx, decoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ past_key_value = (
+ past_key_values[idx] if past_key_values is not None else None
+ )
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(decoder_layer),
+ hidden_states,
+ attention_mask,
+ position_ids,
+ None,
+ output_attentions,
+ None,
+ is_mem,
+ last_section_mask,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ is_mem=is_mem,
+ last_section_mask=last_section_mask,
+ offload_cache_to_cpu=offload_cache_to_cpu,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
+ if v is not None
+ )
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+class LlamaForCausalLM(LlamaPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = LlamaModel(config)
+
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ self.mem_id = None
+ self.mem_freq = None
+ self.top_k = None
+ self.max_seq_len = None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+ @replace_return_docstrings(
+ output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
+ )
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ offload_cache_to_cpu: Optional[bool] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
+
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+    >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ window_len = self.max_seq_len or input_ids.shape[1]
+ last_logits = None
+ for _, idx in enumerate(range(0, input_ids.shape[1], window_len)):
+ if idx >= 1:
+ if output_attentions or output_hidden_states:
+ raise NotImplementedError
+ if not use_cache:
+ raise NotImplementedError
+ outputs = self.model(
+ input_ids=input_ids[:, idx : idx + window_len],
+ attention_mask=attention_mask[
+ :, : idx + window_len + attention_mask.shape[1] - input_ids.shape[1]
+ ]
+ if attention_mask is not None
+ else None,
+ position_ids=position_ids[:, idx : idx + window_len]
+ if position_ids is not None
+ else None,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds[:, idx : idx + window_len]
+ if inputs_embeds is not None
+ else None,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ offload_cache_to_cpu=offload_cache_to_cpu,
+ )
+ past_key_values = outputs.past_key_values
+ if last_logits is not None:
+ last_logits = torch.cat((last_logits, outputs[0]), dim=-2)
+ last_logits = outputs[0]
+
+ hidden_states = last_logits
+ logits = self.lm_head(hidden_states)
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def set_mem_id(self, mem_id):
+ self.mem_id = mem_id
+ self.model.set_mem_id(mem_id)
+
+ def set_mem_cache_args(self, max_seq_len, mem_freq, top_k, max_cache_size):
+ self.mem_freq = mem_freq
+ self.top_k = top_k
+ self.max_seq_len = max_seq_len
+ if self.max_seq_len is not None:
+ assert self.max_seq_len % (self.mem_freq + 1) == 0
+ self.model.set_mem_cache_args(mem_freq, top_k, max_cache_size)
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ **kwargs,
+ ):
+ total_len = input_ids.shape[1]
+ if past_key_values:
+ prev_len = input_ids.shape[1] - 1
+ else:
+ prev_len = 0
+
+ position_ids = kwargs.get("position_ids", None)
+
+ if self.mem_freq is not None:
+ if position_ids is not None:
+ raise NotImplementedError
+ # T = input_ids.shape[1]
+
+ prev_incomplete_len = prev_len % self.mem_freq
+ prev_complete_len = prev_len - prev_incomplete_len
+ incomplete_len = total_len % self.mem_freq
+ new_full_len = total_len - prev_complete_len - incomplete_len
+
+ prev_input, input_ids_with_mem, input_ids_without_mem = torch.split(
+ input_ids, (prev_complete_len, new_full_len, incomplete_len), dim=-1
+ )
+
+ bsz, _ = input_ids.size()
+ input_ids_with_mem = input_ids_with_mem.view(bsz, -1, self.mem_freq)
+ input_ids_with_mem = torch.cat(
+ (
+ input_ids_with_mem,
+ input_ids_with_mem.new_full(
+ (bsz, input_ids_with_mem.shape[1], 1), self.mem_id
+ ),
+ ),
+ dim=-1,
+ ).view(bsz, -1)
+ input_ids = torch.cat(
+ (prev_input, input_ids_with_mem, input_ids_without_mem), dim=-1
+ )
+ if attention_mask is not None:
+ attention_mask_with_mem, attention_mask_without_mem = torch.split(
+ attention_mask,
+ (prev_complete_len + new_full_len, incomplete_len),
+ dim=-1,
+ )
+ attention_mask_with_mem = attention_mask_with_mem.view(
+ bsz, -1, self.mem_freq
+ )
+ attention_mask_with_mem = torch.cat(
+ (
+ attention_mask_with_mem,
+ attention_mask_with_mem.new_ones(
+ (bsz, attention_mask_with_mem.shape[1], 1)
+ ),
+ ),
+ dim=-1,
+ ).view(bsz, -1)
+ attention_mask = torch.cat(
+ (attention_mask_with_mem, attention_mask_without_mem), dim=-1
+ )
+
+ input_ids = input_ids[:, prev_len:]
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ position_ids = position_ids[:, -input_ids.shape[1] :].unsqueeze(-1)
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if (
+ inputs_embeds is not None
+ and past_key_values is None
+ and self.mem_freq is None
+ ):
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
+
+ model_inputs.update(
+ {
+ "position_ids": position_ids,
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache"),
+ "attention_mask": attention_mask,
+ "offload_cache_to_cpu": kwargs.get("offload_cache_to_cpu"),
+ }
+ )
+ return model_inputs
+
+ @staticmethod
+ def _reorder_cache(past_key_values, beam_idx):
+ reordered_past = ()
+ for layer_past in past_key_values:
+ reordered_past += (
+ tuple(
+ past_state.index_select(0, beam_idx) for past_state in layer_past
+ ),
+ )
+ return reordered_past
+
+
+@add_start_docstrings(
+ """
+ The LLaMa Model transformer with a sequence classification head on top (linear layer).
+
+ [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+ (e.g. GPT-2) do.
+
+ Since it does classification on the last token, it requires to know the position of the last token. If a
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+ each row of the batch).
+ """,
+ LLAMA_START_DOCSTRING,
+)
+class LlamaForSequenceClassification(LlamaPreTrainedModel):
+ _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.model = LlamaModel(config)
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ transformer_outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = transformer_outputs[0]
+ logits = self.score(hidden_states)
+
+ if input_ids is not None:
+ batch_size = input_ids.shape[0]
+ else:
+ batch_size = inputs_embeds.shape[0]
+
+ if self.config.pad_token_id is None and batch_size != 1:
+ raise ValueError(
+ "Cannot handle batch sizes > 1 if no padding token is defined."
+ )
+ if self.config.pad_token_id is None:
+ sequence_lengths = -1
+ else:
+ if input_ids is not None:
+ sequence_lengths = (
+ torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
+ ).to(logits.device)
+ else:
+ sequence_lengths = -1
+
+ pooled_logits = logits[
+ torch.arange(batch_size, device=logits.device), sequence_lengths
+ ]
+
+ loss = None
+ if labels is not None:
+ labels = labels.to(logits.device)
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (
+ labels.dtype == torch.long or labels.dtype == torch.int
+ ):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(
+ pooled_logits.view(-1, self.num_labels), labels.view(-1)
+ )
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(pooled_logits, labels)
+ if not return_dict:
+ output = (pooled_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+
+def add_mem_tokens(example, mem_freq, mem_id):
+ x = example["input_ids"]
+ ret = []
+ prev_idx = 0
+ for t_idx in range(mem_freq, len(x), mem_freq):
+ ret.extend(x[prev_idx:t_idx])
+ ret.append(mem_id)
+ prev_idx = t_idx
+ ret.extend(x[prev_idx:])
+ # drop attention_mask
+ return {"input_ids": ret}
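
As a quick illustration (not part of the patch) of what `add_mem_tokens` does to a tokenized example: the input ids, `mem_freq`, and `mem_id` below are made up; in the trainer the real values are `mem_freq=50` and the tokenizer id of `MEM_TOKEN`.

```python
from axolotl.monkeypatch.llama_landmark_attn import add_mem_tokens

# a landmark/memory token id is interleaved after every `mem_freq` input ids
example = {"input_ids": [1, 2, 3, 4, 5, 6, 7]}
print(add_mem_tokens(example, mem_freq=3, mem_id=32001))
# {'input_ids': [1, 2, 3, 32001, 4, 5, 6, 32001, 7]}
```
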
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index 037fa45bf..cba964076 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -78,6 +78,13 @@ def load_tokenized_prepared_datasets(
else:
logging.info(f"Unable to find prepared dataset in {prepared_ds_path}")
logging.info("Loading raw datasets...")
+
+ if cfg.seed:
+ seed = cfg.seed
+ else:
+ logging.info("No seed provided, using default seed of 42")
+ seed = 42
+
datasets = []
# pylint: disable=invalid-name
for d in cfg.datasets:
@@ -127,11 +134,11 @@ def load_tokenized_prepared_datasets(
# support for using a subset of the data
if d.shards:
if "train" in ds:
- ds = ds.shuffle(seed=42)["train"].shard(
+ ds = ds.shuffle(seed=seed)["train"].shard(
num_shards=d.shards, index=0
)
else:
- ds = ds.shuffle(seed=42).shard(num_shards=d.shards, index=0)
+ ds = ds.shuffle(seed=seed).shard(num_shards=d.shards, index=0)
d_type = d.type
d_type_split = d_type.split(":")
d_base_type = d_type_split[0]
@@ -239,7 +246,7 @@ def load_tokenized_prepared_datasets(
samples: List[int] = []
for d in datasets:
samples = samples + list(d)
- dataset = Dataset.from_list(samples).shuffle(seed=42)
+ dataset = Dataset.from_list(samples).shuffle(seed=seed)
if cfg.local_rank == 0:
logging.info(
f"Saving merged prepared dataset to disk... {prepared_ds_path}"
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index c3f988e52..7156adec0 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -20,7 +20,9 @@ from transformers import ( # noqa: F401
)
try:
- from transformers import LlamaForCausalLM
+ from transformers import ( # pylint: disable=unused-import # noqa: F401
+ LlamaForCausalLM,
+ )
except ImportError:
logging.warning(
"This version of transformers does not support Llama. Consider upgrading."
@@ -82,37 +84,47 @@ def load_model(
cfg,
adapter="lora"
):
- # type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
+ # type: (str, str, str, AutoTokenizer, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
"""
Load a model from a base model and a model type.
"""
# TODO refactor as a kwarg
load_in_8bit = cfg.load_in_8bit
- is_llama_derived_model = "llama" in base_model or (
+ cfg.is_llama_derived_model = "llama" in base_model or (
cfg.model_type and "llama" in cfg.model_type.lower()
)
- if is_llama_derived_model and cfg.flash_attention:
- if cfg.device not in ["mps", "cpu"] and cfg.inference is False:
+ if cfg.is_llama_derived_model and cfg.flash_attention:
+ if cfg.device not in ["mps", "cpu"] and inference is False:
from axolotl.flash_attn import replace_llama_attn_with_flash_attn
logging.info("patching with flash attention")
replace_llama_attn_with_flash_attn()
- elif is_llama_derived_model and cfg.xformers_attention:
+ elif cfg.is_llama_derived_model and cfg.xformers_attention:
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
hijack_llama_attention,
)
logging.info("patching with xformers attention")
hijack_llama_attention()
- elif is_llama_derived_model and cfg.sdp_attention:
+ elif cfg.is_llama_derived_model and cfg.sdp_attention:
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
hijack_llama_sdp_attention,
)
logging.info("patching with sdp attention")
hijack_llama_sdp_attention()
+ elif cfg.is_llama_derived_model and cfg.landmark_attention:
+ from axolotl.monkeypatch.llama_landmark_attn import ( # pylint: disable=redefined-outer-name # noqa: F811
+ MEM_TOKEN,
+ LlamaForCausalLM,
+ )
+
+ logging.info("patching with landmark attention")
+
+ # TODO: Check if this would overwrite previous additional_special_tokens
+ tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
if cfg.bf16:
torch_dtype = torch.bfloat16
@@ -127,11 +139,18 @@ def load_model(
)
replace_peft_model_with_int4_lora_model()
- from peft import prepare_model_for_int8_training
except Exception as err:
logging.exception(err)
raise err
+ try:
+ from peft import prepare_model_for_kbit_training
+ except ImportError:
+ # For backward compatibility
+ from peft import (
+ prepare_model_for_int8_training as prepare_model_for_kbit_training,
+ )
+
model_kwargs = {}
if cfg.adapter == "qlora" and cfg.load_in_4bit:
model_kwargs["quantization_config"] = BitsAndBytesConfig(
@@ -143,7 +162,7 @@ def load_model(
bnb_4bit_quant_type="nf4",
)
try:
- if cfg.gptq and is_llama_derived_model:
+ if cfg.gptq and cfg.is_llama_derived_model:
from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
from huggingface_hub import snapshot_download
@@ -181,7 +200,7 @@ def load_model(
else True,
)
load_in_8bit = False
- elif is_llama_derived_model and "LlamaForCausalLM" in globals():
+ elif cfg.is_llama_derived_model and "LlamaForCausalLM" in globals():
config = LlamaConfig.from_pretrained(base_model_config)
model = LlamaForCausalLM.from_pretrained(
base_model,
@@ -235,8 +254,15 @@ def load_model(
)
# Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
# when training starts
- if config.max_seq_len and cfg.sequence_len > config.max_seq_len:
+ if hasattr(config, "max_seq_len") and cfg.sequence_len > config.max_seq_len:
config.max_seq_len = cfg.sequence_len
+ logging.warning(f"increasing context length to {cfg.sequence_len}")
+ elif (
+ hasattr(config, "max_sequence_length")
+ and cfg.sequence_len > config.max_sequence_length
+ ):
+ config.max_sequence_length = cfg.sequence_len
+ logging.warning(f"increasing context length to {cfg.sequence_len}")
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=config,
@@ -268,8 +294,8 @@ def load_model(
(cfg.adapter == "lora" and load_in_8bit)
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
):
- logging.info("converting PEFT model w/ prepare_model_for_int8_training")
- model = prepare_model_for_int8_training(model)
+ logging.info("converting PEFT model w/ prepare_model_for_kbit_training")
+ model = prepare_model_for_kbit_training(model)
model, lora_config = load_adapter(model, cfg, adapter)
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 2986c491b..9ae1e7e93 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -1,6 +1,7 @@
"""Module containing the Trainer class and related functions"""
import importlib
+import logging
import math
import os
import sys
@@ -62,8 +63,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
if cfg.logging_steps is not None
else max(min(int(0.005 * total_num_steps), 10), 1)
)
- save_steps = cfg.save_steps
- eval_steps = cfg.eval_steps
training_arguments_kwargs = {}
if cfg.bf16 == "full":
@@ -74,6 +73,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
training_arguments_kwargs["tf32"] = cfg.tf32
training_arguments_kwargs["warmup_steps"] = warmup_steps
training_arguments_kwargs["logging_steps"] = logging_steps
+
+ if cfg.seed:
+ training_arguments_kwargs["seed"] = cfg.seed
+
if cfg.gradient_checkpointing:
if cfg.gptq:
from alpaca_lora_4bit.gradient_checkpointing import (
@@ -119,16 +122,16 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
num_train_epochs=cfg.num_epochs,
learning_rate=cfg.learning_rate,
evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
- save_strategy="steps" if save_steps else "epoch",
- eval_steps=eval_steps if cfg.val_set_size > 0 else None,
- save_steps=save_steps,
+ save_strategy="steps" if cfg.save_steps else "epoch",
+ eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None,
+ save_steps=cfg.save_steps,
output_dir=cfg.output_dir,
save_total_limit=3,
load_best_model_at_end=(
cfg.load_best_model_at_end is not False
and cfg.val_set_size > 0
- and save_steps
- and save_steps % eval_steps == 0
+ and cfg.save_steps
+ and cfg.save_steps % cfg.eval_steps == 0
and cfg.load_in_8bit is not True
)
or False,
@@ -233,6 +236,23 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
else:
data_collator_kwargs["pad_to_multiple_of"] = 8
+ if cfg.is_llama_derived_model and cfg.landmark_attention:
+ from functools import partial
+
+ from axolotl.monkeypatch.llama_landmark_attn import MEM_TOKEN, add_mem_tokens
+
+ mem_id = tokenizer.convert_tokens_to_ids(MEM_TOKEN)
+ model.set_mem_id(mem_id)
+
+ logging.info("Adding landmark attention tokens to dataset")
+
+ # datasets.Dataset.map returns a new dataset rather than modifying it in
+ # place, so rebind train/eval here; looping over a throwaway variable
+ # would silently discard the mapped copies.
+ map_fn = partial(add_mem_tokens, mem_freq=50, mem_id=mem_id)
+ train_dataset = train_dataset.map(map_fn, batched=False, num_proc=32)
+ eval_dataset = eval_dataset.map(map_fn, batched=False, num_proc=32)
+
trainer_cls = (
OneCycleLRSchedulerTrainer
if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora")
diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py
index 38e0b9819..04ffc4c1b 100644
--- a/src/axolotl/utils/validation.py
+++ b/src/axolotl/utils/validation.py
@@ -54,6 +54,9 @@ def validate_config(cfg):
"Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
)
+ if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
+ raise ValueError("FSDP is not supported for falcon models")
+
# TODO
# MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25
diff --git a/src/axolotl/utils/wandb.py b/src/axolotl/utils/wandb.py
index 90e9c2f73..d22b932cb 100644
--- a/src/axolotl/utils/wandb.py
+++ b/src/axolotl/utils/wandb.py
@@ -15,3 +15,5 @@ def setup_wandb_env_vars(cfg):
os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model
if cfg.wandb_run_id and len(cfg.wandb_run_id) > 0:
os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
+ else:
+ os.environ["WANDB_DISABLED"] = "true"
diff --git a/tests/test_validation.py b/tests/test_validation.py
index ce744f762..50bdf37e6 100644
--- a/tests/test_validation.py
+++ b/tests/test_validation.py
@@ -165,3 +165,36 @@ class ValidationTest(unittest.TestCase):
)
validate_config(cfg)
+
+ def test_falcon_fsdp(self):
+ regex_exp = r".*FSDP is not supported for falcon models.*"
+
+ # Check for lower-case
+ cfg = DictDefault(
+ {
+ "base_model": "tiiuae/falcon-7b",
+ "fsdp": ["full_shard", "auto_wrap"],
+ }
+ )
+
+ with pytest.raises(ValueError, match=regex_exp):
+ validate_config(cfg)
+
+ # Check for upper-case
+ cfg = DictDefault(
+ {
+ "base_model": "Falcon-7b",
+ "fsdp": ["full_shard", "auto_wrap"],
+ }
+ )
+
+ with pytest.raises(ValueError, match=regex_exp):
+ validate_config(cfg)
+
+ cfg = DictDefault(
+ {
+ "base_model": "tiiuae/falcon-7b",
+ }
+ )
+
+ validate_config(cfg)
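
For orientation, a hedged sketch of how the options touched by this patch might appear in a training config; the model name and values are illustrative, not taken from the patch:

```yaml
# illustrative values only; landmark_attention is llama-only per the models.py change above
base_model: huggyllama/llama-7b
base_model_config: huggyllama/llama-7b
sequence_len: 2048
landmark_attention: true   # patches LlamaForCausalLM and registers MEM_TOKEN as a special token
seed: 42                   # used for dataset shuffling and TrainingArguments.seed
# note: combining fsdp with a falcon base model is now rejected by validate_config
```
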