Compare commits

..

56 Commits

Author SHA1 Message Date
Wing Lian
64af21bcb2 set env vars trainer needs for FSDP
Some checks failed
pre-commit / pre-commit (push) Has been cancelled
PyTest / test (3.10) (push) Has been cancelled
PyTest / test (3.9) (push) Has been cancelled
2023-08-11 08:46:26 -04:00
Wing Lian
6b5cf8b5ea optimize length reducer from 9m -> <5sec 2023-08-11 08:30:30 -04:00
Wing Lian
79500f358a need to pass total num tokens to trainer too 2023-08-10 19:08:23 -04:00
Wing Lian
7e977a9b68 optimization if total_num_tokens is already known 2023-08-10 19:02:28 -04:00
Wing Lian
ac4b700daa optimization if total_num_tokens is already known 2023-08-10 19:01:17 -04:00
Wing Lian
2565c2f259 async batching for multipack 2023-08-10 18:28:15 -04:00
Wing Lian
a07f432d9c calculate cum seq lens with pos_ids instead of mask, simplify packing params, fix distributed barrier 2023-08-10 17:16:01 -04:00
Wing Lian
57d9bf711c let's not cleanup the cached datasets 2023-08-08 21:27:55 -04:00
Wing Lian
26983a1974 fix sampler to prevent overfit w new epochs 2023-08-08 15:34:18 -04:00
Wing Lian
1b8747e319 use custom distributed checks 2023-08-08 13:35:04 -04:00
Wing Lian
035b3c760c add numba to requirements. 2023-08-08 10:55:29 -04:00
Wing Lian
17abbd59e1 previous accelerate is still most performant 2023-08-08 09:46:01 -04:00
Wing Lian
6ec76ddb4c fix steps calculation 2023-08-08 05:13:21 -04:00
Wing Lian
21d307b15b fix counts by accounting for num devices 2023-08-08 04:13:10 -04:00
Wing Lian
58e9dee204 fixes and go back to distributed sampler since batch sampler won't work 2023-08-08 03:49:29 -04:00
Wing Lian
4f7c04bae0 more fixes and optimizations 2023-08-08 03:16:00 -04:00
Wing Lian
1162b93b6b filter w multiple cpus 2023-08-08 00:50:56 -04:00
Wing Lian
21f445d763 more packing and dataset optimizations and fixes 2023-08-08 00:45:24 -04:00
Wing Lian
229b9165aa fix test and pylint checks 2023-08-07 09:38:05 -04:00
Wing Lian
394a65f11f add unit tests for cum seq lens, add ability to build cu_seq_lens from positional ids, fix prompt test 2023-08-07 09:38:04 -04:00
Wing Lian
c70dae63cc add chatml 2023-08-07 09:38:04 -04:00
Wing Lian
7712955b35 fix chatml system prompt for openorca, legacy tokenizer opts 2023-08-07 09:38:04 -04:00
Wing Lian
f93f0017cd fix flash-attn, xformers, packing, support chatml 2023-08-07 09:38:04 -04:00
Wing Lian
0b01da0713 properly calculate max len 2023-08-07 09:38:04 -04:00
Wing Lian
b2f7bc7ccd use cumulative seq len with var len flash attn v2 w packing 2023-08-07 09:38:04 -04:00
Wing Lian
b8905e2a91 sample_packing_seq_len_multiplier config 2023-08-07 09:38:04 -04:00
Wing Lian
7e1edc662a make sure the chunk size is an int 2023-08-07 09:38:04 -04:00
Wing Lian
98c9bc69de seq_len_multiple for packing 2023-08-07 09:38:04 -04:00
Wing Lian
8378335dc9 limit packing to sequences of max seq len 2023-08-07 09:38:04 -04:00
Wing Lian
bdd34c7400 weighted CEL fixes 2023-08-07 09:38:04 -04:00
Wing Lian
c6cc54c7d9 weighted CE losses 2023-08-07 09:38:04 -04:00
Wing Lian
83f7362480 don't split batches when packing 2023-08-07 09:38:04 -04:00
Wing Lian
958d423e7c only process eval dataset for packing if not None 2023-08-07 09:38:04 -04:00
Wing Lian
e74eab6e73 add a test for the mask expansion for sequence packing 2023-08-07 09:38:04 -04:00
Wing Lian
487abfc769 pass sample packing efficiency to training args 2023-08-07 09:38:04 -04:00
Wing Lian
2bee646e85 fix step calc for packing 2023-08-07 09:38:04 -04:00
Wing Lian
945f2e5029 better handling so that all devices have the same dataloader len 2023-08-07 09:38:04 -04:00
Wing Lian
daed942fe9 fix rounding of len of batches to int 2023-08-07 09:38:04 -04:00
Wing Lian
df3eb645da better handling of variance in multipack dataloader length and trainer hanging when it runs out of data 2023-08-07 09:38:04 -04:00
Wing Lian
32fed7039d optimized expand mask fn 2023-08-07 09:38:04 -04:00
Wing Lian
7d7b5ebd71 more fixes for 4k and optimizations 2023-08-07 09:38:03 -04:00
Wing Lian
4b7ad9927f validation for sample packing and doc 2023-08-07 09:38:03 -04:00
Wing Lian
fedcf5a089 Update src/axolotl/utils/dataloader.py 2023-08-07 09:38:03 -04:00
Wing Lian
2f2974196d fix for position_ids w packing 2023-08-07 09:38:03 -04:00
Wing Lian
2e295c9f94 use accelerator prepare for dataloader 2023-08-07 09:38:03 -04:00
Wing Lian
4ab9ab79fd use distributed sampler, avoid accelerate prepare 2023-08-07 09:38:03 -04:00
Wing Lian
b02484a83e more fixes for sample packing 2023-08-07 09:38:03 -04:00
Wing Lian
58045f0816 more fixes, position_ids seems broken 2023-08-07 09:38:03 -04:00
Wing Lian
66774011c4 est total tokens, fix field loop 2023-08-07 09:38:03 -04:00
Wing Lian
41d4992029 more fixes for dataloader integration 2023-08-07 09:38:03 -04:00
Wing Lian
762f1b08db add position_ids back 2023-08-07 09:38:03 -04:00
Wing Lian
3aba4c5d7c use multi pack dataloader w random sampler 2023-08-07 09:38:03 -04:00
Wing Lian
ffd96839cf don't move masks to cpu 2023-08-07 09:38:03 -04:00
Wing Lian
ef9bf7ad73 fix expand mask for multiple batch items, make sure we pad position_ids 2023-08-07 09:38:03 -04:00
Wing Lian
4964b0d345 set position ids and use block diagonal attn mask 2023-08-07 09:38:03 -04:00
Wing Lian
36b0e30a9d fix attetion mask with packing 2023-08-07 09:38:03 -04:00
35 changed files with 160 additions and 623 deletions

View File

@@ -375,14 +375,10 @@ dataset_shard_idx:
sequence_len: 2048 sequence_len: 2048
# max sequence length to concatenate training samples together up to # max sequence length to concatenate training samples together up to
# inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning # inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
# FutureWarning: This will soon be DEPRECATED # soon to be DEPRECATED
max_packed_sequence_len: 1024 max_packed_sequence_len: 1024
# use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true' # use efficient multi-packing with block diagonal attention and per sequence position_ids
sample_packing: sample_packing:
# you can set these packing optimizations AFTER starting a training at least once.
# The trainer will provide recommended values for these values.
sample_packing_eff_est:
total_num_tokens:
# if you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model # if you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
adapter: lora adapter: lora
@@ -408,12 +404,11 @@ lora_out_dir:
lora_fan_in_fan_out: false lora_fan_in_fan_out: false
# wandb configuration if you're using it # wandb configuration if you're using it
wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb wandb_mode:
wandb_project: # your wandb project name wandb_project:
wandb_entity: # a wandb Team name if using a Team
wandb_watch: wandb_watch:
wandb_run_id: # set the name of your wandb run wandb_run_id:
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training wandb_log_model: # 'checkpoint'
# where to save the finished model to # where to save the finished model to
output_dir: ./completed-model output_dir: ./completed-model
@@ -428,16 +423,13 @@ learning_rate: 0.00003
logging_steps: logging_steps:
save_steps: save_steps:
eval_steps: eval_steps:
save_total_limit:
# save model as safetensors (require safetensors package) # save model as safetensors (require safetensors package)
save_safetensors: save_safetensors:
# whether to mask out or include the human's prompt from the training labels # whether to mask out or include the human's prompt from the training labels
train_on_inputs: false train_on_inputs: false
# group similarly sized data to minimize padding # don't use this, leads to wonky training (according to someone on the internet)
# may be slower to start, as it must download and sort the entire dataset
# note that training loss may have an oscillating pattern with this enabled
group_by_length: false group_by_length: false
# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing # Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
@@ -483,10 +475,6 @@ landmark_attention:
# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py # xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
# llama only # llama only
xpos_rope: xpos_rope:
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
rope_scaling:
type: # linear | dynamic
factor: # float
# resume from a specific checkpoint dir # resume from a specific checkpoint dir
resume_from_checkpoint: resume_from_checkpoint:
@@ -518,9 +506,6 @@ torchdistx_path:
# Set padding for data collator to 'longest' # Set padding for data collator to 'longest'
collator_pad_to_longest: collator_pad_to_longest:
# Set to HF dataset for type: 'completion' for streaming instead of pre-tokenize
pretraining_dataset:
# Debug mode # Debug mode
debug: debug:
@@ -540,14 +525,7 @@ Run
accelerate launch scripts/finetune.py configs/your_config.yml accelerate launch scripts/finetune.py configs/your_config.yml
``` ```
#### Multi-GPU #### Multi-GPU Config
You can optionally pre-tokenize dataset with the following before finetuning:
```bash
CUDA_VISIBLE_DEVICES="" accelerate ... --prepare_ds_only
```
##### Config
- llama FSDP - llama FSDP
```yaml ```yaml
@@ -562,18 +540,6 @@ fsdp_config:
- llama Deepspeed: append `ACCELERATE_USE_DEEPSPEED=true` in front of finetune command - llama Deepspeed: append `ACCELERATE_USE_DEEPSPEED=true` in front of finetune command
##### Weights & Biases Logging
- wandb options
```yaml
wandb_mode:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
```
### Inference ### Inference
Pass the appropriate flag to the train command: Pass the appropriate flag to the train command:

View File

@@ -23,7 +23,6 @@ lora_target_modules:
lora_target_linear: lora_target_linear:
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:
@@ -36,7 +35,7 @@ torchdistx_path:
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.0002
train_on_inputs: false train_on_inputs: false
group_by_length: false group_by_length: true
bf16: true bf16: true
fp16: false fp16: false
tf32: true tf32: true

View File

@@ -24,7 +24,6 @@ lora_target_modules:
lora_target_linear: true lora_target_linear: true
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -38,7 +38,6 @@ lora_target_linear: true
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -24,7 +24,6 @@ lora_target_modules:
lora_target_linear: true lora_target_linear: true
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -20,7 +20,6 @@ lora_target_modules:
lora_target_linear: true lora_target_linear: true
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:
@@ -33,7 +32,7 @@ torchdistx_path:
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0001 learning_rate: 0.0001
train_on_inputs: false train_on_inputs: false
group_by_length: false group_by_length: true
bf16: true bf16: true
fp16: false fp16: false
tf32: true tf32: true

View File

@@ -22,7 +22,6 @@ lora_target_modules:
- v_proj - v_proj
lora_fan_in_fan_out: false lora_fan_in_fan_out: false
wandb_project: llama-7b-lora-int4 wandb_project: llama-7b-lora-int4
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -18,7 +18,6 @@ lora_dropout:
lora_target_modules: lora_target_modules:
lora_fan_in_fan_out: false lora_fan_in_fan_out: false
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -26,7 +26,6 @@ lora_target_linear: true
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:
@@ -39,7 +38,7 @@ lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.0002
train_on_inputs: false train_on_inputs: false
group_by_length: false group_by_length: true
bf16: true bf16: true
fp16: false fp16: false
tf32: false tf32: false

View File

@@ -27,7 +27,6 @@ lora_target_linear: true
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:
@@ -40,7 +39,7 @@ lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.0002
train_on_inputs: false train_on_inputs: false
group_by_length: false group_by_length: true
bf16: true bf16: true
fp16: false fp16: false
tf32: false tf32: false

View File

@@ -20,7 +20,6 @@ lora_target_modules:
- v_proj - v_proj
lora_fan_in_fan_out: false lora_fan_in_fan_out: false
wandb_project: mpt-alpaca-7b wandb_project: mpt-alpaca-7b
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -22,7 +22,6 @@ lora_target_modules:
lora_target_linear: lora_target_linear:
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -28,7 +28,6 @@ lora_target_modules:
- o_proj - o_proj
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -22,7 +22,6 @@ lora_target_modules:
lora_target_linear: true lora_target_linear: true
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:
@@ -35,7 +34,7 @@ torchdistx_path:
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.0002
train_on_inputs: false train_on_inputs: false
group_by_length: false group_by_length: true
bf16: true bf16: true
fp16: false fp16: false
tf32: true tf32: true

View File

