Compare commits

..

1 Commits

Author SHA1 Message Date
Wing Lian
d465b9fd98 wip, jagged restarts 2024-02-16 14:34:08 -05:00
45 changed files with 210 additions and 509 deletions

View File

@@ -12,6 +12,11 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
include: include:
- cuda: "118"
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.0.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
- cuda: "118" - cuda: "118"
cuda_version: 11.8.0 cuda_version: 11.8.0
python_version: "3.10" python_version: "3.10"

View File

@@ -13,6 +13,11 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
include: include:
- cuda: 118
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.0.1
axolotl_extras:
- cuda: 118 - cuda: 118
cuda_version: 11.8.0 cuda_version: 11.8.0
python_version: "3.10" python_version: "3.10"
@@ -68,6 +73,11 @@ jobs:
strategy: strategy:
matrix: matrix:
include: include:
- cuda: 118
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.0.1
axolotl_extras:
- cuda: 118 - cuda: 118
cuda_version: 11.8.0 cuda_version: 11.8.0
python_version: "3.10" python_version: "3.10"

View File

@@ -69,7 +69,7 @@ jobs:
- cuda: 118 - cuda: 118
cuda_version: 11.8.0 cuda_version: 11.8.0
python_version: "3.10" python_version: "3.10"
pytorch: 2.1.2 pytorch: 2.0.1
- cuda: 121 - cuda: 121
cuda_version: 12.1.0 cuda_version: 12.1.0
python_version: "3.10" python_version: "3.10"

View File

@@ -34,7 +34,7 @@ Features:
- [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset) - [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset)
- [Config](#config) - [Config](#config)
- [Train](#train) - [Train](#train)
- [Inference](#inference-playground) - [Inference](#inference)
- [Merge LORA to Base](#merge-lora-to-base) - [Merge LORA to Base](#merge-lora-to-base)
- [Special Tokens](#special-tokens) - [Special Tokens](#special-tokens)
- Advanced Topics - Advanced Topics
@@ -734,8 +734,6 @@ peft:
# Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed # Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
relora_steps: # Number of steps per ReLoRA restart relora_steps: # Number of steps per ReLoRA restart
relora_warmup_steps: # Number of per-restart warmup steps relora_warmup_steps: # Number of per-restart warmup steps
relora_anneal_steps: # Number of anneal steps for each relora cycle
relora_prune_ratio: # threshold for optimizer magnitude when pruning
relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings
# wandb configuration if you're using it # wandb configuration if you're using it
@@ -784,8 +782,7 @@ save_total_limit: # Checkpoints saved at a time
max_steps: max_steps:
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0 eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128 eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf"]
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training) loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3) loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
@@ -814,7 +811,6 @@ early_stopping_patience: 3
lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
lr_scheduler_kwargs: lr_scheduler_kwargs:
cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
cosine_constant_lr_ratio: # fraction of the training steps after which the lr is held constant, e.g. cosine_constant_lr_ratio=0.8 means the lr reaches cosine_min_lr at 80% of the training steps and stays there (https://arxiv.org/pdf/2308.04014.pdf)
# For one_cycle optim # For one_cycle optim
lr_div_factor: # Learning rate div factor lr_div_factor: # Learning rate div factor

View File

@@ -2,6 +2,7 @@
base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0 base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true load_in_8bit: true
load_in_4bit: false load_in_4bit: false

View File

@@ -1,6 +1,7 @@
base_model: codellama/CodeLlama-13b-hf base_model: codellama/CodeLlama-13b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true load_in_8bit: true
load_in_4bit: false load_in_4bit: false

View File

@@ -1,6 +1,7 @@
base_model: codellama/CodeLlama-13b-hf base_model: codellama/CodeLlama-13b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false load_in_8bit: false
load_in_4bit: true load_in_4bit: true

View File

@@ -1,6 +1,7 @@
base_model: codellama/CodeLlama-34b-hf base_model: codellama/CodeLlama-34b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true load_in_8bit: true
load_in_4bit: false load_in_4bit: false

View File

@@ -1,6 +1,7 @@
base_model: codellama/CodeLlama-34b-hf base_model: codellama/CodeLlama-34b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false load_in_8bit: false
load_in_4bit: true load_in_4bit: true

View File

@@ -1,6 +1,7 @@
base_model: codellama/CodeLlama-7b-hf base_model: codellama/CodeLlama-7b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true load_in_8bit: true
load_in_4bit: false load_in_4bit: false

View File

@@ -1,6 +1,7 @@
base_model: codellama/CodeLlama-7b-hf base_model: codellama/CodeLlama-7b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false load_in_8bit: false
load_in_4bit: true load_in_4bit: true

View File

@@ -177,24 +177,6 @@
"# Buy using the ! the comand will be executed as a bash command\n", "# Buy using the ! the comand will be executed as a bash command\n",
"!accelerate launch -m axolotl.cli.train /content/test_axolotl.yaml" "!accelerate launch -m axolotl.cli.train /content/test_axolotl.yaml"
] ]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Play with inference"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Buy using the ! the comand will be executed as a bash command\n",
"!accelerate launch -m axolotl.cli.inference /content/test_axolotl.yaml \\\n",
" --qlora_model_dir=\"./qlora-out\" --gradio"
]
} }
], ],
"metadata": { "metadata": {

View File

@@ -2,7 +2,7 @@ base_model: tiiuae/falcon-7b
trust_remote_code: true trust_remote_code: true
model_type: AutoModelForCausalLM model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer
is_falcon_derived_model: true
load_in_8bit: true load_in_8bit: true
load_in_4bit: false load_in_4bit: false
gptq: false gptq: false

View File

@@ -5,7 +5,7 @@ base_model: tiiuae/falcon-7b
trust_remote_code: true trust_remote_code: true
model_type: AutoModelForCausalLM model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer
is_falcon_derived_model: true
load_in_8bit: false load_in_8bit: false
# enable 4bit for QLoRA # enable 4bit for QLoRA
load_in_4bit: true load_in_4bit: true

View File

@@ -2,7 +2,7 @@ base_model: tiiuae/falcon-7b
trust_remote_code: true trust_remote_code: true
model_type: AutoModelForCausalLM model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer
is_falcon_derived_model: true
load_in_8bit: false load_in_8bit: false
load_in_4bit: false load_in_4bit: false
gptq: false gptq: false

View File

@@ -1,65 +0,0 @@
# use google/gemma-7b if you have access
base_model: mhenrichsen/gemma-7b
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: false
load_in_4bit: true
strict: false
# huggingface repo
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
val_set_size: 0.1
output_dir: ./out
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
sequence_len: 4096
sample_packing: false
pad_to_sequence_len: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 3
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -1,6 +1,7 @@
base_model: NousResearch/Llama-2-7b-hf base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false load_in_8bit: false
load_in_4bit: false load_in_4bit: false

View File

@@ -1,4 +1,5 @@
base_model: TheBloke/Llama-2-7B-GPTQ base_model: TheBloke/Llama-2-7B-GPTQ
is_llama_derived_model: false
gptq: true gptq: true
gptq_disable_exllama: true gptq_disable_exllama: true
model_type: AutoModelForCausalLM model_type: AutoModelForCausalLM

View File

@@ -1,6 +1,7 @@
base_model: NousResearch/Llama-2-7b-hf base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false load_in_8bit: false
load_in_4bit: false load_in_4bit: false
@@ -59,7 +60,7 @@ s2_attention:
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4
eval_table_size: eval_table_size:
eval_max_new_tokens: 128 eval_table_max_new_tokens: 128
saves_per_epoch: 1 saves_per_epoch: 1
debug: debug:
deepspeed: deepspeed:

View File

@@ -1,6 +1,7 @@
base_model: NousResearch/Llama-2-7b-hf base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true load_in_8bit: true
load_in_4bit: false load_in_4bit: false
@@ -56,7 +57,7 @@ s2_attention:
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4
eval_table_size: eval_table_size:
eval_max_new_tokens: 128 eval_table_max_new_tokens: 128
saves_per_epoch: 1 saves_per_epoch: 1
debug: debug:
deepspeed: deepspeed:

View File

@@ -1,6 +1,7 @@
base_model: NousResearch/Llama-2-7b-hf base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false load_in_8bit: false
load_in_4bit: true load_in_4bit: true

View File

@@ -1,7 +1,7 @@
base_model: NousResearch/Llama-2-7b-hf base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false load_in_8bit: false
load_in_4bit: true load_in_4bit: true

View File

@@ -49,7 +49,7 @@ flash_attention:
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4
eval_table_size: eval_table_size:
eval_max_new_tokens: 128 eval_table_max_new_tokens: 128
saves_per_epoch: 1 saves_per_epoch: 1
debug: debug:
deepspeed: deepspeed:

View File

@@ -2,6 +2,7 @@
base_model: mistralai/Mistral-7B-v0.1 base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_mistral_derived_model: true
load_in_8bit: true load_in_8bit: true
load_in_4bit: false load_in_4bit: false
@@ -60,7 +61,7 @@ flash_attention: true
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4
eval_table_size: eval_table_size:
eval_max_new_tokens: 128 eval_table_max_new_tokens: 128
saves_per_epoch: 1 saves_per_epoch: 1
debug: debug:
# default deepspeed, can use more aggressive if needed like zero2, zero3 # default deepspeed, can use more aggressive if needed like zero2, zero3

View File

@@ -1,6 +1,7 @@
base_model: mistralai/Mistral-7B-v0.1 base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_mistral_derived_model: true
load_in_8bit: false load_in_8bit: false
load_in_4bit: false load_in_4bit: false
@@ -48,7 +49,7 @@ flash_attention: true
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4
eval_table_size: eval_table_size:
eval_max_new_tokens: 128 eval_table_max_new_tokens: 128
saves_per_epoch: 1 saves_per_epoch: 1
debug: debug:
deepspeed: deepspeed:

View File

@@ -81,7 +81,7 @@ loss_watchdog_patience: 3
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4
eval_table_size: eval_table_size:
eval_max_new_tokens: 128 eval_table_max_new_tokens: 128
saves_per_epoch: 1 saves_per_epoch: 1
debug: debug:
deepspeed: deepspeed_configs/zero2.json deepspeed: deepspeed_configs/zero2.json

View File

@@ -1,6 +1,7 @@
base_model: mistralai/Mistral-7B-v0.1 base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_mistral_derived_model: true
load_in_8bit: false load_in_8bit: false
load_in_4bit: true load_in_4bit: true
@@ -67,7 +68,7 @@ loss_watchdog_patience: 3
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4
eval_table_size: eval_table_size:
eval_max_new_tokens: 128 eval_table_max_new_tokens: 128
saves_per_epoch: 1 saves_per_epoch: 1
debug: debug:
deepspeed: deepspeed:

View File

@@ -2,6 +2,7 @@ base_model: Qwen/Qwen-7B
model_type: AutoModelForCausalLM model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer
is_qwen_derived_model: true
trust_remote_code: true trust_remote_code: true
load_in_8bit: true load_in_8bit: true
@@ -57,7 +58,7 @@ flash_attention:
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4
eval_table_size: eval_table_size:
eval_max_new_tokens: 128 eval_table_max_new_tokens: 128
saves_per_epoch: 1 saves_per_epoch: 1
debug: debug:
deepspeed: deepspeed:

View File

@@ -2,6 +2,7 @@ base_model: Qwen/Qwen-7B
model_type: AutoModelForCausalLM model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer
is_qwen_derived_model: true
trust_remote_code: true trust_remote_code: true
load_in_8bit: false load_in_8bit: false
@@ -57,7 +58,7 @@ flash_attention:
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4
eval_table_size: eval_table_size:
eval_max_new_tokens: 128 eval_table_max_new_tokens: 128
saves_per_epoch: 1 saves_per_epoch: 1
debug: debug:
deepspeed: deepspeed:

View File

@@ -1,6 +1,7 @@
base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true load_in_8bit: true
load_in_4bit: false load_in_4bit: false

View File

@@ -1,6 +1,7 @@
base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true load_in_8bit: true
load_in_4bit: false load_in_4bit: false

View File

@@ -2,6 +2,7 @@ base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false load_in_8bit: false
load_in_4bit: false load_in_4bit: false

View File

@@ -1,6 +1,7 @@
base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false load_in_8bit: false
load_in_4bit: true load_in_4bit: true

View File

@@ -1,7 +1,8 @@
base_model: 01-ai/Yi-34B-Chat base_model: 01-ai/Yi-34B-Chat
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_mistral_derived_model: false
is_llama_derived_model: true
load_in_8bit: false load_in_8bit: false
load_in_4bit: true load_in_4bit: true
strict: false strict: false
@@ -28,7 +29,7 @@ num_epochs: 1
val_set_size: 0.1 val_set_size: 0.1
evals_per_epoch: 5 evals_per_epoch: 5
eval_table_size: eval_table_size:
eval_max_new_tokens: 128 eval_table_max_new_tokens: 128
eval_sample_packing: false eval_sample_packing: false
eval_batch_size: 1 eval_batch_size: 1

View File

@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2 packaging==23.2
peft @ git+https://github.com/huggingface/peft.git peft @ git+https://github.com/huggingface/peft.git
transformers @ git+https://github.com/huggingface/transformers.git@ae49b218c3d718df90d8e4a109016450fb8f0632 transformers @ git+https://github.com/huggingface/transformers.git@bebeeee01275c32fccec3fa36d8b148d3813a7dc
tokenizers==0.15.0 tokenizers==0.15.0
bitsandbytes>=0.41.1 bitsandbytes>=0.41.1
accelerate==0.26.1 accelerate==0.26.1
@@ -11,7 +11,7 @@ fire
PyYAML>=6.0 PyYAML>=6.0
requests requests
datasets>=2.15.0 datasets>=2.15.0
flash-attn==2.5.5 flash-attn==2.3.3
sentencepiece sentencepiece
wandb wandb
einops einops
@@ -23,7 +23,7 @@ numba
numpy>=1.24.4 numpy>=1.24.4
mlflow mlflow
# qlora things # qlora things
evaluate==0.4.1 evaluate==0.4.0
scipy scipy
scikit-learn==1.2.2 scikit-learn==1.2.2
pynvml pynvml

View File

@@ -67,7 +67,7 @@ setup(
dependency_links=dependency_links, dependency_links=dependency_links,
extras_require={ extras_require={
"flash-attn": [ "flash-attn": [
"flash-attn==2.5.5", "flash-attn==2.5.0",
], ],
"fused-dense-lib": [ "fused-dense-lib": [
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib", "fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib",

View File

@@ -38,7 +38,6 @@ from axolotl.utils.callbacks import (
SaveAxolotlConfigtoWandBCallback, SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback, SaveBetterTransformerModelCallback,
bench_eval_callback_factory, bench_eval_callback_factory,
causal_lm_bench_eval_callback_factory,
log_prediction_callback_factory, log_prediction_callback_factory,
) )
from axolotl.utils.collators import ( from axolotl.utils.collators import (
@@ -50,8 +49,7 @@ from axolotl.utils.collators import (
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.schedulers import ( from axolotl.utils.schedulers import (
get_cosine_schedule_with_min_lr, get_cosine_schedule_with_min_lr,
get_cosine_schedule_with_quadratic_warmup, get_cosine_schedule_with_quadratic_warmup, JaggedLRRestartScheduler,
get_cosine_schedule_with_warmup_decay_constant,
) )
try: try:
@@ -131,11 +129,19 @@ class AxolotlTrainingArguments(TrainingArguments):
) )
relora_anneal_steps: Optional[int] = field( relora_anneal_steps: Optional[int] = field(
default=None, default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"}, metadata={"help": "how many anneal steps to take before reset for ReLoRA"},
) )
relora_prune_ratio: Optional[float] = field( jagged_restart_steps: Optional[int] = field(
default=0.9, default=None,
metadata={"help": "prune ratio for magnitude pruning of the optimizer"}, metadata={"help": "how often to reset for jagged restarts"},
)
jagged_restarts_warmup_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for jagged restarts"},
)
jagged_restarts_anneal_steps: Optional[int] = field(
default=None,
metadata={"help": "how many anneal steps to take before reset for jagged restarts"},
) )
bench_split: Optional[str] = field( bench_split: Optional[str] = field(
default="eval", metadata={"help": "The benchmark split to run on"} default="eval", metadata={"help": "The benchmark split to run on"}
@@ -149,9 +155,6 @@ class AxolotlTrainingArguments(TrainingArguments):
do_bench_eval: Optional[bool] = field( do_bench_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Benchmark evaluation."} default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
) )
do_causal_lm_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
)
max_bench_samples: Optional[int] = field( max_bench_samples: Optional[int] = field(
default=None, default=None,
metadata={ metadata={
@@ -169,12 +172,6 @@ class AxolotlTrainingArguments(TrainingArguments):
default=None, default=None,
metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"}, metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
) )
cosine_constant_lr_ratio: Optional[float] = field(
default=None,
metadata={
"help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
},
)
class AxolotlTrainer(Trainer): class AxolotlTrainer(Trainer):
@@ -232,16 +229,6 @@ class AxolotlTrainer(Trainer):
num_warmup_steps=self.args.get_warmup_steps(num_training_steps), num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps, num_training_steps=num_training_steps,
) )
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
)
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr: elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0" assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
@@ -251,7 +238,7 @@ class AxolotlTrainer(Trainer):
min_lr_ratio=self.args.cosine_min_lr_ratio, min_lr_ratio=self.args.cosine_min_lr_ratio,
) )
else: else:
return super().create_scheduler(num_training_steps, optimizer) super().create_scheduler(num_training_steps, optimizer)
else: else:
if use_cosine_quadratic: if use_cosine_quadratic:
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).") LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
@@ -259,6 +246,21 @@ class AxolotlTrainer(Trainer):
if use_cosine_min_lr: if use_cosine_min_lr:
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).") LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
if self.args.jagged_restart_steps:
warmup_steps = (
self.args.jagged_restarts_warmup_steps or 10
)
anneal_steps = (
self.args.jagged_restarts_anneal_steps or 1
)
self.lr_scheduler = JaggedLRRestartScheduler(
optimizer,
self.lr_scheduler,
self.args.jagged_restart_steps,
warmup_steps,
anneal_steps,
)
return self.lr_scheduler return self.lr_scheduler
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
@@ -668,11 +670,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.do_bench_eval: if self.cfg.do_bench_eval:
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer)) callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
if self.cfg.do_causal_lm_eval:
CausalLMBenchEvalCallback = causal_lm_bench_eval_callback_factory(
trainer, self.tokenizer
)
callbacks.append(CausalLMBenchEvalCallback(self.cfg))
if self.cfg.early_stopping_patience: if self.cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback( early_stop_cb = EarlyStoppingCallback(
@@ -821,8 +818,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
if self.cfg.bench_dataset: if self.cfg.bench_dataset:
training_arguments_kwargs["bench_dataset"] = self.cfg.bench_dataset training_arguments_kwargs["bench_dataset"] = self.cfg.bench_dataset
if self.cfg.do_causal_lm_eval:
training_arguments_kwargs["do_causal_lm_eval"] = self.cfg.do_causal_lm_eval
if self.cfg.metric_for_best_model: if self.cfg.metric_for_best_model:
training_arguments_kwargs[ training_arguments_kwargs[
"metric_for_best_model" "metric_for_best_model"
@@ -883,10 +878,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self.cfg.load_best_model_at_end is not False self.cfg.load_best_model_at_end is not False
or self.cfg.early_stopping_patience or self.cfg.early_stopping_patience
) )
and ( and not self.cfg.test_datasets
(not self.cfg.test_datasets and self.cfg.val_set_size > 0) and self.cfg.val_set_size > 0
or (self.cfg.test_datasets and self.cfg.val_set_size == 0)
)
and self.cfg.save_steps and self.cfg.save_steps
and self.cfg.eval_steps and self.cfg.eval_steps
and self.cfg.save_steps % self.cfg.eval_steps == 0 and self.cfg.save_steps % self.cfg.eval_steps == 0
@@ -907,6 +900,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["optim"] = ( training_arguments_kwargs["optim"] = (
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf" self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
) )
if self.cfg.save_only_model:
training_arguments_kwargs["save_only_model"] = self.cfg.save_only_model
training_arguments_kwargs["lr_scheduler_type"] = ( training_arguments_kwargs["lr_scheduler_type"] = (
self.cfg.lr_scheduler self.cfg.lr_scheduler
if self.cfg.lr_scheduler if self.cfg.lr_scheduler
@@ -917,9 +912,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {} self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
) )
training_arguments_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio training_arguments_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio
training_arguments_kwargs[
"cosine_constant_lr_ratio"
] = self.cfg.cosine_constant_lr_ratio
training_arguments_kwargs["weight_decay"] = ( training_arguments_kwargs["weight_decay"] = (
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0 self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
) )
@@ -937,20 +929,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs[ training_arguments_kwargs[
"sample_packing_seq_len_multiplier" "sample_packing_seq_len_multiplier"
] = self.cfg.micro_batch_size ] = self.cfg.micro_batch_size
if self.cfg.relora_steps: training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps
training_arguments_kwargs[ training_arguments_kwargs["relora_anneal_steps"] = self.cfg.relora_anneal_steps
"relora_warmup_steps"
] = self.cfg.relora_warmup_steps
if self.cfg.relora_anneal_steps:
training_arguments_kwargs[
"relora_anneal_steps"
] = self.cfg.relora_anneal_steps
if self.cfg.relora_prune_ratio:
training_arguments_kwargs[
"relora_prune_ratio"
] = self.cfg.relora_prune_ratio
training_arguments_kwargs = self.hook_pre_create_training_args( training_arguments_kwargs = self.hook_pre_create_training_args(
training_arguments_kwargs training_arguments_kwargs
) )

View File

@@ -275,9 +275,7 @@ def flashattn_forward_with_s2attn(
kv_seq_len = key_states.shape[-2] kv_seq_len = key_states.shape[-2]
if past_key_value is not None: if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2] kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb( cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
value_states, seq_len=kv_seq_len, position_ids=position_ids
)
query_states, key_states = apply_rotary_pos_emb( query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids query_states, key_states, cos, sin, position_ids
) )
@@ -427,9 +425,7 @@ def flashattn_forward(
if past_key_value is not None: if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2] kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb( cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
value_states, seq_len=kv_seq_len, position_ids=position_ids
)
query_states, key_states = apply_rotary_pos_emb( query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids query_states, key_states, cos, sin, position_ids
) )
@@ -692,9 +688,6 @@ def llama_model_forward(
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[ # pylint: disable=unused-argument
torch.LongTensor
] = None,
) -> Union[Tuple, BaseModelOutputWithPast]: ) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = ( output_attentions = (
output_attentions output_attentions

View File

@@ -6,7 +6,7 @@ from transformers.integrations import is_deepspeed_zero3_enabled
from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3 from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
from axolotl.monkeypatch.utils import get_unpad_data from axolotl.monkeypatch.utils import get_unpad_data
SUPPORTED_MULTIPACK_MODEL_TYPES = ["mixtral", "qwen2", "falcon", "phi", "gemma"] SUPPORTED_MULTIPACK_MODEL_TYPES = ["mixtral", "qwen2", "falcon", "phi"]
def patch_for_multipack(model_type): def patch_for_multipack(model_type):
@@ -28,7 +28,3 @@ def patch_for_multipack(model_type):
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data get_unpad_data
) )
elif model_type == "gemma":
transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)

View File

@@ -46,9 +46,8 @@ def reset_optimizer(
*, *,
reset_params: list[str], # where str is the key to a torch.nn.Parameter reset_params: list[str], # where str is the key to a torch.nn.Parameter
optimizer_state_keys: list[str], optimizer_state_keys: list[str],
prune_ratio: float = 0.9,
): ):
pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio) pruning_fn = partial(magnitude_pruning_, prune_ratio=0.9)
n_zeros = 0 n_zeros = 0
n_total = 0 n_total = 0
@@ -160,7 +159,6 @@ class ReLoRACallback(TrainerCallback):
optimizer, optimizer,
reset_params=lora_params, reset_params=lora_params,
optimizer_state_keys=optimizer_state_keys, optimizer_state_keys=optimizer_state_keys,
prune_ratio=args.relora_prune_ratio,
) )
if self.quantized: if self.quantized:

View File

@@ -0,0 +1,67 @@
from typing import Optional, Dict, Any
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import Prompter
from axolotl.utils.chat_templates import chat_templates
class ChatTemplatePrompter(Prompter):
    """Prompter that renders a conversation through the tokenizer's chat template."""

    def __init__(self, tokenizer, chat_template=None, max_length=2048):
        # chat_template is forwarded verbatim to `apply_chat_template`;
        # max_length is the truncation limit passed with every render.
        self.max_length = max_length
        self.chat_template = chat_template
        self.tokenizer = tokenizer

    def build_prompt(self, conversation, add_generation_prompt=False):
        """
        Render `conversation` with the tokenizer's `apply_chat_template`.

        Truncation is always enabled and capped at `self.max_length`. The
        return value is whatever `apply_chat_template` produces for these
        arguments (presumably a list of token ids - confirm against the
        tokenizer in use).
        """
        template_kwargs = {
            "truncation": True,
            "max_length": self.max_length,
            "add_generation_prompt": add_generation_prompt,
            "chat_template": self.chat_template,
        }
        return self.tokenizer.apply_chat_template(conversation, **template_kwargs)
class ChatTemplateStrategy(PromptTokenizingStrategy):
    """
    Tokenizing strategy for conversations rendered through a chat template.
    """

    def tokenize_prompt(self, prompt):
        """
        Tokenize one conversation sample into input_ids, labels and attention_mask.

        When `self.train_on_inputs` is false, the tokens belonging to the first
        turn (rendered with the generation prompt appended) are masked with -100
        in the labels; everything after that prefix keeps its token id.
        """
        turns = self.get_conversation_thread(prompt)
        prompt_ids = self.prompter.build_prompt([turns[0]], add_generation_prompt=True)
        input_ids = self.prompter.build_prompt(turns)

        if self.train_on_inputs:
            # copy so that later edits to labels cannot silently alter input_ids
            labels = list(input_ids)
        else:
            # The first-turn render can be longer than the full render (e.g. a
            # single-turn sample, where only the prefix carries the generation
            # prompt); clamp so labels always has the same length as input_ids.
            user_prompt_len = min(len(prompt_ids), len(input_ids))
            labels = [-100] * user_prompt_len + input_ids[user_prompt_len:]

        tokenized_prompt = {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": [1] * len(input_ids),
        }

        return tokenized_prompt

    def get_conversation_thread(self, prompt):
        """
        Convert the sample's `conversations` entries (`from`/`value` keys) into
        a list of `{"role": ..., "content": ...}` dicts.

        Raises KeyError for a `from` value that is not in the role map below.
        """
        conversations = prompt["conversations"]
        # remap roles - allow for assistant turn
        role_map = {
            "human": "user",
            "user": "user",
            "assistant": "assistant",
            "gpt": "assistant",
        }
        turns = [
            {"role": role_map[t["from"]], "content": t["value"]} for t in conversations
        ]
        return turns
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
    """
    Build the chat-template tokenizing strategy for a dataset.

    `ds_cfg["conversation"]` names the template to look up via `chat_templates`.
    `ds_cfg` is optional, so when it is missing or carries no `conversation`
    entry, no explicit template is passed and the prompter hands
    `chat_template=None` to the tokenizer (presumably selecting the tokenizer's
    own template - confirm against the tokenizer in use).
    """
    template_name = ds_cfg.get("conversation") if ds_cfg else None
    chat_template = chat_templates(template_name) if template_name else None

    strategy = ChatTemplateStrategy(
        ChatTemplatePrompter(
            tokenizer,
            chat_template,
        ),
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )
    return strategy

