Compare commits

..

56 Commits

Author SHA1 Message Date
Wing Lian
64af21bcb2 set env vars trainer needs for FSDP
Some checks failed
pre-commit / pre-commit (push) Has been cancelled
PyTest / test (3.10) (push) Has been cancelled
PyTest / test (3.9) (push) Has been cancelled
2023-08-11 08:46:26 -04:00
Wing Lian
6b5cf8b5ea optimize length reducer from 9m -> <5sec 2023-08-11 08:30:30 -04:00
Wing Lian
79500f358a need to pass total num tokens to trainer too 2023-08-10 19:08:23 -04:00
Wing Lian
7e977a9b68 optimization if total_num_tokens is already known 2023-08-10 19:02:28 -04:00
Wing Lian
ac4b700daa optimization if total_num_tokens is already known 2023-08-10 19:01:17 -04:00
Wing Lian
2565c2f259 async batching for multipack 2023-08-10 18:28:15 -04:00
Wing Lian
a07f432d9c calculate cum seq lens with pos_ids instead of mask, simplify packing params, fix distributed barrier 2023-08-10 17:16:01 -04:00
Wing Lian
57d9bf711c let's not cleanup the cached datasets 2023-08-08 21:27:55 -04:00
Wing Lian
26983a1974 fix sampler to prevent overfit w new epochs 2023-08-08 15:34:18 -04:00
Wing Lian
1b8747e319 use custom distributed checks 2023-08-08 13:35:04 -04:00
Wing Lian
035b3c760c add numba to requirements. 2023-08-08 10:55:29 -04:00
Wing Lian
17abbd59e1 previous accelerate is still most performant 2023-08-08 09:46:01 -04:00
Wing Lian
6ec76ddb4c fix steps calculation 2023-08-08 05:13:21 -04:00
Wing Lian
21d307b15b fix counts by accounting for num devices 2023-08-08 04:13:10 -04:00
Wing Lian
58e9dee204 fixes and go back to distributed sampler since batch sampler won't work 2023-08-08 03:49:29 -04:00
Wing Lian
4f7c04bae0 more fixes and optimizations 2023-08-08 03:16:00 -04:00
Wing Lian
1162b93b6b filter w multiple cpus 2023-08-08 00:50:56 -04:00
Wing Lian
21f445d763 more packing and dataset optimizations and fixes 2023-08-08 00:45:24 -04:00
Wing Lian
229b9165aa fix test and pylint checks 2023-08-07 09:38:05 -04:00
Wing Lian
394a65f11f add unit tests for cum seq lens, add ability to build cu_seq_lens from positional ids, fix prompt test 2023-08-07 09:38:04 -04:00
Wing Lian
c70dae63cc add chatml 2023-08-07 09:38:04 -04:00
Wing Lian
7712955b35 fix chatml system prompt for openorca, legacy tokenizer opts 2023-08-07 09:38:04 -04:00
Wing Lian
f93f0017cd fix flash-attn, xformers, packing, support chatml 2023-08-07 09:38:04 -04:00
Wing Lian
0b01da0713 properly calculate max len 2023-08-07 09:38:04 -04:00
Wing Lian
b2f7bc7ccd use cumulative seq len with var len flash attn v2 w packing 2023-08-07 09:38:04 -04:00
Wing Lian
b8905e2a91 sample_packing_seq_len_multiplier config 2023-08-07 09:38:04 -04:00
Wing Lian
7e1edc662a make sure the chunk size is an int 2023-08-07 09:38:04 -04:00
Wing Lian
98c9bc69de seq_len_multiple for packing 2023-08-07 09:38:04 -04:00
Wing Lian
8378335dc9 limit packing to sequences of max seq len 2023-08-07 09:38:04 -04:00
Wing Lian
bdd34c7400 weighted CEL fixes 2023-08-07 09:38:04 -04:00
Wing Lian
c6cc54c7d9 weighted CE losses 2023-08-07 09:38:04 -04:00
Wing Lian
83f7362480 don't split batches when packing 2023-08-07 09:38:04 -04:00
Wing Lian
958d423e7c only process eval dataset for packing if not None 2023-08-07 09:38:04 -04:00
Wing Lian
e74eab6e73 add a test for the mask expansion for sequence packing 2023-08-07 09:38:04 -04:00
Wing Lian
487abfc769 pass sample packing efficiency to training args 2023-08-07 09:38:04 -04:00
Wing Lian
2bee646e85 fix step calc for packing 2023-08-07 09:38:04 -04:00
Wing Lian
945f2e5029 better handling so that all devices have the same dataloader len 2023-08-07 09:38:04 -04:00
Wing Lian
daed942fe9 fix rounding of len of batches to int 2023-08-07 09:38:04 -04:00
Wing Lian
df3eb645da better handling of variance in multipack dataloader length and trainer hanging when it runs out of data 2023-08-07 09:38:04 -04:00
Wing Lian
32fed7039d optimized expand mask fn 2023-08-07 09:38:04 -04:00
Wing Lian
7d7b5ebd71 more fixes for 4k and optimizations 2023-08-07 09:38:03 -04:00
Wing Lian
4b7ad9927f validation for sample packing and doc 2023-08-07 09:38:03 -04:00
Wing Lian
fedcf5a089 Update src/axolotl/utils/dataloader.py 2023-08-07 09:38:03 -04:00
Wing Lian
2f2974196d fix for position_ids w packing 2023-08-07 09:38:03 -04:00
Wing Lian
2e295c9f94 use accelerator prepare for dataloader 2023-08-07 09:38:03 -04:00
Wing Lian
4ab9ab79fd use distributed sampler, avoid accelerate prepare 2023-08-07 09:38:03 -04:00
Wing Lian
b02484a83e more fixes for sample packing 2023-08-07 09:38:03 -04:00
Wing Lian
58045f0816 more fixes, position_ids seems broken 2023-08-07 09:38:03 -04:00
Wing Lian
66774011c4 est total tokens, fix field loop 2023-08-07 09:38:03 -04:00
Wing Lian
41d4992029 more fixes for dataloader integration 2023-08-07 09:38:03 -04:00
Wing Lian
762f1b08db add position_ids back 2023-08-07 09:38:03 -04:00
Wing Lian
3aba4c5d7c use multi pack dataloader w random sampler 2023-08-07 09:38:03 -04:00
Wing Lian
ffd96839cf don't move masks to cpu 2023-08-07 09:38:03 -04:00
Wing Lian
ef9bf7ad73 fix expand mask for multiple batch items, make sure we pad position_ids 2023-08-07 09:38:03 -04:00
Wing Lian
4964b0d345 set position ids and use block diagonal attn mask 2023-08-07 09:38:03 -04:00
Wing Lian
36b0e30a9d fix attetion mask with packing 2023-08-07 09:38:03 -04:00
40 changed files with 258 additions and 522 deletions

13
.github/FUNDING.yml vendored
View File

@@ -1,13 +0,0 @@
# These are supported funding model platforms
github: OpenAccess-AI-Collective # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
patreon: # Replace with a single Patreon username
open_collective: # Replace with a single Open Collective username
ko_fi: # Replace with a single Ko-fi username
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
otechie: # Replace with a single Otechie username
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']

View File