@@ -23,7 +23,6 @@ lora_target_modules:
lora_target_linear: true lora_target_linear: true
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -17,7 +17,6 @@ lora_target_modules:
lora_target_linear: lora_target_linear:
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -21,7 +21,6 @@ lora_target_modules:
- v_proj - v_proj
lora_fan_in_fan_out: false lora_fan_in_fan_out: false
wandb_project: redpajama-alpaca-3b wandb_project: redpajama-alpaca-3b
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -20,7 +20,6 @@ lora_target_modules:
- mlp_down - mlp_down
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: lora-replit wandb_project: lora-replit
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -37,7 +37,6 @@ lora_target_linear: true
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -1,6 +1,6 @@
peft @ git+https://github.com/huggingface/peft.git peft @ git+https://github.com/huggingface/peft.git
transformers @ git+https://github.com/huggingface/transformers.git transformers @ git+https://github.com/huggingface/transformers.git
bitsandbytes>=0.41.1 bitsandbytes>=0.39.0
accelerate @ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b accelerate @ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b
addict addict
fire fire
@@ -21,4 +21,3 @@ evaluate==0.4.0
rouge-score==0.1.2 rouge-score==0.1.2
scipy scipy
scikit-learn==1.2.2 scikit-learn==1.2.2
pynvml

View File

