Compare commits

..

28 Commits

Author SHA1 Message Date
Charles Goddard
1afbd8af2d Fix logic errors
Some checks failed
pre-commit / pre-commit (push) Has been cancelled
PyTest / test (3.10) (push) Has been cancelled
PyTest / test (3.9) (push) Has been cancelled
2023-08-12 20:31:59 -04:00
Charles Goddard
b4f2eea2ed Remove redundant assert 2023-08-12 20:31:59 -04:00
Charles Goddard
bbf88b02c1 Fix saving logic 2023-08-12 20:31:59 -04:00
Charles Goddard
64a8e04430 Remove local config 2023-08-12 20:31:59 -04:00
Charles Goddard
c8f7213bc6 Add CPU offload 2023-08-12 20:31:59 -04:00
Charles Goddard
b57238ecec Experimental ReLoRA (+qlora) implementation 2023-08-12 20:31:57 -04:00
Wing Lian
918f1b0dfb revert previous change and build ax images w docker on gpu (#371) 2023-08-12 20:23:00 -04:00
Wing Lian
c3fde36ada attempt to run non-base docker builds on regular cpu hosts (#369) 2023-08-12 19:07:38 -04:00
Wing Lian
2bb0b78975 Attention mask and position id fixes for packing (#285)
* fix attetion mask with packing

* set position ids and use block diagonal attn mask

* fix expand mask for multiple batch items, make sure we pad position_ids

* don't move masks to cpu

* use multi pack dataloader w random sampler

* add position_ids back

* more fixes for dataloader integration

* est total tokens, fix field loop

* more fixes, position_ids seems broken

* more fixes for sample packing

* use distributed sampler, avoid accelerate prepare

* use accelerator prepare for dataloader

* fix for position_ids w packing

* Update src/axolotl/utils/dataloader.py

* validation for sample packing and doc

* more fixes for 4k and optimizations

* optimized expand mask fn

* better handling of variance in multipack dataloader length and trainer hanging when it runs out of data

* fix rounding of len of batches to int

* better handling so that all devices have the same dataloader len

* fix step calc for packing

* pass sample packing efficiency to training args

* add a test for the mask expansion for sequence packing

* only process eval dataset for packing if not None

* don't split batches when packing

* weighted CE losses

* weighted CEL fixes

* limit packing to sequences of max seq len

* seq_len_multiple for packing

* make sure the chunk size is an int

* sample_packing_seq_len_multiplier config

* use cumulative seq len with var len flash attn v2 w packing

* properly calculate max len

* fix flash-attn, xformers, packing, support chatml

* fix chatml system prompt for openorca, legacy tokenizer opts

* add chatml

* add unit tests for cum seq lens, add ability to build cu_seq_lens from positional ids, fix prompt test

* fix test and pylint checks

* more packing and dataset optimizations and fixes

* filter w multiple cpus

* more fixes and optimizations

* fixes and go back to distributed sampler since batch sampler won't work

* fix counts by accounting for num devices

* fix steps calculation

* previous accelerate is still most performant

* add numba to requirements.

* use custom distributed checks

* fix sampler to prevent overfit w new epochs

* let's not cleanup the cached datasets

* calculate cum seq lens with pos_ids instead of mask, simplify packing params, fix distributed barrier

* speed optimizations and set accelerate fsdp env vars

* optimize dataset concatenation?

* more optimizations for dataset handling

* fix import for annotation

* manual pre-commit fixes

* another sum optimization and bug fix for calc steps

* fix packing estimations

* fix formatting

* pylint problems

* add back flash attention branch for handling unpacked sequences seperately

* Address PR feedback

* add optional sample packing config params to readme
2023-08-12 15:14:56 -04:00
NanoCode012
a276c9c88d Fix(save): Save as safetensors (#363) 2023-08-13 01:22:52 +09:00
Morgan McGuire
7019509daa Add wandb_entity to wandb options, update example configs, update README (#361)
* Update wandb_entity and add wandb descriptions

* add wandb to config section

* remove trailing whitespace for pre-commit hook

* remove trailing whitespace for pre-commit hook

---------

Co-authored-by: Morgan McGuire <morganmcguire@Morgans-MacBook-Pro.local>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-08-12 12:17:11 -04:00
NanoCode012
96bd6ae1c4 Fix(model loading): Warn when model revision is passed to gptq (#364)
* fix(model loading): warn when model revision is passed to gptq

* chore: improve message
2023-08-13 01:16:59 +09:00
NanoCode012
e37d9358e6 Fix(message): Improve error message for bad format (#365) 2023-08-13 01:16:18 +09:00
NanoCode012
b5212068ac Feat: Add rope scaling (#343)
* Feat: Add rope scaling

* fix: move rope config
2023-08-13 00:50:15 +09:00
NanoCode012
289d5c403d feat(merge): save tokenizer on merge (#362) 2023-08-13 00:18:10 +09:00
Aman Gupta Karmani
35c8b90306 Merge pull request #355 from tmm1/bitsandbytes-fixes
bump to latest bitsandbytes release with major bug fixes
2023-08-11 15:15:38 -07:00
NanoCode012
fae6ed8092 Update README.md on pretraining_dataset (#360)
* Update README.md on pretraining_dataset

* Fix message
2023-08-11 12:17:07 +09:00
NanoCode012
94d03c8402 Clarify pre-tokenize before multigpu (#359) 2023-08-11 11:27:42 +09:00
Aman Gupta Karmani
11ddccb80f Merge pull request #356 from tmm1/load_model-args
simplify `load_model` signature
2023-08-09 18:24:34 -07:00
Aman Gupta Karmani
964312199e Merge pull request #354 from tmm1/gpu-util
GPU memory usage logging
2023-08-09 15:44:18 -07:00
Aman Karmani
718102271f simplify load_model signature 2023-08-09 22:36:02 +00:00
Aman Gupta Karmani
f5c11f8262 Merge pull request #350 from tmm1/group-len-false-examples
set `group_by_length` to false in all examples
2023-08-09 14:48:48 -07:00
Aman Karmani
fce40aab23 bump to latest bitsandbytes release with major bug fixes 2023-08-09 21:47:11 +00:00
Aman Karmani
9c314101d5 use newer pynvml package 2023-08-09 21:06:28 +00:00
Aman Karmani
e303d64728 log GPU memory usage 2023-08-09 18:26:28 +00:00
Aman Karmani
b4d1d22782 note pattern when using groups 2023-08-07 16:18:42 -07:00
Aman Karmani
9f99104038 update comment for group_by_length 2023-08-07 01:04:56 -07:00
Aman Karmani
36fefcf94b set group_by_length to false in examples 2023-08-06 23:59:09 -07:00
35 changed files with 623 additions and 160 deletions

View File

@@ -375,10 +375,14 @@ dataset_shard_idx:
sequence_len: 2048 sequence_len: 2048
# max sequence length to concatenate training samples together up to # max sequence length to concatenate training samples together up to
# inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning # inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
# soon to be DEPRECATED # FutureWarning: This will soon be DEPRECATED
max_packed_sequence_len: 1024 max_packed_sequence_len: 1024
# use efficient multi-packing with block diagonal attention and per sequence position_ids # use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'
sample_packing: sample_packing:
# you can set these packing optimizations AFTER starting a training at least once.
# The trainer will provide recommended values for these values.
sample_packing_eff_est:
total_num_tokens:
# if you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model # if you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
adapter: lora adapter: lora
@@ -404,11 +408,12 @@ lora_out_dir:
lora_fan_in_fan_out: false lora_fan_in_fan_out: false
# wandb configuration if you're using it # wandb configuration if you're using it
wandb_mode: wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
wandb_project: wandb_project: # your wandb project name
wandb_entity: # a wandb Team name if using a Team
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id: # set the name of your wandb run
wandb_log_model: # 'checkpoint' wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
# where to save the finished model to # where to save the finished model to
output_dir: ./completed-model output_dir: ./completed-model
@@ -423,13 +428,16 @@ learning_rate: 0.00003
logging_steps: logging_steps:
save_steps: save_steps:
eval_steps: eval_steps:
save_total_limit:
# save model as safetensors (require safetensors package) # save model as safetensors (require safetensors package)
save_safetensors: save_safetensors:
# whether to mask out or include the human's prompt from the training labels # whether to mask out or include the human's prompt from the training labels
train_on_inputs: false train_on_inputs: false
# don't use this, leads to wonky training (according to someone on the internet) # group similarly sized data to minimize padding
# may be slower to start, as it must download and sort the entire dataset
# note that training loss may have an oscillating pattern with this enabled
group_by_length: false group_by_length: false
# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing # Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
@@ -475,6 +483,10 @@ landmark_attention:
# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py # xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
# llama only # llama only
xpos_rope: xpos_rope:
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
rope_scaling:
type: # linear | dynamic
factor: # float
# resume from a specific checkpoint dir # resume from a specific checkpoint dir
resume_from_checkpoint: resume_from_checkpoint:
@@ -506,6 +518,9 @@ torchdistx_path:
# Set padding for data collator to 'longest' # Set padding for data collator to 'longest'
collator_pad_to_longest: collator_pad_to_longest:
# Set to HF dataset for type: 'completion' for streaming instead of pre-tokenize
pretraining_dataset:
# Debug mode # Debug mode
debug: debug:
@@ -525,7 +540,14 @@ Run
accelerate launch scripts/finetune.py configs/your_config.yml accelerate launch scripts/finetune.py configs/your_config.yml
``` ```
#### Multi-GPU Config #### Multi-GPU
You can optionally pre-tokenize dataset with the following before finetuning:
```bash
CUDA_VISIBLE_DEVICES="" accelerate ... --prepare_ds_only
```
##### Config
- llama FSDP - llama FSDP
```yaml ```yaml
@@ -540,6 +562,18 @@ fsdp_config:
- llama Deepspeed: append `ACCELERATE_USE_DEEPSPEED=true` in front of finetune command - llama Deepspeed: append `ACCELERATE_USE_DEEPSPEED=true` in front of finetune command
##### Weights & Biases Logging
- wandb options
```yaml
wandb_mode:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
```
### Inference ### Inference
Pass the appropriate flag to the train command: Pass the appropriate flag to the train command:

View File

@@ -23,6 +23,7 @@ lora_target_modules:
lora_target_linear: lora_target_linear:
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:
@@ -35,7 +36,7 @@ torchdistx_path:
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.0002
train_on_inputs: false train_on_inputs: false
group_by_length: true group_by_length: false
bf16: true bf16: true
fp16: false fp16: false
tf32: true tf32: true

View File

@@ -24,6 +24,7 @@ lora_target_modules:
lora_target_linear: true lora_target_linear: true
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -38,6 +38,7 @@ lora_target_linear: true
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -24,6 +24,7 @@ lora_target_modules:
lora_target_linear: true lora_target_linear: true
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -20,6 +20,7 @@ lora_target_modules:
lora_target_linear: true lora_target_linear: true
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:
@@ -32,7 +33,7 @@ torchdistx_path:
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0001 learning_rate: 0.0001
train_on_inputs: false train_on_inputs: false
group_by_length: true group_by_length: false
bf16: true bf16: true
fp16: false fp16: false
tf32: true tf32: true

View File

@@ -22,6 +22,7 @@ lora_target_modules:
- v_proj - v_proj
lora_fan_in_fan_out: false lora_fan_in_fan_out: false
wandb_project: llama-7b-lora-int4 wandb_project: llama-7b-lora-int4
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -18,6 +18,7 @@ lora_dropout:
lora_target_modules: lora_target_modules:
lora_fan_in_fan_out: false lora_fan_in_fan_out: false
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -26,6 +26,7 @@ lora_target_linear: true
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:
@@ -38,7 +39,7 @@ lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.0002
train_on_inputs: false train_on_inputs: false
group_by_length: true group_by_length: false
bf16: true bf16: true
fp16: false fp16: false
tf32: false tf32: false

View File

@@ -27,6 +27,7 @@ lora_target_linear: true
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:
@@ -39,7 +40,7 @@ lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.0002
train_on_inputs: false train_on_inputs: false
group_by_length: true group_by_length: false
bf16: true bf16: true
fp16: false fp16: false
tf32: false tf32: false

View File

@@ -20,6 +20,7 @@ lora_target_modules:
- v_proj - v_proj
lora_fan_in_fan_out: false lora_fan_in_fan_out: false
wandb_project: mpt-alpaca-7b wandb_project: mpt-alpaca-7b
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -22,6 +22,7 @@ lora_target_modules:
lora_target_linear: lora_target_linear:
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -28,6 +28,7 @@ lora_target_modules:
- o_proj - o_proj
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -22,6 +22,7 @@ lora_target_modules:
lora_target_linear: true lora_target_linear: true
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:
@@ -34,7 +35,7 @@ torchdistx_path:
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.0002
train_on_inputs: false train_on_inputs: false
group_by_length: true group_by_length: false
bf16: true bf16: true
fp16: false fp16: false
tf32: true tf32: true

View File

@@ -23,6 +23,7 @@ lora_target_modules:
lora_target_linear: true lora_target_linear: true
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -17,6 +17,7 @@ lora_target_modules:
lora_target_linear: lora_target_linear:
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -21,6 +21,7 @@ lora_target_modules:
- v_proj - v_proj
lora_fan_in_fan_out: false lora_fan_in_fan_out: false
wandb_project: redpajama-alpaca-3b wandb_project: redpajama-alpaca-3b
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -20,6 +20,7 @@ lora_target_modules:
- mlp_down - mlp_down
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: lora-replit wandb_project: lora-replit
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -37,6 +37,7 @@ lora_target_linear: true
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -1,6 +1,6 @@
peft @ git+https://github.com/huggingface/peft.git peft @ git+https://github.com/huggingface/peft.git
transformers @ git+https://github.com/huggingface/transformers.git transformers @ git+https://github.com/huggingface/transformers.git
bitsandbytes>=0.39.0 bitsandbytes>=0.41.1
accelerate @ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b accelerate @ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b
addict addict
fire fire
@@ -21,3 +21,4 @@ evaluate==0.4.0
rouge-score==0.1.2 rouge-score==0.1.2
scipy scipy
scikit-learn==1.2.2 scikit-learn==1.2.2
pynvml

View File

@@ -18,6 +18,7 @@ from optimum.bettertransformer import BetterTransformer
from transformers import GenerationConfig, TextStreamer from transformers import GenerationConfig, TextStreamer
from axolotl.logging_config import configure_logging from axolotl.logging_config import configure_logging
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import barrier, is_main_process from axolotl.utils.distributed import barrier, is_main_process
@@ -268,16 +269,13 @@ def train(
LOG.info("Finished preparing dataset. Exiting...") LOG.info("Finished preparing dataset. Exiting...")
return return
log_gpu_memory_usage(LOG, "baseline", cfg.device)
# Load the model and tokenizer # Load the model and tokenizer
LOG.info("loading model and peft_config...") LOG.info("loading model and (optionally) peft_config...")
model, peft_config = load_model( model, peft_config = load_model(cfg, tokenizer)
cfg.base_model,
cfg.base_model_config, safe_serialization = cfg.save_safetensors is True
cfg.model_type,
tokenizer,
cfg,
adapter=cfg.adapter,
)
if "merge_lora" in kwargs and cfg.adapter is not None: if "merge_lora" in kwargs and cfg.adapter is not None:
LOG.info("running merge of LoRA with base model") LOG.info("running merge of LoRA with base model")
@@ -286,7 +284,11 @@ def train(
if cfg.local_rank == 0: if cfg.local_rank == 0:
LOG.info("saving merged model") LOG.info("saving merged model")
model.save_pretrained(str(Path(cfg.output_dir) / "merged")) model.save_pretrained(
str(Path(cfg.output_dir) / "merged"),
safe_serialization=safe_serialization,
)
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
return return
if cfg.inference: if cfg.inference:
@@ -301,7 +303,7 @@ def train(
return return
if "shard" in kwargs: if "shard" in kwargs:
model.save_pretrained(cfg.output_dir) model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
return return
trainer = setup_trainer( trainer = setup_trainer(
@@ -325,7 +327,7 @@ def train(
def terminate_handler(_, __, model): def terminate_handler(_, __, model):
if cfg.flash_optimum: if cfg.flash_optimum:
model = BetterTransformer.reverse(model) model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir) model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
sys.exit(0) sys.exit(0)
signal.signal( signal.signal(
@@ -369,7 +371,13 @@ def train(
elif cfg.local_rank == 0: elif cfg.local_rank == 0:
if cfg.flash_optimum: if cfg.flash_optimum:
model = BetterTransformer.reverse(model) model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir)
if cfg.adapter == "lora" and cfg.relora_steps:
model = model.merge_and_unload()
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -5,7 +5,7 @@ import os
from typing import List from typing import List
import torch import torch
from datasets import IterableDataset from datasets import Dataset, IterableDataset
from .prompt_tokenizers import PromptTokenizingStrategy from .prompt_tokenizers import PromptTokenizingStrategy
@@ -18,9 +18,9 @@ from .prompt_tokenizers import PromptTokenizingStrategy
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
class TokenizedPromptDataset(IterableDataset): class TokenizedPromptDataset(Dataset):
""" """
Iterable dataset that returns tokenized prompts from a stream of text files. Dataset that returns tokenized prompts from a stream of text files.
Args: Args:
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for proccessing the data. prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for proccessing the data.
dataset (dataset.Dataset): Dataset with text files. dataset (dataset.Dataset): Dataset with text files.
@@ -30,19 +30,18 @@ class TokenizedPromptDataset(IterableDataset):
self, self,
prompt_tokenizer: PromptTokenizingStrategy, prompt_tokenizer: PromptTokenizingStrategy,
dataset: IterableDataset, dataset: IterableDataset,
**kwargs,
): ):
self.prompt_tokenizer = prompt_tokenizer self.prompt_tokenizer = prompt_tokenizer
self.dataset = dataset super().__init__(self.process(dataset).data, **kwargs)
def __iter__(self): def process(self, dataset):
features = self.dataset.features.keys() features = dataset.features.keys()
num_proc = os.cpu_count() num_proc = min(64, os.cpu_count())
return iter( return dataset.map(
self.dataset.map( self.prompt_tokenizer.tokenize_prompt,
self.prompt_tokenizer.tokenize_prompt, num_proc=num_proc,
num_proc=num_proc, remove_columns=features,
remove_columns=features,
)
) )

View File

@@ -7,6 +7,7 @@ from typing import Optional, Tuple
import torch import torch
import transformers import transformers
from einops import rearrange from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input
try: try:
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
@@ -91,7 +92,8 @@ def forward(
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
) )
output = rearrange(output, "(b s) ... -> b s ...", b=bsz) output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
else: elif position_ids.shape[0] == 1:
# special handling using sample packing
qkv = rearrange(qkv, "b s ... -> (b s) ...") qkv = rearrange(qkv, "b s ... -> (b s) ...")
cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids) cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids)
cu_q_lens = cu_q_lens.squeeze() cu_q_lens = cu_q_lens.squeeze()
@@ -100,6 +102,36 @@ def forward(
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
) )
output = rearrange(output, "(b s) ... -> b s ...", b=bsz) output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
else:
nheads = qkv.shape[-2]
# pylint: disable=invalid-name
x = rearrange(qkv, "b s three h d -> b s (three h d)")
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
x_unpad = rearrange(
x_unpad,
"nnz (three h d) -> nnz three h d",
three=3,
h=nheads,
)
output_unpad = flash_attn_varlen_qkvpacked_func(
x_unpad,
cu_q_lens,
max_s,
0.0,
softmax_scale=None,
causal=True,
)
output = rearrange(
pad_input(
rearrange(output_unpad, "nnz h d -> nnz (h d)"),
indices,
bsz,
q_len,
),
"b s (h d) -> b s h d",
h=nheads,
)
return ( return (
self.o_proj(rearrange(output, "b s h d -> b s (h d)")), self.o_proj(rearrange(output, "b s h d -> b s (h d)")),

View File

@@ -0,0 +1,302 @@
# pylint: skip-file
import glob
import json
import logging
import os.path
import shutil
from pathlib import Path
from typing import Dict, List, Sequence
import bitsandbytes as bnb
import peft
import safetensors.torch as st
import torch
from torch.optim.lr_scheduler import LRScheduler
from torch.optim.optimizer import Optimizer
from transformers import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.relora")
def reset_optimizer(optimizer: torch.optim.Optimizer):
for group in optimizer.param_groups:
for param in group["params"]:
param_state = optimizer.state[param]
for key in param_state:
if "qmap" in key:
continue
elif key == "step" and isinstance(param_state[key], int):
param_state[key] = 0
else:
param_state[key] = torch.zeros_like(param_state[key])
class ReLoRACallback(TrainerCallback):
def __init__(self, cfg: DictDefault):
self.relora_steps = cfg.relora_steps
self.cpu_offload = cfg.relora_cpu_offload
self.quantised = cfg.load_in_4bit or cfg.load_in_8bit
self.last_full_model = cfg.base_model
assert os.path.exists(
self.last_full_model
), "for ReLORA base_model must be a local path"
self.num_lora_restarts = 0
self.need_full_save = False
def on_step_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
model: peft.LoraModel,
optimizer: torch.optim.Optimizer,
**_kwargs,
):
if state.global_step > 0 and state.global_step % self.relora_steps == 0:
checkpoint_folder = os.path.join(
args.output_dir,
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
)
with torch.no_grad():
merge_and_save(
model,
self.last_full_model,
checkpoint_folder,
reinit=True,
quantized=self.quantised,
)
reset_optimizer(optimizer)
if self.quantised:
self.last_full_model = checkpoint_folder
self.num_lora_restarts += 1
return control
def on_save(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
model: peft.LoraModel,
**kwargs,
):
checkpoint_folder = os.path.join(
args.output_dir,
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
)
if (
state.global_step >= self.relora_steps
and state.global_step % self.relora_steps != 0
):
if self.quantised and self.last_full_model != checkpoint_folder:
# ensure the latest full parameter save is in the latest checkpoint
# folder, so that automatic pruning of checkpoints does not remove it
LOG.info(f"moving last full parameter save to {checkpoint_folder}")
chunks = glob.glob(
f"{self.last_full_model}/model*.safetensors"
) + glob.glob(f"{self.last_full_model}/model*.index.json")
for path in chunks:
shutil.move(path, checkpoint_folder)
self.last_full_model = checkpoint_folder
else:
model.model.save_pretrained(checkpoint_folder, save_safetensors=True)
return control
def on_log(
self,
_args: TrainingArguments,
_state: TrainerState,
control: TrainerControl,
logs: Dict[str, float],
**_kwargs,
):
logs["num_lora_restarts"] = self.num_lora_restarts
return control
class ReLoRAScheduler(LRScheduler):
def __init__(
self,
optimizer: Optimizer,
inner_schedule: LRScheduler,
relora_steps: int,
warmup_steps: int,
min_lr_scale: float = 0.001,
) -> None:
self.inner_schedule = inner_schedule
self.relora_steps = relora_steps
self.warmup_steps = warmup_steps
self.min_lr_scale = min_lr_scale
super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)
def get_lr(self) -> float:
self.inner_schedule.last_epoch = self.last_epoch
original = self.inner_schedule.get_lr()
step = self.last_epoch
if step < self.relora_steps:
scale = 1
else:
cycle_t = min(1.0, (step % self.relora_steps) / self.warmup_steps)
scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
if isinstance(original, Sequence):
return [lr * scale for lr in original]
else:
return original * scale
def sharded_paths(path: str, keys: List[str]) -> Dict[str, str]:
model_name = "model.safetensors"
if not os.path.exists(str(Path(path) / model_name)) and not os.path.exists(
str(Path(path) / f"{model_name}.index.json")
):
model_name = "pytorch_model.bin"
index_path = str(Path(path) / f"{model_name}.index.json")
if os.path.exists(index_path):
data = json.load(open(index_path, "r"))
return data["weight_map"]
return {key + ".weight": model_name for key in keys}
def lora_delta_weight(layer: peft.tuners.lora.LoraLayer) -> torch.Tensor:
if isinstance(layer, peft.tuners.lora.Linear8bitLt) or isinstance(
layer, peft.tuners.lora.Linear4bit
):
adapter = layer.active_adapter
return (
peft.utils.transpose(
layer.lora_B[adapter].weight @ layer.lora_A[adapter].weight,
getattr(layer, "fan_in_fan_out", False),
)
* layer.scaling[adapter]
)
else:
return layer.get_delta_weight()
def merge_and_save(
model: peft.LoraModel,
model_src: str,
model_dst: str,
reinit: bool = False,
quantized: bool = False,
cpu_offload: bool = False,
):
key_list = [key for key, _ in model.model.named_modules() if "lora" not in key]
if not quantized:
for key in key_list:
try:
_parent, target, _target_name = peft.utils._get_submodules(
model.model, key
)
except AttributeError:
continue
if isinstance(target, peft.tuners.lora.LoraLayer):
update = target.get_delta_weight(target.active_adapter).detach()
target.weight.data += update
if reinit:
for adapter_name in target.lora_A:
target.reset_lora_parameters(adapter_name)
for adapter_name in target.lora_embedding_A:
target.reset_lora_parameters(adapter_name)
return
os.makedirs(model_dst, exist_ok=True)
shard_paths = sharded_paths(model_src, key_list)
unique_shards = list(set(shard_paths.values()))
for shard_path in unique_shards:
out_tensors = {}
if shard_path.endswith(".safetensors"):
in_tensors = st.load_file(str(Path(model_src) / shard_path))
else:
in_tensors = torch.load(Path(model_src) / shard_path)
if "state_dict" in in_tensors:
in_tensors = in_tensors["state_dict"]
for key in key_list:
if (key + ".weight") not in shard_paths or shard_paths[
key + ".weight"
] != shard_path:
continue
try:
_parent, target, _target_name = peft.utils._get_submodules(
model.model, key
)
except AttributeError:
continue
if isinstance(target, peft.tuners.lora.LoraLayer):
orig_weight = in_tensors[key + ".weight"]
old_dev = target.weight.device
math_dev = "cpu" if cpu_offload else old_dev
update = lora_delta_weight(target).detach().to(math_dev)
new_weight = orig_weight.to(math_dev) + update
out_tensors[key + ".weight"] = new_weight
if reinit:
for adapter_name in target.lora_A:
target.reset_lora_parameters(adapter_name)
for adapter_name in target.lora_embedding_A:
target.reset_lora_parameters(adapter_name)
if isinstance(target, peft.tuners.lora.Linear4bit):
target.weight = (
bnb.nn.Params4bit(
new_weight,
requires_grad=False,
compress_statistics=target.weight.compress_statistics,
quant_type=target.weight.quant_type,
)
.cuda(None)
.to(old_dev)
)
elif isinstance(target, peft.tuners.lora.Linear8bitLt):
target.weight = (
bnb.nn.Int8Params(new_weight, requires_grad=False)
.cuda(None)
.to(old_dev)
)
else:
target.weight.data = new_weight.to(old_dev)
for key in in_tensors:
if key not in out_tensors:
out_tensors[key] = in_tensors[key]
del in_tensors
out_shard_name = shard_path
if out_shard_name.startswith("pytorch_model"):
out_shard_name = (
out_shard_name.replace("pytorch_model", "model").rstrip(".bin")
+ ".safetensors"
)
shard_fn = str(Path(model_dst) / out_shard_name)
LOG.info(f"saving tensors to {shard_fn}")
st.save_file(out_tensors, shard_fn)
del out_tensors
torch.cuda.empty_cache()
if len(unique_shards) > 1:
with open(str(Path(model_dst, "model.safetensors.index.json")), "w") as fd:
json.dump({"metadata": {}, "weight_map": shard_paths}, fd)

View File

@@ -95,9 +95,9 @@ class OpenOrcaSystemDataPrompter(SystemDataPrompter):
self.turn_format = "### User:\n{instruction}\n\n### Additional Context:\n{input}\n\n### Assistant:\n" self.turn_format = "### User:\n{instruction}\n\n### Additional Context:\n{input}\n\n### Assistant:\n"
self.turn_no_input_format = "### User:\n{instruction}\n\n### Assistant:\n" self.turn_no_input_format = "### User:\n{instruction}\n\n### Assistant:\n"
if self.prompt_style == PromptStyle.CHAT.value: if self.prompt_style == PromptStyle.CHAT.value:
self.turn_format = "User: {instruction}\n{input}\nAssistant:" self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
self.turn_no_input_format = "User: {instruction}\nAssistant:" self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
self.system_format = "System: {system}\n" self.system_format = "SYSTEM: {system}\n"
if self.prompt_style == PromptStyle.CHATML.value: if self.prompt_style == PromptStyle.CHATML.value:
self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n" self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
self.turn_no_input_format = ( self.turn_no_input_format = (

View File

@@ -29,7 +29,7 @@ from dataclasses import dataclass, field
from typing import Generator, List, Sequence from typing import Generator, List, Sequence
from axolotl.prompt_tokenizers import PromptTokenizingStrategy from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID from axolotl.prompters import IGNORE_TOKEN_ID, SHAREGPT_ASSERTION_FAILED_ROLE
@dataclass @dataclass
@@ -190,7 +190,7 @@ class Llama2ChatPrompter: # pylint: disable=too-few-public-methods
conv.messages = [] # pylint: disable=R0801 conv.messages = [] # pylint: disable=R0801
for j, sentence in enumerate(source): for j, sentence in enumerate(source):
role = roles[sentence["from"]] role = roles[sentence["from"]]
assert role == conv.roles[j % 2] assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
if sentence["value"]: if sentence["value"]:
conv.append_message(role, sentence["value"]) conv.append_message(role, sentence["value"])
yield conv yield conv

View File

@@ -271,6 +271,11 @@ class Conversation:
self.messages.append([role, message]) self.messages.append([role, message])
SHAREGPT_ASSERTION_FAILED_ROLE = (
"Role did not alternate between turns (gpt and human). Please check your data."
)
class ShareGPTPrompter: # pylint: disable=too-few-public-methods class ShareGPTPrompter: # pylint: disable=too-few-public-methods
""" """
A prompter that generates prompts for the ShareGPT A prompter that generates prompts for the ShareGPT
@@ -327,7 +332,7 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
conv.messages = [] conv.messages = []
for j, sentence in enumerate(source): for j, sentence in enumerate(source):
role = roles[sentence["from"]] role = roles[sentence["from"]]
assert role == conv.roles[j % 2] assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
conv.append_message(role, sentence["value"]) conv.append_message(role, sentence["value"])
for part in conv.get_prompt(): for part in conv.get_prompt():

View File

@@ -0,0 +1,23 @@
"""Benchmarking and measurement utilities"""
import pynvml
import torch
def gpu_memory_usage(device):
if isinstance(device, torch.device):
device = device.index
if isinstance(device, str) and device.startswith("cuda:"):
device = int(device[5:])
# NB torch.cuda.memory_usage returns zero so we use lower level api
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
return info.used / 1024.0**3
def log_gpu_memory_usage(log, msg, device):
log.info(
f"GPU memory usage {msg}: {gpu_memory_usage(device):.03f} GB", stacklevel=2
)

View File

@@ -1,5 +1,6 @@
"""Callbacks for Trainer class""" """Callbacks for Trainer class"""
import logging
import os import os
from optimum.bettertransformer import BetterTransformer from optimum.bettertransformer import BetterTransformer
@@ -11,6 +12,10 @@ from transformers import (
) )
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
from axolotl.utils.bench import log_gpu_memory_usage
LOG = logging.getLogger("axolotl.callbacks")
class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
"""Callback to save the PEFT adapter""" """Callback to save the PEFT adapter"""
@@ -28,7 +33,9 @@ class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-
) )
peft_model_path = os.path.join(checkpoint_folder, "adapter_model") peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
kwargs["model"].save_pretrained(peft_model_path) kwargs["model"].save_pretrained(
peft_model_path, save_safetensors=args.save_safetensors
)
return control return control
@@ -67,3 +74,25 @@ class SaveBetterTransformerModelCallback(
# the trainer will raise an exception since it can't save a BetterTransformer wrapped model # the trainer will raise an exception since it can't save a BetterTransformer wrapped model
control.should_save = False control.should_save = False
return control return control
class PrintGPUStatsCallback(
TrainerCallback
): # pylint: disable=too-few-public-methods disable=unused-argument
"""Callback to print GPU utilization"""
def __init__(self, cfg):
self.cfg = cfg
self.logged = False
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
if not self.logged:
log_gpu_memory_usage(LOG, "while training", self.cfg.device)
self.logged = True
return control

View File

@@ -1,14 +1,19 @@
"""Module containing data utilities""" """Module containing data utilities"""
import functools import functools
import hashlib import hashlib
import itertools
import logging import logging
from hashlib import md5 from hashlib import md5
from pathlib import Path from pathlib import Path
from typing import List, Tuple, Union from typing import Tuple, Union
import torch import torch
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk from datasets import (
Dataset,
DatasetDict,
concatenate_datasets,
load_dataset,
load_from_disk,
)
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
@@ -265,20 +270,12 @@ def load_tokenized_prepared_datasets(
raise ValueError( raise ValueError(
f"unhandled prompt tokenization strategy: {d.type} {suffix}" f"unhandled prompt tokenization strategy: {d.type} {suffix}"
) )
LOG.info("tokenizing, merging, and shuffling master dataset") LOG.info("merging datasets")
dataset = concatenate_datasets(datasets)
samples: List[int] = [] if len(datasets) > 1:
chunk_size = 1000 LOG.info("shuffle merged datasets")
for d in datasets: dataset = dataset.shuffle(seed=seed)
d_iter = iter(d)
while True:
chunk = list(itertools.islice(d_iter, chunk_size))
if not chunk:
break
samples.extend(chunk)
LOG.info("shuffle")
dataset = Dataset.from_list(samples).shuffle(seed=seed)
if cfg.local_rank == 0: if cfg.local_rank == 0:
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}") LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
dataset.save_to_disk(prepared_ds_path) dataset.save_to_disk(prepared_ds_path)

View File

@@ -3,9 +3,7 @@ import hashlib
import itertools import itertools
import logging import logging
import math import math
import queue from typing import Any, Callable, List, Union
import threading
from typing import Any, Callable, List, Optional, Union
import numba import numba
import numpy as np import numpy as np
@@ -80,6 +78,7 @@ def allocate(
s = 0 s = 0
start_index = 0 start_index = 0
result = [] result = []
result_totseqs = []
while True: while True:
# binary search [left, right) # binary search [left, right)
@@ -105,8 +104,10 @@ def allocate(
# add local rank # add local rank
result.append(batch[rank]) result.append(batch[rank])
# add total seqs for all ranks
yield batch[rank], tot_seqs, s, len(result) * c * n result_totseqs.append(tot_seqs)
# yield batch[rank], tot_seqs, s, len(result) * c * n
return result, result_totseqs, s, len(result) * c * n
def chunk(iterable, n): def chunk(iterable, n):
@@ -148,14 +149,15 @@ class MultipackDistributedDataloader:
packing_efficiency_estimate: float = 1.0, packing_efficiency_estimate: float = 1.0,
sample_packing_seq_len_multiplier: int = 1, sample_packing_seq_len_multiplier: int = 1,
device_count: int = 1, device_count: int = 1,
total_num_tokens: Optional[int] = None,
): ):
# Dataset # Dataset
self.dataset = dataset self.dataset = dataset
lengths_series = ( self.lengths = (
dataset.data.column("position_ids").to_pandas().apply(lambda x: x[-1] + 1) dataset.data.column("position_ids")
.to_pandas()
.apply(lambda x: x[-1] + 1)
.values
) )
self.lengths: np.ndarray = lengths_series.values
assert isinstance(self.lengths, np.ndarray) assert isinstance(self.lengths, np.ndarray)
assert batch_size % sample_packing_seq_len_multiplier == 0 assert batch_size % sample_packing_seq_len_multiplier == 0
assert batch_size >= sample_packing_seq_len_multiplier assert batch_size >= sample_packing_seq_len_multiplier
@@ -170,17 +172,11 @@ class MultipackDistributedDataloader:
self.rank = 0 self.rank = 0
# statistics # statistics
self.total_num_tokens = total_num_tokens
self.eff_total_used = 0 self.eff_total_used = 0
self.eff_total_slots = 0 self.eff_total_slots = 0
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0 self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
self.device_count = device_count self.device_count = device_count
# for non-blocking batch creation
self.batch_queue: queue.Queue = queue.Queue(
maxsize=10
) # Adjust maxsize as needed
def generate_batches(self, set_stats=False): def generate_batches(self, set_stats=False):
LOG.info("generating packed batches") LOG.info("generating packed batches")
if self.sampler: if self.sampler:
@@ -192,83 +188,65 @@ class MultipackDistributedDataloader:
lengths = self.lengths[indices] lengths = self.lengths[indices]
lengths_cumsum = np.cumsum(lengths) lengths_cumsum = np.cumsum(lengths)
alloc_iter = iter( batches, totseqs, total_used, total_slots = allocate(
allocate( lengths=lengths,
lengths=lengths, lengths_cumsum=lengths_cumsum,
lengths_cumsum=lengths_cumsum, rank=self.rank,
rank=self.rank, # c=self.batch_max_length,
# c=self.batch_max_length, c=self.seq_max_length * self.sample_packing_seq_len_multiplier,
c=self.seq_max_length * self.sample_packing_seq_len_multiplier, n=self.num_replicas,
n=self.num_replicas,
)
) )
for batch, tot_seqs, total_used, total_slots in alloc_iter: batches = [[indices[b_idx] for b_idx in batch] for batch in batches]
self.batch_queue.put([indices[b_idx] for b_idx in batch])
# statistics
if set_stats:
self.eff_total_used = total_used
self.eff_total_slots = total_slots
self.batch_queue.put(None) # Signal the end of batch generation
def _generate_batches_thread(self): # statistics
try: if set_stats:
self.generate_batches(set_stats=True) self.eff_total_used += total_used
except Exception as e: self.eff_total_slots += total_slots
LOG.error(f"Error in batch generation thread: {e}")
self.batch_queue.put( return batches, totseqs
None
) # Signal the end of batch generation in case of error
def __iter__(self): def __iter__(self):
if hasattr(self.sampler, "set_epoch"): if hasattr(self.sampler, "set_epoch"):
new_epoch = self.sampler.epoch + 1 new_epoch = self.sampler.epoch + 1
self.sampler.set_epoch(new_epoch) self.sampler.set_epoch(new_epoch)
LOG.info(f"calling sampler.set_epoch({new_epoch})") LOG.info(f"calling sampler.set_epoch({new_epoch})")
# Start the batch generation in a separate thread all_batches, _ = self.generate_batches(set_stats=True)
batch_gen_thread = threading.Thread(target=self._generate_batches_thread)
batch_gen_thread.start()
features = self.dataset.features.keys() features = self.dataset.features.keys()
len_remaining = self._len_est() len_remaining = self._len_est()
while True: for batches in chunk(
batch = self.batch_queue.get() all_batches, self.batch_size // self.sample_packing_seq_len_multiplier
if batch is None: # Sentinel value received, stop iteration ):
break
chunked_data = [] chunked_data = []
attn_mask_cum_idx = 0 attn_mask_cum_idx = 0
concatenated = {} for batch in batches:
batched_data = [self.dataset[batch_idx] for batch_idx in batch] concatenated = {}
for feature in features: batched_data = [self.dataset[batch_idx] for batch_idx in batch]
if feature == "attention_mask": for feature in features:
arrays = [ if feature == "attention_mask":
(attn_mask_cum_idx + idx + 1) * np.array(item[feature]) arrays = [
for idx, item in enumerate(batched_data) (attn_mask_cum_idx + idx + 1) * np.array(item[feature])
if feature in item for idx, item in enumerate(batched_data)
] if feature in item
attn_mask_cum_idx += len(batched_data) ]
concatenated[feature] = np.concatenate(arrays) attn_mask_cum_idx += len(batched_data)
else: concatenated[feature] = np.concatenate(arrays)
arrays = [ else:
np.array(item[feature]) arrays = [
for item in batched_data np.array(item[feature])
if feature in item for item in batched_data
] if feature in item
concatenated[feature] = np.concatenate(arrays) ]
chunked_data.append(concatenated) concatenated[feature] = np.concatenate(arrays)
chunked_data.append(concatenated)
yield self.collate_fn(chunked_data) yield self.collate_fn(chunked_data)
len_remaining -= 1 len_remaining -= 1
if not len_remaining: if not len_remaining:
break return
# Wait for the batch generation thread to finish
batch_gen_thread.join(timeout=5)
LOG.info(f"actual packing efficiency: {self.efficiency()}")
def _len_est(self): def _len_est(self):
if not self.total_num_tokens: lengths_sum = np.sum(self.lengths)
self.total_num_tokens = np.sum(self.lengths) lengths_sum_per_device = lengths_sum // self.device_count
lengths_sum_per_device = self.total_num_tokens // self.device_count
LOG.info( LOG.info(
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} " f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
f"total_num_tokens per device: {lengths_sum_per_device}" f"total_num_tokens per device: {lengths_sum_per_device}"

View File

@@ -22,6 +22,7 @@ from transformers import ( # noqa: F401
) )
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
from axolotl.utils.bench import log_gpu_memory_usage
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
@@ -83,19 +84,22 @@ def load_tokenizer(
def load_model( def load_model(
base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora" cfg, tokenizer
): ): # type: (DictDefault, PreTrainedTokenizerBase) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
# type: (str, str, str, PreTrainedTokenizerBase, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
""" """
Load a model from a base model and a model type. Load a model for a given configuration and tokenizer.
""" """
base_model = cfg.base_model
base_model_config = cfg.base_model_config
model_type = cfg.model_type
adapter = cfg.adapter
# TODO refactor as a kwarg # TODO refactor as a kwarg
load_in_8bit = cfg.load_in_8bit load_in_8bit = cfg.load_in_8bit
cfg.is_llama_derived_model = ( cfg.is_llama_derived_model = (
"llama" in base_model "llama" in base_model
or (cfg.model_type and "llama" in cfg.model_type.lower()) or (cfg.model_type and "llama" in cfg.model_type.lower())
or cfg.is_llama_derived_model is True or cfg.is_llama_derived_model
) )
if cfg.is_llama_derived_model and cfg.flash_attention: if cfg.is_llama_derived_model and cfg.flash_attention:
@@ -231,7 +235,9 @@ def load_model(
elif cfg.is_llama_derived_model and not cfg.trust_remote_code: elif cfg.is_llama_derived_model and not cfg.trust_remote_code:
from transformers import LlamaForCausalLM from transformers import LlamaForCausalLM
config = LlamaConfig.from_pretrained(base_model_config) config = LlamaConfig.from_pretrained(
base_model_config, rope_scaling=cfg.rope_scaling
)
model = LlamaForCausalLM.from_pretrained( model = LlamaForCausalLM.from_pretrained(
base_model, base_model,
config=config, config=config,
@@ -336,6 +342,9 @@ def load_model(
) )
model.config.max_position_embeddings = cfg.sequence_len model.config.max_position_embeddings = cfg.sequence_len
if model.device.type == "cuda":
log_gpu_memory_usage(LOG, "after model load", model.device)
if not cfg.gptq and ( if not cfg.gptq and (
(cfg.adapter == "lora" and load_in_8bit) (cfg.adapter == "lora" and load_in_8bit)
or (cfg.adapter == "qlora" and cfg.load_in_4bit) or (cfg.adapter == "qlora" and cfg.load_in_4bit)
@@ -372,6 +381,9 @@ def load_model(
module.scales = module.scales.half() module.scales = module.scales.half()
module.bias = module.bias.half() module.bias = module.bias.half()
if model.device.type == "cuda":
log_gpu_memory_usage(LOG, "after adapters", model.device)
if ( if (
torch.cuda.device_count() > 1 torch.cuda.device_count() > 1
and int(os.getenv("WORLD_SIZE", "1")) > 1 and int(os.getenv("WORLD_SIZE", "1")) > 1

View File

@@ -11,6 +11,7 @@ from pathlib import Path
from typing import Optional, Union from typing import Optional, Union
import bitsandbytes as bnb import bitsandbytes as bnb
import numpy as np
import torch.cuda import torch.cuda
import transformers import transformers
from datasets import Dataset, set_caching_enabled from datasets import Dataset, set_caching_enabled
@@ -20,7 +21,9 @@ from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_pt_utils import get_parameter_names from transformers.trainer_pt_utils import get_parameter_names
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils.callbacks import ( from axolotl.utils.callbacks import (
PrintGPUStatsCallback,
SaveBetterTransformerModelCallback, SaveBetterTransformerModelCallback,
SavePeftModelCallback, SavePeftModelCallback,
) )
@@ -122,10 +125,6 @@ class AxolotlTrainingArguments(TrainingArguments):
default=1, default=1,
metadata={"help": "the multiplier for the max len for packed sequences"}, metadata={"help": "the multiplier for the max len for packed sequences"},
) )
train_data_total_num_tokens: Optional[int] = field(
default=None,
metadata={"help": "the total number of tokens in the train dataset"},
)
class AxolotlTrainer(Trainer): class AxolotlTrainer(Trainer):
@@ -186,7 +185,6 @@ class AxolotlTrainer(Trainer):
packing_efficiency_estimate=self.args.sample_packing_efficiency, packing_efficiency_estimate=self.args.sample_packing_efficiency,
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier, sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
device_count=int(os.environ.get("WORLD_SIZE", 1)), device_count=int(os.environ.get("WORLD_SIZE", 1)),
total_num_tokens=self.args.train_data_total_num_tokens,
) )
) )
return super().get_train_dataloader() return super().get_train_dataloader()
@@ -209,7 +207,6 @@ class AxolotlTrainer(Trainer):
packing_efficiency_estimate=self.args.sample_packing_efficiency, packing_efficiency_estimate=self.args.sample_packing_efficiency,
sample_packing_seq_len_multiplier=self.args.eval_batch_size, sample_packing_seq_len_multiplier=self.args.eval_batch_size,
device_count=int(os.environ.get("WORLD_SIZE", 1)), device_count=int(os.environ.get("WORLD_SIZE", 1)),
total_num_tokens=None,
) )
) )
return super().get_eval_dataloader(eval_dataset) return super().get_eval_dataloader(eval_dataset)
@@ -288,13 +285,16 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
if cfg.sample_packing: if cfg.sample_packing:
# we have to drop anything longer then sequence len otherwise # we have to drop anything longer then sequence len otherwise
# flash attention with position ids fails # flash attention with position ids fails
total_num_tokens = (
cfg.total_num_tokens
if cfg.total_num_tokens
else sum(len(s["input_ids"]) for s in train_dataset)
)
if not cfg.total_num_tokens: if not cfg.total_num_tokens:
LOG.info("calculating total_num_tokens")
total_num_tokens = np.sum(
train_dataset.data.column("input_ids")
.to_pandas()
.apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda
.values
)
LOG.info(f"📝 UPDATE CONFIG WITH: `total_num_tokens: {total_num_tokens}`") LOG.info(f"📝 UPDATE CONFIG WITH: `total_num_tokens: {total_num_tokens}`")
cfg.total_num_tokens = total_num_tokens
if cfg.sample_packing_eff_est: if cfg.sample_packing_eff_est:
total_num_steps = ( total_num_steps = (
@@ -302,9 +302,9 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
( (
math.floor( math.floor(
0.99 0.99
* total_num_tokens * cfg.total_num_tokens
/ cfg.sample_packing_eff_est / cfg.sample_packing_eff_est
/ 2048 / cfg.sequence_len
// cfg.batch_size // cfg.batch_size
// int(os.environ.get("WORLD_SIZE", 1)) // int(os.environ.get("WORLD_SIZE", 1))
) )
@@ -313,7 +313,7 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
* cfg.num_epochs * cfg.num_epochs
) )
LOG.info( LOG.info(
f"total_num_tokens: {total_num_tokens}, total_num_steps: {total_num_steps}" f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}"
) )
else: else:
sampler = RandomSampler(train_dataset) sampler = RandomSampler(train_dataset)
@@ -345,6 +345,7 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
LOG.info( LOG.info(
f"📝 UPDATE CONFIG WITH: `sample_packing_eff_est: {math.ceil(actual_eff * 100.0) / 100.0}`" f"📝 UPDATE CONFIG WITH: `sample_packing_eff_est: {math.ceil(actual_eff * 100.0) / 100.0}`"
) )
cfg.sample_packing_eff_est = math.ceil(actual_eff * 100.0) / 100.0
else: else:
total_num_steps = int( total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
@@ -483,8 +484,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
else "cosine", else "cosine",
weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0, weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
sample_packing=cfg.sample_packing if cfg.sample_packing else False, sample_packing=cfg.sample_packing if cfg.sample_packing else False,
sample_packing_seq_len_multiplier=cfg.micro_batch_size or 1, sample_packing_seq_len_multiplier=cfg.micro_batch_size,
train_data_total_num_tokens=cfg.total_num_tokens,
**training_arguments_kwargs, **training_arguments_kwargs,
) )
@@ -556,6 +556,19 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
trainer_kwargs["optimizers"] = (optimizer, lr_scheduler) trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)
callbacks = [] callbacks = []
callbacks.append(PrintGPUStatsCallback(cfg))
if cfg.relora_steps:
relora_steps = int(cfg.relora_steps)
relora_warmup_steps = int(cfg.relora_warmup_steps)
callbacks.append(ReLoRACallback(cfg))
(optimizer, lr_scheduler) = trainer_kwargs["optimizers"]
trainer_kwargs["optimizers"] = (
optimizer,
ReLoRAScheduler(optimizer, lr_scheduler, relora_steps, relora_warmup_steps),
)
# TODO on_save callback to sync checkpoints to GCP/AWS in background # TODO on_save callback to sync checkpoints to GCP/AWS in background
if cfg.early_stopping_patience: if cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback( early_stop_cb = EarlyStoppingCallback(

View File

@@ -61,6 +61,9 @@ def validate_config(cfg):
if not cfg.load_in_8bit and cfg.adapter == "lora": if not cfg.load_in_8bit and cfg.adapter == "lora":
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning") LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
if cfg.relora_steps and cfg.adapter not in ("lora", "qlora"):
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
if cfg.trust_remote_code: if cfg.trust_remote_code:
LOG.warning( LOG.warning(
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model." "`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
@@ -110,6 +113,13 @@ def validate_config(cfg):
"push_to_hub_model_id is deprecated. Please use hub_model_id instead." "push_to_hub_model_id is deprecated. Please use hub_model_id instead."
) )
if cfg.gptq and cfg.model_revision:
raise ValueError(
"model_revision is not supported for GPTQ models. "
+ "Please download the model from HuggingFace Hub manually for correct branch, "
+ "point to its path, and remove model_revision from the config."
)
if cfg.sample_packing and cfg.sdp_attention: if cfg.sample_packing and cfg.sdp_attention:
# incompatible due to bug w/ accelerate causing 0.0 loss when using llama2 # incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
raise ValueError( raise ValueError(

View File

@@ -9,6 +9,8 @@ def setup_wandb_env_vars(cfg):
elif cfg.wandb_project and len(cfg.wandb_project) > 0: elif cfg.wandb_project and len(cfg.wandb_project) > 0:
os.environ["WANDB_PROJECT"] = cfg.wandb_project os.environ["WANDB_PROJECT"] = cfg.wandb_project
cfg.use_wandb = True cfg.use_wandb = True
if cfg.wandb_entity and len(cfg.wandb_entity) > 0:
os.environ["WANDB_ENTITY"] = cfg.wandb_entity
if cfg.wandb_watch and len(cfg.wandb_watch) > 0: if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
os.environ["WANDB_WATCH"] = cfg.wandb_watch os.environ["WANDB_WATCH"] = cfg.wandb_watch
if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0: if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0: