Compare commits

..

12 Commits

Author SHA1 Message Date
Wing Lian
718a8f4153 update flash attention to 2.5.5 for gemma 2024-02-21 23:32:44 -05:00
NanoCode012
a359579371 deprecate: pytorch 2.0.1 image (#1315) [skip ci]
* deprecate: pytorch 2.0.1 image

* deprecate from main image

* Update main.yml

* Update tests.yml
2024-02-22 11:39:47 +09:00
Wing Lian
2752d5f958 multipack for gemma (#1313)
* multipack for gemma

* chore: lint

* handle cache_position kwarg in updated llama modeling

* add position_ids to rotary embed call for updated llama modeling
2024-02-21 19:24:21 -05:00
Monk
9e300aca0c Adding Google's gemma Model (#1312) 2024-02-21 12:56:47 -05:00
NanoCode012
3d2cd804ae fix(readme): update inference md link (#1311) [skip ci] 2024-02-22 02:48:06 +09:00
Jared Palmer
6ab69ec5f8 Add instructions for playing with qlora model to colab example (#1290)
* Add instructions for playing with qlora model to colab example

* Update examples/colab-notebooks/colab-axolotl-example.ipynb

Co-authored-by: JohanWork <39947546+JohanWork@users.noreply.github.com>

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
Co-authored-by: JohanWork <39947546+JohanWork@users.noreply.github.com>
2024-02-22 02:46:27 +09:00
David Meikle
3c00f406d6 Allow load_best_model_at_end to be configured for early stopping on custom evaluation datasets (#1291)
* Allow load_best_model_at_end when using test_datasets and val_set_size is zero for custom evaluation datasets

* Fixed formatting following failed Lint check
2024-02-22 00:57:18 +09:00
NanoCode012
a7a9a1433a fix(examples): remove is_*_derived as it's parsed automatically (#1297) 2024-02-22 00:52:46 +09:00
Leonardo Emili
e2786cce6a Validation always happens on first step (#1300) 2024-02-22 00:52:24 +09:00
Leonardo Emili
5a5d47458d Add seq2seq eval benchmark callback (#1274)
* Add CausalLMBenchEvalCallback for measuring seq2seq performance

* Fix code for pre-commit

* Fix typing and improve logging

* eval_sample_packing must be false with CausalLMBenchEvalCallback
2024-02-13 08:24:30 -08:00
김진원
8430db22e2 Scheduler implementation of Continual Pre-Training of Large Language Models: How to (re)warm your model? (#1273) 2024-02-12 21:23:28 -08:00
Wing Lian
4b997c3e1a allow the optimizer prune ratio for ReLoRA to be configurable (#1287)
* allow the optimizer prune ration for relora to be configurable

* update docs for relora

* prevent circular imports
2024-02-12 11:39:51 -08:00
48 changed files with 526 additions and 127 deletions

View File

@@ -12,11 +12,6 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
include: include:
- cuda: "118"
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.0.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
- cuda: "118" - cuda: "118"
cuda_version: 11.8.0 cuda_version: 11.8.0
python_version: "3.10" python_version: "3.10"

View File

@@ -13,11 +13,6 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
include: include:
- cuda: 118
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.0.1
axolotl_extras:
- cuda: 118 - cuda: 118
cuda_version: 11.8.0 cuda_version: 11.8.0
python_version: "3.10" python_version: "3.10"
@@ -73,11 +68,6 @@ jobs:
strategy: strategy:
matrix: matrix:
include: include:
- cuda: 118
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.0.1
axolotl_extras:
- cuda: 118 - cuda: 118
cuda_version: 11.8.0 cuda_version: 11.8.0
python_version: "3.10" python_version: "3.10"

View File

@@ -69,7 +69,7 @@ jobs:
- cuda: 118 - cuda: 118
cuda_version: 11.8.0 cuda_version: 11.8.0
python_version: "3.10" python_version: "3.10"
pytorch: 2.0.1 pytorch: 2.1.2
- cuda: 121 - cuda: 121
cuda_version: 12.1.0 cuda_version: 12.1.0
python_version: "3.10" python_version: "3.10"

View File

@@ -34,7 +34,7 @@ Features:
- [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset) - [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset)
- [Config](#config) - [Config](#config)
- [Train](#train) - [Train](#train)
- [Inference](#inference) - [Inference](#inference-playground)
- [Merge LORA to Base](#merge-lora-to-base) - [Merge LORA to Base](#merge-lora-to-base)
- [Special Tokens](#special-tokens) - [Special Tokens](#special-tokens)
- Advanced Topics - Advanced Topics
@@ -734,6 +734,8 @@ peft:
# Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed # Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
relora_steps: # Number of steps per ReLoRA restart relora_steps: # Number of steps per ReLoRA restart
relora_warmup_steps: # Number of per-restart warmup steps relora_warmup_steps: # Number of per-restart warmup steps
relora_anneal_steps: # Number of anneal steps for each relora cycle
relora_prune_ratio: # threshold for optimizer magnitude when pruning
relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings
# wandb configuration if you're using it # wandb configuration if you're using it
@@ -782,7 +784,8 @@ save_total_limit: # Checkpoints saved at a time
max_steps: max_steps:
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0 eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128 eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", chrf]
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training) loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3) loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
@@ -811,6 +814,7 @@ early_stopping_patience: 3
lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
lr_scheduler_kwargs: lr_scheduler_kwargs:
cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step (https://arxiv.org/pdf/2308.04014.pdf)
# For one_cycle optim # For one_cycle optim
lr_div_factor: # Learning rate div factor lr_div_factor: # Learning rate div factor

View File

@@ -2,7 +2,6 @@
base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0 base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true load_in_8bit: true
load_in_4bit: false load_in_4bit: false

View File

@@ -1,7 +1,6 @@
base_model: codellama/CodeLlama-13b-hf base_model: codellama/CodeLlama-13b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true load_in_8bit: true
load_in_4bit: false load_in_4bit: false

View File

@@ -1,7 +1,6 @@
base_model: codellama/CodeLlama-13b-hf base_model: codellama/CodeLlama-13b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false load_in_8bit: false
load_in_4bit: true load_in_4bit: true

View File

@@ -1,7 +1,6 @@
base_model: codellama/CodeLlama-34b-hf base_model: codellama/CodeLlama-34b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true load_in_8bit: true
load_in_4bit: false load_in_4bit: false

View File

@@ -1,7 +1,6 @@
base_model: codellama/CodeLlama-34b-hf base_model: codellama/CodeLlama-34b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false load_in_8bit: false
load_in_4bit: true load_in_4bit: true

View File

@@ -1,7 +1,6 @@
base_model: codellama/CodeLlama-7b-hf base_model: codellama/CodeLlama-7b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true load_in_8bit: true
load_in_4bit: false load_in_4bit: false

View File

@@ -1,7 +1,6 @@
base_model: codellama/CodeLlama-7b-hf base_model: codellama/CodeLlama-7b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false load_in_8bit: false
load_in_4bit: true load_in_4bit: true

View File

@@ -177,6 +177,24 @@
"# Buy using the ! the comand will be executed as a bash command\n", "# Buy using the ! the comand will be executed as a bash command\n",
"!accelerate launch -m axolotl.cli.train /content/test_axolotl.yaml" "!accelerate launch -m axolotl.cli.train /content/test_axolotl.yaml"
] ]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Play with inference"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Buy using the ! the comand will be executed as a bash command\n",
"!accelerate launch -m axolotl.cli.inference /content/test_axolotl.yaml \\\n",
" --qlora_model_dir=\"./qlora-out\" --gradio"
]
} }
], ],
"metadata": { "metadata": {

View File

@@ -2,7 +2,7 @@ base_model: tiiuae/falcon-7b
trust_remote_code: true trust_remote_code: true
model_type: AutoModelForCausalLM model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer
is_falcon_derived_model: true
load_in_8bit: true load_in_8bit: true
load_in_4bit: false load_in_4bit: false
gptq: false gptq: false

View File

@@ -5,7 +5,7 @@ base_model: tiiuae/falcon-7b
trust_remote_code: true trust_remote_code: true
model_type: AutoModelForCausalLM model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer
is_falcon_derived_model: true
load_in_8bit: false load_in_8bit: false
# enable 4bit for QLoRA # enable 4bit for QLoRA
load_in_4bit: true load_in_4bit: true

View File

@@ -2,7 +2,7 @@ base_model: tiiuae/falcon-7b
trust_remote_code: true trust_remote_code: true
model_type: AutoModelForCausalLM model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer
is_falcon_derived_model: true
load_in_8bit: false load_in_8bit: false
load_in_4bit: false load_in_4bit: false
gptq: false gptq: false

65
examples/gemma/qlora.yml Normal file
View File

@@ -0,0 +1,65 @@
# use google/gemma-7b if you have access
base_model: mhenrichsen/gemma-7b
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: false
load_in_4bit: true
strict: false
# huggingface repo
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
val_set_size: 0.1
output_dir: ./out
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
sequence_len: 4096
sample_packing: false
pad_to_sequence_len: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 3
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -1,7 +1,6 @@
base_model: NousResearch/Llama-2-7b-hf base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false load_in_8bit: false
load_in_4bit: false load_in_4bit: false

View File

@@ -1,5 +1,4 @@
base_model: TheBloke/Llama-2-7B-GPTQ base_model: TheBloke/Llama-2-7B-GPTQ
is_llama_derived_model: false
gptq: true gptq: true
gptq_disable_exllama: true gptq_disable_exllama: true
model_type: AutoModelForCausalLM model_type: AutoModelForCausalLM

View File

@@ -1,7 +1,6 @@
base_model: NousResearch/Llama-2-7b-hf base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false load_in_8bit: false
load_in_4bit: false load_in_4bit: false
@@ -60,7 +59,7 @@ s2_attention:
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4
eval_table_size: eval_table_size:
eval_table_max_new_tokens: 128 eval_max_new_tokens: 128
saves_per_epoch: 1 saves_per_epoch: 1
debug: debug:
deepspeed: deepspeed:

View File

@@ -1,7 +1,6 @@
base_model: NousResearch/Llama-2-7b-hf base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true load_in_8bit: true
load_in_4bit: false load_in_4bit: false
@@ -57,7 +56,7 @@ s2_attention:
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4
eval_table_size: eval_table_size:
eval_table_max_new_tokens: 128 eval_max_new_tokens: 128
saves_per_epoch: 1 saves_per_epoch: 1
debug: debug:
deepspeed: deepspeed:

View File

@@ -1,7 +1,6 @@
base_model: NousResearch/Llama-2-7b-hf base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false load_in_8bit: false
load_in_4bit: true load_in_4bit: true

View File

@@ -1,7 +1,7 @@
base_model: NousResearch/Llama-2-7b-hf base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false load_in_8bit: false
load_in_4bit: true load_in_4bit: true

View File

@@ -49,7 +49,7 @@ flash_attention:
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4
eval_table_size: eval_table_size:
eval_table_max_new_tokens: 128 eval_max_new_tokens: 128
saves_per_epoch: 1 saves_per_epoch: 1
debug: debug:
deepspeed: deepspeed:

View File

@@ -2,7 +2,6 @@
base_model: mistralai/Mistral-7B-v0.1 base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_mistral_derived_model: true
load_in_8bit: true load_in_8bit: true
load_in_4bit: false load_in_4bit: false
@@ -61,7 +60,7 @@ flash_attention: true
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4
eval_table_size: eval_table_size:
eval_table_max_new_tokens: 128 eval_max_new_tokens: 128
saves_per_epoch: 1 saves_per_epoch: 1
debug: debug:
#default deepspeed, can use more aggresive if needed like zero2, zero3 #default deepspeed, can use more aggresive if needed like zero2, zero3

View File

@@ -1,7 +1,6 @@
base_model: mistralai/Mistral-7B-v0.1 base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_mistral_derived_model: true
load_in_8bit: false load_in_8bit: false
load_in_4bit: false load_in_4bit: false
@@ -49,7 +48,7 @@ flash_attention: true
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4
eval_table_size: eval_table_size:
eval_table_max_new_tokens: 128 eval_max_new_tokens: 128
saves_per_epoch: 1 saves_per_epoch: 1
debug: debug:
deepspeed: deepspeed:

View File

@@ -81,7 +81,7 @@ loss_watchdog_patience: 3
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4
eval_table_size: eval_table_size:
eval_table_max_new_tokens: 128 eval_max_new_tokens: 128
saves_per_epoch: 1 saves_per_epoch: 1
debug: debug:
deepspeed: deepspeed_configs/zero2.json deepspeed: deepspeed_configs/zero2.json

View File

@@ -1,7 +1,6 @@
base_model: mistralai/Mistral-7B-v0.1 base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_mistral_derived_model: true
load_in_8bit: false load_in_8bit: false
load_in_4bit: true load_in_4bit: true
@@ -68,7 +67,7 @@ loss_watchdog_patience: 3
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4
eval_table_size: eval_table_size:
eval_table_max_new_tokens: 128 eval_max_new_tokens: 128
saves_per_epoch: 1 saves_per_epoch: 1
debug: debug:
deepspeed: deepspeed:

View File

@@ -2,7 +2,6 @@ base_model: Qwen/Qwen-7B
model_type: AutoModelForCausalLM model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer
is_qwen_derived_model: true
trust_remote_code: true trust_remote_code: true
load_in_8bit: true load_in_8bit: true
@@ -58,7 +57,7 @@ flash_attention:
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4
eval_table_size: eval_table_size:
eval_table_max_new_tokens: 128 eval_max_new_tokens: 128
saves_per_epoch: 1 saves_per_epoch: 1
debug: debug:
deepspeed: deepspeed:

View File

@@ -2,7 +2,6 @@ base_model: Qwen/Qwen-7B
model_type: AutoModelForCausalLM model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer
is_qwen_derived_model: true
trust_remote_code: true trust_remote_code: true
load_in_8bit: false load_in_8bit: false
@@ -58,7 +57,7 @@ flash_attention:
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4
eval_table_size: eval_table_size:
eval_table_max_new_tokens: 128 eval_max_new_tokens: 128
saves_per_epoch: 1 saves_per_epoch: 1
debug: debug:
deepspeed: deepspeed:

View File

@@ -1,7 +1,6 @@
base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true load_in_8bit: true
load_in_4bit: false load_in_4bit: false

View File

@@ -1,7 +1,6 @@
base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true load_in_8bit: true
load_in_4bit: false load_in_4bit: false

View File

@@ -2,7 +2,6 @@ base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false load_in_8bit: false
load_in_4bit: false load_in_4bit: false
@@ -10,9 +9,9 @@ strict: false
max_steps: 200 max_steps: 200
pretraining_dataset: pretraining_dataset:
- path: c4 path: c4
name: en name: en
type: pretrain type: pretrain
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.0 val_set_size: 0.0
output_dir: ./model-out output_dir: ./model-out

View File

@@ -1,7 +1,6 @@
base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false load_in_8bit: false
load_in_4bit: true load_in_4bit: true

View File

@@ -1,8 +1,7 @@
base_model: 01-ai/Yi-34B-Chat base_model: 01-ai/Yi-34B-Chat
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_mistral_derived_model: false
is_llama_derived_model: true
load_in_8bit: false load_in_8bit: false
load_in_4bit: true load_in_4bit: true
strict: false strict: false
@@ -29,7 +28,7 @@ num_epochs: 1
val_set_size: 0.1 val_set_size: 0.1
evals_per_epoch: 5 evals_per_epoch: 5
eval_table_size: eval_table_size:
eval_table_max_new_tokens: 128 eval_max_new_tokens: 128
eval_sample_packing: false eval_sample_packing: false
eval_batch_size: 1 eval_batch_size: 1

View File

@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2 packaging==23.2
peft @ git+https://github.com/huggingface/peft.git peft @ git+https://github.com/huggingface/peft.git
transformers @ git+https://github.com/huggingface/transformers.git@bebeeee01275c32fccec3fa36d8b148d3813a7dc transformers @ git+https://github.com/huggingface/transformers.git@ae49b218c3d718df90d8e4a109016450fb8f0632
tokenizers==0.15.0 tokenizers==0.15.0
bitsandbytes>=0.41.1 bitsandbytes>=0.41.1
accelerate==0.26.1 accelerate==0.26.1
@@ -11,7 +11,7 @@ fire
PyYAML>=6.0 PyYAML>=6.0
requests requests
datasets>=2.15.0 datasets>=2.15.0
flash-attn==2.3.3 flash-attn==2.5.5
sentencepiece sentencepiece
wandb wandb
einops einops
@@ -23,7 +23,7 @@ numba
numpy>=1.24.4 numpy>=1.24.4
mlflow mlflow
# qlora things # qlora things
evaluate==0.4.0 evaluate==0.4.1
scipy scipy
scikit-learn==1.2.2 scikit-learn==1.2.2
pynvml pynvml

View File

@@ -67,7 +67,7 @@ setup(
dependency_links=dependency_links, dependency_links=dependency_links,
extras_require={ extras_require={
"flash-attn": [ "flash-attn": [
"flash-attn==2.5.0", "flash-attn==2.5.5",
], ],
"fused-dense-lib": [ "fused-dense-lib": [
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib", "fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib",

View File

@@ -38,6 +38,7 @@ from axolotl.utils.callbacks import (
SaveAxolotlConfigtoWandBCallback, SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback, SaveBetterTransformerModelCallback,
bench_eval_callback_factory, bench_eval_callback_factory,
causal_lm_bench_eval_callback_factory,
log_prediction_callback_factory, log_prediction_callback_factory,
) )
from axolotl.utils.collators import ( from axolotl.utils.collators import (
@@ -50,6 +51,7 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.schedulers import ( from axolotl.utils.schedulers import (
get_cosine_schedule_with_min_lr, get_cosine_schedule_with_min_lr,
get_cosine_schedule_with_quadratic_warmup, get_cosine_schedule_with_quadratic_warmup,
get_cosine_schedule_with_warmup_decay_constant,
) )
try: try:
@@ -131,6 +133,10 @@ class AxolotlTrainingArguments(TrainingArguments):
default=None, default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"}, metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
) )
relora_prune_ratio: Optional[float] = field(
default=0.9,
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
)
bench_split: Optional[str] = field( bench_split: Optional[str] = field(
default="eval", metadata={"help": "The benchmark split to run on"} default="eval", metadata={"help": "The benchmark split to run on"}
) )
@@ -143,6 +149,9 @@ class AxolotlTrainingArguments(TrainingArguments):
do_bench_eval: Optional[bool] = field( do_bench_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Benchmark evaluation."} default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
) )
do_causal_lm_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
)
max_bench_samples: Optional[int] = field( max_bench_samples: Optional[int] = field(
default=None, default=None,
metadata={ metadata={
@@ -160,6 +169,12 @@ class AxolotlTrainingArguments(TrainingArguments):
default=None, default=None,
metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"}, metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
) )
cosine_constant_lr_ratio: Optional[float] = field(
default=None,
metadata={
"help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
},
)
class AxolotlTrainer(Trainer): class AxolotlTrainer(Trainer):
@@ -217,6 +232,16 @@ class AxolotlTrainer(Trainer):
num_warmup_steps=self.args.get_warmup_steps(num_training_steps), num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps, num_training_steps=num_training_steps,
) )
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
)
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr: elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0" assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
@@ -643,6 +668,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.do_bench_eval: if self.cfg.do_bench_eval:
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer)) callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
if self.cfg.do_causal_lm_eval:
CausalLMBenchEvalCallback = causal_lm_bench_eval_callback_factory(
trainer, self.tokenizer
)
callbacks.append(CausalLMBenchEvalCallback(self.cfg))
if self.cfg.early_stopping_patience: if self.cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback( early_stop_cb = EarlyStoppingCallback(
@@ -791,6 +821,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
if self.cfg.bench_dataset: if self.cfg.bench_dataset:
training_arguments_kwargs["bench_dataset"] = self.cfg.bench_dataset training_arguments_kwargs["bench_dataset"] = self.cfg.bench_dataset
if self.cfg.do_causal_lm_eval:
training_arguments_kwargs["do_causal_lm_eval"] = self.cfg.do_causal_lm_eval
if self.cfg.metric_for_best_model: if self.cfg.metric_for_best_model:
training_arguments_kwargs[ training_arguments_kwargs[
"metric_for_best_model" "metric_for_best_model"
@@ -851,8 +883,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self.cfg.load_best_model_at_end is not False self.cfg.load_best_model_at_end is not False
or self.cfg.early_stopping_patience or self.cfg.early_stopping_patience
) )
and not self.cfg.test_datasets and (
and self.cfg.val_set_size > 0 (not self.cfg.test_datasets and self.cfg.val_set_size > 0)
or (self.cfg.test_datasets and self.cfg.val_set_size == 0)
)
and self.cfg.save_steps and self.cfg.save_steps
and self.cfg.eval_steps and self.cfg.eval_steps
and self.cfg.save_steps % self.cfg.eval_steps == 0 and self.cfg.save_steps % self.cfg.eval_steps == 0
@@ -883,6 +917,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {} self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
) )
training_arguments_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio training_arguments_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio
training_arguments_kwargs[
"cosine_constant_lr_ratio"
] = self.cfg.cosine_constant_lr_ratio
training_arguments_kwargs["weight_decay"] = ( training_arguments_kwargs["weight_decay"] = (
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0 self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
) )
@@ -900,9 +937,20 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs[ training_arguments_kwargs[
"sample_packing_seq_len_multiplier" "sample_packing_seq_len_multiplier"
] = self.cfg.micro_batch_size ] = self.cfg.micro_batch_size
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps if self.cfg.relora_steps:
training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
training_arguments_kwargs["relora_anneal_steps"] = self.cfg.relora_anneal_steps training_arguments_kwargs[
"relora_warmup_steps"
] = self.cfg.relora_warmup_steps
if self.cfg.relora_anneal_steps:
training_arguments_kwargs[
"relora_anneal_steps"
] = self.cfg.relora_anneal_steps
if self.cfg.relora_prune_ratio:
training_arguments_kwargs[
"relora_prune_ratio"
] = self.cfg.relora_prune_ratio
training_arguments_kwargs = self.hook_pre_create_training_args( training_arguments_kwargs = self.hook_pre_create_training_args(
training_arguments_kwargs training_arguments_kwargs
) )

View File

@@ -275,7 +275,9 @@ def flashattn_forward_with_s2attn(
kv_seq_len = key_states.shape[-2] kv_seq_len = key_states.shape[-2]
if past_key_value is not None: if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2] kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) cos, sin = self.rotary_emb(
value_states, seq_len=kv_seq_len, position_ids=position_ids
)
query_states, key_states = apply_rotary_pos_emb( query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids query_states, key_states, cos, sin, position_ids
) )
@@ -425,7 +427,9 @@ def flashattn_forward(
if past_key_value is not None: if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2] kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) cos, sin = self.rotary_emb(
value_states, seq_len=kv_seq_len, position_ids=position_ids
)
query_states, key_states = apply_rotary_pos_emb( query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids query_states, key_states, cos, sin, position_ids
) )
@@ -688,6 +692,9 @@ def llama_model_forward(
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[ # pylint: disable=unused-argument
torch.LongTensor
] = None,
) -> Union[Tuple, BaseModelOutputWithPast]: ) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = ( output_attentions = (
output_attentions output_attentions

View File

@@ -6,7 +6,7 @@ from transformers.integrations import is_deepspeed_zero3_enabled
from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3 from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
from axolotl.monkeypatch.utils import get_unpad_data from axolotl.monkeypatch.utils import get_unpad_data
SUPPORTED_MULTIPACK_MODEL_TYPES = ["mixtral", "qwen2", "falcon", "phi"] SUPPORTED_MULTIPACK_MODEL_TYPES = ["mixtral", "qwen2", "falcon", "phi", "gemma"]
def patch_for_multipack(model_type): def patch_for_multipack(model_type):
@@ -28,3 +28,7 @@ def patch_for_multipack(model_type):
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data get_unpad_data
) )
elif model_type == "gemma":
transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)

View File

@@ -46,8 +46,9 @@ def reset_optimizer(
*, *,
reset_params: list[str], # where str is the key to a torch.nn.Parameter reset_params: list[str], # where str is the key to a torch.nn.Parameter
optimizer_state_keys: list[str], optimizer_state_keys: list[str],
prune_ratio: float = 0.9,
): ):
pruning_fn = partial(magnitude_pruning_, prune_ratio=0.9) pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)
n_zeros = 0 n_zeros = 0
n_total = 0 n_total = 0
@@ -159,6 +160,7 @@ class ReLoRACallback(TrainerCallback):
optimizer, optimizer,
reset_params=lora_params, reset_params=lora_params,
optimizer_state_keys=optimizer_state_keys, optimizer_state_keys=optimizer_state_keys,
prune_ratio=args.relora_prune_ratio,
) )
if self.quantized: if self.quantized:

View File

@@ -1,28 +0,0 @@
import os
from typing import Callable, Generator, Tuple
import psycopg
import psycopg.conninfo
def pgsql(pgsql_table=None, id_field="id", **kwargs) -> Callable:
pgsql_conn = os.environ.get("PGSQL_CONN", None)
if not pgsql_conn:
raise ValueError("missing PGSQL_CONN environment variable")
conn_dict = psycopg.conninfo.conninfo_to_dict(pgsql_conn)
def data_generator() -> Generator[Tuple, None, None]:
with psycopg.connect(**conn_dict) as conn:
with conn.cursor() as cur:
page_size = 10
last_id = None
while True:
if last_id:
where_clause = f" WHERE {id_field} > {last_id}"
cur.execute(
f"SELECT * FROM {pgsql_table}{where_clause} ORDER BY {id_field} ASC LIMIT {page_size}"
)
for row in cur.fetchall():
yield row[id_field], dict(row)
return data_generator

View File

@@ -62,7 +62,6 @@ class EvalFirstStepCallback(
): ):
if ( if (
args.evaluation_strategy == IntervalStrategy.STEPS args.evaluation_strategy == IntervalStrategy.STEPS
and args.eval_steps < 1.0
and state.global_step == 1 and state.global_step == 1
): ):
control.should_evaluate = True control.should_evaluate = True
@@ -361,6 +360,187 @@ def bench_eval_callback_factory(trainer, tokenizer):
return BenchEvalCallback return BenchEvalCallback
def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
class CausalLMBenchEvalCallback(TrainerCallback):
"""Callback to log prediction values during each evaluation"""
def __init__(self, cfg):
self.cfg = cfg
self.logged = False
self.metrics = self.__maybe_load_metrics()
def __maybe_load_metrics(self):
metrics = {}
for metric in self.cfg.eval_causal_lm_metrics:
try:
metrics[metric] = evaluate.load(metric)
except Exception as exc: # pylint: disable=broad-exception-caught
LOG.warning(f"{metric}: {exc.args}")
return metrics
def on_evaluate(
self,
args: AxolotlTrainingArguments, # pylint: disable=unused-argument
state: TrainerState,
control: TrainerControl,
train_dataloader, # pylint: disable=unused-argument
eval_dataloader,
**kwargs, # pylint: disable=unused-argument
):
trainer.model.eval()
device = torch.device(self.cfg.device)
# pylint: disable=duplicate-code
generation_config = GenerationConfig(
max_new_tokens=self.cfg.eval_max_new_tokens,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=False,
use_cache=True,
return_dict_in_generate=True,
output_attentions=False,
output_hidden_states=False,
output_scores=False,
)
def find_ranges(lst):
ranges = []
start = 0
for i in range(1, len(lst)):
if lst[i] == 0:
ranges.append((start, i - 1))
start = i
end = len(lst) - 1
ranges.append((start, end))
return ranges
def compute(metric: evaluate.Metric, **kwargs):
# safely compute a metric and return the score if the format is correct
metric_score = None
try:
metric_score = metric.compute(**kwargs)
return (
metric_score["score"]
if "score" in metric_score
else metric_score["mean_score"]
)
except Exception: # pylint: disable=broad-exception-caught
LOG.debug(
f"Failed to compute metric {metric.name} with kwargs {kwargs.keys()}"
)
return metric_score
def evaluate_preds(sources, predictions, references):
scores = {}
for metric_name, metric in self.metrics.items():
score = compute(
metric,
references=references,
predictions=predictions,
sources=sources,
)
score = score or compute(
metric,
references=[[r] for r in references],
predictions=predictions,
)
scores[metric_name] = score
return scores
def predict_with_generate():
eval_src, eval_pred, eval_ref = [], [], []
for batch in tqdm(eval_dataloader):
batch_labels = batch["labels"].to(device)
batch_input_ids = batch["input_ids"].to(device)
if "position_ids" in batch:
batch_pos_ids = batch["position_ids"].tolist()
else:
batch_pos_ids = [None] * len(batch["input_ids"])
prompt_token_ids_list = []
completion_token_ids_list = []
for input_ids_all, labels_all, pos_ids in zip(
batch_input_ids,
batch_labels,
batch_pos_ids,
):
if pos_ids is None:
pos_ranges = [(0, len(input_ids_all) - 1)]
else:
pos_ranges = find_ranges(pos_ids)
for pos_range in pos_ranges:
start, end = pos_range
if start == end:
continue
input_ids = input_ids_all[start : end + 1]
labels = labels_all[start : end + 1]
tokens_without_loss = labels == IGNORE_INDEX
tokens_with_loss = labels != IGNORE_INDEX
tokens_exclude_padding = input_ids != tokenizer.pad_token_id
prompt_token_includes = (
tokens_without_loss & tokens_exclude_padding
)
prompt_token_ids = input_ids[prompt_token_includes]
prompt_token_ids_list.append(prompt_token_ids)
completion_token_ids = input_ids[tokens_with_loss]
completion_token_ids_list.append(completion_token_ids)
prompt_texts = tokenizer.batch_decode(
prompt_token_ids_list, skip_special_tokens=True
)
completion_texts = tokenizer.batch_decode(
completion_token_ids_list, skip_special_tokens=True
)
with torch.no_grad():
prompt_encoding = tokenizer(
prompt_texts, padding=True, return_tensors="pt"
).to(self.cfg.device)
predictions = trainer.model.generate(
**prompt_encoding, generation_config=generation_config
)
prediction_all_tokens = predictions["sequences"].cpu().tolist()
prediction_without_prompt_tokens_list = []
for prompt_token_ids, prediction_tokens in zip(
prompt_token_ids_list, prediction_all_tokens
):
prediction_without_prompt_tokens = prediction_tokens[
len(prompt_token_ids) :
]
prediction_without_prompt_tokens_list.append(
prediction_without_prompt_tokens
)
predicted_texts = tokenizer.batch_decode(
prediction_without_prompt_tokens_list, skip_special_tokens=True
)
eval_src.extend(prompt_texts)
eval_pred.extend(predicted_texts)
eval_ref.extend(completion_texts)
return eval_src, eval_pred, eval_ref
if is_main_process():
eval_preds = predict_with_generate()
trainer.log(evaluate_preds(*eval_preds))
return control
return CausalLMBenchEvalCallback
def log_prediction_callback_factory(trainer: Trainer, tokenizer): def log_prediction_callback_factory(trainer: Trainer, tokenizer):
class LogPredictionCallback(TrainerCallback): class LogPredictionCallback(TrainerCallback):
"""Callback to log prediction values during each evaluation""" """Callback to log prediction values during each evaluation"""
@@ -388,7 +568,7 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
generation_config = GenerationConfig( generation_config = GenerationConfig(
max_new_tokens=self.cfg.eval_table_max_new_tokens, max_new_tokens=self.cfg.eval_max_new_tokens,
bos_token_id=tokenizer.bos_token_id, bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id, pad_token_id=tokenizer.pad_token_id,

View File

@@ -56,7 +56,13 @@ def normalize_config(cfg):
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1)) cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0)) cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
cfg.eval_table_size = cfg.eval_table_size or 0 cfg.eval_table_size = cfg.eval_table_size or 0
cfg.eval_table_max_new_tokens = cfg.eval_table_max_new_tokens or 128 cfg.eval_max_new_tokens = cfg.eval_max_new_tokens or 128
cfg.eval_causal_lm_metrics = cfg.eval_causal_lm_metrics or [
"sacrebleu",
"comet",
"ter",
"chrf",
]
choose_device(cfg) choose_device(cfg)
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1 cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
if cfg.ddp: if cfg.ddp:
@@ -550,6 +556,21 @@ def validate_config(cfg):
if cfg.fsdp and "bnb" in cfg.optimizer: if cfg.fsdp and "bnb" in cfg.optimizer:
raise ValueError(f"FSDP not compatible with {cfg.optimizer}") raise ValueError(f"FSDP not compatible with {cfg.optimizer}")
if cfg.do_causal_lm_eval and cfg.eval_sample_packing:
raise ValueError(
"do_causal_lm_eval is enabled, eval_sample_packing must be set to False"
)
if cfg.eval_causal_lm_metrics:
supported_metrics = ["sacrebleu", "comet", "ter", "chrf"]
if not isinstance(cfg.eval_causal_lm_metrics, list):
raise ValueError("eval_causal_lm_metrics must be a list")
# only ["sacrebleu", "comet", "ter", "chrf"] supported
if set(cfg.eval_causal_lm_metrics) - set(supported_metrics):
raise ValueError(
f"eval_causal_lm_metrics must be one of {supported_metrics}"
)
# TODO # TODO
# MPT 7b # MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25 # https://github.com/facebookresearch/bitsandbytes/issues/25

View File

@@ -1,7 +1,6 @@
"""Module containing data utilities""" """Module containing data utilities"""
import functools import functools
import hashlib import hashlib
import importlib
import logging import logging
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
@@ -12,12 +11,10 @@ import yaml
from datasets import ( from datasets import (
Dataset, Dataset,
DatasetDict, DatasetDict,
IterableDataset,
concatenate_datasets, concatenate_datasets,
load_dataset, load_dataset,
load_from_disk, load_from_disk,
) )
from datasets.iterable_dataset import ExamplesIterable
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from huggingface_hub.utils import HFValidationError from huggingface_hub.utils import HFValidationError
from torch.utils.data import RandomSampler from torch.utils.data import RandomSampler
@@ -67,25 +64,6 @@ def md5(to_hash: str, encoding: str = "utf-8") -> str:
return hashlib.md5(to_hash.encode(encoding)).hexdigest() # nosec return hashlib.md5(to_hash.encode(encoding)).hexdigest() # nosec
def get_streaming_dataset(ds_cfg):
path = ds_cfg["path"]
func = None
try:
load_fn = path.split(".")[-1]
module_name = ".".join(load_fn.split(".")[:-1])
mod = importlib.import_module(f".{module_name}", "axolotl")
func = getattr(mod, load_fn)
except Exception:
pass
if func:
data_producer = func(**ds_cfg)
return IterableDataset(ExamplesIterable(data_producer, {}))
else:
split = ds_cfg["split"] or "train"
return load_dataset(path, streaming=True, split=split, name=ds_cfg["name"])
def prepare_dataset(cfg, tokenizer): def prepare_dataset(cfg, tokenizer):
prompters = [] prompters = []
if not cfg.pretraining_dataset: if not cfg.pretraining_dataset:
@@ -102,6 +80,14 @@ def prepare_dataset(cfg, tokenizer):
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
) )
else: else:
path = cfg.pretraining_dataset
name = None
if isinstance(cfg.pretraining_dataset, list) and isinstance(
cfg.pretraining_dataset[0], dict
):
path = cfg.pretraining_dataset[0]["path"]
name = cfg.pretraining_dataset[0]["name"]
ds_wrapper_partial = functools.partial( ds_wrapper_partial = functools.partial(
get_dataset_wrapper, get_dataset_wrapper,
cfg.pretraining_dataset[0], cfg.pretraining_dataset[0],
@@ -111,7 +97,7 @@ def prepare_dataset(cfg, tokenizer):
) )
train_dataset = wrap_pretraining_dataset( train_dataset = wrap_pretraining_dataset(
get_streaming_dataset(cfg.pretraining_dataset[0]), load_dataset(path, streaming=True, split="train", name=name),
tokenizer, tokenizer,
cfg, cfg,
ds_wrapper_partial, ds_wrapper_partial,

View File

@@ -52,7 +52,7 @@ def _get_cosine_schedule_with_quadratic_warmup_lr_lambda(
*, *,
num_warmup_steps: int, num_warmup_steps: int,
num_training_steps: int, num_training_steps: int,
num_cycles: float num_cycles: float,
): ):
if current_step < num_warmup_steps: if current_step < num_warmup_steps:
return (float(current_step) / float(max(1, num_warmup_steps))) ** 2 return (float(current_step) / float(max(1, num_warmup_steps))) ** 2
@@ -107,7 +107,7 @@ def _get_cosine_schedule_with_min_lr_lambda(
*, *,
num_warmup_steps: int, num_warmup_steps: int,
num_training_steps: int, num_training_steps: int,
min_lr_ratio: float min_lr_ratio: float,
): ):
# Warm up # Warm up
if current_step < num_warmup_steps: if current_step < num_warmup_steps:
@@ -140,3 +140,80 @@ def get_cosine_schedule_with_min_lr(
min_lr_ratio=min_lr_ratio, min_lr_ratio=min_lr_ratio,
) )
return LambdaLR(optimizer, lr_lambda) return LambdaLR(optimizer, lr_lambda)
def _get_cosine_schedule_with_warmup_decay_constant_lr_lambda(
current_step: int,
*,
num_warmup_steps: int,
num_training_steps: int,
constant_lr_ratio: float,
min_lr_ratio: float,
num_cycles: float,
):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
num_constant_steps = int(num_training_steps * constant_lr_ratio)
current_step = min(current_step, num_constant_steps)
progress = float(current_step - num_warmup_steps) / float(
max(1, num_constant_steps - num_warmup_steps)
)
return (
max(
0,
(1 - min_lr_ratio)
* 0.5
* (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
)
+ min_lr_ratio
)
def get_cosine_schedule_with_warmup_decay_constant(
optimizer: Optimizer,
num_warmup_steps: int,
num_training_steps: int,
constant_lr_ratio: float,
min_lr_ratio: float,
num_cycles: float = 0.5,
last_epoch: int = -1,
):
"""
Implementation of Continual Pre-Training of Large Language Models: How to (re)warm your model? (https://arxiv.org/pdf/2308.04014.pdf)
Create a schedule with a learning rate that decreases following the values of the cosine function between the
initial lr set in the optimizer to min_lr_ratio until num_training_steps * constant_lr_ratio, after constant_rate returns constant value of min_rate
, after a warmup period during which it increases linearly between 0 and the initial lr set in the optimizer.
Args:
optimizer ([`~torch.optim.Optimizer`]):
The optimizer for which to schedule the learning rate.
num_warmup_steps (`int`):
The number of steps for the warmup phase.
num_training_steps (`int`):
The total number of training steps.
constant_lr_ratio: (`float`):
The ratio of num_training_steps to decrease by cosine function.
min_lr_ratio: (`float):
The ratio of maximum learning rate for cosine function to decay to minimum learning rate.
num_cycles (`float`, *optional*, defaults to 0.5):
The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
following a half-cosine).
last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training.
Return:
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""
lr_lambda = partial(
_get_cosine_schedule_with_warmup_decay_constant_lr_lambda,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
constant_lr_ratio=constant_lr_ratio,
min_lr_ratio=min_lr_ratio,
num_cycles=num_cycles,
)
return LambdaLR(optimizer, lr_lambda, last_epoch)

52
tests/test_schedulers.py Normal file
View File

@@ -0,0 +1,52 @@
"""
test module for the axolotl.utis.data module
"""
import unittest
import torch
from torch.optim import SGD
from axolotl.utils.schedulers import get_cosine_schedule_with_warmup_decay_constant
class TestCosineConstantLr(unittest.TestCase):
"""
test class for encode pretraining and md5 helper
"""
def setUp(self):
self.train_steps = 1000
self.warmup_steps = 10
self.min_lr_ratio = 0.1
self.constant_lr_ratio = 0.8
self._lr = 0.01
self.optimizer = SGD([torch.tensor(1)], lr=self._lr)
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
self.optimizer,
num_warmup_steps=self.warmup_steps,
num_training_steps=self.train_steps,
min_lr_ratio=self.min_lr_ratio,
constant_lr_ratio=self.constant_lr_ratio,
)
def test_schedulers(self):
self.assertEqual(self.lr_scheduler.get_last_lr()[0], 0)
for _ in range(self.warmup_steps):
self.lr_scheduler.step()
self.assertEqual(self.lr_scheduler.get_last_lr()[0], self._lr)
constant_step = int(self.train_steps * self.constant_lr_ratio)
remaining_step = self.train_steps - constant_step
for _ in range(constant_step):
self.lr_scheduler.step()
self.assertEqual(
self.lr_scheduler.get_last_lr()[0], self._lr * self.min_lr_ratio
)
for _ in range(remaining_step):
self.lr_scheduler.step()
self.assertEqual(
self.lr_scheduler.get_last_lr()[0], self._lr * self.min_lr_ratio
)
if __name__ == "__main__":
unittest.main()