View File

@@ -62,6 +62,7 @@ class EvalFirstStepCallback(
): ):
if ( if (
args.evaluation_strategy == IntervalStrategy.STEPS args.evaluation_strategy == IntervalStrategy.STEPS
and (args.eval_steps < 1.0 or args.eval_steps > 1)
and state.global_step == 1 and state.global_step == 1
): ):
control.should_evaluate = True control.should_evaluate = True
@@ -360,187 +361,6 @@ def bench_eval_callback_factory(trainer, tokenizer):
return BenchEvalCallback return BenchEvalCallback
def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
    """
    Build a TrainerCallback class that, on every evaluation, generates
    completions for the eval prompts and logs the configured text metrics
    through `trainer.log`.
    """

    class CausalLMBenchEvalCallback(TrainerCallback):
        """Callback to log prediction values during each evaluation"""

        def __init__(self, cfg):
            self.cfg = cfg
            self.logged = False
            self.metrics = self.__maybe_load_metrics()

        def __maybe_load_metrics(self):
            # Load each configured metric; a metric that fails to load is
            # skipped with a warning instead of aborting training.
            metrics = {}
            for metric in self.cfg.eval_causal_lm_metrics:
                try:
                    metrics[metric] = evaluate.load(metric)
                except Exception as exc:  # pylint: disable=broad-exception-caught
                    LOG.warning(f"{metric}: {exc.args}")
            return metrics

        def on_evaluate(
            self,
            args: AxolotlTrainingArguments,  # pylint: disable=unused-argument
            state: TrainerState,
            control: TrainerControl,
            train_dataloader,  # pylint: disable=unused-argument
            eval_dataloader,
            **kwargs,  # pylint: disable=unused-argument
        ):
            trainer.model.eval()
            device = torch.device(self.cfg.device)

            # pylint: disable=duplicate-code
            generation_config = GenerationConfig(
                max_new_tokens=self.cfg.eval_max_new_tokens,
                bos_token_id=tokenizer.bos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                do_sample=False,
                use_cache=True,
                return_dict_in_generate=True,
                output_attentions=False,
                output_hidden_states=False,
                output_scores=False,
            )

            def find_ranges(lst):
                # Split a position-id list into inclusive (start, end) index
                # ranges; a new range begins wherever the value is 0.
                ranges = []
                start = 0
                for i in range(1, len(lst)):
                    if lst[i] == 0:
                        ranges.append((start, i - 1))
                        start = i
                end = len(lst) - 1
                ranges.append((start, end))
                return ranges

            def compute(metric: evaluate.Metric, **kwargs):
                # safely compute a metric and return the score if the format is correct
                metric_score = None
                try:
                    metric_score = metric.compute(**kwargs)
                    return (
                        metric_score["score"]
                        if "score" in metric_score
                        else metric_score["mean_score"]
                    )
                except Exception:  # pylint: disable=broad-exception-caught
                    LOG.debug(
                        f"Failed to compute metric {metric.name} with kwargs {kwargs.keys()}"
                    )
                return metric_score

            def evaluate_preds(sources, predictions, references):
                scores = {}

                for metric_name, metric in self.metrics.items():
                    score = compute(
                        metric,
                        references=references,
                        predictions=predictions,
                        sources=sources,
                    )
                    # Retry with the nested-reference signature only when the
                    # first attempt produced nothing; a legitimate score of 0
                    # must not trigger a recompute.
                    if score is None:
                        score = compute(
                            metric,
                            references=[[r] for r in references],
                            predictions=predictions,
                        )
                    scores[metric_name] = score
                return scores

            def predict_with_generate():
                eval_src, eval_pred, eval_ref = [], [], []

                for batch in tqdm(eval_dataloader):
                    batch_labels = batch["labels"].to(device)
                    batch_input_ids = batch["input_ids"].to(device)

                    if "position_ids" in batch:
                        batch_pos_ids = batch["position_ids"].tolist()
                    else:
                        batch_pos_ids = [None] * len(batch["input_ids"])

                    prompt_token_ids_list = []
                    completion_token_ids_list = []

                    for input_ids_all, labels_all, pos_ids in zip(
                        batch_input_ids,
                        batch_labels,
                        batch_pos_ids,
                    ):
                        if pos_ids is None:
                            pos_ranges = [(0, len(input_ids_all) - 1)]
                        else:
                            pos_ranges = find_ranges(pos_ids)

                        for pos_range in pos_ranges:
                            start, end = pos_range
                            if start == end:
                                continue

                            input_ids = input_ids_all[start : end + 1]
                            labels = labels_all[start : end + 1]

                            # prompt = tokens excluded from the loss (minus
                            # padding); completion = tokens included in the loss
                            tokens_without_loss = labels == IGNORE_INDEX
                            tokens_with_loss = labels != IGNORE_INDEX
                            tokens_exclude_padding = input_ids != tokenizer.pad_token_id
                            prompt_token_includes = (
                                tokens_without_loss & tokens_exclude_padding
                            )

                            prompt_token_ids = input_ids[prompt_token_includes]
                            prompt_token_ids_list.append(prompt_token_ids)

                            completion_token_ids = input_ids[tokens_with_loss]
                            completion_token_ids_list.append(completion_token_ids)

                    prompt_texts = tokenizer.batch_decode(
                        prompt_token_ids_list, skip_special_tokens=True
                    )
                    completion_texts = tokenizer.batch_decode(
                        completion_token_ids_list, skip_special_tokens=True
                    )

                    with torch.no_grad():
                        prompt_encoding = tokenizer(
                            prompt_texts, padding=True, return_tensors="pt"
                        ).to(self.cfg.device)
                        predictions = trainer.model.generate(
                            **prompt_encoding, generation_config=generation_config
                        )

                    prediction_all_tokens = predictions["sequences"].cpu().tolist()

                    # The prompts were decoded and re-encoded with padding, so
                    # the prompt prefix to strip is the padded width of the
                    # re-encoded batch, not the length of the original
                    # per-sample prompt ids (assumes generate() echoes its
                    # input ids ahead of the new tokens, as for causal LMs).
                    prompt_width = prompt_encoding["input_ids"].shape[1]
                    prediction_without_prompt_tokens_list = [
                        prediction_tokens[prompt_width:]
                        for prediction_tokens in prediction_all_tokens
                    ]

                    predicted_texts = tokenizer.batch_decode(
                        prediction_without_prompt_tokens_list, skip_special_tokens=True
                    )

                    eval_src.extend(prompt_texts)
                    eval_pred.extend(predicted_texts)
                    eval_ref.extend(completion_texts)

                return eval_src, eval_pred, eval_ref

            if is_main_process():
                eval_preds = predict_with_generate()
                trainer.log(evaluate_preds(*eval_preds))

            return control

    return CausalLMBenchEvalCallback
