Compare commits
56 Commits
embeddings
...
packing-at
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
64af21bcb2 | ||
|
|
6b5cf8b5ea | ||
|
|
79500f358a | ||
|
|
7e977a9b68 | ||
|
|
ac4b700daa | ||
|
|
2565c2f259 | ||
|
|
a07f432d9c | ||
|
|
57d9bf711c | ||
|
|
26983a1974 | ||
|
|
1b8747e319 | ||
|
|
035b3c760c | ||
|
|
17abbd59e1 | ||
|
|
6ec76ddb4c | ||
|
|
21d307b15b | ||
|
|
58e9dee204 | ||
|
|
4f7c04bae0 | ||
|
|
1162b93b6b | ||
|
|
21f445d763 | ||
|
|
229b9165aa | ||
|
|
394a65f11f | ||
|
|
c70dae63cc | ||
|
|
7712955b35 | ||
|
|
f93f0017cd | ||
|
|
0b01da0713 | ||
|
|
b2f7bc7ccd | ||
|
|
b8905e2a91 | ||
|
|
7e1edc662a | ||
|
|
98c9bc69de | ||
|
|
8378335dc9 | ||
|
|
bdd34c7400 | ||
|
|
c6cc54c7d9 | ||
|
|
83f7362480 | ||
|
|
958d423e7c | ||
|
|
e74eab6e73 | ||
|
|
487abfc769 | ||
|
|
2bee646e85 | ||
|
|
945f2e5029 | ||
|
|
daed942fe9 | ||
|
|
df3eb645da | ||
|
|
32fed7039d | ||
|
|
7d7b5ebd71 | ||
|
|
4b7ad9927f | ||
|
|
fedcf5a089 | ||
|
|
2f2974196d | ||
|
|
2e295c9f94 | ||
|
|
4ab9ab79fd | ||
|
|
b02484a83e | ||
|
|
58045f0816 | ||
|
|
66774011c4 | ||
|
|
41d4992029 | ||
|
|
762f1b08db | ||
|
|
3aba4c5d7c | ||
|
|
ffd96839cf | ||
|
|
ef9bf7ad73 | ||
|
|
4964b0d345 | ||
|
|
36b0e30a9d |
13
.github/FUNDING.yml
vendored
13
.github/FUNDING.yml
vendored
@@ -1,13 +0,0 @@
|
|||||||
# These are supported funding model platforms
|
|
||||||
|
|
||||||
github: OpenAccess-AI-Collective # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
|
|
||||||
patreon: # Replace with a single Patreon username
|
|
||||||
open_collective: # Replace with a single Open Collective username
|
|
||||||
ko_fi: # Replace with a single Ko-fi username
|
|
||||||
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
|
|
||||||
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
|
|
||||||
liberapay: # Replace with a single Liberapay username
|
|
||||||
issuehunt: # Replace with a single IssueHunt username
|
|
||||||
otechie: # Replace with a single Otechie username
|
|
||||||
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
|
|
||||||
custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
|
|
||||||
66
README.md
66
README.md
@@ -136,7 +136,7 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
|||||||
```json
|
```json
|
||||||
{"instruction": "...", "input": "...", "output": "..."}
|
{"instruction": "...", "input": "...", "output": "..."}
|
||||||
```
|
```
|
||||||
- `sharegpt:chat`: conversations where `from` is `human`/`gpt`
|
- `sharegpt:chat`: conversations
|
||||||
```json
|
```json
|
||||||
{"conversations": [{"from": "...", "value": "..."}]}
|
{"conversations": [{"from": "...", "value": "..."}]}
|
||||||
```
|
```
|
||||||
@@ -225,10 +225,6 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
|||||||
```json
|
```json
|
||||||
{"conversations": [{"role": "...", "value": "..."}]}
|
{"conversations": [{"role": "...", "value": "..."}]}
|
||||||
```
|
```
|
||||||
- `sharegpt_simple.load_guanaco`: conversations where `from` is `prompter`/`assistant` instead of default sharegpt
|
|
||||||
```json
|
|
||||||
{"conversations": [{"from": "...", "value": "..."}]}
|
|
||||||
```
|
|
||||||
- `sharegpt_jokes`: creates a chat where bot is asked to tell a joke, then explain why the joke is funny
|
- `sharegpt_jokes`: creates a chat where bot is asked to tell a joke, then explain why the joke is funny
|
||||||
```json
|
```json
|
||||||
{"conversations": [{"title": "...", "text": "...", "explanation": "..."}]}
|
{"conversations": [{"title": "...", "text": "...", "explanation": "..."}]}
|
||||||
@@ -326,9 +322,9 @@ tokenizer_type: AutoTokenizer
|
|||||||
trust_remote_code:
|
trust_remote_code:
|
||||||
# use_fast option for tokenizer loading from_pretrained, default to True
|
# use_fast option for tokenizer loading from_pretrained, default to True
|
||||||
tokenizer_use_fast:
|
tokenizer_use_fast:
|
||||||
# resize the model embeddings when new tokens are added to multiples of N
|
# resize the model embeddings when new tokens are added to multiples of 32
|
||||||
# multiples of 32 are reported to improve training speed on some models
|
# this is reported to improve training speed on some models
|
||||||
resize_token_embeddings_multiple:
|
resize_token_embeddings_to_32x:
|
||||||
|
|
||||||
# whether you are training a 4-bit GPTQ quantized model
|
# whether you are training a 4-bit GPTQ quantized model
|
||||||
gptq: true
|
gptq: true
|
||||||
@@ -364,9 +360,6 @@ dataset_prepared_path: data/last_run_prepared
|
|||||||
push_dataset_to_hub: # repo path
|
push_dataset_to_hub: # repo path
|
||||||
# push checkpoints to hub
|
# push checkpoints to hub
|
||||||
hub_model_id: # repo path to push finetuned model
|
hub_model_id: # repo path to push finetuned model
|
||||||
# how to push checkpoints to hub
|
|
||||||
# https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy
|
|
||||||
hub_strategy:
|
|
||||||
# whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
|
# whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
|
||||||
# required to be true when used in combination with `push_dataset_to_hub`
|
# required to be true when used in combination with `push_dataset_to_hub`
|
||||||
hf_use_auth_token: # boolean
|
hf_use_auth_token: # boolean
|
||||||
@@ -382,14 +375,10 @@ dataset_shard_idx:
|
|||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
# max sequence length to concatenate training samples together up to
|
# max sequence length to concatenate training samples together up to
|
||||||
# inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
|
# inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
|
||||||
# FutureWarning: This will soon be DEPRECATED
|
# soon to be DEPRECATED
|
||||||
max_packed_sequence_len: 1024
|
max_packed_sequence_len: 1024
|
||||||
# use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'
|
# use efficient multi-packing with block diagonal attention and per sequence position_ids
|
||||||
sample_packing:
|
sample_packing:
|
||||||
# you can set these packing optimizations AFTER starting a training at least once.
|
|
||||||
# The trainer will provide recommended values for these values.
|
|
||||||
sample_packing_eff_est:
|
|
||||||
total_num_tokens:
|
|
||||||
|
|
||||||
# if you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
|
# if you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
|
||||||
adapter: lora
|
adapter: lora
|
||||||
@@ -415,12 +404,11 @@ lora_out_dir:
|
|||||||
lora_fan_in_fan_out: false
|
lora_fan_in_fan_out: false
|
||||||
|
|
||||||
# wandb configuration if you're using it
|
# wandb configuration if you're using it
|
||||||
wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
|
wandb_mode:
|
||||||
wandb_project: # your wandb project name
|
wandb_project:
|
||||||
wandb_entity: # a wandb Team name if using a Team
|
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id: # set the name of your wandb run
|
wandb_run_id:
|
||||||
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
|
wandb_log_model: # 'checkpoint'
|
||||||
|
|
||||||
# where to save the finished model to
|
# where to save the finished model to
|
||||||
output_dir: ./completed-model
|
output_dir: ./completed-model
|
||||||
@@ -435,17 +423,13 @@ learning_rate: 0.00003
|
|||||||
logging_steps:
|
logging_steps:
|
||||||
save_steps:
|
save_steps:
|
||||||
eval_steps:
|
eval_steps:
|
||||||
save_total_limit: # checkpoints saved at a time
|
|
||||||
max_steps:
|
|
||||||
|
|
||||||
# save model as safetensors (require safetensors package)
|
# save model as safetensors (require safetensors package)
|
||||||
save_safetensors:
|
save_safetensors:
|
||||||
|
|
||||||
# whether to mask out or include the human's prompt from the training labels
|
# whether to mask out or include the human's prompt from the training labels
|
||||||
train_on_inputs: false
|
train_on_inputs: false
|
||||||
# group similarly sized data to minimize padding
|
# don't use this, leads to wonky training (according to someone on the internet)
|
||||||
# may be slower to start, as it must download and sort the entire dataset
|
|
||||||
# note that training loss may have an oscillating pattern with this enabled
|
|
||||||
group_by_length: false
|
group_by_length: false
|
||||||
|
|
||||||
# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
|
# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
|
||||||
@@ -491,10 +475,6 @@ landmark_attention:
|
|||||||
# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
|
# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
|
||||||
# llama only
|
# llama only
|
||||||
xpos_rope:
|
xpos_rope:
|
||||||
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
|
|
||||||
rope_scaling:
|
|
||||||
type: # linear | dynamic
|
|
||||||
factor: # float
|
|
||||||
|
|
||||||
# resume from a specific checkpoint dir
|
# resume from a specific checkpoint dir
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
@@ -526,9 +506,6 @@ torchdistx_path:
|
|||||||
# Set padding for data collator to 'longest'
|
# Set padding for data collator to 'longest'
|
||||||
collator_pad_to_longest:
|
collator_pad_to_longest:
|
||||||
|
|
||||||
# Set to HF dataset for type: 'completion' for streaming instead of pre-tokenize
|
|
||||||
pretraining_dataset:
|
|
||||||
|
|
||||||
# Debug mode
|
# Debug mode
|
||||||
debug:
|
debug:
|
||||||
|
|
||||||
@@ -548,14 +525,7 @@ Run
|
|||||||
accelerate launch scripts/finetune.py configs/your_config.yml
|
accelerate launch scripts/finetune.py configs/your_config.yml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Multi-GPU
|
#### Multi-GPU Config
|
||||||
|
|
||||||
You can optionally pre-tokenize dataset with the following before finetuning:
|
|
||||||
```bash
|
|
||||||
CUDA_VISIBLE_DEVICES="" accelerate ... --prepare_ds_only
|
|
||||||
```
|
|
||||||
|
|
||||||
##### Config
|
|
||||||
|
|
||||||
- llama FSDP
|
- llama FSDP
|
||||||
```yaml
|
```yaml
|
||||||
@@ -570,18 +540,6 @@ fsdp_config:
|
|||||||
|
|
||||||
- llama Deepspeed: append `ACCELERATE_USE_DEEPSPEED=true` in front of finetune command
|
- llama Deepspeed: append `ACCELERATE_USE_DEEPSPEED=true` in front of finetune command
|
||||||
|
|
||||||
##### Weights & Biases Logging
|
|
||||||
|
|
||||||
- wandb options
|
|
||||||
```yaml
|
|
||||||
wandb_mode:
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_run_id:
|
|
||||||
wandb_log_model:
|
|
||||||
```
|
|
||||||
|
|
||||||
### Inference
|
### Inference
|
||||||
|
|
||||||
Pass the appropriate flag to the train command:
|
Pass the appropriate flag to the train command:
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
|||||||
|
|
||||||
RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
|
RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
|
||||||
cd flash-attention && \
|
cd flash-attention && \
|
||||||
git checkout v2.0.4 && \
|
git checkout v2.0.1 && \
|
||||||
python3 setup.py bdist_wheel && \
|
python3 setup.py bdist_wheel && \
|
||||||
cd csrc/fused_dense_lib && \
|
cd csrc/fused_dense_lib && \
|
||||||
python3 setup.py bdist_wheel && \
|
python3 setup.py bdist_wheel && \
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ lora_target_modules:
|
|||||||
lora_target_linear:
|
lora_target_linear:
|
||||||
lora_fan_in_fan_out:
|
lora_fan_in_fan_out:
|
||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
@@ -36,7 +35,7 @@ torchdistx_path:
|
|||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
learning_rate: 0.0002
|
learning_rate: 0.0002
|
||||||
train_on_inputs: false
|
train_on_inputs: false
|
||||||
group_by_length: false
|
group_by_length: true
|
||||||
bf16: true
|
bf16: true
|
||||||
fp16: false
|
fp16: false
|
||||||
tf32: true
|
tf32: true
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ lora_target_modules:
|
|||||||
lora_target_linear: true
|
lora_target_linear: true
|
||||||
lora_fan_in_fan_out:
|
lora_fan_in_fan_out:
|
||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|||||||
@@ -38,7 +38,6 @@ lora_target_linear: true
|
|||||||
lora_fan_in_fan_out:
|
lora_fan_in_fan_out:
|
||||||
|
|
||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ lora_target_modules:
|
|||||||
lora_target_linear: true
|
lora_target_linear: true
|
||||||
lora_fan_in_fan_out:
|
lora_fan_in_fan_out:
|
||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ lora_target_modules:
|
|||||||
lora_target_linear: true
|
lora_target_linear: true
|
||||||
lora_fan_in_fan_out:
|
lora_fan_in_fan_out:
|
||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
@@ -33,7 +32,7 @@ torchdistx_path:
|
|||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
learning_rate: 0.0001
|
learning_rate: 0.0001
|
||||||
train_on_inputs: false
|
train_on_inputs: false
|
||||||
group_by_length: false
|
group_by_length: true
|
||||||
bf16: true
|
bf16: true
|
||||||
fp16: false
|
fp16: false
|
||||||
tf32: true
|
tf32: true
|
||||||
|
|||||||
@@ -22,7 +22,6 @@ lora_target_modules:
|
|||||||
- v_proj
|
- v_proj
|
||||||
lora_fan_in_fan_out: false
|
lora_fan_in_fan_out: false
|
||||||
wandb_project: llama-7b-lora-int4
|
wandb_project: llama-7b-lora-int4
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ lora_dropout:
|
|||||||
lora_target_modules:
|
lora_target_modules:
|
||||||
lora_fan_in_fan_out: false
|
lora_fan_in_fan_out: false
|
||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ val_set_size: 0.01
|
|||||||
output_dir: ./lora-out
|
output_dir: ./lora-out
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
max_packed_sequence_len: 4096
|
||||||
|
|
||||||
adapter: lora
|
adapter: lora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
@@ -26,7 +26,6 @@ lora_target_linear: true
|
|||||||
lora_fan_in_fan_out:
|
lora_fan_in_fan_out:
|
||||||
|
|
||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
@@ -39,7 +38,7 @@ lr_scheduler: cosine
|
|||||||
learning_rate: 0.0002
|
learning_rate: 0.0002
|
||||||
|
|
||||||
train_on_inputs: false
|
train_on_inputs: false
|
||||||
group_by_length: false
|
group_by_length: true
|
||||||
bf16: true
|
bf16: true
|
||||||
fp16: false
|
fp16: false
|
||||||
tf32: false
|
tf32: false
|
||||||
@@ -49,8 +48,8 @@ early_stopping_patience:
|
|||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
local_rank:
|
local_rank:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
xformers_attention:
|
xformers_attention: true
|
||||||
flash_attention: true
|
flash_attention:
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 20
|
eval_steps: 20
|
||||||
@@ -64,3 +63,4 @@ special_tokens:
|
|||||||
bos_token: "<s>"
|
bos_token: "<s>"
|
||||||
eos_token: "</s>"
|
eos_token: "</s>"
|
||||||
unk_token: "<unk>"
|
unk_token: "<unk>"
|
||||||
|
pad_token: "<pad>"
|
||||||
|
|||||||
@@ -18,8 +18,7 @@ adapter: qlora
|
|||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
max_packed_sequence_len: 4096
|
||||||
|
|
||||||
lora_r: 32
|
lora_r: 32
|
||||||
lora_alpha: 16
|
lora_alpha: 16
|
||||||
lora_dropout: 0.05
|
lora_dropout: 0.05
|
||||||
@@ -28,7 +27,6 @@ lora_target_linear: true
|
|||||||
lora_fan_in_fan_out:
|
lora_fan_in_fan_out:
|
||||||
|
|
||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
@@ -41,7 +39,7 @@ lr_scheduler: cosine
|
|||||||
learning_rate: 0.0002
|
learning_rate: 0.0002
|
||||||
|
|
||||||
train_on_inputs: false
|
train_on_inputs: false
|
||||||
group_by_length: false
|
group_by_length: true
|
||||||
bf16: true
|
bf16: true
|
||||||
fp16: false
|
fp16: false
|
||||||
tf32: false
|
tf32: false
|
||||||
@@ -51,8 +49,8 @@ early_stopping_patience:
|
|||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
local_rank:
|
local_rank:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
xformers_attention:
|
xformers_attention: true
|
||||||
flash_attention: true
|
flash_attention:
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 20
|
eval_steps: 20
|
||||||
@@ -66,3 +64,4 @@ special_tokens:
|
|||||||
bos_token: "<s>"
|
bos_token: "<s>"
|
||||||
eos_token: "</s>"
|
eos_token: "</s>"
|
||||||
unk_token: "<unk>"
|
unk_token: "<unk>"
|
||||||
|
pad_token: "<pad>"
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ lora_target_modules:
|
|||||||
- v_proj
|
- v_proj
|
||||||
lora_fan_in_fan_out: false
|
lora_fan_in_fan_out: false
|
||||||
wandb_project: mpt-alpaca-7b
|
wandb_project: mpt-alpaca-7b
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|||||||
@@ -22,7 +22,6 @@ lora_target_modules:
|
|||||||
lora_target_linear:
|
lora_target_linear:
|
||||||
lora_fan_in_fan_out:
|
lora_fan_in_fan_out:
|
||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|||||||
@@ -28,7 +28,6 @@ lora_target_modules:
|
|||||||
- o_proj
|
- o_proj
|
||||||
lora_fan_in_fan_out:
|
lora_fan_in_fan_out:
|
||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|||||||
@@ -22,7 +22,6 @@ lora_target_modules:
|
|||||||
lora_target_linear: true
|
lora_target_linear: true
|
||||||
lora_fan_in_fan_out:
|
lora_fan_in_fan_out:
|
||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
@@ -35,7 +34,7 @@ torchdistx_path:
|
|||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
learning_rate: 0.0002
|
learning_rate: 0.0002
|
||||||
train_on_inputs: false
|
train_on_inputs: false
|
||||||
group_by_length: false
|
group_by_length: true
|
||||||
bf16: true
|
bf16: true
|
||||||
fp16: false
|
fp16: false
|
||||||
tf32: true
|
tf32: true
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ lora_target_modules:
|
|||||||
lora_target_linear: true
|
lora_target_linear: true
|
||||||
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
|
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
|
||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ lora_target_modules:
|
|||||||
lora_target_linear:
|
lora_target_linear:
|
||||||
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
|
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
|
||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ lora_target_modules:
|
|||||||
- v_proj
|
- v_proj
|
||||||
lora_fan_in_fan_out: false
|
lora_fan_in_fan_out: false
|
||||||
wandb_project: redpajama-alpaca-3b
|
wandb_project: redpajama-alpaca-3b
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ lora_target_modules:
|
|||||||
- mlp_down
|
- mlp_down
|
||||||
lora_fan_in_fan_out:
|
lora_fan_in_fan_out:
|
||||||
wandb_project: lora-replit
|
wandb_project: lora-replit
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|||||||
@@ -37,7 +37,6 @@ lora_target_linear: true
|
|||||||
lora_fan_in_fan_out:
|
lora_fan_in_fan_out:
|
||||||
|
|
||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
peft @ git+https://github.com/huggingface/peft.git
|
peft @ git+https://github.com/huggingface/peft.git
|
||||||
transformers @ git+https://github.com/huggingface/transformers.git
|
transformers @ git+https://github.com/huggingface/transformers.git
|
||||||
bitsandbytes>=0.41.1
|
bitsandbytes>=0.39.0
|
||||||
accelerate @ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b
|
accelerate @ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b
|
||||||
addict
|
addict
|
||||||
fire
|
fire
|
||||||
@@ -21,4 +21,3 @@ evaluate==0.4.0
|
|||||||
rouge-score==0.1.2
|
rouge-score==0.1.2
|
||||||
scipy
|
scipy
|
||||||
scikit-learn==1.2.2
|
scikit-learn==1.2.2
|
||||||
pynvml
|
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ from optimum.bettertransformer import BetterTransformer
|
|||||||
from transformers import GenerationConfig, TextStreamer
|
from transformers import GenerationConfig, TextStreamer
|
||||||
|
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
|
||||||
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
|
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import barrier, is_main_process
|
from axolotl.utils.distributed import barrier, is_main_process
|
||||||
@@ -29,6 +28,7 @@ from axolotl.utils.trainer import (
|
|||||||
process_datasets_for_packing,
|
process_datasets_for_packing,
|
||||||
setup_trainer,
|
setup_trainer,
|
||||||
)
|
)
|
||||||
|
from axolotl.utils.validation import validate_config
|
||||||
from axolotl.utils.wandb import setup_wandb_env_vars
|
from axolotl.utils.wandb import setup_wandb_env_vars
|
||||||
|
|
||||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
@@ -43,6 +43,27 @@ DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
|||||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
||||||
|
|
||||||
|
|
||||||
|
def choose_device(cfg):
|
||||||
|
def get_device():
|
||||||
|
try:
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
return f"cuda:{cfg.local_rank}"
|
||||||
|
|
||||||
|
if torch.backends.mps.is_available():
|
||||||
|
return "mps"
|
||||||
|
|
||||||
|
raise SystemError("No CUDA/mps device found")
|
||||||
|
except Exception: # pylint: disable=broad-exception-caught
|
||||||
|
return "cpu"
|
||||||
|
|
||||||
|
cfg.device = get_device()
|
||||||
|
if cfg.device_map != "auto":
|
||||||
|
if cfg.device.startswith("cuda"):
|
||||||
|
cfg.device_map = {"": cfg.local_rank}
|
||||||
|
else:
|
||||||
|
cfg.device_map = {"": cfg.device}
|
||||||
|
|
||||||
|
|
||||||
def get_multi_line_input() -> Optional[str]:
|
def get_multi_line_input() -> Optional[str]:
|
||||||
print("Give me an instruction (Ctrl + D to finish): ")
|
print("Give me an instruction (Ctrl + D to finish): ")
|
||||||
instruction = ""
|
instruction = ""
|
||||||
@@ -172,13 +193,36 @@ def train(
|
|||||||
|
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
normalize_config(cfg)
|
# setup some derived config / hyperparams
|
||||||
|
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
|
||||||
|
cfg.batch_size // cfg.micro_batch_size
|
||||||
|
)
|
||||||
|
cfg.batch_size = (
|
||||||
|
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
|
||||||
|
)
|
||||||
|
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||||
|
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||||
|
choose_device(cfg)
|
||||||
|
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
|
||||||
|
if cfg.ddp:
|
||||||
|
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
|
||||||
|
cfg.batch_size = cfg.batch_size * cfg.world_size
|
||||||
|
|
||||||
setup_wandb_env_vars(cfg)
|
setup_wandb_env_vars(cfg)
|
||||||
|
if cfg.device == "mps":
|
||||||
|
cfg.load_in_8bit = False
|
||||||
|
cfg.tf32 = False
|
||||||
|
if cfg.bf16:
|
||||||
|
cfg.fp16 = True
|
||||||
|
cfg.bf16 = False
|
||||||
|
|
||||||
|
if cfg.tf32:
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
|
||||||
# load the tokenizer first
|
# load the tokenizer first
|
||||||
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
|
||||||
tokenizer = load_tokenizer(cfg)
|
LOG.info(f"loading tokenizer... {tokenizer_config}")
|
||||||
|
tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
|
check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
|
||||||
@@ -209,13 +253,7 @@ def train(
|
|||||||
cfg, train_dataset, eval_dataset
|
cfg, train_dataset, eval_dataset
|
||||||
)
|
)
|
||||||
barrier()
|
barrier()
|
||||||
if cfg.max_steps:
|
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
|
||||||
total_num_steps = min(
|
|
||||||
calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps
|
|
||||||
)
|
|
||||||
LOG.info(f"Maximum number of steps set at {total_num_steps}")
|
|
||||||
else:
|
|
||||||
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
|
|
||||||
|
|
||||||
if cfg.debug or "debug" in kwargs:
|
if cfg.debug or "debug" in kwargs:
|
||||||
LOG.info("check_dataset_labels...")
|
LOG.info("check_dataset_labels...")
|
||||||
@@ -231,10 +269,15 @@ def train(
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Load the model and tokenizer
|
# Load the model and tokenizer
|
||||||
LOG.info("loading model and (optionally) peft_config...")
|
LOG.info("loading model and peft_config...")
|
||||||
model, peft_config = load_model(cfg, tokenizer)
|
model, peft_config = load_model(
|
||||||
|
cfg.base_model,
|
||||||
safe_serialization = cfg.save_safetensors is True
|
cfg.base_model_config,
|
||||||
|
cfg.model_type,
|
||||||
|
tokenizer,
|
||||||
|
cfg,
|
||||||
|
adapter=cfg.adapter,
|
||||||
|
)
|
||||||
|
|
||||||
if "merge_lora" in kwargs and cfg.adapter is not None:
|
if "merge_lora" in kwargs and cfg.adapter is not None:
|
||||||
LOG.info("running merge of LoRA with base model")
|
LOG.info("running merge of LoRA with base model")
|
||||||
@@ -243,11 +286,7 @@ def train(
|
|||||||
|
|
||||||
if cfg.local_rank == 0:
|
if cfg.local_rank == 0:
|
||||||
LOG.info("saving merged model")
|
LOG.info("saving merged model")
|
||||||
model.save_pretrained(
|
model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
||||||
str(Path(cfg.output_dir) / "merged"),
|
|
||||||
safe_serialization=safe_serialization,
|
|
||||||
)
|
|
||||||
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
|
||||||
return
|
return
|
||||||
|
|
||||||
if cfg.inference:
|
if cfg.inference:
|
||||||
@@ -262,7 +301,7 @@ def train(
|
|||||||
return
|
return
|
||||||
|
|
||||||
if "shard" in kwargs:
|
if "shard" in kwargs:
|
||||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
model.save_pretrained(cfg.output_dir)
|
||||||
return
|
return
|
||||||
|
|
||||||
trainer = setup_trainer(
|
trainer = setup_trainer(
|
||||||
@@ -286,7 +325,7 @@ def train(
|
|||||||
def terminate_handler(_, __, model):
|
def terminate_handler(_, __, model):
|
||||||
if cfg.flash_optimum:
|
if cfg.flash_optimum:
|
||||||
model = BetterTransformer.reverse(model)
|
model = BetterTransformer.reverse(model)
|
||||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
model.save_pretrained(cfg.output_dir)
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
signal.signal(
|
signal.signal(
|
||||||
@@ -313,7 +352,6 @@ def train(
|
|||||||
|
|
||||||
if not Path(cfg.output_dir).is_dir():
|
if not Path(cfg.output_dir).is_dir():
|
||||||
os.makedirs(cfg.output_dir, exist_ok=True)
|
os.makedirs(cfg.output_dir, exist_ok=True)
|
||||||
tokenizer.save_pretrained(cfg.output_dir)
|
|
||||||
if cfg.flash_optimum:
|
if cfg.flash_optimum:
|
||||||
with torch.backends.cuda.sdp_kernel(
|
with torch.backends.cuda.sdp_kernel(
|
||||||
enable_flash=True, enable_math=True, enable_mem_efficient=True
|
enable_flash=True, enable_math=True, enable_mem_efficient=True
|
||||||
@@ -331,7 +369,7 @@ def train(
|
|||||||
elif cfg.local_rank == 0:
|
elif cfg.local_rank == 0:
|
||||||
if cfg.flash_optimum:
|
if cfg.flash_optimum:
|
||||||
model = BetterTransformer.reverse(model)
|
model = BetterTransformer.reverse(model)
|
||||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
model.save_pretrained(cfg.output_dir)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import os
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from datasets import Dataset, IterableDataset
|
from datasets import IterableDataset
|
||||||
|
|
||||||
from .prompt_tokenizers import PromptTokenizingStrategy
|
from .prompt_tokenizers import PromptTokenizingStrategy
|
||||||
|
|
||||||
@@ -18,9 +18,9 @@ from .prompt_tokenizers import PromptTokenizingStrategy
|
|||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
class TokenizedPromptDataset(Dataset):
|
class TokenizedPromptDataset(IterableDataset):
|
||||||
"""
|
"""
|
||||||
Dataset that returns tokenized prompts from a stream of text files.
|
Iterable dataset that returns tokenized prompts from a stream of text files.
|
||||||
Args:
|
Args:
|
||||||
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for proccessing the data.
|
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for proccessing the data.
|
||||||
dataset (dataset.Dataset): Dataset with text files.
|
dataset (dataset.Dataset): Dataset with text files.
|
||||||
@@ -30,18 +30,19 @@ class TokenizedPromptDataset(Dataset):
|
|||||||
self,
|
self,
|
||||||
prompt_tokenizer: PromptTokenizingStrategy,
|
prompt_tokenizer: PromptTokenizingStrategy,
|
||||||
dataset: IterableDataset,
|
dataset: IterableDataset,
|
||||||
**kwargs,
|
|
||||||
):
|
):
|
||||||
self.prompt_tokenizer = prompt_tokenizer
|
self.prompt_tokenizer = prompt_tokenizer
|
||||||
super().__init__(self.process(dataset).data, **kwargs)
|
self.dataset = dataset
|
||||||
|
|
||||||
def process(self, dataset):
|
def __iter__(self):
|
||||||
features = dataset.features.keys()
|
features = self.dataset.features.keys()
|
||||||
num_proc = min(64, os.cpu_count())
|
num_proc = os.cpu_count()
|
||||||
return dataset.map(
|
return iter(
|
||||||
self.prompt_tokenizer.tokenize_prompt,
|
self.dataset.map(
|
||||||
num_proc=num_proc,
|
self.prompt_tokenizer.tokenize_prompt,
|
||||||
remove_columns=features,
|
num_proc=num_proc,
|
||||||
|
remove_columns=features,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ from typing import Optional, Tuple
|
|||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from flash_attn.bert_padding import pad_input, unpad_input
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
|
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
|
||||||
@@ -92,8 +91,7 @@ def forward(
|
|||||||
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
||||||
)
|
)
|
||||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||||
elif attention_mask.shape[0] == 1:
|
else:
|
||||||
# special handling using sample packing
|
|
||||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||||
cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids)
|
cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids)
|
||||||
cu_q_lens = cu_q_lens.squeeze()
|
cu_q_lens = cu_q_lens.squeeze()
|
||||||
@@ -102,36 +100,6 @@ def forward(
|
|||||||
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
||||||
)
|
)
|
||||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||||
else:
|
|
||||||
nheads = qkv.shape[-2]
|
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
|
||||||
x = rearrange(qkv, "b s three h d -> b s (three h d)")
|
|
||||||
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
|
|
||||||
x_unpad = rearrange(
|
|
||||||
x_unpad,
|
|
||||||
"nnz (three h d) -> nnz three h d",
|
|
||||||
three=3,
|
|
||||||
h=nheads,
|
|
||||||
)
|
|
||||||
output_unpad = flash_attn_varlen_qkvpacked_func(
|
|
||||||
x_unpad,
|
|
||||||
cu_q_lens,
|
|
||||||
max_s,
|
|
||||||
0.0,
|
|
||||||
softmax_scale=None,
|
|
||||||
causal=True,
|
|
||||||
)
|
|
||||||
output = rearrange(
|
|
||||||
pad_input(
|
|
||||||
rearrange(output_unpad, "nnz h d -> nnz (h d)"),
|
|
||||||
indices,
|
|
||||||
bsz,
|
|
||||||
q_len,
|
|
||||||
),
|
|
||||||
"b s (h d) -> b s h d",
|
|
||||||
h=nheads,
|
|
||||||
)
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
self.o_proj(rearrange(output, "b s h d -> b s (h d)")),
|
self.o_proj(rearrange(output, "b s h d -> b s (h d)")),
|
||||||
|
|||||||
@@ -95,9 +95,9 @@ class OpenOrcaSystemDataPrompter(SystemDataPrompter):
|
|||||||
self.turn_format = "### User:\n{instruction}\n\n### Additional Context:\n{input}\n\n### Assistant:\n"
|
self.turn_format = "### User:\n{instruction}\n\n### Additional Context:\n{input}\n\n### Assistant:\n"
|
||||||
self.turn_no_input_format = "### User:\n{instruction}\n\n### Assistant:\n"
|
self.turn_no_input_format = "### User:\n{instruction}\n\n### Assistant:\n"
|
||||||
if self.prompt_style == PromptStyle.CHAT.value:
|
if self.prompt_style == PromptStyle.CHAT.value:
|
||||||
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
|
self.turn_format = "User: {instruction}\n{input}\nAssistant:"
|
||||||
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
|
self.turn_no_input_format = "User: {instruction}\nAssistant:"
|
||||||
self.system_format = "SYSTEM: {system}\n"
|
self.system_format = "System: {system}\n"
|
||||||
if self.prompt_style == PromptStyle.CHATML.value:
|
if self.prompt_style == PromptStyle.CHATML.value:
|
||||||
self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
|
self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
|
||||||
self.turn_no_input_format = (
|
self.turn_no_input_format = (
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ from dataclasses import dataclass, field
|
|||||||
from typing import Generator, List, Sequence
|
from typing import Generator, List, Sequence
|
||||||
|
|
||||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||||
from axolotl.prompters import IGNORE_TOKEN_ID, SHAREGPT_ASSERTION_FAILED_ROLE
|
from axolotl.prompters import IGNORE_TOKEN_ID
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -190,7 +190,7 @@ class Llama2ChatPrompter: # pylint: disable=too-few-public-methods
|
|||||||
conv.messages = [] # pylint: disable=R0801
|
conv.messages = [] # pylint: disable=R0801
|
||||||
for j, sentence in enumerate(source):
|
for j, sentence in enumerate(source):
|
||||||
role = roles[sentence["from"]]
|
role = roles[sentence["from"]]
|
||||||
assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
|
assert role == conv.roles[j % 2]
|
||||||
if sentence["value"]:
|
if sentence["value"]:
|
||||||
conv.append_message(role, sentence["value"])
|
conv.append_message(role, sentence["value"])
|
||||||
yield conv
|
yield conv
|
||||||
|
|||||||
@@ -271,11 +271,6 @@ class Conversation:
|
|||||||
self.messages.append([role, message])
|
self.messages.append([role, message])
|
||||||
|
|
||||||
|
|
||||||
SHAREGPT_ASSERTION_FAILED_ROLE = (
|
|
||||||
"Role did not alternate between turns (gpt and human). Please check your data."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
||||||
"""
|
"""
|
||||||
A prompter that generates prompts for the ShareGPT
|
A prompter that generates prompts for the ShareGPT
|
||||||
@@ -312,9 +307,7 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
|||||||
if len(source) < 2:
|
if len(source) < 2:
|
||||||
# If there isn't a back and forth conversation, ignore it
|
# If there isn't a back and forth conversation, ignore it
|
||||||
# also happens on the data splitting leaving empty conversations
|
# also happens on the data splitting leaving empty conversations
|
||||||
raise IndexError(
|
raise IndexError
|
||||||
f"A conversation entry has less than 2 messages :\n{source}"
|
|
||||||
)
|
|
||||||
|
|
||||||
conv = self._conversation.copy()
|
conv = self._conversation.copy()
|
||||||
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
||||||
@@ -334,7 +327,7 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
|||||||
conv.messages = []
|
conv.messages = []
|
||||||
for j, sentence in enumerate(source):
|
for j, sentence in enumerate(source):
|
||||||
role = roles[sentence["from"]]
|
role = roles[sentence["from"]]
|
||||||
assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
|
assert role == conv.roles[j % 2]
|
||||||
conv.append_message(role, sentence["value"])
|
conv.append_message(role, sentence["value"])
|
||||||
|
|
||||||
for part in conv.get_prompt():
|
for part in conv.get_prompt():
|
||||||
|
|||||||
@@ -1,43 +0,0 @@
|
|||||||
"""Benchmarking and measurement utilities"""
|
|
||||||
|
|
||||||
import pynvml
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
def gpu_memory_usage(device=0):
|
|
||||||
return torch.cuda.memory_allocated(device) / 1024.0**3
|
|
||||||
|
|
||||||
|
|
||||||
def gpu_memory_usage_all(device=0):
|
|
||||||
usage = torch.cuda.memory_allocated(device) / 1024.0**3
|
|
||||||
reserved = torch.cuda.memory_reserved(device) / 1024.0**3
|
|
||||||
smi = gpu_memory_usage_smi(device)
|
|
||||||
return usage, reserved - usage, max(0, smi - reserved)
|
|
||||||
|
|
||||||
|
|
||||||
def gpu_memory_usage_smi(device=0):
|
|
||||||
if isinstance(device, torch.device):
|
|
||||||
device = device.index
|
|
||||||
if isinstance(device, str) and device.startswith("cuda:"):
|
|
||||||
device = int(device[5:])
|
|
||||||
|
|
||||||
pynvml.nvmlInit()
|
|
||||||
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
|
|
||||||
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
|
||||||
return info.used / 1024.0**3
|
|
||||||
|
|
||||||
|
|
||||||
def log_gpu_memory_usage(log, msg, device):
|
|
||||||
if not torch.cuda.is_available():
|
|
||||||
return (0, 0, 0)
|
|
||||||
|
|
||||||
usage, cache, misc = gpu_memory_usage_all(device)
|
|
||||||
extras = []
|
|
||||||
if cache > 0:
|
|
||||||
extras.append(f"+{cache:.03f}GB cache")
|
|
||||||
if misc > 0:
|
|
||||||
extras.append(f"+{misc:.03f}GB misc")
|
|
||||||
log.info(
|
|
||||||
f"GPU memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})", stacklevel=2
|
|
||||||
)
|
|
||||||
return usage, cache, misc
|
|
||||||
@@ -1,6 +1,5 @@
|
|||||||
"""Callbacks for Trainer class"""
|
"""Callbacks for Trainer class"""
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from optimum.bettertransformer import BetterTransformer
|
from optimum.bettertransformer import BetterTransformer
|
||||||
@@ -12,10 +11,6 @@ from transformers import (
|
|||||||
)
|
)
|
||||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
||||||
|
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.callbacks")
|
|
||||||
|
|
||||||
|
|
||||||
class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
|
class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
|
||||||
"""Callback to save the PEFT adapter"""
|
"""Callback to save the PEFT adapter"""
|
||||||
@@ -72,25 +67,3 @@ class SaveBetterTransformerModelCallback(
|
|||||||
# the trainer will raise an exception since it can't save a BetterTransformer wrapped model
|
# the trainer will raise an exception since it can't save a BetterTransformer wrapped model
|
||||||
control.should_save = False
|
control.should_save = False
|
||||||
return control
|
return control
|
||||||
|
|
||||||
|
|
||||||
class GPUStatsCallback(
|
|
||||||
TrainerCallback
|
|
||||||
): # pylint: disable=too-few-public-methods disable=unused-argument
|
|
||||||
"""Callback to track GPU utilization"""
|
|
||||||
|
|
||||||
def __init__(self, cfg):
|
|
||||||
self.cfg = cfg
|
|
||||||
self.logged = False
|
|
||||||
|
|
||||||
def on_step_end(
|
|
||||||
self,
|
|
||||||
args: TrainingArguments,
|
|
||||||
state: TrainerState,
|
|
||||||
control: TrainerControl,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
if not self.logged and state.global_step > 1:
|
|
||||||
log_gpu_memory_usage(LOG, "while training", self.cfg.device)
|
|
||||||
self.logged = True
|
|
||||||
return control
|
|
||||||
|
|||||||
@@ -1,19 +1,14 @@
|
|||||||
"""Module containing data utilities"""
|
"""Module containing data utilities"""
|
||||||
import functools
|
import functools
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
from hashlib import md5
|
from hashlib import md5
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Tuple, Union
|
from typing import List, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from datasets import (
|
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
|
||||||
Dataset,
|
|
||||||
DatasetDict,
|
|
||||||
concatenate_datasets,
|
|
||||||
load_dataset,
|
|
||||||
load_from_disk,
|
|
||||||
)
|
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
@@ -270,12 +265,20 @@ def load_tokenized_prepared_datasets(
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"unhandled prompt tokenization strategy: {d.type} {suffix}"
|
f"unhandled prompt tokenization strategy: {d.type} {suffix}"
|
||||||
)
|
)
|
||||||
LOG.info("merging datasets")
|
LOG.info("tokenizing, merging, and shuffling master dataset")
|
||||||
dataset = concatenate_datasets(datasets)
|
|
||||||
|
|
||||||
if len(datasets) > 1:
|
samples: List[int] = []
|
||||||
LOG.info("shuffle merged datasets")
|
chunk_size = 1000
|
||||||
dataset = dataset.shuffle(seed=seed)
|
for d in datasets:
|
||||||
|
d_iter = iter(d)
|
||||||
|
while True:
|
||||||
|
chunk = list(itertools.islice(d_iter, chunk_size))
|
||||||
|
if not chunk:
|
||||||
|
break
|
||||||
|
samples.extend(chunk)
|
||||||
|
|
||||||
|
LOG.info("shuffle")
|
||||||
|
dataset = Dataset.from_list(samples).shuffle(seed=seed)
|
||||||
if cfg.local_rank == 0:
|
if cfg.local_rank == 0:
|
||||||
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
|
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
|
||||||
dataset.save_to_disk(prepared_ds_path)
|
dataset.save_to_disk(prepared_ds_path)
|
||||||
|
|||||||
@@ -3,7 +3,9 @@ import hashlib
|
|||||||
import itertools
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from typing import Any, Callable, List, Union
|
import queue
|
||||||
|
import threading
|
||||||
|
from typing import Any, Callable, List, Optional, Union
|
||||||
|
|
||||||
import numba
|
import numba
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -78,7 +80,6 @@ def allocate(
|
|||||||
s = 0
|
s = 0
|
||||||
start_index = 0
|
start_index = 0
|
||||||
result = []
|
result = []
|
||||||
result_totseqs = []
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
# binary search [left, right)
|
# binary search [left, right)
|
||||||
@@ -104,10 +105,8 @@ def allocate(
|
|||||||
|
|
||||||
# add local rank
|
# add local rank
|
||||||
result.append(batch[rank])
|
result.append(batch[rank])
|
||||||
# add total seqs for all ranks
|
|
||||||
result_totseqs.append(tot_seqs)
|
yield batch[rank], tot_seqs, s, len(result) * c * n
|
||||||
# yield batch[rank], tot_seqs, s, len(result) * c * n
|
|
||||||
return result, result_totseqs, s, len(result) * c * n
|
|
||||||
|
|
||||||
|
|
||||||
def chunk(iterable, n):
|
def chunk(iterable, n):
|
||||||
@@ -149,15 +148,14 @@ class MultipackDistributedDataloader:
|
|||||||
packing_efficiency_estimate: float = 1.0,
|
packing_efficiency_estimate: float = 1.0,
|
||||||
sample_packing_seq_len_multiplier: int = 1,
|
sample_packing_seq_len_multiplier: int = 1,
|
||||||
device_count: int = 1,
|
device_count: int = 1,
|
||||||
|
total_num_tokens: Optional[int] = None,
|
||||||
):
|
):
|
||||||
# Dataset
|
# Dataset
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
self.lengths = (
|
lengths_series = (
|
||||||
dataset.data.column("position_ids")
|
dataset.data.column("position_ids").to_pandas().apply(lambda x: x[-1] + 1)
|
||||||
.to_pandas()
|
|
||||||
.apply(lambda x: x[-1] + 1)
|
|
||||||
.values
|
|
||||||
)
|
)
|
||||||
|
self.lengths: np.ndarray = lengths_series.values
|
||||||
assert isinstance(self.lengths, np.ndarray)
|
assert isinstance(self.lengths, np.ndarray)
|
||||||
assert batch_size % sample_packing_seq_len_multiplier == 0
|
assert batch_size % sample_packing_seq_len_multiplier == 0
|
||||||
assert batch_size >= sample_packing_seq_len_multiplier
|
assert batch_size >= sample_packing_seq_len_multiplier
|
||||||
@@ -172,11 +170,17 @@ class MultipackDistributedDataloader:
|
|||||||
self.rank = 0
|
self.rank = 0
|
||||||
|
|
||||||
# statistics
|
# statistics
|
||||||
|
self.total_num_tokens = total_num_tokens
|
||||||
self.eff_total_used = 0
|
self.eff_total_used = 0
|
||||||
self.eff_total_slots = 0
|
self.eff_total_slots = 0
|
||||||
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
||||||
self.device_count = device_count
|
self.device_count = device_count
|
||||||
|
|
||||||
|
# for non-blocking batch creation
|
||||||
|
self.batch_queue: queue.Queue = queue.Queue(
|
||||||
|
maxsize=10
|
||||||
|
) # Adjust maxsize as needed
|
||||||
|
|
||||||
def generate_batches(self, set_stats=False):
|
def generate_batches(self, set_stats=False):
|
||||||
LOG.info("generating packed batches")
|
LOG.info("generating packed batches")
|
||||||
if self.sampler:
|
if self.sampler:
|
||||||
@@ -188,65 +192,83 @@ class MultipackDistributedDataloader:
|
|||||||
lengths = self.lengths[indices]
|
lengths = self.lengths[indices]
|
||||||
lengths_cumsum = np.cumsum(lengths)
|
lengths_cumsum = np.cumsum(lengths)
|
||||||
|
|
||||||
batches, totseqs, total_used, total_slots = allocate(
|
alloc_iter = iter(
|
||||||
lengths=lengths,
|
allocate(
|
||||||
lengths_cumsum=lengths_cumsum,
|
lengths=lengths,
|
||||||
rank=self.rank,
|
lengths_cumsum=lengths_cumsum,
|
||||||
# c=self.batch_max_length,
|
rank=self.rank,
|
||||||
c=self.seq_max_length * self.sample_packing_seq_len_multiplier,
|
# c=self.batch_max_length,
|
||||||
n=self.num_replicas,
|
c=self.seq_max_length * self.sample_packing_seq_len_multiplier,
|
||||||
|
n=self.num_replicas,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
batches = [[indices[b_idx] for b_idx in batch] for batch in batches]
|
for batch, tot_seqs, total_used, total_slots in alloc_iter:
|
||||||
|
self.batch_queue.put([indices[b_idx] for b_idx in batch])
|
||||||
|
# statistics
|
||||||
|
if set_stats:
|
||||||
|
self.eff_total_used = total_used
|
||||||
|
self.eff_total_slots = total_slots
|
||||||
|
self.batch_queue.put(None) # Signal the end of batch generation
|
||||||
|
|
||||||
# statistics
|
def _generate_batches_thread(self):
|
||||||
if set_stats:
|
try:
|
||||||
self.eff_total_used += total_used
|
self.generate_batches(set_stats=True)
|
||||||
self.eff_total_slots += total_slots
|
except Exception as e:
|
||||||
|
LOG.error(f"Error in batch generation thread: {e}")
|
||||||
return batches, totseqs
|
self.batch_queue.put(
|
||||||
|
None
|
||||||
|
) # Signal the end of batch generation in case of error
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
if hasattr(self.sampler, "set_epoch"):
|
if hasattr(self.sampler, "set_epoch"):
|
||||||
new_epoch = self.sampler.epoch + 1
|
new_epoch = self.sampler.epoch + 1
|
||||||
self.sampler.set_epoch(new_epoch)
|
self.sampler.set_epoch(new_epoch)
|
||||||
LOG.info(f"calling sampler.set_epoch({new_epoch})")
|
LOG.info(f"calling sampler.set_epoch({new_epoch})")
|
||||||
all_batches, _ = self.generate_batches(set_stats=True)
|
# Start the batch generation in a separate thread
|
||||||
|
batch_gen_thread = threading.Thread(target=self._generate_batches_thread)
|
||||||
|
batch_gen_thread.start()
|
||||||
|
|
||||||
features = self.dataset.features.keys()
|
features = self.dataset.features.keys()
|
||||||
len_remaining = self._len_est()
|
len_remaining = self._len_est()
|
||||||
for batches in chunk(
|
while True:
|
||||||
all_batches, self.batch_size // self.sample_packing_seq_len_multiplier
|
batch = self.batch_queue.get()
|
||||||
):
|
if batch is None: # Sentinel value received, stop iteration
|
||||||
|
break
|
||||||
chunked_data = []
|
chunked_data = []
|
||||||
attn_mask_cum_idx = 0
|
attn_mask_cum_idx = 0
|
||||||
for batch in batches:
|
concatenated = {}
|
||||||
concatenated = {}
|
batched_data = [self.dataset[batch_idx] for batch_idx in batch]
|
||||||
batched_data = [self.dataset[batch_idx] for batch_idx in batch]
|
for feature in features:
|
||||||
for feature in features:
|
if feature == "attention_mask":
|
||||||
if feature == "attention_mask":
|
arrays = [
|
||||||
arrays = [
|
(attn_mask_cum_idx + idx + 1) * np.array(item[feature])
|
||||||
(attn_mask_cum_idx + idx + 1) * np.array(item[feature])
|
for idx, item in enumerate(batched_data)
|
||||||
for idx, item in enumerate(batched_data)
|
if feature in item
|
||||||
if feature in item
|
]
|
||||||
]
|
attn_mask_cum_idx += len(batched_data)
|
||||||
attn_mask_cum_idx += len(batched_data)
|
concatenated[feature] = np.concatenate(arrays)
|
||||||
concatenated[feature] = np.concatenate(arrays)
|
else:
|
||||||
else:
|
arrays = [
|
||||||
arrays = [
|
np.array(item[feature])
|
||||||
np.array(item[feature])
|
for item in batched_data
|
||||||
for item in batched_data
|
if feature in item
|
||||||
if feature in item
|
]
|
||||||
]
|
concatenated[feature] = np.concatenate(arrays)
|
||||||
concatenated[feature] = np.concatenate(arrays)
|
chunked_data.append(concatenated)
|
||||||
chunked_data.append(concatenated)
|
|
||||||
yield self.collate_fn(chunked_data)
|
yield self.collate_fn(chunked_data)
|
||||||
len_remaining -= 1
|
len_remaining -= 1
|
||||||
if not len_remaining:
|
if not len_remaining:
|
||||||
return
|
break
|
||||||
|
# Wait for the batch generation thread to finish
|
||||||
|
batch_gen_thread.join(timeout=5)
|
||||||
|
LOG.info(f"actual packing efficiency: {self.efficiency()}")
|
||||||
|
|
||||||
def _len_est(self):
|
def _len_est(self):
|
||||||
lengths_sum = np.sum(self.lengths)
|
if not self.total_num_tokens:
|
||||||
lengths_sum_per_device = lengths_sum // self.device_count
|
self.total_num_tokens = np.sum(self.lengths)
|
||||||
|
lengths_sum_per_device = self.total_num_tokens // self.device_count
|
||||||
LOG.info(
|
LOG.info(
|
||||||
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
|
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
|
||||||
f"total_num_tokens per device: {lengths_sum_per_device}"
|
f"total_num_tokens per device: {lengths_sum_per_device}"
|
||||||
|
|||||||
@@ -10,6 +10,3 @@ class DictDefault(Dict):
|
|||||||
|
|
||||||
def __missing__(self, key):
|
def __missing__(self, key):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def __or__(self, other):
|
|
||||||
return DictDefault(super().__or__(other))
|
|
||||||
|
|||||||
@@ -22,7 +22,6 @@ from transformers import ( # noqa: F401
|
|||||||
)
|
)
|
||||||
|
|
||||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
|
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
@@ -32,66 +31,37 @@ if TYPE_CHECKING:
|
|||||||
from axolotl.utils.dict import DictDefault # noqa: F401
|
from axolotl.utils.dict import DictDefault # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
def smart_tokenizer_and_embedding_resize(
|
def load_tokenizer(
|
||||||
tokenizer: transformers.PreTrainedTokenizer,
|
tokenizer_config,
|
||||||
model: transformers.PreTrainedModel,
|
tokenizer_type,
|
||||||
resize_token_embeddings_multiple: Optional[int] = None,
|
cfg,
|
||||||
):
|
):
|
||||||
"""Resize tokenizer and embedding.
|
|
||||||
|
|
||||||
Note: This function resizes the tokenizer to accommodate additional special tokens and the
|
|
||||||
embedding matrix of the model to match the new size of the tokenizer. If any new special tokens
|
|
||||||
have been added, the function computes the average embedding values of the existing embeddings
|
|
||||||
and sets those values for the new special token embeddings. This is done separately for the input
|
|
||||||
embeddings and output embeddings of the model.
|
|
||||||
"""
|
|
||||||
|
|
||||||
old_tokens = model.get_input_embeddings().weight.data.shape[0]
|
|
||||||
num_new_tokens = len(tokenizer) - old_tokens
|
|
||||||
embeddings_len = (
|
|
||||||
math.ceil(len(tokenizer) / resize_token_embeddings_multiple)
|
|
||||||
* resize_token_embeddings_multiple
|
|
||||||
if resize_token_embeddings_multiple
|
|
||||||
else len(tokenizer)
|
|
||||||
)
|
|
||||||
model.resize_token_embeddings(embeddings_len)
|
|
||||||
|
|
||||||
if num_new_tokens > 0:
|
|
||||||
input_embeddings = model.get_input_embeddings().weight.data
|
|
||||||
output_embeddings = model.get_output_embeddings().weight.data
|
|
||||||
|
|
||||||
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
|
|
||||||
dim=0, keepdim=True
|
|
||||||
)
|
|
||||||
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
|
|
||||||
dim=0, keepdim=True
|
|
||||||
)
|
|
||||||
|
|
||||||
input_embeddings[-num_new_tokens:] = input_embeddings_avg
|
|
||||||
output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
|
||||||
|
|
||||||
|
|
||||||
def load_tokenizer(cfg):
|
|
||||||
tokenizer_kwargs = {}
|
tokenizer_kwargs = {}
|
||||||
use_fast = True # this is the default
|
use_fast = True # this is the default
|
||||||
|
|
||||||
if cfg.tokenizer_use_fast is not None:
|
if cfg.tokenizer_use_fast is not None:
|
||||||
use_fast = cfg.tokenizer_use_fast
|
use_fast = cfg.tokenizer_use_fast
|
||||||
if cfg.tokenizer_legacy is not None:
|
if cfg.tokenizer_legacy is not None:
|
||||||
# True is the default w/ https://github.com/huggingface/transformers/pull/25224
|
# True is the default w/ https://github.com/huggingface/transformers/pull/25224
|
||||||
tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
|
tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
|
||||||
|
if tokenizer_type:
|
||||||
|
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
|
||||||
|
tokenizer_config,
|
||||||
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
|
use_fast=use_fast,
|
||||||
|
**tokenizer_kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
tokenizer_config,
|
||||||
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
|
use_fast=use_fast,
|
||||||
|
**tokenizer_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
tokenizer_cls = AutoTokenizer
|
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
||||||
if cfg.tokenizer_type:
|
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
|
||||||
tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
|
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
|
||||||
|
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
|
||||||
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
|
|
||||||
tokenizer = tokenizer_cls.from_pretrained(
|
|
||||||
tokenizer_config,
|
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
|
||||||
use_fast=use_fast,
|
|
||||||
**tokenizer_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
if tokenizer.__class__.__name__ in [
|
if tokenizer.__class__.__name__ in [
|
||||||
"LlamaTokenizer",
|
"LlamaTokenizer",
|
||||||
@@ -99,11 +69,6 @@ def load_tokenizer(cfg):
|
|||||||
]:
|
]:
|
||||||
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
|
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
|
||||||
|
|
||||||
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
|
||||||
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
|
|
||||||
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
|
|
||||||
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
|
|
||||||
|
|
||||||
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
|
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
|
||||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
@@ -118,21 +83,19 @@ def load_tokenizer(cfg):
|
|||||||
|
|
||||||
|
|
||||||
def load_model(
|
def load_model(
|
||||||
cfg, tokenizer
|
base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora"
|
||||||
): # type: (DictDefault, PreTrainedTokenizerBase) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
):
|
||||||
|
# type: (str, str, str, PreTrainedTokenizerBase, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||||
"""
|
"""
|
||||||
Load a model for a given configuration and tokenizer.
|
Load a model from a base model and a model type.
|
||||||
"""
|
"""
|
||||||
base_model = cfg.base_model
|
|
||||||
base_model_config = cfg.base_model_config
|
|
||||||
model_type = cfg.model_type
|
|
||||||
|
|
||||||
# TODO refactor as a kwarg
|
# TODO refactor as a kwarg
|
||||||
load_in_8bit = cfg.load_in_8bit
|
load_in_8bit = cfg.load_in_8bit
|
||||||
cfg.is_llama_derived_model = (
|
cfg.is_llama_derived_model = (
|
||||||
"llama" in base_model
|
"llama" in base_model
|
||||||
or (cfg.model_type and "llama" in cfg.model_type.lower())
|
or (cfg.model_type and "llama" in cfg.model_type.lower())
|
||||||
or cfg.is_llama_derived_model
|
or cfg.is_llama_derived_model is True
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.is_llama_derived_model and cfg.flash_attention:
|
if cfg.is_llama_derived_model and cfg.flash_attention:
|
||||||
@@ -268,17 +231,10 @@ def load_model(
|
|||||||
elif cfg.is_llama_derived_model and not cfg.trust_remote_code:
|
elif cfg.is_llama_derived_model and not cfg.trust_remote_code:
|
||||||
from transformers import LlamaForCausalLM
|
from transformers import LlamaForCausalLM
|
||||||
|
|
||||||
config_kwargs = {}
|
config = LlamaConfig.from_pretrained(base_model_config)
|
||||||
if cfg.rope_scaling:
|
|
||||||
config_kwargs["rope_scaling"] = cfg.rope_scaling
|
|
||||||
config = LlamaConfig.from_pretrained(
|
|
||||||
base_model_config,
|
|
||||||
**config_kwargs,
|
|
||||||
)
|
|
||||||
model = LlamaForCausalLM.from_pretrained(
|
model = LlamaForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
config=config,
|
config=config,
|
||||||
device_map=cfg.device_map,
|
|
||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
@@ -313,7 +269,6 @@ def load_model(
|
|||||||
elif model_type and not cfg.trust_remote_code:
|
elif model_type and not cfg.trust_remote_code:
|
||||||
model = getattr(transformers, model_type).from_pretrained(
|
model = getattr(transformers, model_type).from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
device_map=cfg.device_map,
|
|
||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
@@ -344,7 +299,6 @@ def load_model(
|
|||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
config=config,
|
config=config,
|
||||||
device_map=cfg.device_map,
|
|
||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
@@ -358,7 +312,6 @@ def load_model(
|
|||||||
LOG.exception(err)
|
LOG.exception(err)
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
device_map=cfg.device_map,
|
|
||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
@@ -366,25 +319,23 @@ def load_model(
|
|||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
smart_tokenizer_and_embedding_resize(
|
embeddings_len = (
|
||||||
tokenizer,
|
math.ceil(len(tokenizer) / 32) * 32
|
||||||
model,
|
if cfg.resize_token_embeddings_to_32x
|
||||||
resize_token_embeddings_multiple=cfg.resize_token_embeddings_multiple,
|
else len(tokenizer)
|
||||||
)
|
)
|
||||||
|
model.resize_token_embeddings(embeddings_len)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
hasattr(model.config, "max_position_embeddings")
|
hasattr(model.config, "max_position_embeddings")
|
||||||
and model.config.max_position_embeddings
|
and model.config.max_position_embeddings
|
||||||
and cfg.sequence_len > model.config.max_position_embeddings
|
and cfg.sequence_len >= model.config.max_position_embeddings
|
||||||
):
|
):
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
|
f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
|
||||||
)
|
)
|
||||||
model.config.max_position_embeddings = cfg.sequence_len
|
model.config.max_position_embeddings = cfg.sequence_len
|
||||||
|
|
||||||
if model.device.type == "cuda":
|
|
||||||
log_gpu_memory_usage(LOG, "after model load", model.device)
|
|
||||||
|
|
||||||
if not cfg.gptq and (
|
if not cfg.gptq and (
|
||||||
(cfg.adapter == "lora" and load_in_8bit)
|
(cfg.adapter == "lora" and load_in_8bit)
|
||||||
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
|
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
|
||||||
@@ -404,7 +355,7 @@ def load_model(
|
|||||||
if hasattr(module, "weight"):
|
if hasattr(module, "weight"):
|
||||||
module.to(torch_dtype)
|
module.to(torch_dtype)
|
||||||
|
|
||||||
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
model, lora_config = load_adapter(model, cfg, adapter)
|
||||||
|
|
||||||
if cfg.ddp and not load_in_8bit:
|
if cfg.ddp and not load_in_8bit:
|
||||||
model.to(f"cuda:{cfg.local_rank}")
|
model.to(f"cuda:{cfg.local_rank}")
|
||||||
@@ -443,9 +394,6 @@ def load_model(
|
|||||||
if cfg.flash_optimum:
|
if cfg.flash_optimum:
|
||||||
model = BetterTransformer.transform(model)
|
model = BetterTransformer.transform(model)
|
||||||
|
|
||||||
if cfg.adapter is not None:
|
|
||||||
log_gpu_memory_usage(LOG, "after adapters", model.device)
|
|
||||||
|
|
||||||
# TODO resume_from_checkpoint handling
|
# TODO resume_from_checkpoint handling
|
||||||
return model, lora_config
|
return model, lora_config
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ from pathlib import Path
|
|||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
import numpy as np
|
|
||||||
import torch.cuda
|
import torch.cuda
|
||||||
import transformers
|
import transformers
|
||||||
from datasets import Dataset, set_caching_enabled
|
from datasets import Dataset, set_caching_enabled
|
||||||
@@ -22,7 +21,6 @@ from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
|||||||
from transformers.trainer_pt_utils import get_parameter_names
|
from transformers.trainer_pt_utils import get_parameter_names
|
||||||
|
|
||||||
from axolotl.utils.callbacks import (
|
from axolotl.utils.callbacks import (
|
||||||
GPUStatsCallback,
|
|
||||||
SaveBetterTransformerModelCallback,
|
SaveBetterTransformerModelCallback,
|
||||||
SavePeftModelCallback,
|
SavePeftModelCallback,
|
||||||
)
|
)
|
||||||
@@ -124,6 +122,10 @@ class AxolotlTrainingArguments(TrainingArguments):
|
|||||||
default=1,
|
default=1,
|
||||||
metadata={"help": "the multiplier for the max len for packed sequences"},
|
metadata={"help": "the multiplier for the max len for packed sequences"},
|
||||||
)
|
)
|
||||||
|
train_data_total_num_tokens: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "the total number of tokens in the train dataset"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AxolotlTrainer(Trainer):
|
class AxolotlTrainer(Trainer):
|
||||||
@@ -184,6 +186,7 @@ class AxolotlTrainer(Trainer):
|
|||||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||||
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
|
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
|
||||||
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
||||||
|
total_num_tokens=self.args.train_data_total_num_tokens,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return super().get_train_dataloader()
|
return super().get_train_dataloader()
|
||||||
@@ -206,6 +209,7 @@ class AxolotlTrainer(Trainer):
|
|||||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||||
sample_packing_seq_len_multiplier=self.args.eval_batch_size,
|
sample_packing_seq_len_multiplier=self.args.eval_batch_size,
|
||||||
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
||||||
|
total_num_tokens=None,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return super().get_eval_dataloader(eval_dataset)
|
return super().get_eval_dataloader(eval_dataset)
|
||||||
@@ -284,16 +288,13 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
|
|||||||
if cfg.sample_packing:
|
if cfg.sample_packing:
|
||||||
# we have to drop anything longer then sequence len otherwise
|
# we have to drop anything longer then sequence len otherwise
|
||||||
# flash attention with position ids fails
|
# flash attention with position ids fails
|
||||||
|
total_num_tokens = (
|
||||||
|
cfg.total_num_tokens
|
||||||
|
if cfg.total_num_tokens
|
||||||
|
else sum(len(s["input_ids"]) for s in train_dataset)
|
||||||
|
)
|
||||||
if not cfg.total_num_tokens:
|
if not cfg.total_num_tokens:
|
||||||
LOG.info("calculating total_num_tokens")
|
|
||||||
total_num_tokens = np.sum(
|
|
||||||
train_dataset.data.column("input_ids")
|
|
||||||
.to_pandas()
|
|
||||||
.apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda
|
|
||||||
.values
|
|
||||||
)
|
|
||||||
LOG.info(f"📝 UPDATE CONFIG WITH: `total_num_tokens: {total_num_tokens}`")
|
LOG.info(f"📝 UPDATE CONFIG WITH: `total_num_tokens: {total_num_tokens}`")
|
||||||
cfg.total_num_tokens = total_num_tokens
|
|
||||||
|
|
||||||
if cfg.sample_packing_eff_est:
|
if cfg.sample_packing_eff_est:
|
||||||
total_num_steps = (
|
total_num_steps = (
|
||||||
@@ -301,9 +302,9 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
|
|||||||
(
|
(
|
||||||
math.floor(
|
math.floor(
|
||||||
0.99
|
0.99
|
||||||
* cfg.total_num_tokens
|
* total_num_tokens
|
||||||
/ cfg.sample_packing_eff_est
|
/ cfg.sample_packing_eff_est
|
||||||
/ cfg.sequence_len
|
/ 2048
|
||||||
// cfg.batch_size
|
// cfg.batch_size
|
||||||
// int(os.environ.get("WORLD_SIZE", 1))
|
// int(os.environ.get("WORLD_SIZE", 1))
|
||||||
)
|
)
|
||||||
@@ -312,7 +313,7 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
|
|||||||
* cfg.num_epochs
|
* cfg.num_epochs
|
||||||
)
|
)
|
||||||
LOG.info(
|
LOG.info(
|
||||||
f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}"
|
f"total_num_tokens: {total_num_tokens}, total_num_steps: {total_num_steps}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
sampler = RandomSampler(train_dataset)
|
sampler = RandomSampler(train_dataset)
|
||||||
@@ -344,7 +345,6 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
|
|||||||
LOG.info(
|
LOG.info(
|
||||||
f"📝 UPDATE CONFIG WITH: `sample_packing_eff_est: {math.ceil(actual_eff * 100.0) / 100.0}`"
|
f"📝 UPDATE CONFIG WITH: `sample_packing_eff_est: {math.ceil(actual_eff * 100.0) / 100.0}`"
|
||||||
)
|
)
|
||||||
cfg.sample_packing_eff_est = math.ceil(actual_eff * 100.0) / 100.0
|
|
||||||
else:
|
else:
|
||||||
total_num_steps = int(
|
total_num_steps = int(
|
||||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||||
@@ -440,9 +440,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
training_arguments_kwargs["push_to_hub"] = True
|
training_arguments_kwargs["push_to_hub"] = True
|
||||||
training_arguments_kwargs["hub_private_repo"] = True
|
training_arguments_kwargs["hub_private_repo"] = True
|
||||||
|
|
||||||
if cfg.hub_strategy:
|
|
||||||
training_arguments_kwargs["hub_strategy"] = cfg.hub_strategy
|
|
||||||
|
|
||||||
if cfg.save_safetensors:
|
if cfg.save_safetensors:
|
||||||
training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
|
training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
|
||||||
|
|
||||||
@@ -451,17 +448,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
"sample_packing_efficiency"
|
"sample_packing_efficiency"
|
||||||
] = cfg.sample_packing_eff_est
|
] = cfg.sample_packing_eff_est
|
||||||
|
|
||||||
if cfg.val_set_size == 0:
|
|
||||||
evaluation_strategy = "no"
|
|
||||||
elif cfg.eval_steps < 1:
|
|
||||||
# eval every epoch
|
|
||||||
evaluation_strategy = "epoch"
|
|
||||||
else:
|
|
||||||
# eval every eval_steps steps
|
|
||||||
evaluation_strategy = "steps"
|
|
||||||
|
|
||||||
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
||||||
max_steps=total_num_steps if cfg.max_steps else -1,
|
# max_steps=total_num_steps, # this is helpful in case we don't actually know total # of steps
|
||||||
max_seq_length=cfg.sequence_len,
|
max_seq_length=cfg.sequence_len,
|
||||||
per_device_train_batch_size=cfg.micro_batch_size,
|
per_device_train_batch_size=cfg.micro_batch_size,
|
||||||
per_device_eval_batch_size=cfg.eval_batch_size
|
per_device_eval_batch_size=cfg.eval_batch_size
|
||||||
@@ -471,7 +459,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
eval_accumulation_steps=cfg.gradient_accumulation_steps,
|
eval_accumulation_steps=cfg.gradient_accumulation_steps,
|
||||||
num_train_epochs=cfg.num_epochs,
|
num_train_epochs=cfg.num_epochs,
|
||||||
learning_rate=cfg.learning_rate,
|
learning_rate=cfg.learning_rate,
|
||||||
evaluation_strategy=evaluation_strategy,
|
evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
|
||||||
save_strategy="steps" if cfg.save_steps else "epoch",
|
save_strategy="steps" if cfg.save_steps else "epoch",
|
||||||
eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None,
|
eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None,
|
||||||
save_steps=cfg.save_steps,
|
save_steps=cfg.save_steps,
|
||||||
@@ -495,7 +483,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
else "cosine",
|
else "cosine",
|
||||||
weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
|
weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
|
||||||
sample_packing=cfg.sample_packing if cfg.sample_packing else False,
|
sample_packing=cfg.sample_packing if cfg.sample_packing else False,
|
||||||
sample_packing_seq_len_multiplier=cfg.micro_batch_size,
|
sample_packing_seq_len_multiplier=cfg.micro_batch_size or 1,
|
||||||
|
train_data_total_num_tokens=cfg.total_num_tokens,
|
||||||
**training_arguments_kwargs,
|
**training_arguments_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -567,7 +556,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)
|
trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)
|
||||||
|
|
||||||
callbacks = []
|
callbacks = []
|
||||||
callbacks.append(GPUStatsCallback(cfg))
|
|
||||||
# TODO on_save callback to sync checkpoints to GCP/AWS in background
|
# TODO on_save callback to sync checkpoints to GCP/AWS in background
|
||||||
if cfg.early_stopping_patience:
|
if cfg.early_stopping_patience:
|
||||||
early_stop_cb = EarlyStoppingCallback(
|
early_stop_cb = EarlyStoppingCallback(
|
||||||
|
|||||||
@@ -1,70 +1,12 @@
|
|||||||
"""Module for working with config dicts"""
|
"""Module for validating config files"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
def choose_device(cfg):
|
|
||||||
def get_device():
|
|
||||||
try:
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
return f"cuda:{cfg.local_rank}"
|
|
||||||
|
|
||||||
if torch.backends.mps.is_available():
|
|
||||||
return "mps"
|
|
||||||
|
|
||||||
raise SystemError("No CUDA/mps device found")
|
|
||||||
except Exception: # pylint: disable=broad-exception-caught
|
|
||||||
return "cpu"
|
|
||||||
|
|
||||||
cfg.device = get_device()
|
|
||||||
if cfg.device_map != "auto":
|
|
||||||
if cfg.device.startswith("cuda"):
|
|
||||||
cfg.device_map = {"": cfg.local_rank}
|
|
||||||
else:
|
|
||||||
cfg.device_map = {"": cfg.device}
|
|
||||||
|
|
||||||
# in `accelerate launch`, we need to not pass through any device map and let
|
|
||||||
# accelerate figure out which parts of the model to put on which gpu
|
|
||||||
accelerate_vars = [var for var in os.environ if var.startswith("ACCELERATE_USE_")]
|
|
||||||
if accelerate_vars:
|
|
||||||
cfg.device_map = None
|
|
||||||
|
|
||||||
|
|
||||||
def normalize_config(cfg):
|
|
||||||
# setup some derived config / hyperparams
|
|
||||||
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
|
|
||||||
cfg.batch_size // cfg.micro_batch_size
|
|
||||||
)
|
|
||||||
cfg.batch_size = (
|
|
||||||
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
|
|
||||||
)
|
|
||||||
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
|
||||||
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
|
||||||
choose_device(cfg)
|
|
||||||
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
|
|
||||||
if cfg.ddp:
|
|
||||||
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
|
|
||||||
cfg.batch_size = cfg.batch_size * cfg.world_size
|
|
||||||
|
|
||||||
if cfg.device == "mps":
|
|
||||||
cfg.load_in_8bit = False
|
|
||||||
cfg.tf32 = False
|
|
||||||
if cfg.bf16:
|
|
||||||
cfg.fp16 = True
|
|
||||||
cfg.bf16 = False
|
|
||||||
else:
|
|
||||||
torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
|
|
||||||
|
|
||||||
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
|
||||||
|
|
||||||
|
|
||||||
def validate_config(cfg):
|
def validate_config(cfg):
|
||||||
if cfg.max_packed_sequence_len and cfg.sample_packing:
|
if cfg.max_packed_sequence_len and cfg.sample_packing:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -168,13 +110,6 @@ def validate_config(cfg):
|
|||||||
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
|
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.gptq and cfg.model_revision:
|
|
||||||
raise ValueError(
|
|
||||||
"model_revision is not supported for GPTQ models. "
|
|
||||||
+ "Please download the model from HuggingFace Hub manually for correct branch, "
|
|
||||||
+ "point to its path, and remove model_revision from the config."
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.sample_packing and cfg.sdp_attention:
|
if cfg.sample_packing and cfg.sdp_attention:
|
||||||
# incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
|
# incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -9,8 +9,6 @@ def setup_wandb_env_vars(cfg):
|
|||||||
elif cfg.wandb_project and len(cfg.wandb_project) > 0:
|
elif cfg.wandb_project and len(cfg.wandb_project) > 0:
|
||||||
os.environ["WANDB_PROJECT"] = cfg.wandb_project
|
os.environ["WANDB_PROJECT"] = cfg.wandb_project
|
||||||
cfg.use_wandb = True
|
cfg.use_wandb = True
|
||||||
if cfg.wandb_entity and len(cfg.wandb_entity) > 0:
|
|
||||||
os.environ["WANDB_ENTITY"] = cfg.wandb_entity
|
|
||||||
if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
|
if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
|
||||||
os.environ["WANDB_WATCH"] = cfg.wandb_watch
|
os.environ["WANDB_WATCH"] = cfg.wandb_watch
|
||||||
if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0:
|
if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0:
|
||||||
|
|||||||
@@ -72,13 +72,6 @@ class DictDefaultTest(unittest.TestCase):
|
|||||||
|
|
||||||
assert cfg.random_key is None, "DictDefault should return None for missing keys"
|
assert cfg.random_key is None, "DictDefault should return None for missing keys"
|
||||||
|
|
||||||
def test_dict_or(self):
|
|
||||||
cfg = DictDefault({}) | DictDefault({})
|
|
||||||
|
|
||||||
assert (
|
|
||||||
cfg.random_key is None
|
|
||||||
), "DictDefault should return None for missing keys after | operation"
|
|
||||||
|
|
||||||
def test_dict_nested_missingparentkey(self):
|
def test_dict_nested_missingparentkey(self):
|
||||||
"""
|
"""
|
||||||
Due to subclassing Dict, DictDefault will error if we try to access a nested key whose parent key does not exist.
|
Due to subclassing Dict, DictDefault will error if we try to access a nested key whose parent key does not exist.
|
||||||
|
|||||||
@@ -13,22 +13,17 @@ class TestTokenizers(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def test_default_use_fast(self):
|
def test_default_use_fast(self):
|
||||||
cfg = DictDefault(
|
cfg = DictDefault({})
|
||||||
{
|
tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
|
||||||
"tokenizer_config": "huggyllama/llama-7b",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
tokenizer = load_tokenizer(cfg)
|
|
||||||
assert "Fast" in tokenizer.__class__.__name__
|
assert "Fast" in tokenizer.__class__.__name__
|
||||||
|
|
||||||
def test_dont_use_fast(self):
|
def test_dont_use_fast(self):
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"tokenizer_config": "huggyllama/llama-7b",
|
|
||||||
"tokenizer_use_fast": False,
|
"tokenizer_use_fast": False,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
|
||||||
assert "Fast" not in tokenizer.__class__.__name__
|
assert "Fast" not in tokenizer.__class__.__name__
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ from typing import Optional
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from axolotl.utils.config import validate_config
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.validation import validate_config
|
||||||
|
|
||||||
|
|
||||||
class ValidationTest(unittest.TestCase):
|
class ValidationTest(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user