Compare commits
28 Commits
packing-at
...
feature/re
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1afbd8af2d | ||
|
|
b4f2eea2ed | ||
|
|
bbf88b02c1 | ||
|
|
64a8e04430 | ||
|
|
c8f7213bc6 | ||
|
|
b57238ecec | ||
|
|
918f1b0dfb | ||
|
|
c3fde36ada | ||
|
|
2bb0b78975 | ||
|
|
a276c9c88d | ||
|
|
7019509daa | ||
|
|
96bd6ae1c4 | ||
|
|
e37d9358e6 | ||
|
|
b5212068ac | ||
|
|
289d5c403d | ||
|
|
35c8b90306 | ||
|
|
fae6ed8092 | ||
|
|
94d03c8402 | ||
|
|
11ddccb80f | ||
|
|
964312199e | ||
|
|
718102271f | ||
|
|
f5c11f8262 | ||
|
|
fce40aab23 | ||
|
|
9c314101d5 | ||
|
|
e303d64728 | ||
|
|
b4d1d22782 | ||
|
|
9f99104038 | ||
|
|
36fefcf94b |
50
README.md
50
README.md
@@ -375,10 +375,14 @@ dataset_shard_idx:
|
||||
sequence_len: 2048
|
||||
# max sequence length to concatenate training samples together up to
|
||||
# inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
|
||||
# soon to be DEPRECATED
|
||||
# FutureWarning: This will soon be DEPRECATED
|
||||
max_packed_sequence_len: 1024
|
||||
# use efficient multi-packing with block diagonal attention and per sequence position_ids
|
||||
# use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'
|
||||
sample_packing:
|
||||
# you can set these packing optimizations AFTER starting a training at least once.
|
||||
# The trainer will provide recommended values for these values.
|
||||
sample_packing_eff_est:
|
||||
total_num_tokens:
|
||||
|
||||
# if you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
|
||||
adapter: lora
|
||||
@@ -404,11 +408,12 @@ lora_out_dir:
|
||||
lora_fan_in_fan_out: false
|
||||
|
||||
# wandb configuration if you're using it
|
||||
wandb_mode:
|
||||
wandb_project:
|
||||
wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
|
||||
wandb_project: # your wandb project name
|
||||
wandb_entity: # a wandb Team name if using a Team
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model: # 'checkpoint'
|
||||
wandb_run_id: # set the name of your wandb run
|
||||
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
|
||||
|
||||
# where to save the finished model to
|
||||
output_dir: ./completed-model
|
||||
@@ -423,13 +428,16 @@ learning_rate: 0.00003
|
||||
logging_steps:
|
||||
save_steps:
|
||||
eval_steps:
|
||||
save_total_limit:
|
||||
|
||||
# save model as safetensors (require safetensors package)
|
||||
save_safetensors:
|
||||
|
||||
# whether to mask out or include the human's prompt from the training labels
|
||||
train_on_inputs: false
|
||||
# don't use this, leads to wonky training (according to someone on the internet)
|
||||
# group similarly sized data to minimize padding
|
||||
# may be slower to start, as it must download and sort the entire dataset
|
||||
# note that training loss may have an oscillating pattern with this enabled
|
||||
group_by_length: false
|
||||
|
||||
# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
|
||||
@@ -475,6 +483,10 @@ landmark_attention:
|
||||
# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
|
||||
# llama only
|
||||
xpos_rope:
|
||||
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
|
||||
rope_scaling:
|
||||
type: # linear | dynamic
|
||||
factor: # float
|
||||
|
||||
# resume from a specific checkpoint dir
|
||||
resume_from_checkpoint:
|
||||
@@ -506,6 +518,9 @@ torchdistx_path:
|
||||
# Set padding for data collator to 'longest'
|
||||
collator_pad_to_longest:
|
||||
|
||||
# Set to HF dataset for type: 'completion' for streaming instead of pre-tokenize
|
||||
pretraining_dataset:
|
||||
|
||||
# Debug mode
|
||||
debug:
|
||||
|
||||
@@ -525,7 +540,14 @@ Run
|
||||
accelerate launch scripts/finetune.py configs/your_config.yml
|
||||
```
|
||||
|
||||
#### Multi-GPU Config
|
||||
#### Multi-GPU
|
||||
|
||||
You can optionally pre-tokenize dataset with the following before finetuning:
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES="" accelerate ... --prepare_ds_only
|
||||
```
|
||||
|
||||
##### Config
|
||||
|
||||
- llama FSDP
|
||||
```yaml
|
||||
@@ -540,6 +562,18 @@ fsdp_config:
|
||||
|
||||
- llama Deepspeed: append `ACCELERATE_USE_DEEPSPEED=true` in front of finetune command
|
||||
|
||||
##### Weights & Biases Logging
|
||||
|
||||
- wandb options
|
||||
```yaml
|
||||
wandb_mode:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
```
|
||||
|
||||
### Inference
|
||||
|
||||
Pass the appropriate flag to the train command:
|
||||
|
||||
@@ -23,6 +23,7 @@ lora_target_modules:
|
||||
lora_target_linear:
|
||||
lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
@@ -35,7 +36,7 @@ torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
train_on_inputs: false
|
||||
group_by_length: true
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: true
|
||||
|
||||
@@ -24,6 +24,7 @@ lora_target_modules:
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
@@ -38,6 +38,7 @@ lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
@@ -24,6 +24,7 @@ lora_target_modules:
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
@@ -20,6 +20,7 @@ lora_target_modules:
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
@@ -32,7 +33,7 @@ torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0001
|
||||
train_on_inputs: false
|
||||
group_by_length: true
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: true
|
||||
|
||||
@@ -22,6 +22,7 @@ lora_target_modules:
|
||||
- v_proj
|
||||
lora_fan_in_fan_out: false
|
||||
wandb_project: llama-7b-lora-int4
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
@@ -18,6 +18,7 @@ lora_dropout:
|
||||
lora_target_modules:
|
||||
lora_fan_in_fan_out: false
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
@@ -26,6 +26,7 @@ lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
@@ -38,7 +39,7 @@ lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: true
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
@@ -27,6 +27,7 @@ lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
@@ -39,7 +40,7 @@ lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: true
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
@@ -20,6 +20,7 @@ lora_target_modules:
|
||||
- v_proj
|
||||
lora_fan_in_fan_out: false
|
||||
wandb_project: mpt-alpaca-7b
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
@@ -22,6 +22,7 @@ lora_target_modules:
|
||||
lora_target_linear:
|
||||
lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
@@ -28,6 +28,7 @@ lora_target_modules:
|
||||
- o_proj
|
||||
lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
@@ -22,6 +22,7 @@ lora_target_modules:
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
@@ -34,7 +35,7 @@ torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
train_on_inputs: false
|
||||
group_by_length: true
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: true
|
||||
|
||||
@@ -23,6 +23,7 @@ lora_target_modules:
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
@@ -17,6 +17,7 @@ lora_target_modules:
|
||||
lora_target_linear:
|
||||
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
@@ -21,6 +21,7 @@ lora_target_modules:
|
||||
- v_proj
|
||||
lora_fan_in_fan_out: false
|
||||
wandb_project: redpajama-alpaca-3b
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
@@ -20,6 +20,7 @@ lora_target_modules:
|
||||
- mlp_down
|
||||
lora_fan_in_fan_out:
|
||||
wandb_project: lora-replit
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
@@ -37,6 +37,7 @@ lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
peft @ git+https://github.com/huggingface/peft.git
|
||||
transformers @ git+https://github.com/huggingface/transformers.git
|
||||
bitsandbytes>=0.39.0
|
||||
bitsandbytes>=0.41.1
|
||||
accelerate @ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b
|
||||
addict
|
||||
fire
|
||||
@@ -21,3 +21,4 @@ evaluate==0.4.0
|
||||
rouge-score==0.1.2
|
||||
scipy
|
||||
scikit-learn==1.2.2
|
||||
pynvml
|
||||
|
||||
@@ -18,6 +18,7 @@ from optimum.bettertransformer import BetterTransformer
|
||||
from transformers import GenerationConfig, TextStreamer
|
||||
|
||||
from axolotl.logging_config import configure_logging
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import barrier, is_main_process
|
||||
@@ -268,16 +269,13 @@ def train(
|
||||
LOG.info("Finished preparing dataset. Exiting...")
|
||||
return
|
||||
|
||||
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
||||
|
||||
# Load the model and tokenizer
|
||||
LOG.info("loading model and peft_config...")
|
||||
model, peft_config = load_model(
|
||||
cfg.base_model,
|
||||
cfg.base_model_config,
|
||||
cfg.model_type,
|
||||
tokenizer,
|
||||
cfg,
|
||||
adapter=cfg.adapter,
|
||||
)
|
||||
LOG.info("loading model and (optionally) peft_config...")
|
||||
model, peft_config = load_model(cfg, tokenizer)
|
||||
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
|
||||
if "merge_lora" in kwargs and cfg.adapter is not None:
|
||||
LOG.info("running merge of LoRA with base model")
|
||||
@@ -286,7 +284,11 @@ def train(
|
||||
|
||||
if cfg.local_rank == 0:
|
||||
LOG.info("saving merged model")
|
||||
model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
||||
model.save_pretrained(
|
||||
str(Path(cfg.output_dir) / "merged"),
|
||||
safe_serialization=safe_serialization,
|
||||
)
|
||||
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
||||
return
|
||||
|
||||
if cfg.inference:
|
||||
@@ -301,7 +303,7 @@ def train(
|
||||
return
|
||||
|
||||
if "shard" in kwargs:
|
||||
model.save_pretrained(cfg.output_dir)
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
return
|
||||
|
||||
trainer = setup_trainer(
|
||||
@@ -325,7 +327,7 @@ def train(
|
||||
def terminate_handler(_, __, model):
|
||||
if cfg.flash_optimum:
|
||||
model = BetterTransformer.reverse(model)
|
||||
model.save_pretrained(cfg.output_dir)
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(
|
||||
@@ -369,7 +371,13 @@ def train(
|
||||
elif cfg.local_rank == 0:
|
||||
if cfg.flash_optimum:
|
||||
model = BetterTransformer.reverse(model)
|
||||
model.save_pretrained(cfg.output_dir)
|
||||
|
||||
if cfg.adapter == "lora" and cfg.relora_steps:
|
||||
model = model.merge_and_unload()
|
||||
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
|
||||
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -5,7 +5,7 @@ import os
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from datasets import IterableDataset
|
||||
from datasets import Dataset, IterableDataset
|
||||
|
||||
from .prompt_tokenizers import PromptTokenizingStrategy
|
||||
|
||||
@@ -18,9 +18,9 @@ from .prompt_tokenizers import PromptTokenizingStrategy
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
class TokenizedPromptDataset(IterableDataset):
|
||||
class TokenizedPromptDataset(Dataset):
|
||||
"""
|
||||
Iterable dataset that returns tokenized prompts from a stream of text files.
|
||||
Dataset that returns tokenized prompts from a stream of text files.
|
||||
Args:
|
||||
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for proccessing the data.
|
||||
dataset (dataset.Dataset): Dataset with text files.
|
||||
@@ -30,19 +30,18 @@ class TokenizedPromptDataset(IterableDataset):
|
||||
self,
|
||||
prompt_tokenizer: PromptTokenizingStrategy,
|
||||
dataset: IterableDataset,
|
||||
**kwargs,
|
||||
):
|
||||
self.prompt_tokenizer = prompt_tokenizer
|
||||
self.dataset = dataset
|
||||
super().__init__(self.process(dataset).data, **kwargs)
|
||||
|
||||
def __iter__(self):
|
||||
features = self.dataset.features.keys()
|
||||
num_proc = os.cpu_count()
|
||||
return iter(
|
||||
self.dataset.map(
|
||||
self.prompt_tokenizer.tokenize_prompt,
|
||||
num_proc=num_proc,
|
||||
remove_columns=features,
|
||||
)
|
||||
def process(self, dataset):
|
||||
features = dataset.features.keys()
|
||||
num_proc = min(64, os.cpu_count())
|
||||
return dataset.map(
|
||||
self.prompt_tokenizer.tokenize_prompt,
|
||||
num_proc=num_proc,
|
||||
remove_columns=features,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Optional, Tuple
|
||||
import torch
|
||||
import transformers
|
||||
from einops import rearrange
|
||||
from flash_attn.bert_padding import pad_input, unpad_input
|
||||
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
|
||||
@@ -91,7 +92,8 @@ def forward(
|
||||
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
||||
)
|
||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||
else:
|
||||
elif position_ids.shape[0] == 1:
|
||||
# special handling using sample packing
|
||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||
cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids)
|
||||
cu_q_lens = cu_q_lens.squeeze()
|
||||
@@ -100,6 +102,36 @@ def forward(
|
||||
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
||||
)
|
||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||
else:
|
||||
nheads = qkv.shape[-2]
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
x = rearrange(qkv, "b s three h d -> b s (three h d)")
|
||||
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
|
||||
x_unpad = rearrange(
|
||||
x_unpad,
|
||||
"nnz (three h d) -> nnz three h d",
|
||||
three=3,
|
||||
h=nheads,
|
||||
)
|
||||
output_unpad = flash_attn_varlen_qkvpacked_func(
|
||||
x_unpad,
|
||||
cu_q_lens,
|
||||
max_s,
|
||||
0.0,
|
||||
softmax_scale=None,
|
||||
causal=True,
|
||||
)
|
||||
output = rearrange(
|
||||
pad_input(
|
||||
rearrange(output_unpad, "nnz h d -> nnz (h d)"),
|
||||
indices,
|
||||
bsz,
|
||||
q_len,
|
||||
),
|
||||
"b s (h d) -> b s h d",
|
||||
h=nheads,
|
||||
)
|
||||
|
||||
return (
|
||||
self.o_proj(rearrange(output, "b s h d -> b s (h d)")),
|
||||
|
||||
302
src/axolotl/monkeypatch/relora.py
Normal file
302
src/axolotl/monkeypatch/relora.py
Normal file
@@ -0,0 +1,302 @@
|
||||
# pylint: skip-file
|
||||
import glob
|
||||
import json
|
||||
import logging
|
||||
import os.path
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Sequence
|
||||
|
||||
import bitsandbytes as bnb
|
||||
import peft
|
||||
import safetensors.torch as st
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
from torch.optim.optimizer import Optimizer
|
||||
from transformers import (
|
||||
TrainerCallback,
|
||||
TrainerControl,
|
||||
TrainerState,
|
||||
TrainingArguments,
|
||||
)
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = logging.getLogger("axolotl.relora")
|
||||
|
||||
|
||||
def reset_optimizer(optimizer: torch.optim.Optimizer):
|
||||
for group in optimizer.param_groups:
|
||||
for param in group["params"]:
|
||||
param_state = optimizer.state[param]
|
||||
for key in param_state:
|
||||
if "qmap" in key:
|
||||
continue
|
||||
elif key == "step" and isinstance(param_state[key], int):
|
||||
param_state[key] = 0
|
||||
else:
|
||||
param_state[key] = torch.zeros_like(param_state[key])
|
||||
|
||||
|
||||
class ReLoRACallback(TrainerCallback):
|
||||
def __init__(self, cfg: DictDefault):
|
||||
self.relora_steps = cfg.relora_steps
|
||||
self.cpu_offload = cfg.relora_cpu_offload
|
||||
self.quantised = cfg.load_in_4bit or cfg.load_in_8bit
|
||||
self.last_full_model = cfg.base_model
|
||||
|
||||
assert os.path.exists(
|
||||
self.last_full_model
|
||||
), "for ReLORA base_model must be a local path"
|
||||
|
||||
self.num_lora_restarts = 0
|
||||
self.need_full_save = False
|
||||
|
||||
def on_step_begin(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
model: peft.LoraModel,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
**_kwargs,
|
||||
):
|
||||
if state.global_step > 0 and state.global_step % self.relora_steps == 0:
|
||||
checkpoint_folder = os.path.join(
|
||||
args.output_dir,
|
||||
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
merge_and_save(
|
||||
model,
|
||||
self.last_full_model,
|
||||
checkpoint_folder,
|
||||
reinit=True,
|
||||
quantized=self.quantised,
|
||||
)
|
||||
reset_optimizer(optimizer)
|
||||
|
||||
if self.quantised:
|
||||
self.last_full_model = checkpoint_folder
|
||||
self.num_lora_restarts += 1
|
||||
|
||||
return control
|
||||
|
||||
def on_save(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
model: peft.LoraModel,
|
||||
**kwargs,
|
||||
):
|
||||
checkpoint_folder = os.path.join(
|
||||
args.output_dir,
|
||||
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
|
||||
)
|
||||
if (
|
||||
state.global_step >= self.relora_steps
|
||||
and state.global_step % self.relora_steps != 0
|
||||
):
|
||||
if self.quantised and self.last_full_model != checkpoint_folder:
|
||||
# ensure the latest full parameter save is in the latest checkpoint
|
||||
# folder, so that automatic pruning of checkpoints does not remove it
|
||||
LOG.info(f"moving last full parameter save to {checkpoint_folder}")
|
||||
chunks = glob.glob(
|
||||
f"{self.last_full_model}/model*.safetensors"
|
||||
) + glob.glob(f"{self.last_full_model}/model*.index.json")
|
||||
for path in chunks:
|
||||
shutil.move(path, checkpoint_folder)
|
||||
self.last_full_model = checkpoint_folder
|
||||
else:
|
||||
model.model.save_pretrained(checkpoint_folder, save_safetensors=True)
|
||||
|
||||
return control
|
||||
|
||||
def on_log(
|
||||
self,
|
||||
_args: TrainingArguments,
|
||||
_state: TrainerState,
|
||||
control: TrainerControl,
|
||||
logs: Dict[str, float],
|
||||
**_kwargs,
|
||||
):
|
||||
logs["num_lora_restarts"] = self.num_lora_restarts
|
||||
return control
|
||||
|
||||
|
||||
class ReLoRAScheduler(LRScheduler):
|
||||
def __init__(
|
||||
self,
|
||||
optimizer: Optimizer,
|
||||
inner_schedule: LRScheduler,
|
||||
relora_steps: int,
|
||||
warmup_steps: int,
|
||||
min_lr_scale: float = 0.001,
|
||||
) -> None:
|
||||
self.inner_schedule = inner_schedule
|
||||
self.relora_steps = relora_steps
|
||||
self.warmup_steps = warmup_steps
|
||||
self.min_lr_scale = min_lr_scale
|
||||
super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)
|
||||
|
||||
def get_lr(self) -> float:
|
||||
self.inner_schedule.last_epoch = self.last_epoch
|
||||
|
||||
original = self.inner_schedule.get_lr()
|
||||
step = self.last_epoch
|
||||
if step < self.relora_steps:
|
||||
scale = 1
|
||||
else:
|
||||
cycle_t = min(1.0, (step % self.relora_steps) / self.warmup_steps)
|
||||
scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
|
||||
if isinstance(original, Sequence):
|
||||
return [lr * scale for lr in original]
|
||||
else:
|
||||
return original * scale
|
||||
|
||||
|
||||
def sharded_paths(path: str, keys: List[str]) -> Dict[str, str]:
|
||||
model_name = "model.safetensors"
|
||||
if not os.path.exists(str(Path(path) / model_name)) and not os.path.exists(
|
||||
str(Path(path) / f"{model_name}.index.json")
|
||||
):
|
||||
model_name = "pytorch_model.bin"
|
||||
|
||||
index_path = str(Path(path) / f"{model_name}.index.json")
|
||||
if os.path.exists(index_path):
|
||||
data = json.load(open(index_path, "r"))
|
||||
return data["weight_map"]
|
||||
return {key + ".weight": model_name for key in keys}
|
||||
|
||||
|
||||
def lora_delta_weight(layer: peft.tuners.lora.LoraLayer) -> torch.Tensor:
|
||||
if isinstance(layer, peft.tuners.lora.Linear8bitLt) or isinstance(
|
||||
layer, peft.tuners.lora.Linear4bit
|
||||
):
|
||||
adapter = layer.active_adapter
|
||||
return (
|
||||
peft.utils.transpose(
|
||||
layer.lora_B[adapter].weight @ layer.lora_A[adapter].weight,
|
||||
getattr(layer, "fan_in_fan_out", False),
|
||||
)
|
||||
* layer.scaling[adapter]
|
||||
)
|
||||
else:
|
||||
return layer.get_delta_weight()
|
||||
|
||||
|
||||
def merge_and_save(
|
||||
model: peft.LoraModel,
|
||||
model_src: str,
|
||||
model_dst: str,
|
||||
reinit: bool = False,
|
||||
quantized: bool = False,
|
||||
cpu_offload: bool = False,
|
||||
):
|
||||
key_list = [key for key, _ in model.model.named_modules() if "lora" not in key]
|
||||
|
||||
if not quantized:
|
||||
for key in key_list:
|
||||
try:
|
||||
_parent, target, _target_name = peft.utils._get_submodules(
|
||||
model.model, key
|
||||
)
|
||||
except AttributeError:
|
||||
continue
|
||||
|
||||
if isinstance(target, peft.tuners.lora.LoraLayer):
|
||||
update = target.get_delta_weight(target.active_adapter).detach()
|
||||
target.weight.data += update
|
||||
|
||||
if reinit:
|
||||
for adapter_name in target.lora_A:
|
||||
target.reset_lora_parameters(adapter_name)
|
||||
for adapter_name in target.lora_embedding_A:
|
||||
target.reset_lora_parameters(adapter_name)
|
||||
return
|
||||
|
||||
os.makedirs(model_dst, exist_ok=True)
|
||||
shard_paths = sharded_paths(model_src, key_list)
|
||||
|
||||
unique_shards = list(set(shard_paths.values()))
|
||||
for shard_path in unique_shards:
|
||||
out_tensors = {}
|
||||
if shard_path.endswith(".safetensors"):
|
||||
in_tensors = st.load_file(str(Path(model_src) / shard_path))
|
||||
else:
|
||||
in_tensors = torch.load(Path(model_src) / shard_path)
|
||||
if "state_dict" in in_tensors:
|
||||
in_tensors = in_tensors["state_dict"]
|
||||
|
||||
for key in key_list:
|
||||
if (key + ".weight") not in shard_paths or shard_paths[
|
||||
key + ".weight"
|
||||
] != shard_path:
|
||||
continue
|
||||
|
||||
try:
|
||||
_parent, target, _target_name = peft.utils._get_submodules(
|
||||
model.model, key
|
||||
)
|
||||
except AttributeError:
|
||||
continue
|
||||
|
||||
if isinstance(target, peft.tuners.lora.LoraLayer):
|
||||
orig_weight = in_tensors[key + ".weight"]
|
||||
old_dev = target.weight.device
|
||||
math_dev = "cpu" if cpu_offload else old_dev
|
||||
|
||||
update = lora_delta_weight(target).detach().to(math_dev)
|
||||
new_weight = orig_weight.to(math_dev) + update
|
||||
out_tensors[key + ".weight"] = new_weight
|
||||
|
||||
if reinit:
|
||||
for adapter_name in target.lora_A:
|
||||
target.reset_lora_parameters(adapter_name)
|
||||
for adapter_name in target.lora_embedding_A:
|
||||
target.reset_lora_parameters(adapter_name)
|
||||
|
||||
if isinstance(target, peft.tuners.lora.Linear4bit):
|
||||
target.weight = (
|
||||
bnb.nn.Params4bit(
|
||||
new_weight,
|
||||
requires_grad=False,
|
||||
compress_statistics=target.weight.compress_statistics,
|
||||
quant_type=target.weight.quant_type,
|
||||
)
|
||||
.cuda(None)
|
||||
.to(old_dev)
|
||||
)
|
||||
elif isinstance(target, peft.tuners.lora.Linear8bitLt):
|
||||
target.weight = (
|
||||
bnb.nn.Int8Params(new_weight, requires_grad=False)
|
||||
.cuda(None)
|
||||
.to(old_dev)
|
||||
)
|
||||
else:
|
||||
target.weight.data = new_weight.to(old_dev)
|
||||
|
||||
for key in in_tensors:
|
||||
if key not in out_tensors:
|
||||
out_tensors[key] = in_tensors[key]
|
||||
del in_tensors
|
||||
|
||||
out_shard_name = shard_path
|
||||
if out_shard_name.startswith("pytorch_model"):
|
||||
out_shard_name = (
|
||||
out_shard_name.replace("pytorch_model", "model").rstrip(".bin")
|
||||
+ ".safetensors"
|
||||
)
|
||||
|
||||
shard_fn = str(Path(model_dst) / out_shard_name)
|
||||
LOG.info(f"saving tensors to {shard_fn}")
|
||||
st.save_file(out_tensors, shard_fn)
|
||||
del out_tensors
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if len(unique_shards) > 1:
|
||||
with open(str(Path(model_dst, "model.safetensors.index.json")), "w") as fd:
|
||||
json.dump({"metadata": {}, "weight_map": shard_paths}, fd)
|
||||
@@ -95,9 +95,9 @@ class OpenOrcaSystemDataPrompter(SystemDataPrompter):
|
||||
self.turn_format = "### User:\n{instruction}\n\n### Additional Context:\n{input}\n\n### Assistant:\n"
|
||||
self.turn_no_input_format = "### User:\n{instruction}\n\n### Assistant:\n"
|
||||
if self.prompt_style == PromptStyle.CHAT.value:
|
||||
self.turn_format = "User: {instruction}\n{input}\nAssistant:"
|
||||
self.turn_no_input_format = "User: {instruction}\nAssistant:"
|
||||
self.system_format = "System: {system}\n"
|
||||
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
|
||||
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
|
||||
self.system_format = "SYSTEM: {system}\n"
|
||||
if self.prompt_style == PromptStyle.CHATML.value:
|
||||
self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
|
||||
self.turn_no_input_format = (
|
||||
|
||||
@@ -29,7 +29,7 @@ from dataclasses import dataclass, field
|
||||
from typing import Generator, List, Sequence
|
||||
|
||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||
from axolotl.prompters import IGNORE_TOKEN_ID
|
||||
from axolotl.prompters import IGNORE_TOKEN_ID, SHAREGPT_ASSERTION_FAILED_ROLE
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -190,7 +190,7 @@ class Llama2ChatPrompter: # pylint: disable=too-few-public-methods
|
||||
conv.messages = [] # pylint: disable=R0801
|
||||
for j, sentence in enumerate(source):
|
||||
role = roles[sentence["from"]]
|
||||
assert role == conv.roles[j % 2]
|
||||
assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
|
||||
if sentence["value"]:
|
||||
conv.append_message(role, sentence["value"])
|
||||
yield conv
|
||||
|
||||
@@ -271,6 +271,11 @@ class Conversation:
|
||||
self.messages.append([role, message])
|
||||
|
||||
|
||||
SHAREGPT_ASSERTION_FAILED_ROLE = (
|
||||
"Role did not alternate between turns (gpt and human). Please check your data."
|
||||
)
|
||||
|
||||
|
||||
class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
||||
"""
|
||||
A prompter that generates prompts for the ShareGPT
|
||||
@@ -327,7 +332,7 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
||||
conv.messages = []
|
||||
for j, sentence in enumerate(source):
|
||||
role = roles[sentence["from"]]
|
||||
assert role == conv.roles[j % 2]
|
||||
assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
|
||||
conv.append_message(role, sentence["value"])
|
||||
|
||||
for part in conv.get_prompt():
|
||||
|
||||
23
src/axolotl/utils/bench.py
Normal file
23
src/axolotl/utils/bench.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Benchmarking and measurement utilities"""
|
||||
|
||||
import pynvml
|
||||
import torch
|
||||
|
||||
|
||||
def gpu_memory_usage(device):
|
||||
if isinstance(device, torch.device):
|
||||
device = device.index
|
||||
if isinstance(device, str) and device.startswith("cuda:"):
|
||||
device = int(device[5:])
|
||||
|
||||
# NB torch.cuda.memory_usage returns zero so we use lower level api
|
||||
pynvml.nvmlInit()
|
||||
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
|
||||
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
||||
return info.used / 1024.0**3
|
||||
|
||||
|
||||
def log_gpu_memory_usage(log, msg, device):
|
||||
log.info(
|
||||
f"GPU memory usage {msg}: {gpu_memory_usage(device):.03f} GB", stacklevel=2
|
||||
)
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Callbacks for Trainer class"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
@@ -11,6 +12,10 @@ from transformers import (
|
||||
)
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
||||
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
|
||||
LOG = logging.getLogger("axolotl.callbacks")
|
||||
|
||||
|
||||
class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
|
||||
"""Callback to save the PEFT adapter"""
|
||||
@@ -28,7 +33,9 @@ class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-
|
||||
)
|
||||
|
||||
peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
|
||||
kwargs["model"].save_pretrained(peft_model_path)
|
||||
kwargs["model"].save_pretrained(
|
||||
peft_model_path, save_safetensors=args.save_safetensors
|
||||
)
|
||||
|
||||
return control
|
||||
|
||||
@@ -67,3 +74,25 @@ class SaveBetterTransformerModelCallback(
|
||||
# the trainer will raise an exception since it can't save a BetterTransformer wrapped model
|
||||
control.should_save = False
|
||||
return control
|
||||
|
||||
|
||||
class PrintGPUStatsCallback(
|
||||
TrainerCallback
|
||||
): # pylint: disable=too-few-public-methods disable=unused-argument
|
||||
"""Callback to print GPU utilization"""
|
||||
|
||||
def __init__(self, cfg):
|
||||
self.cfg = cfg
|
||||
self.logged = False
|
||||
|
||||
def on_step_end(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
if not self.logged:
|
||||
log_gpu_memory_usage(LOG, "while training", self.cfg.device)
|
||||
self.logged = True
|
||||
return control
|
||||
|
||||
@@ -1,14 +1,19 @@
|
||||
"""Module containing data utilities"""
|
||||
import functools
|
||||
import hashlib
|
||||
import itertools
|
||||
import logging
|
||||
from hashlib import md5
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Union
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
|
||||
from datasets import (
|
||||
Dataset,
|
||||
DatasetDict,
|
||||
concatenate_datasets,
|
||||
load_dataset,
|
||||
load_from_disk,
|
||||
)
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
@@ -265,20 +270,12 @@ def load_tokenized_prepared_datasets(
|
||||
raise ValueError(
|
||||
f"unhandled prompt tokenization strategy: {d.type} {suffix}"
|
||||
)
|
||||
LOG.info("tokenizing, merging, and shuffling master dataset")
|
||||
LOG.info("merging datasets")
|
||||
dataset = concatenate_datasets(datasets)
|
||||
|
||||
samples: List[int] = []
|
||||
chunk_size = 1000
|
||||
for d in datasets:
|
||||
d_iter = iter(d)
|
||||
while True:
|
||||
chunk = list(itertools.islice(d_iter, chunk_size))
|
||||
if not chunk:
|
||||
break
|
||||
samples.extend(chunk)
|
||||
|
||||
LOG.info("shuffle")
|
||||
dataset = Dataset.from_list(samples).shuffle(seed=seed)
|
||||
if len(datasets) > 1:
|
||||
LOG.info("shuffle merged datasets")
|
||||
dataset = dataset.shuffle(seed=seed)
|
||||
if cfg.local_rank == 0:
|
||||
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
|
||||
dataset.save_to_disk(prepared_ds_path)
|
||||
|
||||
@@ -3,9 +3,7 @@ import hashlib
|
||||
import itertools
|
||||
import logging
|
||||
import math
|
||||
import queue
|
||||
import threading
|
||||
from typing import Any, Callable, List, Optional, Union
|
||||
from typing import Any, Callable, List, Union
|
||||
|
||||
import numba
|
||||
import numpy as np
|
||||
@@ -80,6 +78,7 @@ def allocate(
|
||||
s = 0
|
||||
start_index = 0
|
||||
result = []
|
||||
result_totseqs = []
|
||||
|
||||
while True:
|
||||
# binary search [left, right)
|
||||
@@ -105,8 +104,10 @@ def allocate(
|
||||
|
||||
# add local rank
|
||||
result.append(batch[rank])
|
||||
|
||||
yield batch[rank], tot_seqs, s, len(result) * c * n
|
||||
# add total seqs for all ranks
|
||||
result_totseqs.append(tot_seqs)
|
||||
# yield batch[rank], tot_seqs, s, len(result) * c * n
|
||||
return result, result_totseqs, s, len(result) * c * n
|
||||
|
||||
|
||||
def chunk(iterable, n):
|
||||
@@ -148,14 +149,15 @@ class MultipackDistributedDataloader:
|
||||
packing_efficiency_estimate: float = 1.0,
|
||||
sample_packing_seq_len_multiplier: int = 1,
|
||||
device_count: int = 1,
|
||||
total_num_tokens: Optional[int] = None,
|
||||
):
|
||||
# Dataset
|
||||
self.dataset = dataset
|
||||
lengths_series = (
|
||||
dataset.data.column("position_ids").to_pandas().apply(lambda x: x[-1] + 1)
|
||||
self.lengths = (
|
||||
dataset.data.column("position_ids")
|
||||
.to_pandas()
|
||||
.apply(lambda x: x[-1] + 1)
|
||||
.values
|
||||
)
|
||||
self.lengths: np.ndarray = lengths_series.values
|
||||
assert isinstance(self.lengths, np.ndarray)
|
||||
assert batch_size % sample_packing_seq_len_multiplier == 0
|
||||
assert batch_size >= sample_packing_seq_len_multiplier
|
||||
@@ -170,17 +172,11 @@ class MultipackDistributedDataloader:
|
||||
self.rank = 0
|
||||
|
||||
# statistics
|
||||
self.total_num_tokens = total_num_tokens
|
||||
self.eff_total_used = 0
|
||||
self.eff_total_slots = 0
|
||||
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
||||
self.device_count = device_count
|
||||
|
||||
# for non-blocking batch creation
|
||||
self.batch_queue: queue.Queue = queue.Queue(
|
||||
maxsize=10
|
||||
) # Adjust maxsize as needed
|
||||
|
||||
def generate_batches(self, set_stats=False):
|
||||
LOG.info("generating packed batches")
|
||||
if self.sampler:
|
||||
@@ -192,83 +188,65 @@ class MultipackDistributedDataloader:
|
||||
lengths = self.lengths[indices]
|
||||
lengths_cumsum = np.cumsum(lengths)
|
||||
|
||||
alloc_iter = iter(
|
||||
allocate(
|
||||
lengths=lengths,
|
||||
lengths_cumsum=lengths_cumsum,
|
||||
rank=self.rank,
|
||||
# c=self.batch_max_length,
|
||||
c=self.seq_max_length * self.sample_packing_seq_len_multiplier,
|
||||
n=self.num_replicas,
|
||||
)
|
||||
batches, totseqs, total_used, total_slots = allocate(
|
||||
lengths=lengths,
|
||||
lengths_cumsum=lengths_cumsum,
|
||||
rank=self.rank,
|
||||
# c=self.batch_max_length,
|
||||
c=self.seq_max_length * self.sample_packing_seq_len_multiplier,
|
||||
n=self.num_replicas,
|
||||
)
|
||||
|
||||
for batch, tot_seqs, total_used, total_slots in alloc_iter:
|
||||
self.batch_queue.put([indices[b_idx] for b_idx in batch])
|
||||
# statistics
|
||||
if set_stats:
|
||||
self.eff_total_used = total_used
|
||||
self.eff_total_slots = total_slots
|
||||
self.batch_queue.put(None) # Signal the end of batch generation
|
||||
batches = [[indices[b_idx] for b_idx in batch] for batch in batches]
|
||||
|
||||
def _generate_batches_thread(self):
|
||||
try:
|
||||
self.generate_batches(set_stats=True)
|
||||
except Exception as e:
|
||||
LOG.error(f"Error in batch generation thread: {e}")
|
||||
self.batch_queue.put(
|
||||
None
|
||||
) # Signal the end of batch generation in case of error
|
||||
# statistics
|
||||
if set_stats:
|
||||
self.eff_total_used += total_used
|
||||
self.eff_total_slots += total_slots
|
||||
|
||||
return batches, totseqs
|
||||
|
||||
def __iter__(self):
|
||||
if hasattr(self.sampler, "set_epoch"):
|
||||
new_epoch = self.sampler.epoch + 1
|
||||
self.sampler.set_epoch(new_epoch)
|
||||
LOG.info(f"calling sampler.set_epoch({new_epoch})")
|
||||
# Start the batch generation in a separate thread
|
||||
batch_gen_thread = threading.Thread(target=self._generate_batches_thread)
|
||||
batch_gen_thread.start()
|
||||
|
||||
all_batches, _ = self.generate_batches(set_stats=True)
|
||||
features = self.dataset.features.keys()
|
||||
len_remaining = self._len_est()
|
||||
while True:
|
||||
batch = self.batch_queue.get()
|
||||
if batch is None: # Sentinel value received, stop iteration
|
||||
break
|
||||
for batches in chunk(
|
||||
all_batches, self.batch_size // self.sample_packing_seq_len_multiplier
|
||||
):
|
||||
chunked_data = []
|
||||
attn_mask_cum_idx = 0
|
||||
concatenated = {}
|
||||
batched_data = [self.dataset[batch_idx] for batch_idx in batch]
|
||||
for feature in features:
|
||||
if feature == "attention_mask":
|
||||
arrays = [
|
||||
(attn_mask_cum_idx + idx + 1) * np.array(item[feature])
|
||||
for idx, item in enumerate(batched_data)
|
||||
if feature in item
|
||||
]
|
||||
attn_mask_cum_idx += len(batched_data)
|
||||
concatenated[feature] = np.concatenate(arrays)
|
||||
else:
|
||||
arrays = [
|
||||
np.array(item[feature])
|
||||
for item in batched_data
|
||||
if feature in item
|
||||
]
|
||||
concatenated[feature] = np.concatenate(arrays)
|
||||
chunked_data.append(concatenated)
|
||||
|
||||
for batch in batches:
|
||||
concatenated = {}
|
||||
batched_data = [self.dataset[batch_idx] for batch_idx in batch]
|
||||
for feature in features:
|
||||
if feature == "attention_mask":
|
||||
arrays = [
|
||||
(attn_mask_cum_idx + idx + 1) * np.array(item[feature])
|
||||
for idx, item in enumerate(batched_data)
|
||||
if feature in item
|
||||
]
|
||||
attn_mask_cum_idx += len(batched_data)
|
||||
concatenated[feature] = np.concatenate(arrays)
|
||||
else:
|
||||
arrays = [
|
||||
np.array(item[feature])
|
||||
for item in batched_data
|
||||
if feature in item
|
||||
]
|
||||
concatenated[feature] = np.concatenate(arrays)
|
||||
chunked_data.append(concatenated)
|
||||
yield self.collate_fn(chunked_data)
|
||||
len_remaining -= 1
|
||||
if not len_remaining:
|
||||
break
|
||||
# Wait for the batch generation thread to finish
|
||||
batch_gen_thread.join(timeout=5)
|
||||
LOG.info(f"actual packing efficiency: {self.efficiency()}")
|
||||
return
|
||||
|
||||
def _len_est(self):
|
||||
if not self.total_num_tokens:
|
||||
self.total_num_tokens = np.sum(self.lengths)
|
||||
lengths_sum_per_device = self.total_num_tokens // self.device_count
|
||||
lengths_sum = np.sum(self.lengths)
|
||||
lengths_sum_per_device = lengths_sum // self.device_count
|
||||
LOG.info(
|
||||
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
|
||||
f"total_num_tokens per device: {lengths_sum_per_device}"
|
||||
|
||||
@@ -22,6 +22,7 @@ from transformers import ( # noqa: F401
|
||||
)
|
||||
|
||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
@@ -83,19 +84,22 @@ def load_tokenizer(
|
||||
|
||||
|
||||
def load_model(
|
||||
base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora"
|
||||
):
|
||||
# type: (str, str, str, PreTrainedTokenizerBase, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||
cfg, tokenizer
|
||||
): # type: (DictDefault, PreTrainedTokenizerBase) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||
"""
|
||||
Load a model from a base model and a model type.
|
||||
Load a model for a given configuration and tokenizer.
|
||||
"""
|
||||
base_model = cfg.base_model
|
||||
base_model_config = cfg.base_model_config
|
||||
model_type = cfg.model_type
|
||||
adapter = cfg.adapter
|
||||
|
||||
# TODO refactor as a kwarg
|
||||
load_in_8bit = cfg.load_in_8bit
|
||||
cfg.is_llama_derived_model = (
|
||||
"llama" in base_model
|
||||
or (cfg.model_type and "llama" in cfg.model_type.lower())
|
||||
or cfg.is_llama_derived_model is True
|
||||
or cfg.is_llama_derived_model
|
||||
)
|
||||
|
||||
if cfg.is_llama_derived_model and cfg.flash_attention:
|
||||
@@ -231,7 +235,9 @@ def load_model(
|
||||
elif cfg.is_llama_derived_model and not cfg.trust_remote_code:
|
||||
from transformers import LlamaForCausalLM
|
||||
|
||||
config = LlamaConfig.from_pretrained(base_model_config)
|
||||
config = LlamaConfig.from_pretrained(
|
||||
base_model_config, rope_scaling=cfg.rope_scaling
|
||||
)
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
config=config,
|
||||
@@ -336,6 +342,9 @@ def load_model(
|
||||
)
|
||||
model.config.max_position_embeddings = cfg.sequence_len
|
||||
|
||||
if model.device.type == "cuda":
|
||||
log_gpu_memory_usage(LOG, "after model load", model.device)
|
||||
|
||||
if not cfg.gptq and (
|
||||
(cfg.adapter == "lora" and load_in_8bit)
|
||||
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
|
||||
@@ -372,6 +381,9 @@ def load_model(
|
||||
module.scales = module.scales.half()
|
||||
module.bias = module.bias.half()
|
||||
|
||||
if model.device.type == "cuda":
|
||||
log_gpu_memory_usage(LOG, "after adapters", model.device)
|
||||
|
||||
if (
|
||||
torch.cuda.device_count() > 1
|
||||
and int(os.getenv("WORLD_SIZE", "1")) > 1
|
||||
|
||||
@@ -11,6 +11,7 @@ from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import bitsandbytes as bnb
|
||||
import numpy as np
|
||||
import torch.cuda
|
||||
import transformers
|
||||
from datasets import Dataset, set_caching_enabled
|
||||
@@ -20,7 +21,9 @@ from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
|
||||
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
||||
from transformers.trainer_pt_utils import get_parameter_names
|
||||
|
||||
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
||||
from axolotl.utils.callbacks import (
|
||||
PrintGPUStatsCallback,
|
||||
SaveBetterTransformerModelCallback,
|
||||
SavePeftModelCallback,
|
||||
)
|
||||
@@ -122,10 +125,6 @@ class AxolotlTrainingArguments(TrainingArguments):
|
||||
default=1,
|
||||
metadata={"help": "the multiplier for the max len for packed sequences"},
|
||||
)
|
||||
train_data_total_num_tokens: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "the total number of tokens in the train dataset"},
|
||||
)
|
||||
|
||||
|
||||
class AxolotlTrainer(Trainer):
|
||||
@@ -186,7 +185,6 @@ class AxolotlTrainer(Trainer):
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
|
||||
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
||||
total_num_tokens=self.args.train_data_total_num_tokens,
|
||||
)
|
||||
)
|
||||
return super().get_train_dataloader()
|
||||
@@ -209,7 +207,6 @@ class AxolotlTrainer(Trainer):
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
sample_packing_seq_len_multiplier=self.args.eval_batch_size,
|
||||
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
||||
total_num_tokens=None,
|
||||
)
|
||||
)
|
||||
return super().get_eval_dataloader(eval_dataset)
|
||||
@@ -288,13 +285,16 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
|
||||
if cfg.sample_packing:
|
||||
# we have to drop anything longer then sequence len otherwise
|
||||
# flash attention with position ids fails
|
||||
total_num_tokens = (
|
||||
cfg.total_num_tokens
|
||||
if cfg.total_num_tokens
|
||||
else sum(len(s["input_ids"]) for s in train_dataset)
|
||||
)
|
||||
if not cfg.total_num_tokens:
|
||||
LOG.info("calculating total_num_tokens")
|
||||
total_num_tokens = np.sum(
|
||||
train_dataset.data.column("input_ids")
|
||||
.to_pandas()
|
||||
.apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda
|
||||
.values
|
||||
)
|
||||
LOG.info(f"📝 UPDATE CONFIG WITH: `total_num_tokens: {total_num_tokens}`")
|
||||
cfg.total_num_tokens = total_num_tokens
|
||||
|
||||
if cfg.sample_packing_eff_est:
|
||||
total_num_steps = (
|
||||
@@ -302,9 +302,9 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
|
||||
(
|
||||
math.floor(
|
||||
0.99
|
||||
* total_num_tokens
|
||||
* cfg.total_num_tokens
|
||||
/ cfg.sample_packing_eff_est
|
||||
/ 2048
|
||||
/ cfg.sequence_len
|
||||
// cfg.batch_size
|
||||
// int(os.environ.get("WORLD_SIZE", 1))
|
||||
)
|
||||
@@ -313,7 +313,7 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
|
||||
* cfg.num_epochs
|
||||
)
|
||||
LOG.info(
|
||||
f"total_num_tokens: {total_num_tokens}, total_num_steps: {total_num_steps}"
|
||||
f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}"
|
||||
)
|
||||
else:
|
||||
sampler = RandomSampler(train_dataset)
|
||||
@@ -345,6 +345,7 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
|
||||
LOG.info(
|
||||
f"📝 UPDATE CONFIG WITH: `sample_packing_eff_est: {math.ceil(actual_eff * 100.0) / 100.0}`"
|
||||
)
|
||||
cfg.sample_packing_eff_est = math.ceil(actual_eff * 100.0) / 100.0
|
||||
else:
|
||||
total_num_steps = int(
|
||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||
@@ -483,8 +484,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
||||
else "cosine",
|
||||
weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
|
||||
sample_packing=cfg.sample_packing if cfg.sample_packing else False,
|
||||
sample_packing_seq_len_multiplier=cfg.micro_batch_size or 1,
|
||||
train_data_total_num_tokens=cfg.total_num_tokens,
|
||||
sample_packing_seq_len_multiplier=cfg.micro_batch_size,
|
||||
**training_arguments_kwargs,
|
||||
)
|
||||
|
||||
@@ -556,6 +556,19 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
||||
trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)
|
||||
|
||||
callbacks = []
|
||||
callbacks.append(PrintGPUStatsCallback(cfg))
|
||||
|
||||
if cfg.relora_steps:
|
||||
relora_steps = int(cfg.relora_steps)
|
||||
relora_warmup_steps = int(cfg.relora_warmup_steps)
|
||||
callbacks.append(ReLoRACallback(cfg))
|
||||
|
||||
(optimizer, lr_scheduler) = trainer_kwargs["optimizers"]
|
||||
trainer_kwargs["optimizers"] = (
|
||||
optimizer,
|
||||
ReLoRAScheduler(optimizer, lr_scheduler, relora_steps, relora_warmup_steps),
|
||||
)
|
||||
|
||||
# TODO on_save callback to sync checkpoints to GCP/AWS in background
|
||||
if cfg.early_stopping_patience:
|
||||
early_stop_cb = EarlyStoppingCallback(
|
||||
|
||||
@@ -61,6 +61,9 @@ def validate_config(cfg):
|
||||
if not cfg.load_in_8bit and cfg.adapter == "lora":
|
||||
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
|
||||
|
||||
if cfg.relora_steps and cfg.adapter not in ("lora", "qlora"):
|
||||
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
|
||||
|
||||
if cfg.trust_remote_code:
|
||||
LOG.warning(
|
||||
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
|
||||
@@ -110,6 +113,13 @@ def validate_config(cfg):
|
||||
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
|
||||
)
|
||||
|
||||
if cfg.gptq and cfg.model_revision:
|
||||
raise ValueError(
|
||||
"model_revision is not supported for GPTQ models. "
|
||||
+ "Please download the model from HuggingFace Hub manually for correct branch, "
|
||||
+ "point to its path, and remove model_revision from the config."
|
||||
)
|
||||
|
||||
if cfg.sample_packing and cfg.sdp_attention:
|
||||
# incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
|
||||
raise ValueError(
|
||||
|
||||
@@ -9,6 +9,8 @@ def setup_wandb_env_vars(cfg):
|
||||
elif cfg.wandb_project and len(cfg.wandb_project) > 0:
|
||||
os.environ["WANDB_PROJECT"] = cfg.wandb_project
|
||||
cfg.use_wandb = True
|
||||
if cfg.wandb_entity and len(cfg.wandb_entity) > 0:
|
||||
os.environ["WANDB_ENTITY"] = cfg.wandb_entity
|
||||
if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
|
||||
os.environ["WANDB_WATCH"] = cfg.wandb_watch
|
||||
if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0:
|
||||
|
||||
Reference in New Issue
Block a user