@@ -136,7 +136,7 @@ Have dataset(s) in one of the following format (JSONL recommended):
```json ```json
{"instruction": "...", "input": "...", "output": "..."} {"instruction": "...", "input": "...", "output": "..."}
``` ```
- `sharegpt:chat`: conversations where `from` is `human`/`gpt` - `sharegpt:chat`: conversations
```json ```json
{"conversations": [{"from": "...", "value": "..."}]} {"conversations": [{"from": "...", "value": "..."}]}
``` ```
@@ -225,10 +225,6 @@ Have dataset(s) in one of the following format (JSONL recommended):
```json ```json
{"conversations": [{"role": "...", "value": "..."}]} {"conversations": [{"role": "...", "value": "..."}]}
``` ```
- `sharegpt_simple.load_guanaco`: conversations where `from` is `prompter`/`assistant` instead of default sharegpt
```json
{"conversations": [{"from": "...", "value": "..."}]}
```
- `sharegpt_jokes`: creates a chat where bot is asked to tell a joke, then explain why the joke is funny - `sharegpt_jokes`: creates a chat where bot is asked to tell a joke, then explain why the joke is funny
```json ```json
{"conversations": [{"title": "...", "text": "...", "explanation": "..."}]} {"conversations": [{"title": "...", "text": "...", "explanation": "..."}]}
@@ -326,9 +322,9 @@ tokenizer_type: AutoTokenizer
trust_remote_code: trust_remote_code:
# use_fast option for tokenizer loading from_pretrained, default to True # use_fast option for tokenizer loading from_pretrained, default to True
tokenizer_use_fast: tokenizer_use_fast:
# resize the model embeddings when new tokens are added to multiples of N # resize the model embeddings when new tokens are added to multiples of 32
# multiples of 32 are reported to improve training speed on some models # this is reported to improve training speed on some models
resize_token_embeddings_multiple: resize_token_embeddings_to_32x:
# whether you are training a 4-bit GPTQ quantized model # whether you are training a 4-bit GPTQ quantized model
gptq: true gptq: true
@@ -364,9 +360,6 @@ dataset_prepared_path: data/last_run_prepared
push_dataset_to_hub: # repo path push_dataset_to_hub: # repo path
# push checkpoints to hub # push checkpoints to hub
hub_model_id: # repo path to push finetuned model hub_model_id: # repo path to push finetuned model
# how to push checkpoints to hub
# https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy
hub_strategy:
# whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets # whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
# required to be true when used in combination with `push_dataset_to_hub` # required to be true when used in combination with `push_dataset_to_hub`
hf_use_auth_token: # boolean hf_use_auth_token: # boolean
@@ -382,14 +375,10 @@ dataset_shard_idx:
sequence_len: 2048 sequence_len: 2048
# max sequence length to concatenate training samples together up to # max sequence length to concatenate training samples together up to
# inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning # inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
# FutureWarning: This will soon be DEPRECATED # soon to be DEPRECATED
max_packed_sequence_len: 1024 max_packed_sequence_len: 1024
# use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true' # use efficient multi-packing with block diagonal attention and per sequence position_ids
sample_packing: sample_packing:
# you can set these packing optimizations AFTER starting a training at least once.
# The trainer will provide recommended values for these values.
sample_packing_eff_est:
total_num_tokens:
# if you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model # if you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
adapter: lora adapter: lora
@@ -415,12 +404,11 @@ lora_out_dir:
lora_fan_in_fan_out: false lora_fan_in_fan_out: false
# wandb configuration if you're using it # wandb configuration if you're using it
wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb wandb_mode:
wandb_project: # your wandb project name wandb_project:
wandb_entity: # a wandb Team name if using a Team
wandb_watch: wandb_watch:
wandb_run_id: # set the name of your wandb run wandb_run_id:
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training wandb_log_model: # 'checkpoint'
# where to save the finished model to # where to save the finished model to
output_dir: ./completed-model output_dir: ./completed-model
@@ -435,17 +423,13 @@ learning_rate: 0.00003
logging_steps: logging_steps:
save_steps: save_steps:
eval_steps: eval_steps:
save_total_limit: # checkpoints saved at a time
max_steps:
# save model as safetensors (require safetensors package) # save model as safetensors (require safetensors package)
save_safetensors: save_safetensors:
# whether to mask out or include the human's prompt from the training labels # whether to mask out or include the human's prompt from the training labels
train_on_inputs: false train_on_inputs: false
# group similarly sized data to minimize padding # don't use this, leads to wonky training (according to someone on the internet)
# may be slower to start, as it must download and sort the entire dataset
# note that training loss may have an oscillating pattern with this enabled
group_by_length: false group_by_length: false
# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing # Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
@@ -491,10 +475,6 @@ landmark_attention:
# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py # xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
# llama only # llama only
xpos_rope: xpos_rope:
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
rope_scaling:
type: # linear | dynamic
factor: # float
# resume from a specific checkpoint dir # resume from a specific checkpoint dir
resume_from_checkpoint: resume_from_checkpoint:
@@ -526,9 +506,6 @@ torchdistx_path:
# Set padding for data collator to 'longest' # Set padding for data collator to 'longest'
collator_pad_to_longest: collator_pad_to_longest:
# Set to HF dataset for type: 'completion' for streaming instead of pre-tokenize
pretraining_dataset:
# Debug mode # Debug mode
debug: debug:
@@ -548,14 +525,7 @@ Run
accelerate launch scripts/finetune.py configs/your_config.yml accelerate launch scripts/finetune.py configs/your_config.yml
``` ```
#### Multi-GPU #### Multi-GPU Config
You can optionally pre-tokenize dataset with the following before finetuning:
```bash
CUDA_VISIBLE_DEVICES="" accelerate ... --prepare_ds_only
```
##### Config
- llama FSDP - llama FSDP
```yaml ```yaml
@@ -570,18 +540,6 @@ fsdp_config:
- llama Deepspeed: append `ACCELERATE_USE_DEEPSPEED=true` in front of finetune command - llama Deepspeed: append `ACCELERATE_USE_DEEPSPEED=true` in front of finetune command
##### Weights & Biases Logging
- wandb options
```yaml
wandb_mode:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
```
### Inference ### Inference
Pass the appropriate flag to the train command: Pass the appropriate flag to the train command:

View File

@@ -40,7 +40,7 @@ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
RUN git clone https://github.com/Dao-AILab/flash-attention.git && \ RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
cd flash-attention && \ cd flash-attention && \
git checkout v2.0.4 && \ git checkout v2.0.1 && \
python3 setup.py bdist_wheel && \ python3 setup.py bdist_wheel && \
cd csrc/fused_dense_lib && \ cd csrc/fused_dense_lib && \
python3 setup.py bdist_wheel && \ python3 setup.py bdist_wheel && \

View File

@@ -23,7 +23,6 @@ lora_target_modules:
lora_target_linear: lora_target_linear:
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:
@@ -36,7 +35,7 @@ torchdistx_path:
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.0002
train_on_inputs: false train_on_inputs: false
group_by_length: false group_by_length: true
bf16: true bf16: true
fp16: false fp16: false
tf32: true tf32: true

View File

@@ -24,7 +24,6 @@ lora_target_modules:
lora_target_linear: true lora_target_linear: true
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -38,7 +38,6 @@ lora_target_linear: true
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -24,7 +24,6 @@ lora_target_modules:
lora_target_linear: true lora_target_linear: true
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -20,7 +20,6 @@ lora_target_modules:
lora_target_linear: true lora_target_linear: true
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:
@@ -33,7 +32,7 @@ torchdistx_path:
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0001 learning_rate: 0.0001
train_on_inputs: false train_on_inputs: false
group_by_length: false group_by_length: true
bf16: true bf16: true
fp16: false fp16: false
tf32: true tf32: true

View File

@@ -22,7 +22,6 @@ lora_target_modules:
- v_proj - v_proj
lora_fan_in_fan_out: false lora_fan_in_fan_out: false
wandb_project: llama-7b-lora-int4 wandb_project: llama-7b-lora-int4
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -18,7 +18,6 @@ lora_dropout:
lora_target_modules: lora_target_modules:
lora_fan_in_fan_out: false lora_fan_in_fan_out: false
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -15,7 +15,7 @@ val_set_size: 0.01
output_dir: ./lora-out output_dir: ./lora-out
sequence_len: 4096 sequence_len: 4096
sample_packing: true max_packed_sequence_len: 4096
adapter: lora adapter: lora
lora_model_dir: lora_model_dir:
@@ -26,7 +26,6 @@ lora_target_linear: true
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:
@@ -39,7 +38,7 @@ lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.0002
train_on_inputs: false train_on_inputs: false
group_by_length: false group_by_length: true
bf16: true bf16: true
fp16: false fp16: false
tf32: false tf32: false
@@ -49,8 +48,8 @@ early_stopping_patience:
resume_from_checkpoint: resume_from_checkpoint:
local_rank: local_rank:
logging_steps: 1 logging_steps: 1
xformers_attention: xformers_attention: true
flash_attention: true flash_attention:
warmup_steps: 10 warmup_steps: 10
eval_steps: 20 eval_steps: 20
@@ -64,3 +63,4 @@ special_tokens:
bos_token: "<s>" bos_token: "<s>"
eos_token: "</s>" eos_token: "</s>"
unk_token: "<unk>" unk_token: "<unk>"
pad_token: "<pad>"

