Compare commits

...

51 Commits

Author SHA1 Message Date
Wing Lian
c3de28942c fix for gather across multiple gpus
Some checks failed
pre-commit / pre-commit (push) Has been cancelled
PyTest / test (3.10) (push) Has been cancelled
PyTest / test (3.9) (push) Has been cancelled
2023-08-29 06:57:28 -07:00
Wing Lian
45848a9285 gather benchmarks from all ranks 2023-08-28 11:29:59 -04:00
Wing Lian
d6cea18034 improve support for customized dataset for bench evals 2023-08-28 06:03:53 -04:00
Wing Lian
606846e0a5 missing transformers import 2023-08-28 05:43:19 -04:00
Wing Lian
a6c9223114 more fixes 2023-08-28 05:39:13 -04:00
Wing Lian
8b16ecd448 updated dataset 2023-08-28 05:39:13 -04:00
Wing Lian
f5db88a10d fixes 2023-08-28 05:39:13 -04:00
Wing Lian
99d844f215 benchmark callback has its own dataloader and collator 2023-08-28 05:39:13 -04:00
Wing Lian
aefd4d74fa better handling when no subjects 2023-08-28 05:39:13 -04:00
Wing Lian
24b0e93235 dataset handling and aggregate across benchmark 2023-08-28 05:39:13 -04:00
Wing Lian
2455254b92 more fixes 2023-08-28 05:39:13 -04:00
Wing Lian
918e040601 rename mmlu to bench 2023-08-28 05:39:13 -04:00
Wing Lian
ef062d8fcb more fixes 2023-08-28 05:39:13 -04:00
Wing Lian
d4c8b66f3d fix elif and add better messaging 2023-08-28 05:39:13 -04:00
Wing Lian
64e9824d3e fix the data file 2023-08-28 05:39:13 -04:00
Wing Lian
1134654c98 sample benchmarks, ensure we drop long samples 2023-08-28 05:39:13 -04:00
Wing Lian
2fc756c289 fix mmlu evals 2023-08-28 05:39:13 -04:00
Wing Lian
943b84c490 another callback fix for collator max len attribute 2023-08-28 05:39:13 -04:00
Wing Lian
6f166464d8 include metrics in callback 2023-08-28 05:39:13 -04:00
Wing Lian
e3b07402a7 make sure to define all the explicit positional args 2023-08-28 05:39:13 -04:00
Wing Lian
8d3c8a3eab default to mmlu-zs 2023-08-28 05:39:13 -04:00
Wing Lian
c30120e684 use hf dataset for mmlu evals 2023-08-28 05:39:13 -04:00
Wing Lian
9aed60fa54 add mmlu callback 2023-08-28 05:39:12 -04:00
Wing Lian
98bf76e236 fsdp requires params be the same type too (#493) 2023-08-28 04:33:50 -04:00
NanoCode012
4c37bd0b54 Fix(tokenizer): Make sure to add pad for CodeLlamaTokenizer (#489) 2023-08-28 09:39:10 +09:00
Aman Gupta Karmani
f144e98a32 Merge pull request #485 from maximegmd/patch-4
fix: finetune model inference needs the dtype fix to work with flash-attn
2023-08-27 16:27:47 -04:00
Aman Karmani
3a011ea1ef fix condition and add logging 2023-08-27 20:09:26 +00:00
Aman Karmani
1f613e5aa7 Merge branch 'main' into patch-4 2023-08-27 19:57:34 +00:00
Aman Karmani
f319b0bc67 rename var and reformat 2023-08-27 19:55:11 +00:00
Maxime
7fd662dd89 Update src/axolotl/utils/models.py
Co-authored-by: Aman Gupta Karmani <aman@tmm1.net>
2023-08-27 21:01:43 +02:00
Maxime
9e699683d7 Update src/axolotl/utils/models.py
Co-authored-by: Aman Gupta Karmani <aman@tmm1.net>
2023-08-27 21:01:37 +02:00
mhenrichsen
35130711d6 Feat(cfg): Add code-llama configs for all sizes (#479)
* configs for all sizes

* update tokenizer type

---------

Co-authored-by: mhenrichsen <some_email@hey.com>
2023-08-27 10:20:17 +09:00
mhenrichsen
3fc9006298 Feat(deepspeed): Add zero2 config (#476)
* zero2 config

* config added

* linting

---------

Co-authored-by: mhenrichsen <some_email@hey.com>
2023-08-27 10:10:33 +09:00
NanoCode012
ad8be435ad Feat(doc): Update eval_steps doc (#487) 2023-08-27 10:09:09 +09:00
Charles O. Goddard
fe4d6baf92 Add example Llama 2 ReLoRA config (#471)
* Add example Llama 2 ReLoRA config

* Use adamw_bnb_8bit in example relora config
2023-08-27 10:08:34 +09:00
Aman Gupta Karmani
f31301063d Merge pull request #486 from OpenAccess-AI-Collective/adam-bnb-simpler
let transformers handle adamw_bnb_8bit
2023-08-26 20:44:19 -04:00
Aman Karmani
868530c39c let transformers handle adamw_bnb_8bit 2023-08-26 21:40:12 +00:00
Maxime
d03887fad5 ignore: address pr review 2023-08-26 22:45:45 +02:00
Maxime
17605b85d8 fix: inference did not move the model to the correct device (#483) 2023-08-26 16:40:56 -04:00
Maxime
a184549e4c ignore: linter 2023-08-26 22:36:14 +02:00
Maxime
f311df9462 fix: finetune model inference needs the dtype fix to work with flash-attn 2023-08-26 22:34:11 +02:00
Maxime
c500d02517 Fix missing 'packaging' wheel (#482) 2023-08-26 12:02:15 -04:00
Wing Lian
31f3e71764 fix checkpints on multigpu (#481) 2023-08-26 12:00:03 -04:00
Aman Gupta Karmani
56c4a94caf Merge pull request #484 from OpenAccess-AI-Collective/reqs
allow newer deps in requirements.txt
2023-08-26 11:13:41 -04:00
Aman Karmani
c29117a0d7 allow newer deps 2023-08-26 15:06:05 +00:00
Wing Lian
0b7ba57ec4 fix types w lora (#478) 2023-08-25 02:03:24 -04:00
NanoCode012
71bd06243c Fix(tokenizer): Fix condition to add pad token (#477)
* Fix(tokenizer): Fix condition to add pad token

* chore: fix lint
2023-08-25 14:30:50 +09:00
Wing Lian
cb9797ef5a improve llama pad token handling (#475)
* improve llama pad token handling

* tweak logic to not clobber
2023-08-24 13:20:35 -04:00
Charles O. Goddard
bde3c5a478 ReLoRA implementation (with quantization) (#322)
* Experimental ReLoRA (+qlora) implementation

* Add CPU offload

* Remove local config

* Fix saving logic

* Remove redundant assert

* Fix logic errors

* Move ReLoRA into its own trainer class with a method override to create the proper scheduler

* Formatting & typing fixes

* Use safe_serialization

* Don't allow fsdp/deepspeed with ReLoRA

* Fix cpu-offload logic, enable multi gpu

* Document parameters and add comment

* Fix merge issue

* Smooth over some sharp edges

* Implement resume from checkpoint for relora

* Address review comments

* Fix saving logic

* Add necessary metadata to safetensors

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-08-23 23:07:18 -04:00
NanoCode012
55c23c7bcb Fix(doc): Clarify config (#466) 2023-08-23 11:56:01 -04:00
Wing Lian
c69faee7a7 workaround so training doesn't hang when packed dataloader batches aren't even (#461)
* workaround so training doesn't hang when packed dataloader batches aren't even

* don't bother labeling anything in the no-op data
2023-08-23 10:39:11 -04:00
22 changed files with 1421 additions and 124 deletions

View File

@@ -493,6 +493,12 @@ lora_modules_to_save:
lora_out_dir:
lora_fan_in_fan_out: false
# ReLoRA configuration
# must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
relora_steps: # number of steps per ReLoRA restart
relora_warmup_steps: # number of per-restart warmup steps
relora_cpu_offload: # true to perform lora weight merges on cpu during restarts, for modest gpu memory savings
# wandb configuration if you're using it
wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
wandb_project: # your wandb project name
@@ -515,7 +521,7 @@ lr_quadratic_warmup:
logging_steps:
save_strategy: # set to `no` to skip checkpoint saves
save_steps: # leave empty to save at each epoch
eval_steps:
eval_steps: # leave empty to eval at each epoch
save_total_limit: # checkpoints saved at a time
max_steps:
@@ -626,7 +632,7 @@ strict:
Run
```bash
accelerate launch scripts/finetune.py configs/your_config.yml
accelerate launch scripts/finetune.py your_config.yml
```
#### Multi-GPU

46
deepspeed/zero2.json Normal file
View File

@@ -0,0 +1,46 @@
{
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu"
},
"contiguous_gradients": true,
"overlap_comm": true
},
"bf16": {
"enabled": "auto"
},
"fp16": {
"enabled": "auto",
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": [
0.9,
0.999
],
"eps": 1e-8,
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"total_num_steps": "auto"
}
},
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}

View File

@@ -0,0 +1,67 @@
base_model: codellama/CodeLlama-13b-hf
base_model_config: codellama/CodeLlama-13b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./lora-out
sequence_len: 100000
sample_packing: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 20
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -0,0 +1,69 @@
base_model: codellama/CodeLlama-13b-hf
base_model_config: codellama/CodeLlama-13b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 100000
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 3
optimizer: paged_adamw_32bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 20
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -0,0 +1,67 @@
base_model: codellama/CodeLlama-34b-hf
base_model_config: codellama/CodeLlama-34b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./lora-out
sequence_len: 100000
sample_packing: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 20
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -0,0 +1,69 @@
base_model: codellama/CodeLlama-34b-hf
base_model_config: codellama/CodeLlama-34b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 100000
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 3
optimizer: paged_adamw_32bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 20
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -0,0 +1,67 @@
base_model: codellama/CodeLlama-7b-hf
base_model_config: codellama/CodeLlama-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./lora-out
sequence_len: 100000
sample_packing: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 20
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -0,0 +1,69 @@
base_model: codellama/CodeLlama-7b-hf
base_model_config: codellama/CodeLlama-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 100000
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 3
optimizer: paged_adamw_32bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 20
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -0,0 +1,22 @@
# Overview
This is an example of CodeLLaMA configuration for 7b, 13b and 34b.
The 7b variant fits on any 24GB VRAM GPU and will take up about 17 GB of VRAM during training if using qlora and 20 GB if using lora. On a RTX 4090 it trains 3 epochs of the default dataset in about 15 minutes.
The 13b variant will fit if you change these settings to these values:
gradient_accumulation_steps: 2
micro_batch_size: 1
The 34b variant does not fit on 24GB of VRAM - you will need something with +40 gb VRAM that also supports flash attention v2 - A6000 or A100 are good choices.
```shell
accelerate launch scripts/finetune.py examples/code-llama/[MODEL_SIZE]/qlora.yml
```
or
```shell
accelerate launch scripts/finetune.py examples/code-llama/[MODEL_SIZE]/lora.yml
```

View File

@@ -57,7 +57,7 @@ weight_decay: 0.0001
fsdp:
fsdp_config:
tokens:
pad_token: "[PAD]"
pad_token: "<pad>"
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -0,0 +1,73 @@
base_model: meta-llama/Llama-2-7b-hf
base_model_config: meta-llama/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./relora-out
adapter: qlora
lora_model_dir:
sequence_len: 4096
sample_packing: true
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
relora_steps: 150
relora_warmup_steps: 10
relora_cpu_offload: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 4
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 20
save_steps: 50
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -1,12 +1,14 @@
packaging
peft @ git+https://github.com/huggingface/peft.git
transformers @ git+https://github.com/huggingface/transformers.git
bitsandbytes>=0.41.1
accelerate @ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b
addict
evaluate
fire
PyYAML==6.0
PyYAML>=6.0
datasets
flash-attn==2.0.8
flash-attn>=2.0.8
sentencepiece
wandb
einops
@@ -15,7 +17,7 @@ optimum
hf_transfer
colorama
numba
numpy==1.24.4
numpy>=1.24.4
# qlora things
bert-score==0.3.13
evaluate==0.4.0

View File

@@ -82,6 +82,8 @@ def do_inference(cfg, model, tokenizer, prompter: Optional[str]):
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
)
model = model.to(cfg.device)
while True:
print("=" * 80)
# support for multiline inputs
@@ -242,6 +244,21 @@ def train(
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
return
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
possible_checkpoints = [
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
]
if len(possible_checkpoints) > 0:
sorted_paths = sorted(
possible_checkpoints,
key=lambda path: int(path.split("-")[-1]),
)
cfg.resume_from_checkpoint = sorted_paths[-1]
LOG.info(
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
)
resume_from_checkpoint = cfg.resume_from_checkpoint
trainer = setup_trainer(
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
)
@@ -273,20 +290,6 @@ def train(
LOG.info("Starting trainer...")
if cfg.group_by_length:
LOG.info("hang tight... sorting dataset for group_by_length")
resume_from_checkpoint = cfg.resume_from_checkpoint
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
possible_checkpoints = [
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
]
if len(possible_checkpoints) > 0:
sorted_paths = sorted(
possible_checkpoints,
key=lambda path: int(path.split("-")[-1]),
)
resume_from_checkpoint = sorted_paths[-1]
LOG.info(
f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}"
)
if not Path(cfg.output_dir).is_dir():
os.makedirs(cfg.output_dir, exist_ok=True)
@@ -301,6 +304,13 @@ def train(
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
if cfg.relora_steps:
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
model = model.merge_and_unload()
else:
# final model weights have already been saved by `ReLoRACallback.on_train_end`
return
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.fsdp:
@@ -308,6 +318,7 @@ def train(
elif cfg.local_rank == 0:
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)

View File

@@ -0,0 +1,393 @@
"""Implements the ReLoRA training procedure from https://arxiv.org/abs/2307.05695, minus the initial full fine-tune."""
import glob
import json
import logging
import os.path
import shutil
from pathlib import Path
from typing import Dict, List, Sequence
import bitsandbytes as bnb
import peft
import safetensors.torch as st
import torch
from huggingface_hub import snapshot_download
from torch.optim.lr_scheduler import LRScheduler
from torch.optim.optimizer import Optimizer
from transformers import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
LOG = logging.getLogger("axolotl.relora")
def reset_optimizer(optimizer: torch.optim.Optimizer):
for group in optimizer.param_groups:
for param in group["params"]:
param_state = optimizer.state[param]
for key in param_state:
if "qmap" in key:
continue
if key == "step" and isinstance(param_state[key], int):
param_state[key] = 0
else:
param_state[key] = torch.zeros_like(param_state[key])
class ReLoRACallback(TrainerCallback):
"""Callback to merge LoRA weights into the base model and save full-weight checkpoints"""
def __init__(self, cfg: DictDefault):
self.relora_steps = cfg.relora_steps
self.cpu_offload = cfg.relora_cpu_offload
self.quantized = cfg.load_in_4bit or cfg.load_in_8bit
self.last_full_model = cfg.base_model
self.resume_from_checkpoint = cfg.resume_from_checkpoint
if not os.path.exists(self.last_full_model):
self.last_full_model = str(Path(snapshot_download(cfg.base_model)))
assert os.path.exists(
self.last_full_model
), "for ReLORA base_model must be a local path"
self.num_lora_restarts = 0
self.need_full_save = False
def on_train_begin(
self,
_args: TrainingArguments,
_state: TrainerState,
control: TrainerControl,
model: peft.LoraModel,
**_kwargs,
):
if self.resume_from_checkpoint:
weight_path = os.path.join(self.resume_from_checkpoint, "relora")
if not os.path.exists(weight_path):
LOG.warning(
"Resuming ReLoRA from checkpoint, but no full-weight save found"
)
else:
LOG.info(f"Loading adjusted base weights from {weight_path}")
load_weight_checkpoint(model, weight_path)
return control
def on_step_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
model: peft.LoraModel,
optimizer: torch.optim.Optimizer,
**_kwargs,
):
if state.global_step > 0 and state.global_step % self.relora_steps == 0:
checkpoint_folder = os.path.join(
args.output_dir,
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
"relora",
)
with torch.no_grad():
merge_and_save(
model,
self.last_full_model,
checkpoint_folder,
reinit=True,
quantized=self.quantized,
actually_save=is_main_process(),
cpu_offload=self.cpu_offload,
)
reset_optimizer(optimizer)
if self.quantized:
self.last_full_model = checkpoint_folder
self.num_lora_restarts += 1
return control
def on_save(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
model: peft.LoraModel,
**_kwargs,
):
checkpoint_folder = os.path.join(
args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}", "relora"
)
if (
state.global_step >= self.relora_steps
and state.global_step % self.relora_steps != 0
):
if self.quantized:
if is_main_process() and self.last_full_model != checkpoint_folder:
# ensure the latest full parameter save is in the latest checkpoint
# folder, so that automatic pruning of checkpoints does not remove it
LOG.info(f"moving last full parameter save to {checkpoint_folder}")
os.makedirs(checkpoint_folder, exist_ok=True)
chunks = glob.glob(
f"{self.last_full_model}/model*.safetensors"
) + glob.glob(f"{self.last_full_model}/model*.index.json")
for path in chunks:
new_path = os.path.abspath(shutil.move(path, checkpoint_folder))
try:
os.symlink(new_path, path)
except OSError:
# probably on windows without permission to symlink
pass
self.last_full_model = checkpoint_folder
else:
model.model.save_pretrained(checkpoint_folder, safe_serialization=True)
return control
def on_log(
self,
_args: TrainingArguments,
_state: TrainerState,
control: TrainerControl,
logs: Dict[str, float],
**_kwargs,
):
logs["num_lora_restarts"] = self.num_lora_restarts
return control
def on_train_end(
self,
args: TrainingArguments,
_state: TrainerState,
control: TrainerControl,
model: peft.LoraModel,
**_kwargs,
):
if self.quantized:
# perform final merge and save
with torch.no_grad():
merge_and_save(
model,
self.last_full_model,
args.output_dir,
reinit=False,
quantized=self.quantized,
actually_save=is_main_process(),
cpu_offload=self.cpu_offload,
)
# no need to save if unquantized, as finetune.py will call merge_and_unload()
return control
class ReLoRAScheduler(LRScheduler):
"""Wraps another scheduler to apply per-lora-restart learning rate warmups."""
def __init__(
self,
optimizer: Optimizer,
inner_schedule: LRScheduler,
relora_steps: int,
warmup_steps: int,
min_lr_scale: float = 0.001,
) -> None:
self.inner_schedule = inner_schedule
self.relora_steps = relora_steps
self.warmup_steps = warmup_steps
self.min_lr_scale = min_lr_scale
super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)
def get_lr(self) -> float:
self.inner_schedule.last_epoch = self.last_epoch
original = self.inner_schedule.get_lr()
step = self.last_epoch
if step < self.relora_steps:
scale = 1
else:
cycle_t = min(1.0, (step % self.relora_steps) / self.warmup_steps)
scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
if isinstance(original, Sequence):
return [lr * scale for lr in original]
return original * scale
def sharded_paths(path: str, module_names: List[str]) -> Dict[str, str]:
model_name = "model.safetensors"
if not os.path.exists(str(Path(path) / model_name)) and not os.path.exists(
str(Path(path) / f"{model_name}.index.json")
):
model_name = "pytorch_model.bin"
index_path = str(Path(path) / f"{model_name}.index.json")
if os.path.exists(index_path):
with open(index_path, "r", encoding="utf-8") as file:
data = json.load(file)
return data["weight_map"]
return {(module_name + ".weight"): model_name for module_name in module_names}
def lora_delta_weight(layer: peft.tuners.lora.LoraLayer, device) -> torch.Tensor:
if isinstance(layer, (peft.tuners.lora.Linear8bitLt, peft.tuners.lora.Linear4bit)):
adapter = layer.active_adapter
return (
peft.utils.transpose(
layer.lora_B[adapter].weight.detach().to(device)
@ layer.lora_A[adapter].weight.detach().to(device),
getattr(layer, "fan_in_fan_out", False),
)
* layer.scaling[adapter]
)
return layer.get_delta_weight().to(device)
def find_lora_modules(model: peft.LoraModel) -> Dict[str, peft.tuners.lora.LoraLayer]:
modules: Dict[str, peft.tuners.lora.LoraLayer] = {}
key_list = [key for key, _ in model.model.named_modules() if "lora" not in key]
for key in key_list:
try:
# pylint: disable=protected-access
_parent, target, _target_name = peft.utils._get_submodules(model.model, key)
except AttributeError:
continue
if isinstance(target, peft.tuners.lora.LoraLayer):
modules[key] = target
return modules
def update_weights(
target: peft.tuners.lora.LoraLayer, new_weight: torch.Tensor, reinit: bool, device
):
if reinit:
for adapter_name in target.lora_A:
target.reset_lora_parameters(adapter_name)
for adapter_name in target.lora_embedding_A:
target.reset_lora_parameters(adapter_name)
if isinstance(target, peft.tuners.lora.Linear4bit):
# This could be faster, but the quantization of Linear4bit weights occurs
# when the module is moved from cpu to gpu. Without meddling *too* deeply in
# PEFT's innards or maintaining a duplicate of that codepath, this is good
# enough for now.
target.weight.quant_state = None
target.weight.data = new_weight.cpu()
target.to(device)
elif isinstance(target, peft.tuners.lora.Linear8bitLt):
target.weight = bnb.nn.Int8Params(new_weight, requires_grad=False).to(device)
else:
target.weight.data = new_weight.to(device)
def merge_and_save(
model: peft.LoraModel,
model_src: str,
model_dst: str,
reinit: bool = False,
quantized: bool = False,
cpu_offload: bool = False,
actually_save: bool = True,
):
modules = find_lora_modules(model)
if not quantized:
for module_name, target in modules.items():
update = target.get_delta_weight(target.active_adapter).detach()
target.weight.data += update
if reinit:
for adapter_name in target.lora_A:
target.reset_lora_parameters(adapter_name)
for adapter_name in target.lora_embedding_A:
target.reset_lora_parameters(adapter_name)
return
os.makedirs(model_dst, exist_ok=True)
shard_paths = sharded_paths(model_src, modules.keys())
out_shard_paths = {}
unique_shards = list(set(shard_paths.values()))
for shard_path in unique_shards:
out_tensors = {}
if shard_path.endswith(".safetensors"):
in_tensors = st.load_file(str(Path(model_src) / shard_path))
else:
in_tensors = torch.load(Path(model_src) / shard_path)
if "state_dict" in in_tensors:
in_tensors = in_tensors["state_dict"]
for module_name, target in modules.items():
key = module_name + ".weight"
if key not in shard_paths or shard_paths[key] != shard_path:
continue
orig_weight = in_tensors[key]
old_dev = target.weight.device
math_dev = "cpu" if cpu_offload else old_dev
delta_weight = lora_delta_weight(target, math_dev)
new_weight = orig_weight.to(math_dev) + delta_weight
del delta_weight
if actually_save:
out_tensors[key] = new_weight.half().cpu()
update_weights(target, new_weight, reinit=reinit, device=old_dev)
if actually_save:
out_shard_name = shard_path
if out_shard_name.startswith("pytorch_model"):
out_shard_name = (
out_shard_name.replace("pytorch_model", "model").rstrip(".bin")
+ ".safetensors"
)
for module_name in in_tensors:
if module_name not in out_tensors:
out_tensors[module_name] = in_tensors[module_name].half()
out_shard_paths[module_name] = out_shard_name
shard_fn = str(Path(model_dst) / out_shard_name)
LOG.info(f"saving tensors to {shard_fn}")
st.save_file(out_tensors, shard_fn, metadata={"format": "pt"})
del in_tensors
del out_tensors
torch.cuda.empty_cache()
if actually_save and len(unique_shards) > 1:
with open(
str(Path(model_dst, "model.safetensors.index.json")), "w", encoding="utf-8"
) as file:
json.dump({"metadata": {}, "weight_map": out_shard_paths}, file)
def load_weight_checkpoint(model: peft.LoraModel, checkpoint_path: str):
modules = find_lora_modules(model)
shard_paths = sharded_paths(checkpoint_path, modules.keys())
unique_shards = list(set(shard_paths.values()))
for shard_path in unique_shards:
tensors = st.load_file(os.path.join(checkpoint_path, shard_path))
for module_name, target in modules.items():
key = module_name + ".weight"
if key not in shard_paths or shard_paths[key] != shard_path:
continue
new_weight = tensors[key]
update_weights(
target, new_weight, reinit=False, device=target.weight.device
)

View File

@@ -13,7 +13,7 @@ from axolotl.prompters import IGNORE_TOKEN_ID
LOG = logging.getLogger("axolotl")
IGNORE_INDEX = -100
LLAMA_DEFAULT_PAD_TOKEN = "[PAD]" # nosec
LLAMA_DEFAULT_PAD_TOKEN = "<pad>" # nosec
LLAMA_DEFAULT_EOS_TOKEN = "</s>" # nosec
LLAMA_DEFAULT_BOS_TOKEN = "<s>" # nosec
LLAMA_DEFAULT_UNK_TOKEN = "<unk>" # nosec

View File

@@ -1,9 +1,19 @@
"""Callbacks for Trainer class"""
from __future__ import annotations
import logging
import os
from typing import TYPE_CHECKING, Dict, List
import evaluate
import numpy as np
import pandas as pd
import torch
import torch.distributed as dist
from datasets import load_dataset
from optimum.bettertransformer import BetterTransformer
from tqdm import tqdm
from transformers import (
TrainerCallback,
TrainerControl,
@@ -13,8 +23,19 @@ from transformers import (
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.distributed import (
barrier,
gather_scalar_from_all_ranks,
get_world_size,
is_main_process,
zero_first,
)
if TYPE_CHECKING:
from axolotl.utils.trainer import AxolotlTrainingArguments
LOG = logging.getLogger("axolotl.callbacks")
IGNORE_INDEX = -100
class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
@@ -33,7 +54,9 @@ class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-
)
peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
kwargs["model"].save_pretrained(peft_model_path)
kwargs["model"].save_pretrained(
peft_model_path, save_safetensors=args.save_safetensors
)
return control
@@ -94,3 +117,192 @@ class GPUStatsCallback(
log_gpu_memory_usage(LOG, "while training", self.cfg.device)
self.logged = True
return control
def bench_eval_callback_factory(trainer, tokenizer):
accuracy = evaluate.load("accuracy")
abcd_idx = [
tokenizer("A", add_special_tokens=False).input_ids[0],
tokenizer("B", add_special_tokens=False).input_ids[0],
tokenizer("C", add_special_tokens=False).input_ids[0],
tokenizer("D", add_special_tokens=False).input_ids[0],
tokenizer("E", add_special_tokens=False).input_ids[0],
tokenizer("F", add_special_tokens=False).input_ids[0],
tokenizer("G", add_special_tokens=False).input_ids[0],
]
bench_split = "eval"
def transform_bench_subject(example):
# Split on ':' and trim whitespace
parts = example["subject"].split(":")
first_part = (
parts[0].strip().lower().replace("-", "_")
) # Lowercase the first part
second_part = (
parts[1].strip().replace("-", "_") if len(parts) > 1 else "all"
) # Replace hyphens with underscores
# Return the transformed values
return {"name": first_part, "subject": second_part}
if trainer.args.bench_dataset == "mmlu-zs":
bench_dataset = load_dataset(
"openaccess-ai-collective/mmlu-evals",
data_files={
"eval": "zero_shot_mmlu_val.json",
"test": "zero_shot_mmlu_test.json",
},
)
# bench_dataset = bench_dataset.remove_columns("subject")
# MMLU Five-shot (Eval/Test only)
elif trainer.args.bench_dataset in ["mmlu", "mmlu-fs"]:
bench_dataset = load_dataset(
"openaccess-ai-collective/mmlu-evals",
data_files={
"eval": "five_shot_mmlu_val.json",
"test": "five_shot_mmlu_test.json",
},
)
# bench_dataset = bench_dataset.remove_columns('subject')
elif "/" in trainer.args.bench_dataset:
bench_ds = trainer.args.bench_dataset
bench_ds_name = "/".join(bench_ds.split("/", 2)[:2])
bench_ds_data_file = "/".join(bench_ds.split("/", 2)[2:])
bench_dataset = load_dataset(
bench_ds_name,
data_files={
"eval": bench_ds_data_file,
},
)
bench_dataset["eval"] = bench_dataset["eval"].map(transform_bench_subject)
else:
raise ValueError(
f"unhandled value `{trainer.args.bench_dataset}` for bench_dataset training args"
)
bench_dataset = bench_dataset[trainer.args.bench_split]
if trainer.args.max_bench_samples is not None:
bench_dataset = bench_dataset.select(range(trainer.args.max_bench_samples))
def tokenize_evals(example):
source = f"{tokenizer.bos_token}{example['input']}"
target = f"{example['output']}{tokenizer.eos_token}"
tokenized_source = tokenizer(
source,
max_length=2048,
truncation=True,
add_special_tokens=False,
)
tokenized_target = tokenizer(
target,
max_length=2048,
truncation=True,
add_special_tokens=False,
)
input_ids = tokenized_source["input_ids"] + tokenized_target["input_ids"]
labels = [IGNORE_INDEX] * len(tokenized_source["input_ids"]) + tokenized_target[
"input_ids"
]
return {
"input_ids": input_ids,
"labels": labels,
"subject": example["subject"],
}
with zero_first(is_main_process()):
bench_dataset = bench_dataset.map(tokenize_evals)
bench_dataset = bench_dataset.filter(lambda x: x["labels"][-2] in abcd_idx)
class BenchEvalCallback(TrainerCallback):
"""
TrainerCallback that runs the MMLU evals
"""
def on_evaluate(
self,
args: AxolotlTrainingArguments,
state: TrainerState, # pylint: disable=unused-argument
control: TrainerControl, # pylint: disable=unused-argument
metrics: Dict[str, float], # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
data_loader = trainer.get_bench_dataloader(
bench_dataset.remove_columns(["input", "subject", "output", "name"])
)
trainer.model.eval()
preds, refs = [], []
loss_bench = 0
for batch in tqdm(data_loader, total=len(data_loader)):
(loss, logits, labels) = trainer.prediction_step(
trainer.model,
batch,
prediction_loss_only=False,
)
# There are two tokens, the output, and eos token.
for i, logit in enumerate(logits):
label_non_zero_id = (batch["labels"][i] != IGNORE_INDEX).nonzero()[
0
][0]
logit_abcd = logit[label_non_zero_id - 1][abcd_idx]
preds.append(torch.argmax(logit_abcd).item())
labels = labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0]
refs += [
abcd_idx.index(label) if label in abcd_idx else -1
for label in labels.tolist()
]
loss_bench += loss.item()
# Extract results by subject.
bench_name = bench_dataset["name"]
bench_names: dict = {s: {"refs": [], "preds": []} for s in set(bench_name)}
for s, p, r in zip(bench_name, preds, refs): # pylint: disable=invalid-name
bench_names[s]["preds"].append(p)
bench_names[s]["refs"].append(r)
barrier()
local_bench_names = bench_names
gathered_bench_names: List[Dict] = [{} for _ in range(get_world_size())]
# Gather results from all GPUs to GPU 0
loss_bench_ranks = gather_scalar_from_all_ranks(
lambda: loss_bench, get_world_size()
)
len_data_loader_ranks = gather_scalar_from_all_ranks(
lambda: len(data_loader), get_world_size()
)
if not is_main_process():
dist.gather_object(local_bench_names, dst=0)
else:
dist.gather_object(local_bench_names, gathered_bench_names, dst=0)
bench_loss = sum(loss_bench_ranks) / sum(len_data_loader_ranks)
results = {"bench_loss": bench_loss}
# Combine results from all GPUs
combined_bench_names: Dict[str, Dict[str, List]] = {}
for bench_name in gathered_bench_names:
for name, data in bench_name.items():
if name not in combined_bench_names:
combined_bench_names[name] = {"refs": [], "preds": []}
combined_bench_names[name]["refs"].extend(data["refs"])
combined_bench_names[name]["preds"].extend(data["preds"])
bench_scores = []
for (
bench_name
) in combined_bench_names: # pylint: disable=consider-using-dict-items
bench_score = accuracy.compute(
references=combined_bench_names[bench_name]["refs"],
predictions=combined_bench_names[bench_name]["preds"],
)["accuracy"]
if not pd.isna(bench_score):
results[
f"bench_{bench_split}_accuracy_{bench_name}"
] = bench_score
bench_scores.append(bench_score)
else:
results[f"bench_{bench_split}_accuracy_{bench_name}"] = 0.0
bench_scores.append(0.0)
results[f"bench_{bench_split}_accuracy"] = np.mean(bench_scores)
trainer.log(results)
return BenchEvalCallback

View File

@@ -126,6 +126,19 @@ def validate_config(cfg):
if not cfg.load_in_8bit and cfg.adapter == "lora":
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
if cfg.relora_steps:
if cfg.adapter not in ("lora", "qlora"):
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
if cfg.fsdp:
raise ValueError("fsdp not supported with ReLoRA")
if cfg.deepspeed:
raise ValueError("deepspeed not supported with ReLoRA")
if cfg.lr_scheduler == "one_cycle":
raise ValueError("ReLoRA is not compatible with the one_cycle scheduler")
if cfg.trust_remote_code:
LOG.warning(
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."

View File

@@ -54,9 +54,10 @@ DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
def prepare_dataset(cfg, tokenizer):
if not cfg.pretraining_dataset:
train_dataset, eval_dataset = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
with zero_first(is_main_process()):
train_dataset, eval_dataset = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
else:
train_dataset = load_pretraining_dataset(
cfg.pretraining_dataset,

View File

@@ -243,6 +243,18 @@ class MultipackDistributedDataloader:
len_remaining -= 1
if not len_remaining:
return
# yield a no-op for cases where we don't have any data left to pack
for i in range(0, len_remaining):
yield self.collate_fn(
[
{
"input_ids": [0],
"labels": [-100],
"attention_mask": [True],
"position_ids": [0],
}
]
)
def _len_est(self):
lengths_sum = np.sum(self.lengths)

View File

@@ -1,8 +1,10 @@
"""
utility helpers for distributed checks
"""
import os
from contextlib import contextmanager
import torch
import torch.distributed as dist
from accelerate import Accelerator
@@ -43,6 +45,10 @@ def is_main_process():
return dist.get_rank() == 0
def get_world_size():
return int(os.getenv("WORLD_SIZE", "1"))
@contextmanager
def zero_first(is_main):
"""
@@ -53,3 +59,35 @@ def zero_first(is_main):
yield
if is_main: # then rank 0 waits after it has run the context
barrier()
def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-name
"""
Run a callable 'fn' on all ranks and gather the results on the specified rank.
Args:
- fn (callable): A function that computes the value. This should not have any side effects.
- rank (int, optional): The rank that gathers the values. Default is 0.
- world_size (int, optional): Total number of processes in the current distributed setup.
Returns:
- A list of computed values from all ranks if on the gathering rank, otherwise None.
"""
value_scalar = fn()
value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
if not is_main_process():
dist.gather(value_tensor, dst=0)
else:
gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)]
dist.gather(value_tensor, gather_list=gathered_tensors, dst=0)
# Convert tensors back to their original type (int or float)
gathered_values = []
for tensor in gathered_tensors:
if tensor == tensor.int():
gathered_values.append(int(tensor.item()))
else:
gathered_values.append(float(tensor.item()))
return gathered_values
return None

View File

@@ -11,7 +11,6 @@ import bitsandbytes as bnb
import torch
import transformers
from optimum.bettertransformer import BetterTransformer
from peft.tuners.lora import LoraLayer
from transformers import ( # noqa: F401
AutoConfig,
AutoModelForCausalLM,
@@ -22,7 +21,7 @@ from transformers import ( # noqa: F401
PreTrainedTokenizerBase,
)
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.utils.bench import log_gpu_memory_usage
LOG = logging.getLogger("axolotl")
@@ -55,11 +54,18 @@ def load_tokenizer(cfg):
**tokenizer_kwargs,
)
if tokenizer.__class__.__name__ in [
"LlamaTokenizer",
"LlamaTokenizerFast",
]:
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
if (
tokenizer.__class__.__name__
in [
"LlamaTokenizer",
"LlamaTokenizerFast",
"CodeLlamaTokenizer",
]
and hasattr(tokenizer, "pad_token")
and not tokenizer.pad_token
):
# set a pad_token, but use eos_token so we don't add a new token
tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
@@ -342,6 +348,15 @@ def load_model(
if model.device.type == "cuda":
log_gpu_memory_usage(LOG, "after model load", model.device)
# make sure these are fp32 per Ramesh et al. (2021)
for name, module in model.named_modules():
if "norm" in name:
module.to(torch.float32)
if "lm_head" in name or "embed_tokens" in name:
if hasattr(module, "weight"):
module.to(torch.float32)
needs_fa2_dtype = cfg.adapter or cfg.fsdp
if not cfg.gptq and (
(cfg.adapter == "lora" and load_in_8bit)
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
@@ -350,6 +365,18 @@ def load_model(
model = prepare_model_for_kbit_training(
model, use_gradient_checkpointing=cfg.gradient_checkpointing
)
needs_fa2_dtype = True
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
# convert them back to fp16/bf16 for flash-attn compatibility.
if needs_fa2_dtype and (cfg.flash_attention and cfg.is_llama_derived_model):
LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
for name, module in model.named_modules():
if "norm" in name:
module.to(cfg.torch_dtype)
if "lm_head" in name or "embed_tokens" in name:
if hasattr(module, "weight"):
module.to(cfg.torch_dtype)
model, lora_config = load_adapter(model, cfg, cfg.adapter)
@@ -494,22 +521,6 @@ def load_lora(model, cfg):
else:
model = get_peft_model(model, lora_config)
for name, module in model.named_modules():
if isinstance(module, LoraLayer):
module = module.to(cfg.torch_dtype)
if "norm" in name:
module = module.to(torch.float32)
if "lm_head" in name or "embed_tokens" in name:
if hasattr(module, "weight"):
module = module.to(cfg.torch_dtype)
# LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
# convert them back to fp16/bf16 for flash-attn compatibility.
if cfg.flash_attention and cfg.is_llama_derived_model:
for name, module in model.named_modules():
if "norm" in name:
module = module.to(cfg.torch_dtype)
model.print_trainable_parameters()
return model, lora_config

View File

@@ -10,31 +10,30 @@ from functools import partial
from pathlib import Path
from typing import Optional, Union
import bitsandbytes as bnb
import numpy as np
import torch.cuda
import transformers
from datasets import Dataset, set_caching_enabled
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_pt_utils import (
SequentialDistributedSampler,
get_parameter_names,
from torch.utils.data import (
DataLoader,
DistributedSampler,
RandomSampler,
SequentialSampler,
)
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_pt_utils import SequentialDistributedSampler
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils.callbacks import (
GPUStatsCallback,
SaveBetterTransformerModelCallback,
SavePeftModelCallback,
bench_eval_callback_factory,
)
from axolotl.utils.collators import DataCollatorForSeq2Seq
from axolotl.utils.dataloader import MultipackDistributedDataloader
from axolotl.utils.schedulers import (
InterpolatingLogScheduler,
get_cosine_schedule_with_quadratic_warmup,
)
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
LOG = logging.getLogger("axolotl")
@@ -127,6 +126,35 @@ class AxolotlTrainingArguments(TrainingArguments):
default=1,
metadata={"help": "the multiplier for the max len for packed sequences"},
)
relora_steps: Optional[int] = field(
default=None,
metadata={"help": "how often to reset for ReLoRA"},
)
relora_warmup_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
bench_split: Optional[str] = field(
default="eval", metadata={"help": "The benchmark split to run on"}
)
bench_dataset: Optional[str] = field(
default="pharaouk/dharma-1/dharma_1_mini.json",
metadata={
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
},
)
do_bench_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
)
max_bench_samples: Optional[int] = field(
default=None,
metadata={
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
},
)
bench_source_max_len: int = field(
default=2048, metadata={"help": "Maximum source sequence length for bench."}
)
class AxolotlTrainer(Trainer):
@@ -136,6 +164,10 @@ class AxolotlTrainer(Trainer):
args = None # type: AxolotlTrainingArguments
def __init__(self, *args, bench_data_collator=None, **kwargs):
self.bench_data_collator = bench_data_collator
super().__init__(*args, **kwargs)
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
):
@@ -226,6 +258,31 @@ class AxolotlTrainer(Trainer):
)
return super().get_eval_dataloader(eval_dataset)
def _get_bench_sampler(
self, bench_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if self.args.world_size <= 1:
return SequentialSampler(bench_dataset)
return None
def get_bench_dataloader(
self,
bench_dataset: Dataset,
) -> Union[DataLoader, MultipackDistributedDataloader]:
dataloader_params = {
"batch_size": self.args.eval_batch_size,
"collate_fn": self.bench_data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
dataloader_params["drop_last"] = self.args.dataloader_drop_last
return DataLoader(bench_dataset, **dataloader_params)
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
def compute_loss(self, model, inputs, return_outputs=False):
# use one's weighted cross entropy loss calc
# if self.args.sample_packing:
@@ -265,6 +322,39 @@ class OneCycleLRSchedulerTrainer(AxolotlTrainer):
return self.lr_scheduler
class ReLoRATrainer(AxolotlTrainer):
"""
Trainer subclass that uses the OneCycleLR scheduler
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lr_scheduler = None
def create_scheduler(
self,
num_training_steps: int,
optimizer: Optional[torch.optim.Optimizer] = None,
):
optimizer = self.optimizer if optimizer is None else optimizer
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
if self.args.relora_steps:
warmup_steps = (
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
)
self.lr_scheduler = ReLoRAScheduler(
optimizer,
lr_scheduler,
self.args.relora_steps,
warmup_steps,
)
else:
self.lr_scheduler = lr_scheduler
return self.lr_scheduler
def add_position_ids(sample):
sample["position_ids"] = torch.arange(len(sample["input_ids"]))
return sample
@@ -484,6 +574,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
"steps" if cfg.save_steps else "epoch"
)
if cfg.do_bench_eval:
training_arguments_kwargs["do_bench_eval"] = cfg.do_bench_eval
if cfg.bench_dataset:
training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
max_steps=total_num_steps if cfg.max_steps else -1,
max_seq_length=cfg.sequence_len,
@@ -517,6 +612,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
sample_packing=cfg.sample_packing if cfg.sample_packing else False,
sample_packing_seq_len_multiplier=cfg.micro_batch_size,
relora_steps=cfg.relora_steps,
relora_warmup_steps=cfg.relora_warmup_steps,
**training_arguments_kwargs,
)
@@ -526,69 +623,13 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
if Path(cfg.torchdistx_path).exists():
sys.path.append(cfg.torchdistx_path)
importlib.import_module("torchdistx")
if (
cfg.optimizer == "adamw_bnb_8bit"
and not cfg.gptq
and "deepspeed" not in training_arguments_kwargs
and not cfg.fsdp
):
decay_parameters = get_parameter_names(model, [nn.LayerNorm])
decay_parameters = [name for name in decay_parameters if "bias" not in name]
optimizer_grouped_parameters = [
{
"params": [
p
for n, p in model.named_parameters()
if (n in decay_parameters and p.requires_grad)
],
"weight_decay": training_args.weight_decay,
},
{
"params": [
p
for n, p in model.named_parameters()
if (n not in decay_parameters and p.requires_grad)
],
"weight_decay": 0.0,
},
]
optimizer = bnb.optim.Adam8bit(
optimizer_grouped_parameters,
betas=(training_args.adam_beta1, training_args.adam_beta2),
eps=training_args.adam_epsilon,
lr=training_args.learning_rate,
)
if cfg.lr_scheduler == "one_cycle":
lr_scheduler_kwargs = (
cfg.lr_scheduler_kwargs if cfg.lr_scheduler_kwargs else {}
)
lr_scheduler = OneCycleLR(
optimizer,
cfg.learning_rate,
total_steps=total_num_steps,
epochs=cfg.num_epochs,
div_factor=cfg.lr_div_factor if cfg.lr_div_factor else 6,
**lr_scheduler_kwargs,
)
elif cfg.lr_scheduler == "log_sweep":
lr_scheduler = InterpolatingLogScheduler(
optimizer,
cfg.warmup_steps,
cfg.log_sweep_min_lr if cfg.log_sweep_min_lr else 1e-10,
cfg.log_sweep_max_lr if cfg.log_sweep_max_lr else 10,
)
else:
lr_scheduler = transformers.get_cosine_schedule_with_warmup(
optimizer,
training_args.warmup_steps,
total_num_steps,
)
trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)
callbacks = []
callbacks.append(GPUStatsCallback(cfg))
if cfg.relora_steps:
callbacks.append(ReLoRACallback(cfg))
# TODO on_save callback to sync checkpoints to GCP/AWS in background
if cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback(
@@ -633,11 +674,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
num_proc=32,
)
trainer_cls = (
OneCycleLRSchedulerTrainer
if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora")
else AxolotlTrainer
)
trainer_cls = AxolotlTrainer
if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora"):
trainer_cls = OneCycleLRSchedulerTrainer
elif cfg.relora_steps:
trainer_cls = ReLoRATrainer
trainer = trainer_cls(
model=model,
train_dataset=train_dataset,
@@ -648,8 +689,16 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
return_tensors="pt",
**data_collator_kwargs,
),
bench_data_collator=transformers.DataCollatorForSeq2Seq(
tokenizer,
return_tensors="pt",
**data_collator_kwargs,
),
callbacks=callbacks,
**trainer_kwargs,
)
if cfg.do_bench_eval:
trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
return trainer