Compare commits


2 Commits

Author SHA1 Message Date
Wing Lian 080612219b use even if not using sample packing 2023-10-13 17:54:35 -04:00
Wing Lian f95858d369 alternate impl of NEFT 2023-10-13 17:45:24 -04:00
43 changed files with 149 additions and 509 deletions

View File

@@ -12,4 +12,3 @@ generated-members=numpy.*, torch.*
disable=missing-function-docstring, line-too-long, import-error,
too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
too-many-boolean-expressions,

View File

@@ -96,7 +96,7 @@ accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
# inference
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
--peft_model_dir="./lora-out"
--lora_model_dir="./lora-out"
```
## Installation
@@ -297,24 +297,25 @@ Have dataset(s) in one of the following format (JSONL recommended):
#### How to add custom prompts
For a dataset that is preprocessed for instruction purposes:
```json
{"instruction": "...", "output": "..."}
```
You can use this example in your YAML config:
Using yaml. Example:
```yaml
datasets:
- path: repo
type:
system_prompt: ""
field_system: system
format: "[INST] {instruction} [/INST]"
no_input_format: "[INST] {instruction} [/INST]"
no_input_format: |-
User: {instruction}<|end_of_turn|>
Assistant:
format: |-
User: {instruction}
{input}<|end_of_turn|>
Assistant:
```
Using file:
1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example.
2. Use your custom file name as the dataset type `<prompt_strategies_file>.load_<load_fn>`.
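A minimal sketch of such a strategy file, modeled on the Alpaca strategy that appears later in this diff; the file name `my_custom.py` is a placeholder, and reusing `AlpacaPromptTokenizingStrategy` and `AlpacaPrompter` is an assumption for illustration:
```python
# src/axolotl/prompt_strategies/my_custom.py  (hypothetical file name)
"""Example custom prompt strategy, referenced from the dataset `type:` as described above."""
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter, PromptStyle


def load(tokenizer, cfg):
    # axolotl calls load() with the tokenizer and the run config
    return AlpacaPromptTokenizingStrategy(
        AlpacaPrompter(PromptStyle.CHAT.value),
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )
```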
#### How to use your custom pretokenized dataset
- Do not pass a `type:`
@@ -384,10 +385,10 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- lora
```yaml
adapter: lora # qlora or leave blank for full finetune
peft_r: 8
peft_alpha: 16
peft_dropout: 0.05
peft_target_modules:
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- q_proj
- v_proj
```
@@ -531,15 +532,15 @@ total_num_tokens:
adapter: lora
# If you already have a lora model trained that you want to load, put that here.
# This means after training, if you want to test the model, you should set this to the value of `lora_out_dir`.
peft_model_dir:
lora_model_dir:
# LoRA hyperparameters
# For more details about the following options, see:
# https://www.anyscale.com/blog/fine-tuning-llms-lora-or-full-parameter-an-in-depth-analysis-with-llama-2
peft_r: 8
peft_alpha: 16
peft_dropout: 0.05
peft_target_modules:
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- q_proj
- v_proj
# - k_proj
@@ -547,13 +548,13 @@ peft_target_modules:
# - gate_proj
# - down_proj
# - up_proj
peft_target_linear: # if true, will target all linear layers
lora_target_linear: # If true, will target all linear layers
# If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
# For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
# `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities.
# https://github.com/huggingface/peft/issues/334#issuecomment-1561727994
peft_modules_to_save:
lora_modules_to_save:
# - embed_tokens
# - lm_head
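A minimal sketch of what this looks like when new tokens are added for a LLaMA-style model; the `tokens:` key and the specific marker strings are assumptions for illustration:
```yaml
# hypothetical: two chat-turn markers are added, so the embedding and head weights
# must be saved alongside the LoRA adapter
tokens:
  - "<|im_start|>"
  - "<|im_end|>"
lora_modules_to_save:
  - embed_tokens
  - lm_head
```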
@@ -561,8 +562,7 @@ peft_modules_to_save:
# If you merge the adapter to the base model, a subdirectory `merged` will be created under this directory.
# Make sure `lora_model_dir` points to this directory if you want to use the trained model.
lora_out_dir:
peft_fan_in_fan_out: false
peft_feedforward_modules: # ffn modules for IA3, for llama down projection
lora_fan_in_fan_out: false
# ReLoRA configuration
# Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
@@ -870,7 +870,7 @@ Pass the appropriate flag to the train command:
- Pretrained LORA:
```bash
python -m axolotl.cli.inference examples/your_config.yml --peft_model_dir="./lora-output-dir"
python -m axolotl.cli.inference examples/your_config.yml --lora_model_dir="./lora-output-dir"
```
- Full weights finetune:
```bash
@@ -891,7 +891,7 @@ Please use `--sample_packing False` if you have it on and receive the error simi
Add below flag to train command above
```bash
python3 -m axolotl.cli.merge_lora examples/your_config.yml --peft_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
python3 -m axolotl.cli.merge_lora examples/your_config.yml --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
```
If you run out of CUDA memory, you can try to merge in system RAM with
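A minimal sketch of such a CPU-only merge; the exact invocation is an assumption (the idea is simply to hide all CUDA devices so the merge runs in system RAM):
```bash
# hypothetical invocation: hide the GPUs so the merge happens on CPU / system RAM
CUDA_VISIBLE_DEVICES="" python3 -m axolotl.cli.merge_lora examples/your_config.yml --lora_model_dir="./completed-model"
```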

View File

@@ -18,7 +18,7 @@ dataset_prepared_path: last_prepared_run
val_set_size: 0.01
adapter:
peft_model_dir:
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len:
sample_packing: false

View File

@@ -10,7 +10,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.01
adapter: qlora
peft_model_dir:
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len: 2048
lora_r: 16

View File

@@ -20,7 +20,7 @@ sample_packing: true
pad_to_sequence_len: true
adapter: lora
peft_model_dir:
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05

View File

@@ -16,7 +16,7 @@ val_set_size: 0.01
output_dir: ./qlora-out
adapter: qlora
peft_model_dir:
lora_model_dir:
sequence_len: 4096
sample_packing: true

View File

@@ -20,7 +20,7 @@ sample_packing: true
pad_to_sequence_len: true
adapter: lora
peft_model_dir:
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05

View File

@@ -16,7 +16,7 @@ val_set_size: 0.01
output_dir: ./qlora-out
adapter: qlora
peft_model_dir:
lora_model_dir:
sequence_len: 4096
sample_packing: true

View File

@@ -20,7 +20,7 @@ sample_packing: true
pad_to_sequence_len: true
adapter: lora
peft_model_dir:
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05

View File

@@ -16,7 +16,7 @@ val_set_size: 0.01
output_dir: ./qlora-out
adapter: qlora
peft_model_dir:
lora_model_dir:
sequence_len: 4096
sample_packing: true

View File

@@ -15,7 +15,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.01
adapter: lora
peft_model_dir:
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len:
lora_r: 16

View File

@@ -22,7 +22,7 @@ dataset_prepared_path:
val_set_size: 0.01
# enable QLoRA
adapter: qlora
peft_model_dir:
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len:

View File

@@ -15,7 +15,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.01
adapter:
peft_model_dir:
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len:
lora_r: 64

View File

@@ -10,7 +10,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.01
adapter: qlora
peft_model_dir:
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len:
lora_r: 8

View File

@@ -9,7 +9,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.02
adapter:
peft_model_dir:
lora_model_dir:
sequence_len: 512
max_packed_sequence_len:
lora_r:

View File

@@ -18,7 +18,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.01
adapter: lora
peft_model_dir:
lora_model_dir:
sequence_len: 4096
sample_packing:
lora_r: 8

View File

@@ -1,72 +0,0 @@
base_model: meta-llama/Llama-2-7b-hf
base_model_config: meta-llama/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./ia3-out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter: ia3
peft_model_dir:
peft_target_modules:
- k_proj
- v_proj
- down_proj
peft_feedforward_modules:
- down_proj
peft_fan_in_fan_out: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 2
num_epochs: 5
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 0.05
eval_table_size:
eval_table_max_new_tokens:
save_steps:
debug:
deepspeed:
weight_decay: 0.1
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -20,7 +20,7 @@ sample_packing: true
pad_to_sequence_len: true
adapter: lora
peft_model_dir:
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05

View File

@@ -16,7 +16,7 @@ val_set_size: 0.01
output_dir: ./qlora-out
adapter: qlora
peft_model_dir:
lora_model_dir:
sequence_len: 4096
sample_packing: true

View File

@@ -16,7 +16,7 @@ val_set_size: 0.01
output_dir: ./relora-out
adapter: qlora
peft_model_dir:
lora_model_dir:
sequence_len: 4096
sample_packing: true

View File

@@ -20,7 +20,7 @@ sequence_len: 4096
sample_packing: true
adapter: lora
peft_model_dir:
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05

View File

@@ -9,7 +9,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.02
adapter:
peft_model_dir:
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len:
lora_r: 8

View File

@@ -12,7 +12,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.02
adapter:
peft_model_dir:
lora_model_dir:
sequence_len: 1024
sample_packing: true
lora_r:

View File

@@ -12,7 +12,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.02
adapter: lora
peft_model_dir:
lora_model_dir:
sequence_len: 1024
sample_packing: true
lora_r: 8

View File

@@ -12,7 +12,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.01
adapter: qlora
peft_model_dir:
lora_model_dir:
sequence_len: 1024
sample_packing: true
lora_r: 8

View File

@@ -22,7 +22,7 @@ sample_packing: true
pad_to_sequence_len:
adapter:
peft_model_dir:
lora_model_dir:
lora_r:
lora_alpha:
lora_dropout:

View File

@@ -22,7 +22,7 @@ sample_packing: false # not CURRENTLY compatible with LoRAs
pad_to_sequence_len:
adapter: qlora
peft_model_dir:
lora_model_dir:
lora_r: 64
lora_alpha: 32
lora_dropout: 0.05

View File

@@ -13,7 +13,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.05
adapter:
peft_model_dir:
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len: 2048
lora_r: 64

View File

@@ -7,7 +7,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.05
adapter: lora
peft_model_dir:
lora_model_dir:
sequence_len: 512
lora_r: 16
lora_alpha: 32

View File

@@ -10,7 +10,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.02
adapter:
peft_model_dir:
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len:
lora_r: 8

View File

@@ -8,7 +8,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.05
adapter: lora
peft_model_dir:
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len:
lora_r: 8

View File

@@ -20,7 +20,7 @@ dataset_prepared_path:
val_set_size: 0.01
# enable QLoRA
adapter: qlora
peft_model_dir:
lora_model_dir:
sequence_len: 8192
max_packed_sequence_len:

Binary image file changed (not shown); size before: 370 KiB.

View File

@@ -46,7 +46,7 @@ setup(
dependency_links=dependency_links,
extras_require={
"flash-attn": [
"flash-attn>=2.3.0",
"flash-attn>=2.2.1",
],
"deepspeed": [
"deepspeed",

View File

@@ -42,11 +42,21 @@ def replace_llama_attn_with_flash_attn(
packed: Optional[bool] = False,
cross_entropy: Optional[bool] = False,
rms_norm: Optional[bool] = False,
noisy_embeddings_alpha: Optional[int] = False,
):
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
_prepare_decoder_attention_mask
)
transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward
if noisy_embeddings_alpha:
transformers.models.llama.modeling_llama.LlamaModel.get_inputs_embeds = partial(
llama_model_get_inputs_embeds, noisy_embeddings_alpha=noisy_embeddings_alpha
)
else:
transformers.models.llama.modeling_llama.LlamaModel.get_inputs_embeds = (
llama_model_get_inputs_embeds
)
if packed:
transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
transformers.models.llama.modeling_llama.LlamaModel.forward = (
@@ -116,8 +126,6 @@ def flashattn_forward(
attention_mask: [bsz, q_len]
"""
# pylint: disable=duplicate-code
original_dtype = hidden_states.dtype
bsz, q_len, _ = hidden_states.size()
if not hasattr(self, "pretraining_tp"):
@@ -153,13 +161,6 @@ def flashattn_forward(
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
if query_states.dtype == torch.float32:
query_states = query_states.to(dtype=original_dtype)
if key_states.dtype == torch.float32:
key_states = key_states.to(dtype=original_dtype)
if value_states.dtype == torch.float32:
value_states = value_states.to(dtype=original_dtype)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
@@ -318,10 +319,6 @@ def flashattn_forward(
else:
attn_output = self.o_proj(attn_output)
# handle conversion back for IA3
if attn_output.dtype == torch.float32:
attn_output = attn_output.to(dtype=original_dtype)
return attn_output, None, past_key_value
@@ -424,6 +421,28 @@ def generate_qkv(
)
def llama_model_get_inputs_embeds(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
noisy_embeddings_alpha: Optional[int] = None,
):
inputs_embeds = self.embed_tokens(input_ids)
if noisy_embeddings_alpha:
input_mask = attention_mask.to(inputs_embeds) # B x L
input_lengths = torch.sum(input_mask, 1) # B
noise_ = torch.zeros_like(inputs_embeds).uniform_(-1, 1)
delta = noise_ * input_mask.unsqueeze(2)
dims = input_lengths * inputs_embeds.size(-1)
mag = noisy_embeddings_alpha / torch.sqrt(dims)
delta = (delta * mag.view(-1, 1, 1)).detach()
inputs_embeds += delta
return inputs_embeds
def llama_model_forward(
self,
input_ids: torch.LongTensor = None,
@@ -490,7 +509,8 @@ def llama_model_forward(
cu_seqlens = cu_seqlens.squeeze()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
inputs_embeds = self.get_inputs_embeds(input_ids, attention_mask)
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
@@ -515,7 +535,6 @@ def llama_model_forward(
)
hidden_states = inputs_embeds
original_dtype = hidden_states.dtype
if self.gradient_checkpointing and self.training:
if use_cache:
@@ -573,10 +592,6 @@ def llama_model_forward(
hidden_states = layer_outputs[0]
# handle conversion back for IA3
if hidden_states.dtype == torch.float32:
hidden_states = hidden_states.to(dtype=original_dtype)
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
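The added `llama_model_get_inputs_embeds` above implements the NEFT (NEFTune-style) noisy embeddings: per-example uniform noise in [-1, 1], masked to real tokens, and scaled by alpha / sqrt(L * d), where L is the unpadded sequence length and d is the embedding width. A standalone sketch of that computation, outside the monkeypatch (the helper name and shapes are assumptions for illustration):
```python
import torch


def neft_noise(inputs_embeds: torch.Tensor, attention_mask: torch.Tensor, alpha: float) -> torch.Tensor:
    """Return NEFT-style noise for a batch of embeddings (B x L x D), zeroed on padding."""
    input_mask = attention_mask.to(inputs_embeds)              # B x L, 1.0 on real tokens
    input_lengths = input_mask.sum(dim=1)                      # B, unpadded lengths
    noise = torch.zeros_like(inputs_embeds).uniform_(-1, 1)    # uniform noise in [-1, 1]
    noise = noise * input_mask.unsqueeze(2)                    # zero out padding positions
    dims = input_lengths * inputs_embeds.size(-1)              # L_i * d per example
    mag = alpha / torch.sqrt(dims)                             # NEFT scaling factor
    return (noise * mag.view(-1, 1, 1)).detach()


# usage sketch (training only); alpha=5.0 is an illustrative value
# embeds = model.embed_tokens(input_ids)
# embeds = embeds + neft_noise(embeds, attention_mask, alpha=5.0)
```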

View File

@@ -14,9 +14,6 @@ from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-impor
flash_attn_varlen_qkvpacked_func,
)
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.mistral.modeling_mistral import (
MistralAttention as OriginalMistralAttention,
)
from transformers.models.mistral.modeling_mistral import (
MistralDecoderLayer as OriginalMistralDecoderLayer,
)
@@ -45,44 +42,6 @@ def replace_mistral_attn_with_flash_attn(
)
@torch.jit.script
def _make_sliding_window_causal_mask(
bsz: int,
tgt_len: int,
dtype: torch.dtype,
device: torch.device,
past_key_values_length: int = 0,
sliding_window: int = 4096,
):
"""
Make causal mask used for sliding window attention
"""
tensor = torch.full(
(tgt_len, tgt_len),
fill_value=1,
device=device,
)
mask = torch.tril(tensor, diagonal=0)
# make the mask banded to account for sliding window
# NOTE: HF implementation is wrong as of 14-10-2023 for torch.triu, needs +1
mask = torch.triu(mask, diagonal=-sliding_window + 1)
mask = torch.log(mask).to(dtype)
if past_key_values_length > 0:
mask = torch.cat(
[
torch.zeros(
tgt_len, past_key_values_length, dtype=dtype, device=device
),
mask,
],
dim=-1,
)
return mask[None, None, :, :].expand(
bsz, 1, tgt_len, tgt_len + past_key_values_length
)
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
@@ -94,29 +53,11 @@ def _prepare_decoder_attention_mask(
sliding_window,
): # pylint: disable=unused-argument
# [bsz, seq_len]
if attention_mask is None:
return attention_mask
# NOTE: attention mask and sliding masks are only broadcastable in certain scenarios.
# Without attention_mask.shape[0] == 1, error will trigger after eval loss but only when wandb is enabled.
if input_shape[-1] > 1 and attention_mask.shape[0] == 1:
sliding_window_mask = _make_sliding_window_causal_mask(
bsz=input_shape[0],
tgt_len=input_shape[1],
dtype=inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
sliding_window=sliding_window,
)
attention_mask = attention_mask + sliding_window_mask
else:
LOG.info("skipping sliding window mask, not broadcastable with attention mask")
return attention_mask
def flashattn_forward(
self: OriginalMistralAttention,
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
@@ -150,41 +91,10 @@ def flashattn_forward(
query_states, key_states, cos, sin, position_ids
)
use_sliding_windows = (
hasattr(self.config, "sliding_window") is not None
and kv_seq_len > self.config.sliding_window
)
if use_sliding_windows:
window_size = (self.config.sliding_window, self.config.sliding_window)
else:
window_size = (-1, -1)
if past_key_value is not None:
# Activate slicing cache only if the config has a value `sliding_windows` attribute
if (
hasattr(self.config, "sliding_window")
and kv_seq_len > self.config.sliding_window
):
slicing_tokens = kv_seq_len - self.config.sliding_window
past_key = past_key_value[0]
past_value = past_key_value[1]
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
if past_key.shape[-2] != self.config.sliding_window - 1:
raise ValueError(
f"past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
f" {past_key.shape}"
)
past_key_value = (past_key, past_value) if use_cache else None
if past_key_value is not None:
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
@@ -210,13 +120,7 @@ def flashattn_forward(
qkv = rearrange(qkv, "b s ... -> (b s) ...")
output = flash_attn_varlen_qkvpacked_func(
qkv,
cu_seqlens,
max_seqlen,
0.0,
softmax_scale=None,
causal=True,
window_size=window_size,
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
)
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
elif query_states.shape == key_states.shape:
@@ -242,7 +146,6 @@ def flashattn_forward(
0.0,
softmax_scale=None,
causal=is_causal,
window_size=window_size,
)
output = output_pad_fn(output_unpad)
else:
@@ -254,7 +157,6 @@ def flashattn_forward(
query_states,
torch.stack([key_states, value_states], 2),
causal=is_causal,
window_size=window_size,
)
else:
( # pylint: disable=unbalanced-tuple-unpacking
@@ -289,7 +191,6 @@ def flashattn_forward(
0.0,
softmax_scale=None,
causal=is_causal,
window_size=window_size,
)
output = output_pad_fn(output_unpad)

View File

@@ -1,6 +1,6 @@
"""Module for Alpaca prompt strategy classes"""
"""Module containing the AlpacaQAPromptTokenizingStrategy class"""
from typing import Any, Dict, Optional, Tuple
from typing import Tuple
from axolotl.prompt_tokenizers import (
AlpacaPromptTokenizingStrategy,
@@ -9,13 +9,9 @@ from axolotl.prompt_tokenizers import (
from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
prompt_style = PromptStyle.CHAT.value
if ds_cfg and "conversation" in ds_cfg:
prompt_style = ds_cfg["conversation"]
def load(tokenizer, cfg):
return AlpacaPromptTokenizingStrategy(
AlpacaPrompter(prompt_style),
AlpacaPrompter(PromptStyle.CHAT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,

View File

@@ -121,18 +121,6 @@ def normalize_config(cfg):
log_gpu_memory_usage(LOG, "baseline", cfg.device)
if cfg.adapter is not None:
for key in list(cfg.keys()):
if key.startswith("lora_"):
new_key = key.replace("lora_", "peft_")
LOG.warning(
PendingDeprecationWarning(
f"{key} soon to be deprecated. please use {new_key}"
)
)
cfg[new_key] = cfg[key]
del cfg[key]
def validate_config(cfg):
if is_torch_bf16_gpu_available():
@@ -202,10 +190,7 @@ def validate_config(cfg):
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
if not cfg.load_in_8bit and cfg.adapter == "lora":
LOG.warning("We recommend setting `load_in_8bit: true` for LoRA finetuning")
if not cfg.load_in_8bit and cfg.adapter == "ia3":
LOG.warning("We recommend setting `load_in_8bit: true` for IA3 finetuning")
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
if cfg.relora_steps:
if cfg.adapter not in ("lora", "qlora"):

View File

@@ -158,7 +158,7 @@ def load_tokenized_prepared_datasets(
token=use_auth_token,
)
ds_from_hub = True
except (FileNotFoundError, ConnectionError):
except FileNotFoundError:
pass
# prefer local dataset, even if hub exists

View File

@@ -136,7 +136,11 @@ def load_model(
replace_stablelm_attn_with_flash_attn(cfg.base_model)
if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing:
if (
cfg.is_llama_derived_model
and cfg.flash_attention
and (cfg.noisy_embeddings_alpha or cfg.sample_packing)
):
if cfg.device not in ["mps", "cpu"] and not inference:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
replace_llama_attn_with_flash_attn,
@@ -147,6 +151,7 @@ def load_model(
packed=cfg.sample_packing,
cross_entropy=cfg.flash_attn_cross_entropy,
rms_norm=cfg.flash_attn_rms_norm,
noisy_embeddings_alpha=cfg.noisy_embeddings_alpha,
)
elif cfg.is_llama_derived_model and cfg.xformers_attention:
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
@@ -180,16 +185,16 @@ def load_model(
LOG.info("patching with flash attention")
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
if cfg.is_llama_derived_model and cfg.noisy_embedding_alpha:
from axolotl.monkeypatch.llama_embeddings_hijack import (
replace_llama_embeddings_with_uniform_distribution,
)
LOG.info("patching with noisy embeddings")
replace_llama_embeddings_with_uniform_distribution(
noise_alpha=cfg.noisy_embedding_alpha
)
# if cfg.is_llama_derived_model and cfg.noisy_embedding_alpha:
# from axolotl.monkeypatch.llama_embeddings_hijack import (
# replace_llama_embeddings_with_uniform_distribution,
# )
#
# LOG.info("patching with noisy embeddings")
# replace_llama_embeddings_with_uniform_distribution(
# noise_alpha=cfg.noisy_embedding_alpha
# )
#
if cfg.is_mistral_derived_model and cfg.noisy_embedding_alpha:
from axolotl.monkeypatch.mistral_embeddings_hijack import (
replace_mistral_embeddings_with_uniform_distribution,
@@ -406,21 +411,21 @@ def load_model(
if hasattr(module, "weight"):
module.to(torch.float32)
require_peft: bool = False
if cfg.adapter in ["lora", "qlora", "ia3"]:
require_peft = True
if require_peft:
needs_fa2_dtype = cfg.adapter or cfg.fsdp
if (cfg.adapter == "lora" and load_in_8bit) or (
cfg.adapter == "qlora" and cfg.load_in_4bit
):
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
if cfg.gradient_checkpointing:
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(
model, use_gradient_checkpointing=cfg.gradient_checkpointing
)
needs_fa2_dtype = True
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
# convert them back to fp16/bf16 for flash-attn compatibility.
if require_peft or cfg.fsdp or (cfg.flash_attention and cfg.is_llama_derived_model):
if needs_fa2_dtype or (cfg.flash_attention and cfg.is_llama_derived_model):
LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
for name, module in model.named_modules():
if "norm" in name:
@@ -429,7 +434,7 @@ def load_model(
if hasattr(module, "weight"):
module.to(cfg.torch_dtype)
model, peft_config = load_adapter(model, cfg, cfg.adapter)
model, lora_config = load_adapter(model, cfg, cfg.adapter)
if cfg.ddp and not load_in_8bit:
model.to(f"cuda:{cfg.local_rank}")
@@ -460,7 +465,7 @@ def load_model(
log_gpu_memory_usage(LOG, "after adapters", model.device)
# TODO resume_from_checkpoint handling
return model, peft_config
return model, lora_config
def load_adapter(model, cfg, adapter, inference=False):
@@ -470,8 +475,6 @@ def load_adapter(model, cfg, adapter, inference=False):
return model, None
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
if adapter == "ia3":
return load_ia3(model, cfg, inference=inference)
if adapter in ["lora", "qlora"]:
return load_lora(model, cfg, inference=inference)
if adapter == "llama-adapter":
@@ -490,11 +493,11 @@ def load_llama_adapter(model, cfg):
task_type="CAUSAL_LM",
)
if cfg.peft_model_dir:
if cfg.lora_model_dir:
LOG.debug("Loading pretained PEFT - llama_adapter")
model = PeftModel.from_pretrained(
model,
cfg.peft_model_dir,
cfg.lora_model_dir,
torch_dtype=torch.float16,
)
else:
@@ -507,20 +510,16 @@ def load_llama_adapter(model, cfg):
def find_all_linear_names(model):
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
peft_module_names = set()
lora_module_names = set()
for name, module in model.named_modules():
if (
isinstance(module, cls)
or "Linear" in module.__class__.__name__
and module.__class__.__name__ not in ("LlamaLinearScalingRotaryEmbedding",)
):
if isinstance(module, cls) or "Linear" in module.__class__.__name__:
names = name.split(".")
peft_module_names.add(names[0] if len(names) == 1 else names[-1])
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
if "lm_head" in peft_module_names: # needed for 16-bit
peft_module_names.remove("lm_head")
if "lm_head" in lora_module_names: # needed for 16-bit
lora_module_names.remove("lm_head")
return list(peft_module_names)
return list(lora_module_names)
def load_lora(model, cfg, inference=False):
@@ -528,68 +527,34 @@ def load_lora(model, cfg, inference=False):
from peft import LoraConfig, PeftModel, get_peft_model
peft_target_modules = list(cfg.peft_target_modules or [])
lora_target_modules = list(cfg.lora_target_modules or [])
if cfg.peft_target_linear:
if cfg.lora_target_linear:
linear_names = find_all_linear_names(model)
LOG.info(f"found linear modules: {repr(linear_names)}")
peft_target_modules = list(set(peft_target_modules + linear_names))
lora_target_modules = list(set(lora_target_modules + linear_names))
peft_config = LoraConfig(
r=cfg.peft_r,
lora_alpha=cfg.peft_alpha,
target_modules=peft_target_modules,
lora_dropout=cfg.peft_dropout,
fan_in_fan_out=cfg.peft_fan_in_fan_out,
modules_to_save=cfg.peft_modules_to_save if cfg.peft_modules_to_save else None,
lora_config = LoraConfig(
r=cfg.lora_r,
lora_alpha=cfg.lora_alpha,
target_modules=lora_target_modules,
lora_dropout=cfg.lora_dropout,
fan_in_fan_out=cfg.lora_fan_in_fan_out,
modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
bias="none",
task_type="CAUSAL_LM",
)
if cfg.peft_model_dir:
if cfg.lora_model_dir:
LOG.debug("Loading pretained PEFT - LoRA")
model = PeftModel.from_pretrained(
model,
cfg.peft_model_dir,
cfg.lora_model_dir,
is_trainable=(not inference),
)
else:
model = get_peft_model(model, peft_config)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
return model, peft_config
def load_ia3(model, cfg, inference=False):
# type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
from peft import IA3Config, PeftModel, get_peft_model
peft_config_kwargs = {}
if cfg.peft_init_ia3_weights is not None:
peft_config_kwargs["init_ia3_weights"] = cfg.peft_init_ia3_weights
if cfg.peft_fan_in_fan_out is not None:
peft_config_kwargs["fan_in_fan_out"] = cfg.peft_fan_in_fan_out
peft_config = IA3Config(
target_modules=cfg.peft_target_modules,
feedforward_modules=cfg.peft_feedforward_modules,
modules_to_save=cfg.peft_modules_to_save,
task_type="CAUSAL_LM",
**peft_config_kwargs,
)
if cfg.peft_model_dir:
LOG.debug("Loading pretained PEFT - IA3")
model = PeftModel.from_pretrained(
model,
cfg.peft_model_dir,
is_trainable=(not inference),
)
else:
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
return model, peft_config
return model, lora_config

View File

@@ -423,9 +423,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
)
# Phi doesn't want the attention_mask feature when training
if "CodeGenTokenizer" in tokenizer.__class__.__name__ or (
cfg.is_mistral_derived_model and cfg.flash_attention
):
if "CodeGenTokenizer" in tokenizer.__class__.__name__:
train_dataset = train_dataset.remove_columns("attention_mask")
if eval_dataset:
eval_dataset = eval_dataset.remove_columns("attention_mask")

View File

@@ -24,10 +24,6 @@ class TestLoraLlama(unittest.TestCase):
"""
def test_lora(self):
"""
support for legacy lora_ configs
:return:
"""
# pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault(
@@ -70,101 +66,6 @@ class TestLoraLlama(unittest.TestCase):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "adapter_model.bin").exists()
def test_lora_peft(self):
"""
support for legacy lora_ configs
:return:
"""
# pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"base_model_config": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"peft_r": 32,
"peft_alpha": 64,
"peft_dropout": 0.05,
"peft_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": output_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "adapter_model.bin").exists()
def test_ia3_peft(self):
"""
support for IA3 peft
:return:
"""
# pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"base_model_config": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "ia3",
"peft_r": 32,
"peft_alpha": 64,
"peft_dropout": 0.05,
"peft_target_modules": ["k_proj", "v_proj", "down_proj"],
"peft_feedforward_modules": ["down_proj"],
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": output_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "adapter_model.bin").exists()
def test_lora_packing(self):
# pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()

View File

@@ -1,48 +0,0 @@
"""Module for testing the validation module"""
import logging
import unittest
from typing import Optional
import pytest
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
class NormalizationTest(unittest.TestCase):
"""
Test the cfg normalization module
"""
_caplog: Optional[pytest.LogCaptureFixture] = None
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
self._caplog = caplog
def test_lora_to_peft(self):
base_cfg = DictDefault(
{
"gradient_accumulation_steps": 1,
"micro_batch_size": 1,
"base_model": "NousResearch/Llama-2-7b-hf",
"base_model_config": "NousResearch/Llama-2-7b-hf",
}
)
cfg = base_cfg | DictDefault(
{
"adapter": "lora",
"lora_r": 128,
"lora_alpha": 64,
}
)
with self._caplog.at_level(logging.WARNING):
normalize_config(cfg)
assert any(
"soon to be deprecated. please use peft_" in record.message
for record in self._caplog.records
)
assert cfg.peft_r == 128
assert cfg.peft_alpha == 64