def log_prediction_callback_factory(trainer: Trainer, tokenizer): def log_prediction_callback_factory(trainer: Trainer, tokenizer):
class LogPredictionCallback(TrainerCallback): class LogPredictionCallback(TrainerCallback):
"""Callback to log prediction values during each evaluation""" """Callback to log prediction values during each evaluation"""
@@ -568,7 +388,7 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
generation_config = GenerationConfig( generation_config = GenerationConfig(
max_new_tokens=self.cfg.eval_max_new_tokens, max_new_tokens=self.cfg.eval_table_max_new_tokens,
bos_token_id=tokenizer.bos_token_id, bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id, pad_token_id=tokenizer.pad_token_id,

View File

@@ -56,13 +56,7 @@ def normalize_config(cfg):
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1)) cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0)) cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
cfg.eval_table_size = cfg.eval_table_size or 0 cfg.eval_table_size = cfg.eval_table_size or 0
cfg.eval_max_new_tokens = cfg.eval_max_new_tokens or 128 cfg.eval_table_max_new_tokens = cfg.eval_table_max_new_tokens or 128
cfg.eval_causal_lm_metrics = cfg.eval_causal_lm_metrics or [
"sacrebleu",
"comet",
"ter",
"chrf",
]
choose_device(cfg) choose_device(cfg)
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1 cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
if cfg.ddp: if cfg.ddp:
@@ -556,21 +550,6 @@ def validate_config(cfg):
if cfg.fsdp and "bnb" in cfg.optimizer: if cfg.fsdp and "bnb" in cfg.optimizer:
raise ValueError(f"FSDP not compatible with {cfg.optimizer}") raise ValueError(f"FSDP not compatible with {cfg.optimizer}")
if cfg.do_causal_lm_eval and cfg.eval_sample_packing:
raise ValueError(
"do_causal_lm_eval is enabled, eval_sample_packing must be set to False"
)
if cfg.eval_causal_lm_metrics:
supported_metrics = ["sacrebleu", "comet", "ter", "chrf"]
if not isinstance(cfg.eval_causal_lm_metrics, list):
raise ValueError("eval_causal_lm_metrics must be a list")
# only ["sacrebleu", "comet", "ter", "chrf"] supported
if set(cfg.eval_causal_lm_metrics) - set(supported_metrics):
raise ValueError(
f"eval_causal_lm_metrics must be one of {supported_metrics}"
)
# TODO # TODO
# MPT 7b # MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25 # https://github.com/facebookresearch/bitsandbytes/issues/25