View File

@@ -18,8 +18,7 @@ adapter: qlora
lora_model_dir: lora_model_dir:
sequence_len: 4096 sequence_len: 4096
sample_packing: true max_packed_sequence_len: 4096
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
lora_dropout: 0.05 lora_dropout: 0.05
@@ -28,7 +27,6 @@ lora_target_linear: true
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:
@@ -41,7 +39,7 @@ lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.0002
train_on_inputs: false train_on_inputs: false
group_by_length: false group_by_length: true
bf16: true bf16: true
fp16: false fp16: false
tf32: false tf32: false
@@ -51,8 +49,8 @@ early_stopping_patience:
resume_from_checkpoint: resume_from_checkpoint:
local_rank: local_rank:
logging_steps: 1 logging_steps: 1
xformers_attention: xformers_attention: true
flash_attention: true flash_attention:
warmup_steps: 10 warmup_steps: 10
eval_steps: 20 eval_steps: 20
@@ -66,3 +64,4 @@ special_tokens:
bos_token: "<s>" bos_token: "<s>"
eos_token: "</s>" eos_token: "</s>"
unk_token: "<unk>" unk_token: "<unk>"
pad_token: "<pad>"

View File

@@ -20,7 +20,6 @@ lora_target_modules:
- v_proj - v_proj
lora_fan_in_fan_out: false lora_fan_in_fan_out: false
wandb_project: mpt-alpaca-7b wandb_project: mpt-alpaca-7b
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -22,7 +22,6 @@ lora_target_modules:
lora_target_linear: lora_target_linear:
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -28,7 +28,6 @@ lora_target_modules:
- o_proj - o_proj
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -22,7 +22,6 @@ lora_target_modules:
lora_target_linear: true lora_target_linear: true
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:
@@ -35,7 +34,7 @@ torchdistx_path:
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.0002
train_on_inputs: false train_on_inputs: false
group_by_length: false group_by_length: true
bf16: true bf16: true
fp16: false fp16: false
tf32: true tf32: true

View File

@@ -23,7 +23,6 @@ lora_target_modules:
lora_target_linear: true lora_target_linear: true
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -17,7 +17,6 @@ lora_target_modules:
lora_target_linear: lora_target_linear:
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -21,7 +21,6 @@ lora_target_modules:
- v_proj - v_proj
lora_fan_in_fan_out: false lora_fan_in_fan_out: false
wandb_project: redpajama-alpaca-3b wandb_project: redpajama-alpaca-3b
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -20,7 +20,6 @@ lora_target_modules:
- mlp_down - mlp_down
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: lora-replit wandb_project: lora-replit
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -37,7 +37,6 @@ lora_target_linear: true
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -1,6 +1,6 @@
peft @ git+https://github.com/huggingface/peft.git peft @ git+https://github.com/huggingface/peft.git
transformers @ git+https://github.com/huggingface/transformers.git transformers @ git+https://github.com/huggingface/transformers.git
bitsandbytes>=0.41.1 bitsandbytes>=0.39.0
accelerate @ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b accelerate @ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b
addict addict
fire fire
@@ -21,4 +21,3 @@ evaluate==0.4.0
rouge-score==0.1.2 rouge-score==0.1.2
scipy scipy
scikit-learn==1.2.2 scikit-learn==1.2.2
pynvml

View File