@@ -18,7 +18,6 @@ from optimum.bettertransformer import BetterTransformer
from transformers import GenerationConfig, TextStreamer from transformers import GenerationConfig, TextStreamer
from axolotl.logging_config import configure_logging from axolotl.logging_config import configure_logging
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import barrier, is_main_process from axolotl.utils.distributed import barrier, is_main_process
@@ -269,13 +268,16 @@ def train(
LOG.info("Finished preparing dataset. Exiting...") LOG.info("Finished preparing dataset. Exiting...")
return return
log_gpu_memory_usage(LOG, "baseline", cfg.device)
# Load the model and tokenizer # Load the model and tokenizer
LOG.info("loading model and (optionally) peft_config...") LOG.info("loading model and peft_config...")
model, peft_config = load_model(cfg, tokenizer) model, peft_config = load_model(
cfg.base_model,
safe_serialization = cfg.save_safetensors is True cfg.base_model_config,
cfg.model_type,
tokenizer,
cfg,
adapter=cfg.adapter,
)
if "merge_lora" in kwargs and cfg.adapter is not None: if "merge_lora" in kwargs and cfg.adapter is not None:
LOG.info("running merge of LoRA with base model") LOG.info("running merge of LoRA with base model")
@@ -284,11 +286,7 @@ def train(
if cfg.local_rank == 0: if cfg.local_rank == 0:
LOG.info("saving merged model") LOG.info("saving merged model")
model.save_pretrained( model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
str(Path(cfg.output_dir) / "merged"),
safe_serialization=safe_serialization,
)
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
return return
if cfg.inference: if cfg.inference:
@@ -303,7 +301,7 @@ def train(
return return
if "shard" in kwargs: if "shard" in kwargs:
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) model.save_pretrained(cfg.output_dir)
return return
trainer = setup_trainer( trainer = setup_trainer(
@@ -327,7 +325,7 @@ def train(
def terminate_handler(_, __, model): def terminate_handler(_, __, model):
if cfg.flash_optimum: if cfg.flash_optimum:
model = BetterTransformer.reverse(model) model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) model.save_pretrained(cfg.output_dir)
sys.exit(0) sys.exit(0)
signal.signal( signal.signal(
@@ -371,13 +369,7 @@ def train(
elif cfg.local_rank == 0: elif cfg.local_rank == 0:
if cfg.flash_optimum: if cfg.flash_optimum:
model = BetterTransformer.reverse(model) model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir)
if cfg.adapter == "lora" and cfg.relora_steps:
model = model.merge_and_unload()
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -5,7 +5,7 @@ import os
from typing import List from typing import List
import torch import torch
from datasets import Dataset, IterableDataset from datasets import IterableDataset
from .prompt_tokenizers import PromptTokenizingStrategy from .prompt_tokenizers import PromptTokenizingStrategy
@@ -18,9 +18,9 @@ from .prompt_tokenizers import PromptTokenizingStrategy
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
class TokenizedPromptDataset(Dataset): class TokenizedPromptDataset(IterableDataset):
""" """
Dataset that returns tokenized prompts from a stream of text files. Iterable dataset that returns tokenized prompts from a stream of text files.
Args: Args:
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for proccessing the data. prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for proccessing the data.
dataset (dataset.Dataset): Dataset with text files. dataset (dataset.Dataset): Dataset with text files.
@@ -30,18 +30,19 @@ class TokenizedPromptDataset(Dataset):
self, self,
prompt_tokenizer: PromptTokenizingStrategy, prompt_tokenizer: PromptTokenizingStrategy,
dataset: IterableDataset, dataset: IterableDataset,
**kwargs,
): ):
self.prompt_tokenizer = prompt_tokenizer self.prompt_tokenizer = prompt_tokenizer
super().__init__(self.process(dataset).data, **kwargs) self.dataset = dataset
def process(self, dataset): def __iter__(self):
features = dataset.features.keys() features = self.dataset.features.keys()
num_proc = min(64, os.cpu_count()) num_proc = os.cpu_count()
return dataset.map( return iter(
self.prompt_tokenizer.tokenize_prompt, self.dataset.map(
num_proc=num_proc, self.prompt_tokenizer.tokenize_prompt,
remove_columns=features, num_proc=num_proc,
remove_columns=features,
)
) )

View File

@@ -7,7 +7,6 @@ from typing import Optional, Tuple
import torch import torch
import transformers import transformers
from einops import rearrange from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input
try: try:
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
@@ -92,8 +91,7 @@ def forward(
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
) )
output = rearrange(output, "(b s) ... -> b s ...", b=bsz) output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
elif position_ids.shape[0] == 1: else:
# special handling using sample packing
qkv = rearrange(qkv, "b s ... -> (b s) ...") qkv = rearrange(qkv, "b s ... -> (b s) ...")
cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids) cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids)
cu_q_lens = cu_q_lens.squeeze() cu_q_lens = cu_q_lens.squeeze()
@@ -102,36 +100,6 @@ def forward(
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
) )
output = rearrange(output, "(b s) ... -> b s ...", b=bsz) output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
else:
nheads = qkv.shape[-2]
# pylint: disable=invalid-name
x = rearrange(qkv, "b s three h d -> b s (three h d)")
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
x_unpad = rearrange(
x_unpad,
"nnz (three h d) -> nnz three h d",
three=3,
h=nheads,
)
output_unpad = flash_attn_varlen_qkvpacked_func(
x_unpad,
cu_q_lens,
max_s,
0.0,
softmax_scale=None,
causal=True,
)
output = rearrange(
pad_input(
rearrange(output_unpad, "nnz h d -> nnz (h d)"),
indices,
bsz,
q_len,
),
"b s (h d) -> b s h d",
h=nheads,
)
return ( return (
self.o_proj(rearrange(output, "b s h d -> b s (h d)")), self.o_proj(rearrange(output, "b s h d -> b s (h d)")),

View File

@@ -1,302 +0,0 @@
# pylint: skip-file
import glob
import json
import logging
import os.path
import shutil
from pathlib import Path
from typing import Dict, List, Sequence
import bitsandbytes as bnb
import peft
import safetensors.torch as st
import torch
from torch.optim.lr_scheduler import LRScheduler
from torch.optim.optimizer import Optimizer
from transformers import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.relora")
def reset_optimizer(optimizer: torch.optim.Optimizer):
for group in optimizer.param_groups:
for param in group["params"]:
param_state = optimizer.state[param]
for key in param_state:
if "qmap" in key:
continue
elif key == "step" and isinstance(param_state[key], int):
param_state[key] = 0
else:
param_state[key] = torch.zeros_like(param_state[key])
class ReLoRACallback(TrainerCallback):
def __init__(self, cfg: DictDefault):
self.relora_steps = cfg.relora_steps
self.cpu_offload = cfg.relora_cpu_offload
self.quantised = cfg.load_in_4bit or cfg.load_in_8bit
self.last_full_model = cfg.base_model
assert os.path.exists(
self.last_full_model
), "for ReLORA base_model must be a local path"
self.num_lora_restarts = 0
self.need_full_save = False
def on_step_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
model: peft.LoraModel,
optimizer: torch.optim.Optimizer,
**_kwargs,
):
if state.global_step > 0 and state.global_step % self.relora_steps == 0:
checkpoint_folder = os.path.join(
args.output_dir,
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
)
with torch.no_grad():
merge_and_save(
model,
self.last_full_model,
checkpoint_folder,
reinit=True,
quantized=self.quantised,
)
reset_optimizer(optimizer)
if self.quantised:
self.last_full_model = checkpoint_folder
self.num_lora_restarts += 1
return control
def on_save(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
model: peft.LoraModel,
**kwargs,
):
checkpoint_folder = os.path.join(
args.output_dir,
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
)
if (
state.global_step >= self.relora_steps
and state.global_step % self.relora_steps != 0
):
if self.quantised and self.last_full_model != checkpoint_folder:
# ensure the latest full parameter save is in the latest checkpoint
# folder, so that automatic pruning of checkpoints does not remove it
LOG.info(f"moving last full parameter save to {checkpoint_folder}")
chunks = glob.glob(
f"{self.last_full_model}/model*.safetensors"
) + glob.glob(f"{self.last_full_model}/model*.index.json")
for path in chunks:
shutil.move(path, checkpoint_folder)
self.last_full_model = checkpoint_folder
else:
model.model.save_pretrained(checkpoint_folder, save_safetensors=True)
return control
def on_log(
self,
_args: TrainingArguments,
_state: TrainerState,
control: TrainerControl,
logs: Dict[str, float],
**_kwargs,
):
logs["num_lora_restarts"] = self.num_lora_restarts
return control
class ReLoRAScheduler(LRScheduler):
def __init__(
self,
optimizer: Optimizer,
inner_schedule: LRScheduler,
relora_steps: int,
warmup_steps: int,
min_lr_scale: float = 0.001,
) -> None:
self.inner_schedule = inner_schedule
self.relora_steps = relora_steps
self.warmup_steps = warmup_steps
self.min_lr_scale = min_lr_scale
super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)
def get_lr(self) -> float:
self.inner_schedule.last_epoch = self.last_epoch
original = self.inner_schedule.get_lr()
step = self.last_epoch
if step < self.relora_steps:
scale = 1
else:
cycle_t = min(1.0, (step % self.relora_steps) / self.warmup_steps)
scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
if isinstance(original, Sequence):
return [lr * scale for lr in original]
else:
return original * scale
def sharded_paths(path: str, keys: List[str]) -> Dict[str, str]:
model_name = "model.safetensors"
if not os.path.exists(str(Path(path) / model_name)) and not os.path.exists(
str(Path(path) / f"{model_name}.index.json")
):
model_name = "pytorch_model.bin"
index_path = str(Path(path) / f"{model_name}.index.json")
if os.path.exists(index_path):
data = json.load(open(index_path, "r"))
return data["weight_map"]
return {key + ".weight": model_name for key in keys}
def lora_delta_weight(layer: peft.tuners.lora.LoraLayer) -> torch.Tensor:
if isinstance(layer, peft.tuners.lora.Linear8bitLt) or isinstance(
layer, peft.tuners.lora.Linear4bit
):
adapter = layer.active_adapter
return (
peft.utils.transpose(
layer.lora_B[adapter].weight @ layer.lora_A[adapter].weight,
getattr(layer, "fan_in_fan_out", False),
)
* layer.scaling[adapter]
)
else:
return layer.get_delta_weight()
def merge_and_save(
model: peft.LoraModel,
model_src: str,
model_dst: str,
reinit: bool = False,
quantized: bool = False,
cpu_offload: bool = False,
):
key_list = [key for key, _ in model.model.named_modules() if "lora" not in key]
if not quantized:
for key in key_list:
try:
_parent, target, _target_name = peft.utils._get_submodules(
model.model, key
)
except AttributeError:
continue
if isinstance(target, peft.tuners.lora.LoraLayer):
update = target.get_delta_weight(target.active_adapter).detach()
target.weight.data += update
if reinit:
for adapter_name in target.lora_A:
target.reset_lora_parameters(adapter_name)
for adapter_name in target.lora_embedding_A:
target.reset_lora_parameters(adapter_name)
return
os.makedirs(model_dst, exist_ok=True)
shard_paths = sharded_paths(model_src, key_list)
unique_shards = list(set(shard_paths.values()))
for shard_path in unique_shards:
out_tensors = {}
if shard_path.endswith(".safetensors"):
in_tensors = st.load_file(str(Path(model_src) / shard_path))
else:
in_tensors = torch.load(Path(model_src) / shard_path)
if "state_dict" in in_tensors:
in_tensors = in_tensors["state_dict"]
for key in key_list:
if (key + ".weight") not in shard_paths or shard_paths[
key + ".weight"
] != shard_path:
continue
try:
_parent, target, _target_name = peft.utils._get_submodules(
model.model, key
)
except AttributeError:
continue
if isinstance(target, peft.tuners.lora.LoraLayer):
orig_weight = in_tensors[key + ".weight"]
old_dev = target.weight.device
math_dev = "cpu" if cpu_offload else old_dev
update = lora_delta_weight(target).detach().to(math_dev)
new_weight = orig_weight.to(math_dev) + update
out_tensors[key + ".weight"] = new_weight
if reinit:
for adapter_name in target.lora_A:
target.reset_lora_parameters(adapter_name)
for adapter_name in target.lora_embedding_A:
target.reset_lora_parameters(adapter_name)
if isinstance(target, peft.tuners.lora.Linear4bit):
target.weight = (
bnb.nn.Params4bit(
new_weight,
requires_grad=False,
compress_statistics=target.weight.compress_statistics,
quant_type=target.weight.quant_type,
)
.cuda(None)
.to(old_dev)
)
elif isinstance(target, peft.tuners.lora.Linear8bitLt):
target.weight = (
bnb.nn.Int8Params(new_weight, requires_grad=False)
.cuda(None)
.to(old_dev)
)
else:
target.weight.data = new_weight.to(old_dev)
for key in in_tensors:
if key not in out_tensors:
out_tensors[key] = in_tensors[key]
del in_tensors
out_shard_name = shard_path
if out_shard_name.startswith("pytorch_model"):
out_shard_name = (
out_shard_name.replace("pytorch_model", "model").rstrip(".bin")
+ ".safetensors"
)
shard_fn = str(Path(model_dst) / out_shard_name)
LOG.info(f"saving tensors to {shard_fn}")
st.save_file(out_tensors, shard_fn)
del out_tensors
torch.cuda.empty_cache()
if len(unique_shards) > 1:
with open(str(Path(model_dst, "model.safetensors.index.json")), "w") as fd:
json.dump({"metadata": {}, "weight_map": shard_paths}, fd)