View File

@@ -1,6 +1,7 @@
"""Module for custom LRScheduler class""" """Module for custom LRScheduler class"""
import math import math
from functools import partial from functools import partial
from typing import Sequence
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler from torch.optim.lr_scheduler import LambdaLR, LRScheduler
@@ -52,7 +53,7 @@ def _get_cosine_schedule_with_quadratic_warmup_lr_lambda(
*, *,
num_warmup_steps: int, num_warmup_steps: int,
num_training_steps: int, num_training_steps: int,
num_cycles: float, num_cycles: float
): ):
if current_step < num_warmup_steps: if current_step < num_warmup_steps:
return (float(current_step) / float(max(1, num_warmup_steps))) ** 2 return (float(current_step) / float(max(1, num_warmup_steps))) ** 2
@@ -107,7 +108,7 @@ def _get_cosine_schedule_with_min_lr_lambda(
*, *,
num_warmup_steps: int, num_warmup_steps: int,
num_training_steps: int, num_training_steps: int,
min_lr_ratio: float, min_lr_ratio: float
): ):
# Warm up # Warm up
if current_step < num_warmup_steps: if current_step < num_warmup_steps:
@@ -142,78 +143,46 @@ def get_cosine_schedule_with_min_lr(
return LambdaLR(optimizer, lr_lambda) return LambdaLR(optimizer, lr_lambda)
def _get_cosine_schedule_with_warmup_decay_constant_lr_lambda( class JaggedLRRestartScheduler(LRScheduler):
current_step: int, """Wraps another scheduler to apply per-lora-restart learning rate warmups."""
*,
num_warmup_steps: int,
num_training_steps: int,
constant_lr_ratio: float,
min_lr_ratio: float,
num_cycles: float,
):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
num_constant_steps = int(num_training_steps * constant_lr_ratio) def __init__(
current_step = min(current_step, num_constant_steps) self,
optimizer: Optimizer,
inner_schedule: LRScheduler,
jagged_restarts_steps: int,
jagged_restarts_warmup_steps: int,
jagged_restarts_anneal_steps: int = 1,
min_lr_scale: float = 0.001,
) -> None:
self.inner_schedule = inner_schedule
self.restarts_steps = jagged_restarts_steps
self.warmup_steps = jagged_restarts_warmup_steps
self.anneal_steps = jagged_restarts_anneal_steps
self.min_lr_scale = min_lr_scale
super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)
progress = float(current_step - num_warmup_steps) / float( def get_lr(self) -> float:
max(1, num_constant_steps - num_warmup_steps) self.inner_schedule.last_epoch = self.last_epoch
)
return ( original = self.inner_schedule.get_lr()
max( step = self.last_epoch
0,
(1 - min_lr_ratio)
* 0.5
* (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
)
+ min_lr_ratio
)
if step < self.restarts_steps:
scale = 1
else:
per_relora_progress = step % self.restarts_steps
if per_relora_progress < self.warmup_steps:
cycle_t = min(1.0, (per_relora_progress) / self.warmup_steps)
elif per_relora_progress > (self.restarts_steps - self.anneal_steps):
cycle_t = min(
1.0,
(self.restarts_steps - per_relora_progress) / self.anneal_steps,
)
else:
cycle_t = 1
scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
def get_cosine_schedule_with_warmup_decay_constant( if isinstance(original, Sequence):
optimizer: Optimizer, return [lr * scale for lr in original]
num_warmup_steps: int, return original * scale
num_training_steps: int,
constant_lr_ratio: float,
min_lr_ratio: float,
num_cycles: float = 0.5,
last_epoch: int = -1,
):
"""
Implementation of Continual Pre-Training of Large Language Models: How to (re)warm your model? (https://arxiv.org/pdf/2308.04014.pdf)
Create a schedule with a learning rate that decreases following the values of the cosine function between the
initial lr set in the optimizer to min_lr_ratio until num_training_steps * constant_lr_ratio, after constant_rate returns constant value of min_rate
, after a warmup period during which it increases linearly between 0 and the initial lr set in the optimizer.
Args:
optimizer ([`~torch.optim.Optimizer`]):
The optimizer for which to schedule the learning rate.
num_warmup_steps (`int`):
The number of steps for the warmup phase.
num_training_steps (`int`):
The total number of training steps.
constant_lr_ratio: (`float`):
The ratio of num_training_steps to decrease by cosine function.
min_lr_ratio: (`float):
The ratio of maximum learning rate for cosine function to decay to minimum learning rate.
num_cycles (`float`, *optional*, defaults to 0.5):
The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
following a half-cosine).
last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training.
Return:
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""
lr_lambda = partial(
_get_cosine_schedule_with_warmup_decay_constant_lr_lambda,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
constant_lr_ratio=constant_lr_ratio,
min_lr_ratio=min_lr_ratio,
num_cycles=num_cycles,
)
return LambdaLR(optimizer, lr_lambda, last_epoch)

View File

@@ -1,52 +0,0 @@
"""
test module for the axolotl.utils.schedulers module
"""
import unittest
import torch
from torch.optim import SGD
from axolotl.utils.schedulers import get_cosine_schedule_with_warmup_decay_constant
class TestCosineConstantLr(unittest.TestCase):
    """
    test class for the cosine schedule with warmup, decay and constant min lr
    """

    def setUp(self):
        self.train_steps = 1000
        self.warmup_steps = 10
        self.min_lr_ratio = 0.1
        self.constant_lr_ratio = 0.8
        self._lr = 0.01
        self.optimizer = SGD([torch.tensor(1)], lr=self._lr)
        self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant(  # pylint: disable=attribute-defined-outside-init
            self.optimizer,
            num_warmup_steps=self.warmup_steps,
            num_training_steps=self.train_steps,
            min_lr_ratio=self.min_lr_ratio,
            constant_lr_ratio=self.constant_lr_ratio,
        )

    def test_schedulers(self):
        # Learning rates are floats produced by trigonometric arithmetic, so
        # compare with a tolerance instead of exact equality.
        self.assertAlmostEqual(self.lr_scheduler.get_last_lr()[0], 0)

        # end of warmup: the schedule reaches the optimizer's base lr
        for _ in range(self.warmup_steps):
            self.lr_scheduler.step()
        self.assertAlmostEqual(self.lr_scheduler.get_last_lr()[0], self._lr)

        constant_step = int(self.train_steps * self.constant_lr_ratio)
        remaining_step = self.train_steps - constant_step

        # end of the cosine decay: lr has dropped to base lr * min_lr_ratio
        for _ in range(constant_step):
            self.lr_scheduler.step()
        self.assertAlmostEqual(
            self.lr_scheduler.get_last_lr()[0], self._lr * self.min_lr_ratio
        )

        # remaining steps: lr stays at the minimum
        for _ in range(remaining_step):
            self.lr_scheduler.step()
        self.assertAlmostEqual(
            self.lr_scheduler.get_last_lr()[0], self._lr * self.min_lr_ratio
        )
if __name__ == "__main__":
unittest.main()