@@ -18,7 +18,6 @@ from optimum.bettertransformer import BetterTransformer
from transformers import GenerationConfig, TextStreamer from transformers import GenerationConfig, TextStreamer
from axolotl.logging_config import configure_logging from axolotl.logging_config import configure_logging
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import barrier, is_main_process from axolotl.utils.distributed import barrier, is_main_process
@@ -29,6 +28,7 @@ from axolotl.utils.trainer import (
process_datasets_for_packing, process_datasets_for_packing,
setup_trainer, setup_trainer,
) )
from axolotl.utils.validation import validate_config
from axolotl.utils.wandb import setup_wandb_env_vars from axolotl.utils.wandb import setup_wandb_env_vars
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -43,6 +43,27 @@ DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
def choose_device(cfg):
def get_device():
try:
if torch.cuda.is_available():
return f"cuda:{cfg.local_rank}"
if torch.backends.mps.is_available():
return "mps"
raise SystemError("No CUDA/mps device found")
except Exception: # pylint: disable=broad-exception-caught
return "cpu"
cfg.device = get_device()
if cfg.device_map != "auto":
if cfg.device.startswith("cuda"):
cfg.device_map = {"": cfg.local_rank}
else:
cfg.device_map = {"": cfg.device}
def get_multi_line_input() -> Optional[str]: def get_multi_line_input() -> Optional[str]:
print("Give me an instruction (Ctrl + D to finish): ") print("Give me an instruction (Ctrl + D to finish): ")
instruction = "" instruction = ""
@@ -172,13 +193,36 @@ def train(
validate_config(cfg) validate_config(cfg)
normalize_config(cfg) # setup some derived config / hyperparams
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
cfg.batch_size // cfg.micro_batch_size
)
cfg.batch_size = (
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
)
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
choose_device(cfg)
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
if cfg.ddp:
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
cfg.batch_size = cfg.batch_size * cfg.world_size
setup_wandb_env_vars(cfg) setup_wandb_env_vars(cfg)
if cfg.device == "mps":
cfg.load_in_8bit = False
cfg.tf32 = False
if cfg.bf16:
cfg.fp16 = True
cfg.bf16 = False
if cfg.tf32:
torch.backends.cuda.matmul.allow_tf32 = True
# load the tokenizer first # load the tokenizer first
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}") tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
tokenizer = load_tokenizer(cfg) LOG.info(f"loading tokenizer... {tokenizer_config}")
tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
if ( if (
check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
@@ -209,13 +253,7 @@ def train(
cfg, train_dataset, eval_dataset cfg, train_dataset, eval_dataset
) )
barrier() barrier()
if cfg.max_steps: total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
total_num_steps = min(
calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps
)
LOG.info(f"Maximum number of steps set at {total_num_steps}")
else:
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
if cfg.debug or "debug" in kwargs: if cfg.debug or "debug" in kwargs:
LOG.info("check_dataset_labels...") LOG.info("check_dataset_labels...")
@@ -231,10 +269,15 @@ def train(
return return
# Load the model and tokenizer # Load the model and tokenizer
LOG.info("loading model and (optionally) peft_config...") LOG.info("loading model and peft_config...")
model, peft_config = load_model(cfg, tokenizer) model, peft_config = load_model(
cfg.base_model,
safe_serialization = cfg.save_safetensors is True cfg.base_model_config,
cfg.model_type,
tokenizer,
cfg,
adapter=cfg.adapter,
)
if "merge_lora" in kwargs and cfg.adapter is not None: if "merge_lora" in kwargs and cfg.adapter is not None:
LOG.info("running merge of LoRA with base model") LOG.info("running merge of LoRA with base model")
@@ -243,11 +286,7 @@ def train(
if cfg.local_rank == 0: if cfg.local_rank == 0:
LOG.info("saving merged model") LOG.info("saving merged model")
model.save_pretrained( model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
str(Path(cfg.output_dir) / "merged"),
safe_serialization=safe_serialization,
)
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
return return
if cfg.inference: if cfg.inference:
@@ -262,7 +301,7 @@ def train(
return return
if "shard" in kwargs: if "shard" in kwargs:
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) model.save_pretrained(cfg.output_dir)
return return
trainer = setup_trainer( trainer = setup_trainer(
@@ -286,7 +325,7 @@ def train(
def terminate_handler(_, __, model): def terminate_handler(_, __, model):
if cfg.flash_optimum: if cfg.flash_optimum:
model = BetterTransformer.reverse(model) model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) model.save_pretrained(cfg.output_dir)
sys.exit(0) sys.exit(0)
signal.signal( signal.signal(
@@ -313,7 +352,6 @@ def train(
if not Path(cfg.output_dir).is_dir(): if not Path(cfg.output_dir).is_dir():
os.makedirs(cfg.output_dir, exist_ok=True) os.makedirs(cfg.output_dir, exist_ok=True)
tokenizer.save_pretrained(cfg.output_dir)
if cfg.flash_optimum: if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel( with torch.backends.cuda.sdp_kernel(
enable_flash=True, enable_math=True, enable_mem_efficient=True enable_flash=True, enable_math=True, enable_mem_efficient=True
@@ -331,7 +369,7 @@ def train(
elif cfg.local_rank == 0: elif cfg.local_rank == 0:
if cfg.flash_optimum: if cfg.flash_optimum:
model = BetterTransformer.reverse(model) model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) model.save_pretrained(cfg.output_dir)
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -5,7 +5,7 @@ import os
from typing import List from typing import List
import torch import torch
from datasets import Dataset, IterableDataset from datasets import IterableDataset
from .prompt_tokenizers import PromptTokenizingStrategy from .prompt_tokenizers import PromptTokenizingStrategy
@@ -18,9 +18,9 @@ from .prompt_tokenizers import PromptTokenizingStrategy
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
class TokenizedPromptDataset(Dataset): class TokenizedPromptDataset(IterableDataset):
""" """
Dataset that returns tokenized prompts from a stream of text files. Iterable dataset that returns tokenized prompts from a stream of text files.
Args: Args:
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for proccessing the data. prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for proccessing the data.
dataset (dataset.Dataset): Dataset with text files. dataset (dataset.Dataset): Dataset with text files.
@@ -30,18 +30,19 @@ class TokenizedPromptDataset(Dataset):
self, self,
prompt_tokenizer: PromptTokenizingStrategy, prompt_tokenizer: PromptTokenizingStrategy,
dataset: IterableDataset, dataset: IterableDataset,
**kwargs,
): ):
self.prompt_tokenizer = prompt_tokenizer self.prompt_tokenizer = prompt_tokenizer
super().__init__(self.process(dataset).data, **kwargs) self.dataset = dataset
def process(self, dataset): def __iter__(self):
features = dataset.features.keys() features = self.dataset.features.keys()
num_proc = min(64, os.cpu_count()) num_proc = os.cpu_count()
return dataset.map( return iter(
self.prompt_tokenizer.tokenize_prompt, self.dataset.map(
num_proc=num_proc, self.prompt_tokenizer.tokenize_prompt,
remove_columns=features, num_proc=num_proc,
remove_columns=features,
)
) )

View File

@@ -7,7 +7,6 @@ from typing import Optional, Tuple
import torch import torch
import transformers import transformers
from einops import rearrange from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input
try: try:
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
@@ -92,8 +91,7 @@ def forward(
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
) )
output = rearrange(output, "(b s) ... -> b s ...", b=bsz) output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
elif attention_mask.shape[0] == 1: else:
# special handling using sample packing
qkv = rearrange(qkv, "b s ... -> (b s) ...") qkv = rearrange(qkv, "b s ... -> (b s) ...")
cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids) cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids)
cu_q_lens = cu_q_lens.squeeze() cu_q_lens = cu_q_lens.squeeze()
@@ -102,36 +100,6 @@ def forward(
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
) )
output = rearrange(output, "(b s) ... -> b s ...", b=bsz) output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
else:
nheads = qkv.shape[-2]
# pylint: disable=invalid-name
x = rearrange(qkv, "b s three h d -> b s (three h d)")
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
x_unpad = rearrange(
x_unpad,
"nnz (three h d) -> nnz three h d",
three=3,
h=nheads,
)
output_unpad = flash_attn_varlen_qkvpacked_func(
x_unpad,
cu_q_lens,
max_s,
0.0,
softmax_scale=None,
causal=True,
)
output = rearrange(
pad_input(
rearrange(output_unpad, "nnz h d -> nnz (h d)"),
indices,
bsz,
q_len,
),
"b s (h d) -> b s h d",
h=nheads,
)
return ( return (
self.o_proj(rearrange(output, "b s h d -> b s (h d)")), self.o_proj(rearrange(output, "b s h d -> b s (h d)")),

View File

@@ -95,9 +95,9 @@ class OpenOrcaSystemDataPrompter(SystemDataPrompter):
self.turn_format = "### User:\n{instruction}\n\n### Additional Context:\n{input}\n\n### Assistant:\n" self.turn_format = "### User:\n{instruction}\n\n### Additional Context:\n{input}\n\n### Assistant:\n"
self.turn_no_input_format = "### User:\n{instruction}\n\n### Assistant:\n" self.turn_no_input_format = "### User:\n{instruction}\n\n### Assistant:\n"
if self.prompt_style == PromptStyle.CHAT.value: if self.prompt_style == PromptStyle.CHAT.value:
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:" self.turn_format = "User: {instruction}\n{input}\nAssistant:"
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:" self.turn_no_input_format = "User: {instruction}\nAssistant:"
self.system_format = "SYSTEM: {system}\n" self.system_format = "System: {system}\n"
if self.prompt_style == PromptStyle.CHATML.value: if self.prompt_style == PromptStyle.CHATML.value:
self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n" self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
self.turn_no_input_format = ( self.turn_no_input_format = (

View File

@@ -29,7 +29,7 @@ from dataclasses import dataclass, field
from typing import Generator, List, Sequence from typing import Generator, List, Sequence
from axolotl.prompt_tokenizers import PromptTokenizingStrategy from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID, SHAREGPT_ASSERTION_FAILED_ROLE from axolotl.prompters import IGNORE_TOKEN_ID
@dataclass @dataclass
@@ -190,7 +190,7 @@ class Llama2ChatPrompter: # pylint: disable=too-few-public-methods
conv.messages = [] # pylint: disable=R0801 conv.messages = [] # pylint: disable=R0801
for j, sentence in enumerate(source): for j, sentence in enumerate(source):
role = roles[sentence["from"]] role = roles[sentence["from"]]
assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE assert role == conv.roles[j % 2]
if sentence["value"]: if sentence["value"]:
conv.append_message(role, sentence["value"]) conv.append_message(role, sentence["value"])
yield conv yield conv

View File

@@ -271,11 +271,6 @@ class Conversation:
self.messages.append([role, message]) self.messages.append([role, message])
SHAREGPT_ASSERTION_FAILED_ROLE = (
"Role did not alternate between turns (gpt and human). Please check your data."
)
class ShareGPTPrompter: # pylint: disable=too-few-public-methods class ShareGPTPrompter: # pylint: disable=too-few-public-methods
""" """
A prompter that generates prompts for the ShareGPT A prompter that generates prompts for the ShareGPT
@@ -312,9 +307,7 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
if len(source) < 2: if len(source) < 2:
# If there isn't a back and forth conversation, ignore it # If there isn't a back and forth conversation, ignore it
# also happens on the data splitting leaving empty conversations # also happens on the data splitting leaving empty conversations
raise IndexError( raise IndexError
f"A conversation entry has less than 2 messages :\n{source}"
)
conv = self._conversation.copy() conv = self._conversation.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]} roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
@@ -334,7 +327,7 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
conv.messages = [] conv.messages = []
for j, sentence in enumerate(source): for j, sentence in enumerate(source):
role = roles[sentence["from"]] role = roles[sentence["from"]]
assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE assert role == conv.roles[j % 2]
conv.append_message(role, sentence["value"]) conv.append_message(role, sentence["value"])
for part in conv.get_prompt(): for part in conv.get_prompt():

View File

@@ -1,43 +0,0 @@
"""Benchmarking and measurement utilities"""
import pynvml
import torch
def gpu_memory_usage(device=0):
return torch.cuda.memory_allocated(device) / 1024.0**3
def gpu_memory_usage_all(device=0):
usage = torch.cuda.memory_allocated(device) / 1024.0**3
reserved = torch.cuda.memory_reserved(device) / 1024.0**3
smi = gpu_memory_usage_smi(device)
return usage, reserved - usage, max(0, smi - reserved)
def gpu_memory_usage_smi(device=0):
if isinstance(device, torch.device):
device = device.index
if isinstance(device, str) and device.startswith("cuda:"):
device = int(device[5:])
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
return info.used / 1024.0**3
def log_gpu_memory_usage(log, msg, device):
if not torch.cuda.is_available():
return (0, 0, 0)
usage, cache, misc = gpu_memory_usage_all(device)
extras = []
if cache > 0:
extras.append(f"+{cache:.03f}GB cache")
if misc > 0:
extras.append(f"+{misc:.03f}GB misc")
log.info(
f"GPU memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})", stacklevel=2
)
return usage, cache, misc

View File

@@ -1,6 +1,5 @@
"""Callbacks for Trainer class""" """Callbacks for Trainer class"""
import logging
import os import os
from optimum.bettertransformer import BetterTransformer from optimum.bettertransformer import BetterTransformer
@@ -12,10 +11,6 @@ from transformers import (
) )
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
from axolotl.utils.bench import log_gpu_memory_usage
LOG = logging.getLogger("axolotl.callbacks")
class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
"""Callback to save the PEFT adapter""" """Callback to save the PEFT adapter"""
@@ -72,25 +67,3 @@ class SaveBetterTransformerModelCallback(
# the trainer will raise an exception since it can't save a BetterTransformer wrapped model # the trainer will raise an exception since it can't save a BetterTransformer wrapped model
control.should_save = False control.should_save = False
return control return control
class GPUStatsCallback(
TrainerCallback
): # pylint: disable=too-few-public-methods disable=unused-argument
"""Callback to track GPU utilization"""
def __init__(self, cfg):
self.cfg = cfg
self.logged = False
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
if not self.logged and state.global_step > 1:
log_gpu_memory_usage(LOG, "while training", self.cfg.device)
self.logged = True
return control

View File

@@ -1,19 +1,14 @@
"""Module containing data utilities""" """Module containing data utilities"""
import functools import functools
import hashlib import hashlib
import itertools
import logging import logging
from hashlib import md5 from hashlib import md5
from pathlib import Path from pathlib import Path
from typing import Tuple, Union from typing import List, Tuple, Union
import torch import torch
from datasets import ( from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
Dataset,
DatasetDict,
concatenate_datasets,
load_dataset,
load_from_disk,
)
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
@@ -270,12 +265,20 @@ def load_tokenized_prepared_datasets(
raise ValueError( raise ValueError(
f"unhandled prompt tokenization strategy: {d.type} {suffix}" f"unhandled prompt tokenization strategy: {d.type} {suffix}"
) )
LOG.info("merging datasets") LOG.info("tokenizing, merging, and shuffling master dataset")
dataset = concatenate_datasets(datasets)
if len(datasets) > 1: samples: List[int] = []
LOG.info("shuffle merged datasets") chunk_size = 1000
dataset = dataset.shuffle(seed=seed) for d in datasets:
d_iter = iter(d)
while True:
chunk = list(itertools.islice(d_iter, chunk_size))
if not chunk:
break
samples.extend(chunk)
LOG.info("shuffle")
dataset = Dataset.from_list(samples).shuffle(seed=seed)
if cfg.local_rank == 0: if cfg.local_rank == 0:
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}") LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
dataset.save_to_disk(prepared_ds_path) dataset.save_to_disk(prepared_ds_path)

View File

@@ -3,7 +3,9 @@ import hashlib
import itertools import itertools
import logging import logging
import math import math
from typing import Any, Callable, List, Union import queue
import threading
from typing import Any, Callable, List, Optional, Union
import numba import numba
import numpy as np import numpy as np
@@ -78,7 +80,6 @@ def allocate(
s = 0 s = 0
start_index = 0 start_index = 0
result = [] result = []
result_totseqs = []
while True: while True:
# binary search [left, right) # binary search [left, right)
@@ -104,10 +105,8 @@ def allocate(
# add local rank # add local rank
result.append(batch[rank]) result.append(batch[rank])
# add total seqs for all ranks
result_totseqs.append(tot_seqs) yield batch[rank], tot_seqs, s, len(result) * c * n
# yield batch[rank], tot_seqs, s, len(result) * c * n
return result, result_totseqs, s, len(result) * c * n
def chunk(iterable, n): def chunk(iterable, n):
@@ -149,15 +148,14 @@ class MultipackDistributedDataloader:
packing_efficiency_estimate: float = 1.0, packing_efficiency_estimate: float = 1.0,
sample_packing_seq_len_multiplier: int = 1, sample_packing_seq_len_multiplier: int = 1,
device_count: int = 1, device_count: int = 1,
total_num_tokens: Optional[int] = None,
): ):
# Dataset # Dataset
self.dataset = dataset self.dataset = dataset
self.lengths = ( lengths_series = (
dataset.data.column("position_ids") dataset.data.column("position_ids").to_pandas().apply(lambda x: x[-1] + 1)
.to_pandas()
.apply(lambda x: x[-1] + 1)
.values
) )
self.lengths: np.ndarray = lengths_series.values
assert isinstance(self.lengths, np.ndarray) assert isinstance(self.lengths, np.ndarray)
assert batch_size % sample_packing_seq_len_multiplier == 0 assert batch_size % sample_packing_seq_len_multiplier == 0
assert batch_size >= sample_packing_seq_len_multiplier assert batch_size >= sample_packing_seq_len_multiplier
@@ -172,11 +170,17 @@ class MultipackDistributedDataloader:
self.rank = 0 self.rank = 0
# statistics # statistics
self.total_num_tokens = total_num_tokens
self.eff_total_used = 0 self.eff_total_used = 0
self.eff_total_slots = 0 self.eff_total_slots = 0
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0 self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
self.device_count = device_count self.device_count = device_count
# for non-blocking batch creation
self.batch_queue: queue.Queue = queue.Queue(
maxsize=10
) # Adjust maxsize as needed
def generate_batches(self, set_stats=False): def generate_batches(self, set_stats=False):
LOG.info("generating packed batches") LOG.info("generating packed batches")
if self.sampler: if self.sampler:
@@ -188,65 +192,83 @@ class MultipackDistributedDataloader:
lengths = self.lengths[indices] lengths = self.lengths[indices]
lengths_cumsum = np.cumsum(lengths) lengths_cumsum = np.cumsum(lengths)
batches, totseqs, total_used, total_slots = allocate( alloc_iter = iter(
lengths=lengths, allocate(
lengths_cumsum=lengths_cumsum, lengths=lengths,
rank=self.rank, lengths_cumsum=lengths_cumsum,
# c=self.batch_max_length, rank=self.rank,
c=self.seq_max_length * self.sample_packing_seq_len_multiplier, # c=self.batch_max_length,
n=self.num_replicas, c=self.seq_max_length * self.sample_packing_seq_len_multiplier,
n=self.num_replicas,
)
) )
batches = [[indices[b_idx] for b_idx in batch] for batch in batches] for batch, tot_seqs, total_used, total_slots in alloc_iter:
self.batch_queue.put([indices[b_idx] for b_idx in batch])
# statistics
if set_stats:
self.eff_total_used = total_used
self.eff_total_slots = total_slots
self.batch_queue.put(None) # Signal the end of batch generation
# statistics def _generate_batches_thread(self):
if set_stats: try:
self.eff_total_used += total_used self.generate_batches(set_stats=True)
self.eff_total_slots += total_slots except Exception as e:
LOG.error(f"Error in batch generation thread: {e}")
return batches, totseqs self.batch_queue.put(
None
) # Signal the end of batch generation in case of error
def __iter__(self): def __iter__(self):
if hasattr(self.sampler, "set_epoch"): if hasattr(self.sampler, "set_epoch"):
new_epoch = self.sampler.epoch + 1 new_epoch = self.sampler.epoch + 1
self.sampler.set_epoch(new_epoch) self.sampler.set_epoch(new_epoch)
LOG.info(f"calling sampler.set_epoch({new_epoch})") LOG.info(f"calling sampler.set_epoch({new_epoch})")
all_batches, _ = self.generate_batches(set_stats=True) # Start the batch generation in a separate thread
batch_gen_thread = threading.Thread(target=self._generate_batches_thread)
batch_gen_thread.start()
features = self.dataset.features.keys() features = self.dataset.features.keys()
len_remaining = self._len_est() len_remaining = self._len_est()
for batches in chunk( while True:
all_batches, self.batch_size // self.sample_packing_seq_len_multiplier batch = self.batch_queue.get()
): if batch is None: # Sentinel value received, stop iteration
break
chunked_data = [] chunked_data = []
attn_mask_cum_idx = 0 attn_mask_cum_idx = 0
for batch in batches: concatenated = {}
concatenated = {} batched_data = [self.dataset[batch_idx] for batch_idx in batch]
batched_data = [self.dataset[batch_idx] for batch_idx in batch] for feature in features:
for feature in features: if feature == "attention_mask":
if feature == "attention_mask": arrays = [
arrays = [ (attn_mask_cum_idx + idx + 1) * np.array(item[feature])
(attn_mask_cum_idx + idx + 1) * np.array(item[feature]) for idx, item in enumerate(batched_data)
for idx, item in enumerate(batched_data) if feature in item
if feature in item ]
] attn_mask_cum_idx += len(batched_data)
attn_mask_cum_idx += len(batched_data) concatenated[feature] = np.concatenate(arrays)
concatenated[feature] = np.concatenate(arrays) else:
else: arrays = [
arrays = [ np.array(item[feature])
np.array(item[feature]) for item in batched_data
for item in batched_data if feature in item
if feature in item ]
] concatenated[feature] = np.concatenate(arrays)
concatenated[feature] = np.concatenate(arrays) chunked_data.append(concatenated)
chunked_data.append(concatenated)
yield self.collate_fn(chunked_data) yield self.collate_fn(chunked_data)
len_remaining -= 1 len_remaining -= 1
if not len_remaining: if not len_remaining:
return break
# Wait for the batch generation thread to finish
batch_gen_thread.join(timeout=5)
LOG.info(f"actual packing efficiency: {self.efficiency()}")
def _len_est(self): def _len_est(self):
lengths_sum = np.sum(self.lengths) if not self.total_num_tokens:
lengths_sum_per_device = lengths_sum // self.device_count self.total_num_tokens = np.sum(self.lengths)
lengths_sum_per_device = self.total_num_tokens // self.device_count
LOG.info( LOG.info(
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} " f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
f"total_num_tokens per device: {lengths_sum_per_device}" f"total_num_tokens per device: {lengths_sum_per_device}"

View File

@@ -10,6 +10,3 @@ class DictDefault(Dict):
def __missing__(self, key): def __missing__(self, key):
return None return None
def __or__(self, other):
return DictDefault(super().__or__(other))

View File

@@ -22,7 +22,6 @@ from transformers import ( # noqa: F401
) )
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
from axolotl.utils.bench import log_gpu_memory_usage
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
@@ -32,66 +31,37 @@ if TYPE_CHECKING:
from axolotl.utils.dict import DictDefault # noqa: F401 from axolotl.utils.dict import DictDefault # noqa: F401
def smart_tokenizer_and_embedding_resize( def load_tokenizer(
tokenizer: transformers.PreTrainedTokenizer, tokenizer_config,
model: transformers.PreTrainedModel, tokenizer_type,
resize_token_embeddings_multiple: Optional[int] = None, cfg,
): ):
"""Resize tokenizer and embedding.
Note: This function resizes the tokenizer to accommodate additional special tokens and the
embedding matrix of the model to match the new size of the tokenizer. If any new special tokens
have been added, the function computes the average embedding values of the existing embeddings
and sets those values for the new special token embeddings. This is done separately for the input
embeddings and output embeddings of the model.
"""
old_tokens = model.get_input_embeddings().weight.data.shape[0]
num_new_tokens = len(tokenizer) - old_tokens
embeddings_len = (
math.ceil(len(tokenizer) / resize_token_embeddings_multiple)
* resize_token_embeddings_multiple
if resize_token_embeddings_multiple
else len(tokenizer)
)
model.resize_token_embeddings(embeddings_len)
if num_new_tokens > 0:
input_embeddings = model.get_input_embeddings().weight.data
output_embeddings = model.get_output_embeddings().weight.data
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True
)
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True
)
input_embeddings[-num_new_tokens:] = input_embeddings_avg
output_embeddings[-num_new_tokens:] = output_embeddings_avg
def load_tokenizer(cfg):
tokenizer_kwargs = {} tokenizer_kwargs = {}
use_fast = True # this is the default use_fast = True # this is the default
if cfg.tokenizer_use_fast is not None: if cfg.tokenizer_use_fast is not None:
use_fast = cfg.tokenizer_use_fast use_fast = cfg.tokenizer_use_fast
if cfg.tokenizer_legacy is not None: if cfg.tokenizer_legacy is not None:
# True is the default w/ https://github.com/huggingface/transformers/pull/25224 # True is the default w/ https://github.com/huggingface/transformers/pull/25224
tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
if tokenizer_type:
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
tokenizer_config,
trust_remote_code=cfg.trust_remote_code or False,
use_fast=use_fast,
**tokenizer_kwargs,
)
else:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_config,
trust_remote_code=cfg.trust_remote_code or False,
use_fast=use_fast,
**tokenizer_kwargs,
)
tokenizer_cls = AutoTokenizer LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
if cfg.tokenizer_type: LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
tokenizer_cls = getattr(transformers, cfg.tokenizer_type) LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
tokenizer = tokenizer_cls.from_pretrained(
tokenizer_config,
trust_remote_code=cfg.trust_remote_code or False,
use_fast=use_fast,
**tokenizer_kwargs,
)
if tokenizer.__class__.__name__ in [ if tokenizer.__class__.__name__ in [
"LlamaTokenizer", "LlamaTokenizer",
@@ -99,11 +69,6 @@ def load_tokenizer(cfg):
]: ]:
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast": if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
tokenizer.add_special_tokens({"pad_token": "[PAD]"}) tokenizer.add_special_tokens({"pad_token": "[PAD]"})
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -118,21 +83,19 @@ def load_tokenizer(cfg):
def load_model( def load_model(
cfg, tokenizer base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora"
): # type: (DictDefault, PreTrainedTokenizerBase) -> Tuple[PreTrainedModel, Optional[PeftConfig]] ):
# type: (str, str, str, PreTrainedTokenizerBase, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
""" """
Load a model for a given configuration and tokenizer. Load a model from a base model and a model type.
""" """
base_model = cfg.base_model
base_model_config = cfg.base_model_config
model_type = cfg.model_type
# TODO refactor as a kwarg # TODO refactor as a kwarg
load_in_8bit = cfg.load_in_8bit load_in_8bit = cfg.load_in_8bit
cfg.is_llama_derived_model = ( cfg.is_llama_derived_model = (
"llama" in base_model "llama" in base_model
or (cfg.model_type and "llama" in cfg.model_type.lower()) or (cfg.model_type and "llama" in cfg.model_type.lower())
or cfg.is_llama_derived_model or cfg.is_llama_derived_model is True
) )
if cfg.is_llama_derived_model and cfg.flash_attention: if cfg.is_llama_derived_model and cfg.flash_attention:
@@ -268,17 +231,10 @@ def load_model(
elif cfg.is_llama_derived_model and not cfg.trust_remote_code: elif cfg.is_llama_derived_model and not cfg.trust_remote_code:
from transformers import LlamaForCausalLM from transformers import LlamaForCausalLM
config_kwargs = {} config = LlamaConfig.from_pretrained(base_model_config)
if cfg.rope_scaling:
config_kwargs["rope_scaling"] = cfg.rope_scaling
config = LlamaConfig.from_pretrained(
base_model_config,
**config_kwargs,
)
model = LlamaForCausalLM.from_pretrained( model = LlamaForCausalLM.from_pretrained(
base_model, base_model,
config=config, config=config,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
@@ -313,7 +269,6 @@ def load_model(
elif model_type and not cfg.trust_remote_code: elif model_type and not cfg.trust_remote_code:
model = getattr(transformers, model_type).from_pretrained( model = getattr(transformers, model_type).from_pretrained(
base_model, base_model,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
@@ -344,7 +299,6 @@ def load_model(
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
base_model, base_model,
config=config, config=config,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
@@ -358,7 +312,6 @@ def load_model(
LOG.exception(err) LOG.exception(err)
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
base_model, base_model,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
@@ -366,25 +319,23 @@ def load_model(
**model_kwargs, **model_kwargs,
) )
smart_tokenizer_and_embedding_resize( embeddings_len = (
tokenizer, math.ceil(len(tokenizer) / 32) * 32
model, if cfg.resize_token_embeddings_to_32x
resize_token_embeddings_multiple=cfg.resize_token_embeddings_multiple, else len(tokenizer)
) )
model.resize_token_embeddings(embeddings_len)
if ( if (
hasattr(model.config, "max_position_embeddings") hasattr(model.config, "max_position_embeddings")
and model.config.max_position_embeddings and model.config.max_position_embeddings
and cfg.sequence_len > model.config.max_position_embeddings and cfg.sequence_len >= model.config.max_position_embeddings
): ):
LOG.warning( LOG.warning(
f"increasing model.config.max_position_embeddings to {cfg.sequence_len}" f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
) )
model.config.max_position_embeddings = cfg.sequence_len model.config.max_position_embeddings = cfg.sequence_len
if model.device.type == "cuda":
log_gpu_memory_usage(LOG, "after model load", model.device)
if not cfg.gptq and ( if not cfg.gptq and (
(cfg.adapter == "lora" and load_in_8bit) (cfg.adapter == "lora" and load_in_8bit)
or (cfg.adapter == "qlora" and cfg.load_in_4bit) or (cfg.adapter == "qlora" and cfg.load_in_4bit)
@@ -404,7 +355,7 @@ def load_model(
if hasattr(module, "weight"): if hasattr(module, "weight"):
module.to(torch_dtype) module.to(torch_dtype)
model, lora_config = load_adapter(model, cfg, cfg.adapter) model, lora_config = load_adapter(model, cfg, adapter)
if cfg.ddp and not load_in_8bit: if cfg.ddp and not load_in_8bit:
model.to(f"cuda:{cfg.local_rank}") model.to(f"cuda:{cfg.local_rank}")
@@ -443,9 +394,6 @@ def load_model(
if cfg.flash_optimum: if cfg.flash_optimum:
model = BetterTransformer.transform(model) model = BetterTransformer.transform(model)
if cfg.adapter is not None:
log_gpu_memory_usage(LOG, "after adapters", model.device)
# TODO resume_from_checkpoint handling # TODO resume_from_checkpoint handling
return model, lora_config return model, lora_config

View File

@@ -11,7 +11,6 @@ from pathlib import Path
from typing import Optional, Union from typing import Optional, Union
import bitsandbytes as bnb import bitsandbytes as bnb
import numpy as np
import torch.cuda import torch.cuda
import transformers import transformers
from datasets import Dataset, set_caching_enabled from datasets import Dataset, set_caching_enabled
@@ -22,7 +21,6 @@ from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_pt_utils import get_parameter_names from transformers.trainer_pt_utils import get_parameter_names
from axolotl.utils.callbacks import ( from axolotl.utils.callbacks import (
GPUStatsCallback,
SaveBetterTransformerModelCallback, SaveBetterTransformerModelCallback,
SavePeftModelCallback, SavePeftModelCallback,
) )
@@ -124,6 +122,10 @@ class AxolotlTrainingArguments(TrainingArguments):
default=1, default=1,
metadata={"help": "the multiplier for the max len for packed sequences"}, metadata={"help": "the multiplier for the max len for packed sequences"},
) )
train_data_total_num_tokens: Optional[int] = field(
default=None,
metadata={"help": "the total number of tokens in the train dataset"},
)
class AxolotlTrainer(Trainer): class AxolotlTrainer(Trainer):
@@ -184,6 +186,7 @@ class AxolotlTrainer(Trainer):
packing_efficiency_estimate=self.args.sample_packing_efficiency, packing_efficiency_estimate=self.args.sample_packing_efficiency,
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier, sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
device_count=int(os.environ.get("WORLD_SIZE", 1)), device_count=int(os.environ.get("WORLD_SIZE", 1)),
total_num_tokens=self.args.train_data_total_num_tokens,
) )
) )
return super().get_train_dataloader() return super().get_train_dataloader()
@@ -206,6 +209,7 @@ class AxolotlTrainer(Trainer):
packing_efficiency_estimate=self.args.sample_packing_efficiency, packing_efficiency_estimate=self.args.sample_packing_efficiency,
sample_packing_seq_len_multiplier=self.args.eval_batch_size, sample_packing_seq_len_multiplier=self.args.eval_batch_size,
device_count=int(os.environ.get("WORLD_SIZE", 1)), device_count=int(os.environ.get("WORLD_SIZE", 1)),
total_num_tokens=None,
) )
) )
return super().get_eval_dataloader(eval_dataset) return super().get_eval_dataloader(eval_dataset)
@@ -284,16 +288,13 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
if cfg.sample_packing: if cfg.sample_packing:
# we have to drop anything longer then sequence len otherwise # we have to drop anything longer then sequence len otherwise
# flash attention with position ids fails # flash attention with position ids fails
total_num_tokens = (
cfg.total_num_tokens
if cfg.total_num_tokens
else sum(len(s["input_ids"]) for s in train_dataset)
)
if not cfg.total_num_tokens: if not cfg.total_num_tokens:
LOG.info("calculating total_num_tokens")
total_num_tokens = np.sum(
train_dataset.data.column("input_ids")
.to_pandas()
.apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda
.values
)
LOG.info(f"📝 UPDATE CONFIG WITH: `total_num_tokens: {total_num_tokens}`") LOG.info(f"📝 UPDATE CONFIG WITH: `total_num_tokens: {total_num_tokens}`")
cfg.total_num_tokens = total_num_tokens
if cfg.sample_packing_eff_est: if cfg.sample_packing_eff_est:
total_num_steps = ( total_num_steps = (
@@ -301,9 +302,9 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
( (
math.floor( math.floor(
0.99 0.99
* cfg.total_num_tokens * total_num_tokens
/ cfg.sample_packing_eff_est / cfg.sample_packing_eff_est
/ cfg.sequence_len / 2048
// cfg.batch_size // cfg.batch_size
// int(os.environ.get("WORLD_SIZE", 1)) // int(os.environ.get("WORLD_SIZE", 1))
) )
@@ -312,7 +313,7 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
* cfg.num_epochs * cfg.num_epochs
) )
LOG.info( LOG.info(
f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}" f"total_num_tokens: {total_num_tokens}, total_num_steps: {total_num_steps}"
) )
else: else:
sampler = RandomSampler(train_dataset) sampler = RandomSampler(train_dataset)
@@ -344,7 +345,6 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
LOG.info( LOG.info(
f"📝 UPDATE CONFIG WITH: `sample_packing_eff_est: {math.ceil(actual_eff * 100.0) / 100.0}`" f"📝 UPDATE CONFIG WITH: `sample_packing_eff_est: {math.ceil(actual_eff * 100.0) / 100.0}`"
) )
cfg.sample_packing_eff_est = math.ceil(actual_eff * 100.0) / 100.0
else: else:
total_num_steps = int( total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
@@ -440,9 +440,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
training_arguments_kwargs["push_to_hub"] = True training_arguments_kwargs["push_to_hub"] = True
training_arguments_kwargs["hub_private_repo"] = True training_arguments_kwargs["hub_private_repo"] = True
if cfg.hub_strategy:
training_arguments_kwargs["hub_strategy"] = cfg.hub_strategy
if cfg.save_safetensors: if cfg.save_safetensors:
training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
@@ -451,17 +448,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
"sample_packing_efficiency" "sample_packing_efficiency"
] = cfg.sample_packing_eff_est ] = cfg.sample_packing_eff_est
if cfg.val_set_size == 0:
evaluation_strategy = "no"
elif cfg.eval_steps < 1:
# eval every epoch
evaluation_strategy = "epoch"
else:
# eval every eval_steps steps
evaluation_strategy = "steps"
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
max_steps=total_num_steps if cfg.max_steps else -1, # max_steps=total_num_steps, # this is helpful in case we don't actually know total # of steps
max_seq_length=cfg.sequence_len, max_seq_length=cfg.sequence_len,
per_device_train_batch_size=cfg.micro_batch_size, per_device_train_batch_size=cfg.micro_batch_size,
per_device_eval_batch_size=cfg.eval_batch_size per_device_eval_batch_size=cfg.eval_batch_size
@@ -471,7 +459,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
eval_accumulation_steps=cfg.gradient_accumulation_steps, eval_accumulation_steps=cfg.gradient_accumulation_steps,
num_train_epochs=cfg.num_epochs, num_train_epochs=cfg.num_epochs,
learning_rate=cfg.learning_rate, learning_rate=cfg.learning_rate,
evaluation_strategy=evaluation_strategy, evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
save_strategy="steps" if cfg.save_steps else "epoch", save_strategy="steps" if cfg.save_steps else "epoch",
eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None, eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None,
save_steps=cfg.save_steps, save_steps=cfg.save_steps,
@@ -495,7 +483,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
else "cosine", else "cosine",
weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0, weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
sample_packing=cfg.sample_packing if cfg.sample_packing else False, sample_packing=cfg.sample_packing if cfg.sample_packing else False,
sample_packing_seq_len_multiplier=cfg.micro_batch_size, sample_packing_seq_len_multiplier=cfg.micro_batch_size or 1,
train_data_total_num_tokens=cfg.total_num_tokens,
**training_arguments_kwargs, **training_arguments_kwargs,
) )
@@ -567,7 +556,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
trainer_kwargs["optimizers"] = (optimizer, lr_scheduler) trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)
callbacks = [] callbacks = []
callbacks.append(GPUStatsCallback(cfg))
# TODO on_save callback to sync checkpoints to GCP/AWS in background # TODO on_save callback to sync checkpoints to GCP/AWS in background
if cfg.early_stopping_patience: if cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback( early_stop_cb = EarlyStoppingCallback(

View File

@@ -1,70 +1,12 @@
"""Module for working with config dicts""" """Module for validating config files"""
import logging import logging
import os
import torch import torch
from axolotl.utils.bench import log_gpu_memory_usage
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
def choose_device(cfg):
def get_device():
try:
if torch.cuda.is_available():
return f"cuda:{cfg.local_rank}"
if torch.backends.mps.is_available():
return "mps"
raise SystemError("No CUDA/mps device found")
except Exception: # pylint: disable=broad-exception-caught
return "cpu"
cfg.device = get_device()
if cfg.device_map != "auto":
if cfg.device.startswith("cuda"):
cfg.device_map = {"": cfg.local_rank}
else:
cfg.device_map = {"": cfg.device}
# in `accelerate launch`, we need to not pass through any device map and let
# accelerate figure out which parts of the model to put on which gpu
accelerate_vars = [var for var in os.environ if var.startswith("ACCELERATE_USE_")]
if accelerate_vars:
cfg.device_map = None
def normalize_config(cfg):
# setup some derived config / hyperparams
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
cfg.batch_size // cfg.micro_batch_size
)
cfg.batch_size = (
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
)
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
choose_device(cfg)
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
if cfg.ddp:
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
cfg.batch_size = cfg.batch_size * cfg.world_size
if cfg.device == "mps":
cfg.load_in_8bit = False
cfg.tf32 = False
if cfg.bf16:
cfg.fp16 = True
cfg.bf16 = False
else:
torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
log_gpu_memory_usage(LOG, "baseline", cfg.device)
def validate_config(cfg): def validate_config(cfg):
if cfg.max_packed_sequence_len and cfg.sample_packing: if cfg.max_packed_sequence_len and cfg.sample_packing:
raise ValueError( raise ValueError(
@@ -168,13 +110,6 @@ def validate_config(cfg):
"push_to_hub_model_id is deprecated. Please use hub_model_id instead." "push_to_hub_model_id is deprecated. Please use hub_model_id instead."
) )
if cfg.gptq and cfg.model_revision:
raise ValueError(
"model_revision is not supported for GPTQ models. "
+ "Please download the model from HuggingFace Hub manually for correct branch, "
+ "point to its path, and remove model_revision from the config."
)
if cfg.sample_packing and cfg.sdp_attention: if cfg.sample_packing and cfg.sdp_attention:
# incompatible due to bug w/ accelerate causing 0.0 loss when using llama2 # incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
raise ValueError( raise ValueError(

View File

@@ -9,8 +9,6 @@ def setup_wandb_env_vars(cfg):
elif cfg.wandb_project and len(cfg.wandb_project) > 0: elif cfg.wandb_project and len(cfg.wandb_project) > 0:
os.environ["WANDB_PROJECT"] = cfg.wandb_project os.environ["WANDB_PROJECT"] = cfg.wandb_project
cfg.use_wandb = True cfg.use_wandb = True
if cfg.wandb_entity and len(cfg.wandb_entity) > 0:
os.environ["WANDB_ENTITY"] = cfg.wandb_entity
if cfg.wandb_watch and len(cfg.wandb_watch) > 0: if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
os.environ["WANDB_WATCH"] = cfg.wandb_watch os.environ["WANDB_WATCH"] = cfg.wandb_watch
if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0: if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0:

View File

@@ -72,13 +72,6 @@ class DictDefaultTest(unittest.TestCase):
assert cfg.random_key is None, "DictDefault should return None for missing keys" assert cfg.random_key is None, "DictDefault should return None for missing keys"
def test_dict_or(self):
cfg = DictDefault({}) | DictDefault({})
assert (
cfg.random_key is None
), "DictDefault should return None for missing keys after | operation"
def test_dict_nested_missingparentkey(self): def test_dict_nested_missingparentkey(self):
""" """
Due to subclassing Dict, DictDefault will error if we try to access a nested key whose parent key does not exist. Due to subclassing Dict, DictDefault will error if we try to access a nested key whose parent key does not exist.

View File

@@ -13,22 +13,17 @@ class TestTokenizers(unittest.TestCase):
""" """
def test_default_use_fast(self): def test_default_use_fast(self):
cfg = DictDefault( cfg = DictDefault({})
{ tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
"tokenizer_config": "huggyllama/llama-7b",
}
)
tokenizer = load_tokenizer(cfg)
assert "Fast" in tokenizer.__class__.__name__ assert "Fast" in tokenizer.__class__.__name__
def test_dont_use_fast(self): def test_dont_use_fast(self):
cfg = DictDefault( cfg = DictDefault(
{ {
"tokenizer_config": "huggyllama/llama-7b",
"tokenizer_use_fast": False, "tokenizer_use_fast": False,
} }
) )
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
assert "Fast" not in tokenizer.__class__.__name__ assert "Fast" not in tokenizer.__class__.__name__

View File

@@ -6,8 +6,8 @@ from typing import Optional
import pytest import pytest
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.validation import validate_config
class ValidationTest(unittest.TestCase): class ValidationTest(unittest.TestCase):