View File

@@ -95,9 +95,9 @@ class OpenOrcaSystemDataPrompter(SystemDataPrompter):
self.turn_format = "### User:\n{instruction}\n\n### Additional Context:\n{input}\n\n### Assistant:\n" self.turn_format = "### User:\n{instruction}\n\n### Additional Context:\n{input}\n\n### Assistant:\n"
self.turn_no_input_format = "### User:\n{instruction}\n\n### Assistant:\n" self.turn_no_input_format = "### User:\n{instruction}\n\n### Assistant:\n"
if self.prompt_style == PromptStyle.CHAT.value: if self.prompt_style == PromptStyle.CHAT.value:
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:" self.turn_format = "User: {instruction}\n{input}\nAssistant:"
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:" self.turn_no_input_format = "User: {instruction}\nAssistant:"
self.system_format = "SYSTEM: {system}\n" self.system_format = "System: {system}\n"
if self.prompt_style == PromptStyle.CHATML.value: if self.prompt_style == PromptStyle.CHATML.value:
self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n" self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
self.turn_no_input_format = ( self.turn_no_input_format = (

View File

@@ -29,7 +29,7 @@ from dataclasses import dataclass, field
from typing import Generator, List, Sequence from typing import Generator, List, Sequence
from axolotl.prompt_tokenizers import PromptTokenizingStrategy from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID, SHAREGPT_ASSERTION_FAILED_ROLE from axolotl.prompters import IGNORE_TOKEN_ID
@dataclass @dataclass
@@ -190,7 +190,7 @@ class Llama2ChatPrompter: # pylint: disable=too-few-public-methods
conv.messages = [] # pylint: disable=R0801 conv.messages = [] # pylint: disable=R0801
for j, sentence in enumerate(source): for j, sentence in enumerate(source):
role = roles[sentence["from"]] role = roles[sentence["from"]]
assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE assert role == conv.roles[j % 2]
if sentence["value"]: if sentence["value"]:
conv.append_message(role, sentence["value"]) conv.append_message(role, sentence["value"])
yield conv yield conv

View File

@@ -271,11 +271,6 @@ class Conversation:
self.messages.append([role, message]) self.messages.append([role, message])
SHAREGPT_ASSERTION_FAILED_ROLE = (
"Role did not alternate between turns (gpt and human). Please check your data."
)
class ShareGPTPrompter: # pylint: disable=too-few-public-methods class ShareGPTPrompter: # pylint: disable=too-few-public-methods
""" """
A prompter that generates prompts for the ShareGPT A prompter that generates prompts for the ShareGPT
@@ -332,7 +327,7 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
conv.messages = [] conv.messages = []
for j, sentence in enumerate(source): for j, sentence in enumerate(source):
role = roles[sentence["from"]] role = roles[sentence["from"]]
assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE assert role == conv.roles[j % 2]
conv.append_message(role, sentence["value"]) conv.append_message(role, sentence["value"])
for part in conv.get_prompt(): for part in conv.get_prompt():

View File

@@ -1,23 +0,0 @@
"""Benchmarking and measurement utilities"""
import pynvml
import torch
def gpu_memory_usage(device):
if isinstance(device, torch.device):
device = device.index
if isinstance(device, str) and device.startswith("cuda:"):
device = int(device[5:])
# NB torch.cuda.memory_usage returns zero so we use lower level api
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
return info.used / 1024.0**3
def log_gpu_memory_usage(log, msg, device):
log.info(
f"GPU memory usage {msg}: {gpu_memory_usage(device):.03f} GB", stacklevel=2
)

View File

@@ -1,6 +1,5 @@
"""Callbacks for Trainer class""" """Callbacks for Trainer class"""
import logging
import os import os
from optimum.bettertransformer import BetterTransformer from optimum.bettertransformer import BetterTransformer
@@ -12,10 +11,6 @@ from transformers import (
) )
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
from axolotl.utils.bench import log_gpu_memory_usage
LOG = logging.getLogger("axolotl.callbacks")
class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
"""Callback to save the PEFT adapter""" """Callback to save the PEFT adapter"""
@@ -33,9 +28,7 @@ class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-
) )
peft_model_path = os.path.join(checkpoint_folder, "adapter_model") peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
kwargs["model"].save_pretrained( kwargs["model"].save_pretrained(peft_model_path)
peft_model_path, save_safetensors=args.save_safetensors
)
return control return control
@@ -74,25 +67,3 @@ class SaveBetterTransformerModelCallback(
# the trainer will raise an exception since it can't save a BetterTransformer wrapped model # the trainer will raise an exception since it can't save a BetterTransformer wrapped model
control.should_save = False control.should_save = False
return control return control
class PrintGPUStatsCallback(
TrainerCallback
): # pylint: disable=too-few-public-methods disable=unused-argument
"""Callback to print GPU utilization"""
def __init__(self, cfg):
self.cfg = cfg
self.logged = False
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
if not self.logged:
log_gpu_memory_usage(LOG, "while training", self.cfg.device)
self.logged = True
return control

View File

@@ -1,19 +1,14 @@
"""Module containing data utilities""" """Module containing data utilities"""
import functools import functools
import hashlib import hashlib
import itertools
import logging import logging
from hashlib import md5 from hashlib import md5
from pathlib import Path from pathlib import Path
from typing import Tuple, Union from typing import List, Tuple, Union
import torch import torch
from datasets import ( from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
Dataset,
DatasetDict,
concatenate_datasets,
load_dataset,
load_from_disk,
)
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
@@ -270,12 +265,20 @@ def load_tokenized_prepared_datasets(
raise ValueError( raise ValueError(
f"unhandled prompt tokenization strategy: {d.type} {suffix}" f"unhandled prompt tokenization strategy: {d.type} {suffix}"
) )
LOG.info("merging datasets") LOG.info("tokenizing, merging, and shuffling master dataset")
dataset = concatenate_datasets(datasets)
if len(datasets) > 1: samples: List[int] = []
LOG.info("shuffle merged datasets") chunk_size = 1000
dataset = dataset.shuffle(seed=seed) for d in datasets:
d_iter = iter(d)
while True:
chunk = list(itertools.islice(d_iter, chunk_size))
if not chunk:
break
samples.extend(chunk)
LOG.info("shuffle")
dataset = Dataset.from_list(samples).shuffle(seed=seed)
if cfg.local_rank == 0: if cfg.local_rank == 0:
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}") LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
dataset.save_to_disk(prepared_ds_path) dataset.save_to_disk(prepared_ds_path)

View File

@@ -3,7 +3,9 @@ import hashlib
import itertools import itertools
import logging import logging
import math import math
from typing import Any, Callable, List, Union import queue
import threading
from typing import Any, Callable, List, Optional, Union
import numba import numba
import numpy as np import numpy as np
@@ -78,7 +80,6 @@ def allocate(
s = 0 s = 0
start_index = 0 start_index = 0
result = [] result = []
result_totseqs = []
while True: while True:
# binary search [left, right) # binary search [left, right)
@@ -104,10 +105,8 @@ def allocate(
# add local rank # add local rank
result.append(batch[rank]) result.append(batch[rank])
# add total seqs for all ranks
result_totseqs.append(tot_seqs) yield batch[rank], tot_seqs, s, len(result) * c * n
# yield batch[rank], tot_seqs, s, len(result) * c * n
return result, result_totseqs, s, len(result) * c * n
def chunk(iterable, n): def chunk(iterable, n):
@@ -149,15 +148,14 @@ class MultipackDistributedDataloader:
packing_efficiency_estimate: float = 1.0, packing_efficiency_estimate: float = 1.0,
sample_packing_seq_len_multiplier: int = 1, sample_packing_seq_len_multiplier: int = 1,
device_count: int = 1, device_count: int = 1,
total_num_tokens: Optional[int] = None,
): ):
# Dataset # Dataset
self.dataset = dataset self.dataset = dataset
self.lengths = ( lengths_series = (
dataset.data.column("position_ids") dataset.data.column("position_ids").to_pandas().apply(lambda x: x[-1] + 1)
.to_pandas()
.apply(lambda x: x[-1] + 1)
.values
) )
self.lengths: np.ndarray = lengths_series.values
assert isinstance(self.lengths, np.ndarray) assert isinstance(self.lengths, np.ndarray)
assert batch_size % sample_packing_seq_len_multiplier == 0 assert batch_size % sample_packing_seq_len_multiplier == 0
assert batch_size >= sample_packing_seq_len_multiplier assert batch_size >= sample_packing_seq_len_multiplier
@@ -172,11 +170,17 @@ class MultipackDistributedDataloader:
self.rank = 0 self.rank = 0
# statistics # statistics
self.total_num_tokens = total_num_tokens
self.eff_total_used = 0 self.eff_total_used = 0
self.eff_total_slots = 0 self.eff_total_slots = 0
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0 self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
self.device_count = device_count self.device_count = device_count
# for non-blocking batch creation
self.batch_queue: queue.Queue = queue.Queue(
maxsize=10
) # Adjust maxsize as needed
def generate_batches(self, set_stats=False): def generate_batches(self, set_stats=False):
LOG.info("generating packed batches") LOG.info("generating packed batches")
if self.sampler: if self.sampler:
@@ -188,65 +192,83 @@ class MultipackDistributedDataloader:
lengths = self.lengths[indices] lengths = self.lengths[indices]
lengths_cumsum = np.cumsum(lengths) lengths_cumsum = np.cumsum(lengths)
batches, totseqs, total_used, total_slots = allocate( alloc_iter = iter(
lengths=lengths, allocate(
lengths_cumsum=lengths_cumsum, lengths=lengths,
rank=self.rank, lengths_cumsum=lengths_cumsum,
# c=self.batch_max_length, rank=self.rank,
c=self.seq_max_length * self.sample_packing_seq_len_multiplier, # c=self.batch_max_length,
n=self.num_replicas, c=self.seq_max_length * self.sample_packing_seq_len_multiplier,
n=self.num_replicas,
)
) )
batches = [[indices[b_idx] for b_idx in batch] for batch in batches] for batch, tot_seqs, total_used, total_slots in alloc_iter:
self.batch_queue.put([indices[b_idx] for b_idx in batch])
# statistics
if set_stats:
self.eff_total_used = total_used
self.eff_total_slots = total_slots
self.batch_queue.put(None) # Signal the end of batch generation
# statistics def _generate_batches_thread(self):
if set_stats: try:
self.eff_total_used += total_used self.generate_batches(set_stats=True)
self.eff_total_slots += total_slots except Exception as e:
LOG.error(f"Error in batch generation thread: {e}")
return batches, totseqs self.batch_queue.put(
None
) # Signal the end of batch generation in case of error
def __iter__(self): def __iter__(self):
if hasattr(self.sampler, "set_epoch"): if hasattr(self.sampler, "set_epoch"):
new_epoch = self.sampler.epoch + 1 new_epoch = self.sampler.epoch + 1
self.sampler.set_epoch(new_epoch) self.sampler.set_epoch(new_epoch)
LOG.info(f"calling sampler.set_epoch({new_epoch})") LOG.info(f"calling sampler.set_epoch({new_epoch})")
all_batches, _ = self.generate_batches(set_stats=True) # Start the batch generation in a separate thread
batch_gen_thread = threading.Thread(target=self._generate_batches_thread)
batch_gen_thread.start()
features = self.dataset.features.keys() features = self.dataset.features.keys()
len_remaining = self._len_est() len_remaining = self._len_est()
for batches in chunk( while True:
all_batches, self.batch_size // self.sample_packing_seq_len_multiplier batch = self.batch_queue.get()
): if batch is None: # Sentinel value received, stop iteration
break
chunked_data = [] chunked_data = []
attn_mask_cum_idx = 0 attn_mask_cum_idx = 0
for batch in batches: concatenated = {}
concatenated = {} batched_data = [self.dataset[batch_idx] for batch_idx in batch]
batched_data = [self.dataset[batch_idx] for batch_idx in batch] for feature in features:
for feature in features: if feature == "attention_mask":
if feature == "attention_mask": arrays = [
arrays = [ (attn_mask_cum_idx + idx + 1) * np.array(item[feature])
(attn_mask_cum_idx + idx + 1) * np.array(item[feature]) for idx, item in enumerate(batched_data)
for idx, item in enumerate(batched_data) if feature in item
if feature in item ]
] attn_mask_cum_idx += len(batched_data)
attn_mask_cum_idx += len(batched_data) concatenated[feature] = np.concatenate(arrays)
concatenated[feature] = np.concatenate(arrays) else:
else: arrays = [
arrays = [ np.array(item[feature])
np.array(item[feature]) for item in batched_data
for item in batched_data if feature in item
if feature in item ]
] concatenated[feature] = np.concatenate(arrays)
concatenated[feature] = np.concatenate(arrays) chunked_data.append(concatenated)
chunked_data.append(concatenated)
yield self.collate_fn(chunked_data) yield self.collate_fn(chunked_data)
len_remaining -= 1 len_remaining -= 1
if not len_remaining: if not len_remaining:
return break
# Wait for the batch generation thread to finish
batch_gen_thread.join(timeout=5)
LOG.info(f"actual packing efficiency: {self.efficiency()}")
def _len_est(self): def _len_est(self):
lengths_sum = np.sum(self.lengths) if not self.total_num_tokens:
lengths_sum_per_device = lengths_sum // self.device_count self.total_num_tokens = np.sum(self.lengths)
lengths_sum_per_device = self.total_num_tokens // self.device_count
LOG.info( LOG.info(
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} " f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
f"total_num_tokens per device: {lengths_sum_per_device}" f"total_num_tokens per device: {lengths_sum_per_device}"

View File

@@ -22,7 +22,6 @@ from transformers import ( # noqa: F401
) )
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
from axolotl.utils.bench import log_gpu_memory_usage
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
@@ -84,22 +83,19 @@ def load_tokenizer(
def load_model( def load_model(
cfg, tokenizer base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora"
): # type: (DictDefault, PreTrainedTokenizerBase) -> Tuple[PreTrainedModel, Optional[PeftConfig]] ):
# type: (str, str, str, PreTrainedTokenizerBase, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
""" """
Load a model for a given configuration and tokenizer. Load a model from a base model and a model type.
""" """
base_model = cfg.base_model
base_model_config = cfg.base_model_config
model_type = cfg.model_type
adapter = cfg.adapter
# TODO refactor as a kwarg # TODO refactor as a kwarg
load_in_8bit = cfg.load_in_8bit load_in_8bit = cfg.load_in_8bit
cfg.is_llama_derived_model = ( cfg.is_llama_derived_model = (
"llama" in base_model "llama" in base_model
or (cfg.model_type and "llama" in cfg.model_type.lower()) or (cfg.model_type and "llama" in cfg.model_type.lower())
or cfg.is_llama_derived_model or cfg.is_llama_derived_model is True
) )
if cfg.is_llama_derived_model and cfg.flash_attention: if cfg.is_llama_derived_model and cfg.flash_attention:
@@ -235,9 +231,7 @@ def load_model(
elif cfg.is_llama_derived_model and not cfg.trust_remote_code: elif cfg.is_llama_derived_model and not cfg.trust_remote_code:
from transformers import LlamaForCausalLM from transformers import LlamaForCausalLM
config = LlamaConfig.from_pretrained( config = LlamaConfig.from_pretrained(base_model_config)
base_model_config, rope_scaling=cfg.rope_scaling
)
model = LlamaForCausalLM.from_pretrained( model = LlamaForCausalLM.from_pretrained(
base_model, base_model,
config=config, config=config,
@@ -342,9 +336,6 @@ def load_model(
) )
model.config.max_position_embeddings = cfg.sequence_len model.config.max_position_embeddings = cfg.sequence_len
if model.device.type == "cuda":
log_gpu_memory_usage(LOG, "after model load", model.device)
if not cfg.gptq and ( if not cfg.gptq and (
(cfg.adapter == "lora" and load_in_8bit) (cfg.adapter == "lora" and load_in_8bit)
or (cfg.adapter == "qlora" and cfg.load_in_4bit) or (cfg.adapter == "qlora" and cfg.load_in_4bit)
@@ -381,9 +372,6 @@ def load_model(
module.scales = module.scales.half() module.scales = module.scales.half()
module.bias = module.bias.half() module.bias = module.bias.half()
if model.device.type == "cuda":
log_gpu_memory_usage(LOG, "after adapters", model.device)
if ( if (
torch.cuda.device_count() > 1 torch.cuda.device_count() > 1
and int(os.getenv("WORLD_SIZE", "1")) > 1 and int(os.getenv("WORLD_SIZE", "1")) > 1

View File

@@ -11,7 +11,6 @@ from pathlib import Path
from typing import Optional, Union from typing import Optional, Union
import bitsandbytes as bnb import bitsandbytes as bnb
import numpy as np
import torch.cuda import torch.cuda
import transformers import transformers
from datasets import Dataset, set_caching_enabled from datasets import Dataset, set_caching_enabled
@@ -21,9 +20,7 @@ from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_pt_utils import get_parameter_names from transformers.trainer_pt_utils import get_parameter_names
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils.callbacks import ( from axolotl.utils.callbacks import (
PrintGPUStatsCallback,
SaveBetterTransformerModelCallback, SaveBetterTransformerModelCallback,
SavePeftModelCallback, SavePeftModelCallback,
) )
@@ -125,6 +122,10 @@ class AxolotlTrainingArguments(TrainingArguments):
default=1, default=1,
metadata={"help": "the multiplier for the max len for packed sequences"}, metadata={"help": "the multiplier for the max len for packed sequences"},
) )
train_data_total_num_tokens: Optional[int] = field(
default=None,
metadata={"help": "the total number of tokens in the train dataset"},
)
class AxolotlTrainer(Trainer): class AxolotlTrainer(Trainer):
@@ -185,6 +186,7 @@ class AxolotlTrainer(Trainer):
packing_efficiency_estimate=self.args.sample_packing_efficiency, packing_efficiency_estimate=self.args.sample_packing_efficiency,
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier, sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
device_count=int(os.environ.get("WORLD_SIZE", 1)), device_count=int(os.environ.get("WORLD_SIZE", 1)),
total_num_tokens=self.args.train_data_total_num_tokens,
) )
) )
return super().get_train_dataloader() return super().get_train_dataloader()
@@ -207,6 +209,7 @@ class AxolotlTrainer(Trainer):
packing_efficiency_estimate=self.args.sample_packing_efficiency, packing_efficiency_estimate=self.args.sample_packing_efficiency,
sample_packing_seq_len_multiplier=self.args.eval_batch_size, sample_packing_seq_len_multiplier=self.args.eval_batch_size,
device_count=int(os.environ.get("WORLD_SIZE", 1)), device_count=int(os.environ.get("WORLD_SIZE", 1)),
total_num_tokens=None,
) )
) )
return super().get_eval_dataloader(eval_dataset) return super().get_eval_dataloader(eval_dataset)
@@ -285,16 +288,13 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
if cfg.sample_packing: if cfg.sample_packing:
# we have to drop anything longer then sequence len otherwise # we have to drop anything longer then sequence len otherwise
# flash attention with position ids fails # flash attention with position ids fails
total_num_tokens = (
cfg.total_num_tokens
if cfg.total_num_tokens
else sum(len(s["input_ids"]) for s in train_dataset)
)
if not cfg.total_num_tokens: if not cfg.total_num_tokens:
LOG.info("calculating total_num_tokens")
total_num_tokens = np.sum(
train_dataset.data.column("input_ids")
.to_pandas()
.apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda
.values
)
LOG.info(f"📝 UPDATE CONFIG WITH: `total_num_tokens: {total_num_tokens}`") LOG.info(f"📝 UPDATE CONFIG WITH: `total_num_tokens: {total_num_tokens}`")
cfg.total_num_tokens = total_num_tokens
if cfg.sample_packing_eff_est: if cfg.sample_packing_eff_est:
total_num_steps = ( total_num_steps = (
@@ -302,9 +302,9 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
( (
math.floor( math.floor(
0.99 0.99
* cfg.total_num_tokens * total_num_tokens
/ cfg.sample_packing_eff_est / cfg.sample_packing_eff_est
/ cfg.sequence_len / 2048
// cfg.batch_size // cfg.batch_size
// int(os.environ.get("WORLD_SIZE", 1)) // int(os.environ.get("WORLD_SIZE", 1))
) )
@@ -313,7 +313,7 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
* cfg.num_epochs * cfg.num_epochs
) )
LOG.info( LOG.info(
f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}" f"total_num_tokens: {total_num_tokens}, total_num_steps: {total_num_steps}"
) )
else: else:
sampler = RandomSampler(train_dataset) sampler = RandomSampler(train_dataset)
@@ -345,7 +345,6 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
LOG.info( LOG.info(
f"📝 UPDATE CONFIG WITH: `sample_packing_eff_est: {math.ceil(actual_eff * 100.0) / 100.0}`" f"📝 UPDATE CONFIG WITH: `sample_packing_eff_est: {math.ceil(actual_eff * 100.0) / 100.0}`"
) )
cfg.sample_packing_eff_est = math.ceil(actual_eff * 100.0) / 100.0
else: else:
total_num_steps = int( total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
@@ -484,7 +483,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
else "cosine", else "cosine",
weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0, weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
sample_packing=cfg.sample_packing if cfg.sample_packing else False, sample_packing=cfg.sample_packing if cfg.sample_packing else False,
sample_packing_seq_len_multiplier=cfg.micro_batch_size, sample_packing_seq_len_multiplier=cfg.micro_batch_size or 1,
train_data_total_num_tokens=cfg.total_num_tokens,
**training_arguments_kwargs, **training_arguments_kwargs,
) )
@@ -556,19 +556,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
trainer_kwargs["optimizers"] = (optimizer, lr_scheduler) trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)
callbacks = [] callbacks = []
callbacks.append(PrintGPUStatsCallback(cfg))
if cfg.relora_steps:
relora_steps = int(cfg.relora_steps)
relora_warmup_steps = int(cfg.relora_warmup_steps)
callbacks.append(ReLoRACallback(cfg))
(optimizer, lr_scheduler) = trainer_kwargs["optimizers"]
trainer_kwargs["optimizers"] = (
optimizer,
ReLoRAScheduler(optimizer, lr_scheduler, relora_steps, relora_warmup_steps),
)
# TODO on_save callback to sync checkpoints to GCP/AWS in background # TODO on_save callback to sync checkpoints to GCP/AWS in background
if cfg.early_stopping_patience: if cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback( early_stop_cb = EarlyStoppingCallback(

View File

@@ -61,9 +61,6 @@ def validate_config(cfg):
if not cfg.load_in_8bit and cfg.adapter == "lora": if not cfg.load_in_8bit and cfg.adapter == "lora":
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning") LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
if cfg.relora_steps and cfg.adapter not in ("lora", "qlora"):
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
if cfg.trust_remote_code: if cfg.trust_remote_code:
LOG.warning( LOG.warning(
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model." "`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
@@ -113,13 +110,6 @@ def validate_config(cfg):
"push_to_hub_model_id is deprecated. Please use hub_model_id instead." "push_to_hub_model_id is deprecated. Please use hub_model_id instead."
) )
if cfg.gptq and cfg.model_revision:
raise ValueError(
"model_revision is not supported for GPTQ models. "
+ "Please download the model from HuggingFace Hub manually for correct branch, "
+ "point to its path, and remove model_revision from the config."
)
if cfg.sample_packing and cfg.sdp_attention: if cfg.sample_packing and cfg.sdp_attention:
# incompatible due to bug w/ accelerate causing 0.0 loss when using llama2 # incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
raise ValueError( raise ValueError(

View File

@@ -9,8 +9,6 @@ def setup_wandb_env_vars(cfg):
elif cfg.wandb_project and len(cfg.wandb_project) > 0: elif cfg.wandb_project and len(cfg.wandb_project) > 0:
os.environ["WANDB_PROJECT"] = cfg.wandb_project os.environ["WANDB_PROJECT"] = cfg.wandb_project
cfg.use_wandb = True cfg.use_wandb = True
if cfg.wandb_entity and len(cfg.wandb_entity) > 0:
os.environ["WANDB_ENTITY"] = cfg.wandb_entity
if cfg.wandb_watch and len(cfg.wandb_watch) > 0: if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
os.environ["WANDB_WATCH"] = cfg.wandb_watch os.environ["WANDB_WATCH"] = cfg.wandb_watch
if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0: if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0: