Compare commits
9 commits: refactor-f… ... unsloth_mo…
| Author | SHA1 | Date |
|---|---|---|
|  | 4f9b172c47 |  |
|  | 8671ed5a0c |  |
|  | 538c004080 |  |
|  | add3b139ed |  |
|  | a581e9f8f6 |  |
|  | 992e742cdc |  |
|  | a1da39cd48 |  |
|  | 58ec8b1113 |  |
|  | 476a205cea |  |
README.md (14 changes)
@@ -612,6 +612,12 @@ eval_sample_packing:
 sample_packing_eff_est:
 total_num_tokens:
 
+# Passed through to transformers when loading the model when launched without accelerate
+# Use `sequential` when training w/ model parallelism to limit memory
+device_map:
+# Defines the max memory usage per gpu on the system. Passed through to transformers when loading the model.
+max_memory:
+
 # If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
 adapter: lora
 # If you already have a lora model trained that you want to load, put that here.
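The two new options above are forwarded to transformers when the model is loaded. A hedged sketch of the equivalent direct call follows; the model id and memory figures are placeholders, not values from this diff:

```python
# Rough equivalent of setting device_map / max_memory in the axolotl config.
# The model id and memory limits below are illustrative placeholders.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",
    device_map="sequential",                  # limit memory when using model parallelism
    max_memory={0: "20GiB", "cpu": "60GiB"},  # per-device caps passed through to accelerate
)
```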
@@ -659,7 +665,8 @@ wandb_mode: # "offline" to save run metadata locally and not sync to the server,
 wandb_project: # Your wandb project name
 wandb_entity: # A wandb Team name if using a Team
 wandb_watch:
-wandb_run_id: # Set the name of your wandb run
+wandb_name: # Set the name of your wandb run
+wandb_run_id: # Set the ID of your wandb run
 wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
 
 # Where to save the full-finetuned model to

@@ -694,6 +701,9 @@ max_steps:
 eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
 eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
 
+loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
+loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
+
 # Save model as safetensors (require safetensors package)
 save_safetensors:
 

@@ -952,7 +962,7 @@ wandb_mode:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 ```
 
@@ -24,16 +24,6 @@
       "weight_decay": "auto"
     }
   },
-  "scheduler": {
-    "type": "WarmupDecayLR",
-    "params": {
-      "warmup_min_lr": "auto",
-      "warmup_max_lr": "auto",
-      "warmup_num_steps": "auto",
-      "warmup_type": "linear",
-      "total_num_steps": "auto"
-    }
-  },
   "gradient_accumulation_steps": "auto",
   "train_batch_size": "auto",
   "train_micro_batch_size_per_gpu": "auto",

@@ -28,16 +28,6 @@
       "weight_decay": "auto"
     }
   },
-  "scheduler": {
-    "type": "WarmupDecayLR",
-    "params": {
-      "warmup_min_lr": "auto",
-      "warmup_max_lr": "auto",
-      "warmup_num_steps": "auto",
-      "warmup_type": "linear",
-      "total_num_steps": "auto"
-    }
-  },
   "gradient_accumulation_steps": "auto",
   "train_batch_size": "auto",
   "train_micro_batch_size_per_gpu": "auto",

@@ -32,16 +32,6 @@
       "weight_decay": "auto"
     }
   },
-  "scheduler": {
-    "type": "WarmupDecayLR",
-    "params": {
-      "warmup_min_lr": "auto",
-      "warmup_max_lr": "auto",
-      "warmup_num_steps": "auto",
-      "warmup_type": "linear",
-      "total_num_steps": "auto"
-    }
-  },
   "gradient_accumulation_steps": "auto",
   "train_batch_size": "auto",
   "train_micro_batch_size_per_gpu": "auto",
@@ -35,7 +35,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 
 output_dir: btlm-out

@@ -24,7 +24,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./qlora-out
 batch_size: 4

@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 
 gradient_accumulation_steps: 4

@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 
 gradient_accumulation_steps: 4

@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 
 gradient_accumulation_steps: 4

@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 
 gradient_accumulation_steps: 4

@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 
 gradient_accumulation_steps: 4

@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 
 gradient_accumulation_steps: 4

@@ -26,7 +26,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./falcon-7b
 batch_size: 2

@@ -40,7 +40,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./qlora-out
 

@@ -26,7 +26,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./falcon-7b
 batch_size: 2

@@ -21,7 +21,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./qlora-out
 gradient_accumulation_steps: 2

@@ -19,7 +19,7 @@ lora_fan_in_fan_out: false
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./jeopardy-bot-7b
 gradient_accumulation_steps: 1

@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 
 gradient_accumulation_steps: 1

@@ -32,7 +32,7 @@ lora_target_linear:
 lora_fan_in_fan_out:
 wandb_project:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./model-out
 gradient_accumulation_steps: 1

@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 
 gradient_accumulation_steps: 4

@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 
 gradient_accumulation_steps: 4

@@ -35,7 +35,7 @@ relora_cpu_offload: false
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 
 gradient_accumulation_steps: 4

@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 
 gradient_accumulation_steps: 4

@@ -21,7 +21,7 @@ pad_to_sequence_len: true
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 
 gradient_accumulation_steps: 4

@@ -38,7 +38,7 @@ lora_target_modules:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 
 gradient_accumulation_steps: 4
@@ -62,6 +62,9 @@ logging_steps: 1
 xformers_attention:
 flash_attention: true
 
+loss_watchdog_threshold: 5.0
+loss_watchdog_patience: 3
+
 warmup_steps: 10
 eval_steps: 0.05
 eval_table_size:
@@ -21,7 +21,7 @@ lora_fan_in_fan_out: false
 wandb_project: mpt-alpaca-7b
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./mpt-alpaca-7b
 gradient_accumulation_steps: 1

@@ -23,7 +23,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./openllama-out
 gradient_accumulation_steps: 1

@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./lora-out
 gradient_accumulation_steps: 1

@@ -23,7 +23,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./qlora-out
 gradient_accumulation_steps: 1

@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 
 gradient_accumulation_steps: 1

@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 
 gradient_accumulation_steps: 1

@@ -24,7 +24,7 @@ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./pythia-12b
 gradient_accumulation_steps: 1

@@ -18,7 +18,7 @@ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./lora-alpaca-pythia
 gradient_accumulation_steps: 1

@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 
 gradient_accumulation_steps: 4

@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 
 gradient_accumulation_steps: 4

@@ -22,7 +22,7 @@ lora_fan_in_fan_out: false
 wandb_project: redpajama-alpaca-3b
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./redpajama-alpaca-3b
 batch_size: 4

@@ -21,7 +21,7 @@ lora_fan_in_fan_out:
 wandb_project: lora-replit
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./lora-replit
 batch_size: 8

@@ -38,7 +38,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./qlora-out
 
@@ -25,6 +25,7 @@ from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
 from axolotl.utils.callbacks import (
     EvalFirstStepCallback,
     GPUStatsCallback,
+    LossWatchDogCallback,
     SaveAxolotlConfigtoWandBCallback,
     SaveBetterTransformerModelCallback,
     bench_eval_callback_factory,
@@ -430,6 +431,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
             )
 
+        if self.cfg.loss_watchdog_threshold is not None:
+            callbacks.append(LossWatchDogCallback(self.cfg))
+
         return callbacks
 
     def get_post_trainer_create_callbacks(self, trainer):
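For context, a minimal sketch of what a loss watchdog of this shape typically does. The real LossWatchDogCallback lives in axolotl.utils.callbacks and its implementation is not part of this diff, so everything below other than the two config keys is illustrative:

```python
# Illustrative only: the real LossWatchDogCallback is imported from
# axolotl.utils.callbacks; this sketch just shows the intended behaviour.
from transformers import TrainerCallback


class LossWatchdogSketch(TrainerCallback):
    """Abort training after N consecutive logged steps above a loss threshold."""

    def __init__(self, cfg):
        self.threshold = cfg.loss_watchdog_threshold  # e.g. ~2x the loss at the start of training
        self.patience = cfg.loss_watchdog_patience or 3
        self.violations = 0

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs and "loss" in logs:
            if logs["loss"] >= self.threshold:
                self.violations += 1
                if self.violations >= self.patience:
                    control.should_training_stop = True  # learning has broken down, abort
            else:
                self.violations = 0  # reset on any healthy step
        return control
```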
@@ -643,7 +647,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
         training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None
         training_arguments_kwargs["run_name"] = (
-            self.cfg.wandb_run_id if self.cfg.use_wandb else None
+            self.cfg.wandb_name if self.cfg.use_wandb else None
         )
         training_arguments_kwargs["optim"] = (
             self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
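The net effect of this change, sketched with plain TrainingArguments (the values are placeholders): the wandb run display name now comes from `wandb_name`, leaving `wandb_run_id` free to identify a specific run id.

```python
# Sketch of the resulting wiring; output_dir and run_name are placeholder values.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./out",
    report_to="wandb",
    run_name="my-finetune",  # now sourced from cfg.wandb_name rather than cfg.wandb_run_id
)
```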
src/axolotl/monkeypatch/cross_entropy.py (new file, 168 lines)
@@ -0,0 +1,168 @@ (new file)

# Adapted from Unsloth
# https://github.com/unslothai/unsloth/blob/4b97a810b509c93f44be4c037c7aa18fb8922884/unsloth/kernels/cross_entropy_loss.py

import triton
import triton.language as tl
import torch

MAX_FUSED_SIZE = 65536


def calculate_settings(n):
    BLOCK_SIZE = triton.next_power_of_2(n)
    # CUDA only supports 65536 - 2^16 threads per block
    if BLOCK_SIZE > MAX_FUSED_SIZE:
        raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "
                           f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")
    num_warps = 4
    if BLOCK_SIZE >= 32768: num_warps = 32
    elif BLOCK_SIZE >= 8192: num_warps = 16
    elif BLOCK_SIZE >= 2048: num_warps = 8
    return BLOCK_SIZE, num_warps
pass


@triton.jit
def _cross_entropy_forward(logits_ptr, logits_row_stride,
                           loss_ptr,
                           lse_ptr,
                           labels_ptr,
                           n_cols,
                           BLOCK_SIZE: tl.constexpr,):
    """
    Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]
    Pi = exp(xi) / sum(exp(xi))
    CE_i = -y log(p) = -y log[ exp(x) / sum(exp(x)) ]
         = -y [ x - log[sum(exp(x))] ]
         = y * (log[sum(exp(x))] - x)
    If y == 0: CE_i = 0
    If y == 1: CE_i = logsumexp - x
    """
    row_idx = tl.program_id(0)
    logits_ptr += row_idx * logits_row_stride
    loss_ptr += row_idx
    lse_ptr += row_idx
    labels_ptr += row_idx

    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    # TODO: Fixup int32 locations to int64
    label_idx = tl.load(labels_ptr).to(tl.int32)
    logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
    max_logits = tl.max(logits, 0)
    # Maximum stops overflow
    lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits
    tl.store(lse_ptr, lse)

    if label_idx != -100:
        logits_label = tl.load(logits_ptr + label_idx).to(tl.float32)
        loss = lse - logits_label
    else:
        loss = 0.0
    tl.store(loss_ptr, loss)
pass


@triton.jit
def _cross_entropy_backward(logits_ptr, logits_row_stride,
                            dloss_ptr, dloss_row_stride,
                            lse_ptr,
                            labels_ptr,
                            n_cols,
                            BLOCK_SIZE: tl.constexpr,):
    """
    CE_i = -y log(P) = y * (log[sum(exp(x))] - x)
    dC/dx = d/dx (y * log[sum(exp(x))] - x * y)

    From https://en.wikipedia.org/wiki/LogSumExp
    d/dx logsumexp = exp(x) / sum(exp(x)) = softmax(x)

    dC/dx = y * exp(x) / sum(exp(x)) - d/dx (x * y)
    dC/dx = y * exp[ log[exp(x) / sum(exp(x))] ] using x = exp(log(x)) trick
    dC/dx = y * exp[x - logsumexp] - d/dx (x * y)

    If y == 0: dC/dx = 0
    If y == 1 and x == label: dC/dlabel = exp[x - logsumexp] - 1
    If y == 1 and x != label: dC/dx = exp[x - logsumexp]
    """
    row_idx = tl.program_id(0)
    logits_ptr += row_idx * logits_row_stride
    dloss_ptr += row_idx * dloss_row_stride
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    # TODO: Fixup int32 locations to int64
    label_idx = tl.load(labels_ptr + row_idx).to(tl.int32)

    if label_idx != -100:
        dloss = tl.load(dloss_ptr)
    else:
        dloss = 0.0
    logits = tl.load(logits_ptr + col_offsets, mask = mask, other = 0).to(tl.float32)
    lse = tl.load(lse_ptr + row_idx)
    probs = tl.exp(logits - lse)

    probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
    tl.store(logits_ptr + col_offsets, dloss * probs, mask = mask)


class CrossEntropyLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, logits, labels):
        n_rows, n_cols = logits.shape
        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
        losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda")
        logsumexp = torch.empty(n_rows, dtype = torch.float32, device = "cuda")

        _cross_entropy_forward[(n_rows,)](
            logits, logits.stride(0),
            losses,
            logsumexp,
            labels,
            n_cols,
            BLOCK_SIZE = BLOCK_SIZE,
            num_warps = num_warps,
        )

        ctx.BLOCK_SIZE = BLOCK_SIZE
        ctx.num_warps = num_warps
        ctx.save_for_backward(logits, logsumexp, labels)
        return losses
    pass

    @staticmethod
    def backward(ctx, dlosses):
        logits, logsumexp, labels = ctx.saved_tensors
        n_rows, n_cols = logits.shape

        _cross_entropy_backward[(n_rows,)](
            logits, logits.stride(0),
            dlosses, dlosses.stride(0),
            logsumexp,
            labels,
            n_cols,
            BLOCK_SIZE = ctx.BLOCK_SIZE,
            num_warps = ctx.num_warps,
        )
        return logits, None, None,
    pass
pass


def fast_cross_entropy_loss(logits, labels):
    """
    Arguments:
        logits: (batch, seq_len, vocab_size)
        labels: (batch, seq_len,)
    Returns:
        losses: float
    """
    batch, seq_len, d = logits.shape
    assert(labels.shape == (batch, seq_len))

    loss = CrossEntropyLoss.apply(
        logits.view(batch*seq_len, d),
        labels.view(-1),
    )
    n_items = torch.count_nonzero(labels != -100)
    return loss.sum() / n_items
pass
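A minimal smoke test for the new kernel might look like the following. It is not part of the diff; it assumes a CUDA device with a working triton install, and the shapes are illustrative:

```python
# Hypothetical smoke test for fast_cross_entropy_loss (not part of this diff).
import torch
from axolotl.monkeypatch.cross_entropy import fast_cross_entropy_loss

batch, seq_len, vocab = 2, 16, 32000
logits = torch.randn(batch, seq_len, vocab, device="cuda", requires_grad=True)
labels = torch.randint(0, vocab, (batch, seq_len), device="cuda")
labels[:, :4] = -100  # ignored positions, as in causal-LM label masking

loss = fast_cross_entropy_loss(logits, labels)  # mean loss over non-masked tokens
loss.backward()                                 # gradients flow through the custom autograd.Function
print(loss.item(), logits.grad.shape)
```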
@@ -1,426 +0,0 @@ (file deleted in its entirety)

import torch
import logging
import warnings
from einops import rearrange
from functools import partial
import torch.nn.functional as F
from typing import Optional, Tuple
from flash_attn.bert_padding import pad_input, unpad_input
from axolotl.monkeypatch.fused_modules import FusedAttention

try:
    from flash_attn.flash_attn_interface import (  # pylint: disable=ungrouped-imports
        flash_attn_kvpacked_func,
        flash_attn_varlen_kvpacked_func,
        flash_attn_varlen_qkvpacked_func,
    )
except ImportError:
    from flash_attn.flash_attn_interface import (
        flash_attn_unpadded_kvpacked_func as flash_attn_varlen_kvpacked_func,
    )
    from flash_attn.flash_attn_interface import (
        flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
    )

LOG = logging.getLogger("axolotl")


def flashattn_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[torch.Tensor] = None,
    *args,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Input shape: Batch x Time x Channel

    attention_mask: [bsz, q_len]
    """
    # pylint: disable=duplicate-code
    bsz, q_len, _ = hidden_states.size()

    if not hasattr(self, "pretraining_tp"):
        self.pretraining_tp = 1

    if self.pretraining_tp > 1:
        key_value_slicing = (
            self.num_key_value_heads * self.head_dim
        ) // self.pretraining_tp
        query_slices = self.q_proj.weight.split(
            (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
        )
        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

        query_states = [
            F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
        ]
        query_states = torch.cat(query_states, dim=-1)

        key_states = [
            F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
        ]
        key_states = torch.cat(key_states, dim=-1)

        value_states = [
            F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
        ]
        value_states = torch.cat(value_states, dim=-1)

    else:
        if isinstance(self, FusedAttention):
            query_states, key_states, value_states = self.qkv_proj(hidden_states).split(
                self.out_features, dim=-1
            )
        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

    query_states = query_states.view(
        bsz, q_len, self.num_heads, self.head_dim
    ).transpose(1, 2)
    key_states = key_states.view(
        bsz, q_len, self.num_key_value_heads, self.head_dim
    ).transpose(1, 2)
    value_states = value_states.view(
        bsz, q_len, self.num_key_value_heads, self.head_dim
    ).transpose(1, 2)
    # [bsz, q_len, nh, hd]
    # [bsz, nh, q_len, hd]

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]

    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = self.apply_rotary_fn(
        query_states, key_states, cos, sin, position_ids
    )
    # [bsz, nh, t, hd]

    use_sliding_windows = (
        hasattr(self.config, "sliding_window") is not None
        and kv_seq_len > self.config.sliding_window
    )

    if use_sliding_windows:
        window_size = (self.config.sliding_window, self.config.sliding_window)
    else:
        window_size = (-1, -1)

    if past_key_value is not None:
        # Activate slicing cache only if the config has a value `sliding_windows` attribute
        if (
            hasattr(self.config, "sliding_window")
            and kv_seq_len > self.config.sliding_window
        ):
            slicing_tokens = kv_seq_len - self.config.sliding_window

            past_key = past_key_value[0]
            past_value = past_key_value[1]

            past_key = past_key[:, :, slicing_tokens:, :].contiguous()
            past_value = past_value[:, :, slicing_tokens:, :].contiguous()

            if past_key.shape[-2] != self.config.sliding_window - 1:
                raise ValueError(
                    f"past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
                    f" {past_key.shape}"
                )

            past_key_value = (past_key, past_value) if use_cache else None

    if past_key_value is not None:
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)

    past_key_value = (key_states, value_states) if use_cache else None

    # repeat k/v heads if n_kv_heads < n_heads
    key_states = self.repeat_kv_fn(key_states, self.num_key_value_groups)
    value_states = self.repeat_kv_fn(value_states, self.num_key_value_groups)

    if output_attentions:
        warnings.warn(
            "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
        )

    #
    # flash-attn v2 start
    #

    if self.training:
        # during training q,k,v always have same seqlen
        assert key_states.shape == query_states.shape
        is_causal = True
    else:
        # turn off FA causal mask after first inference autoregressive iteration
        # only on first autoregressive step q,k,v have same seqlen
        is_causal = key_states.shape == query_states.shape

    dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)

    if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
        # special handling using sample packing
        qkv = torch.stack(
            [query_states, key_states, value_states], dim=2
        )  # [bsz, nh, 3, q_len, hd]
        qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
        qkv = rearrange(qkv, "b s ... -> (b s) ...")

        output = flash_attn_varlen_qkvpacked_func(
            qkv,
            cu_seqlens,
            max_seqlen,
            dropout_p=dropout_rate,
            softmax_scale=None,
            causal=True,
            window_size=window_size,
        )
        output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
    elif query_states.shape == key_states.shape:
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
            query_states,
            key_states,
            value_states,
            qkvpacked=True,
            # We have disabled _prepare_decoder_attention_mask in LlamaModel
            # the attention_mask should be the same as the key_padding_mask
            key_padding_mask=attention_mask,
            query_padding_mask=attention_mask[:, -query_states.size(1) :]
            if attention_mask is not None
            else None,
        )
        output_unpad = flash_attn_varlen_qkvpacked_func(
            qkv_unpad,
            cu_seqlens_q,
            max_seqlen_q,
            dropout_p=dropout_rate,
            softmax_scale=None,
            causal=is_causal,
            window_size=window_size,
        )
        output = output_pad_fn(output_unpad)
    else:
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        if attention_mask is None or attention_mask.all().item():
            output = flash_attn_kvpacked_func(
                query_states,
                torch.stack([key_states, value_states], 2),
                dropout_p=dropout_rate,
                causal=is_causal,
                window_size=window_size,
            )
        else:
            (  # pylint: disable=unbalanced-tuple-unpacking
                q_unpad,
                kv_unpad,
                cu_seqlens_q,
                cu_seqlens_k,
                max_seqlen_q,
                max_seqlen_k,
                _,
                _,
                output_pad_fn,
            ) = generate_qkv(
                query_states,
                key_states,
                value_states,
                kvpacked=True,
                key_padding_mask=attention_mask,
                query_padding_mask=attention_mask[:, -query_states.size(1) :]
                if attention_mask is not None
                else None,
            )
            if q_unpad.dtype != kv_unpad.dtype:
                kv_unpad = kv_unpad.to(q_unpad.dtype)
            output_unpad = flash_attn_varlen_kvpacked_func(
                q_unpad,
                kv_unpad,
                cu_seqlens_q,
                cu_seqlens_k,
                max_seqlen_q,
                max_seqlen_k,
                dropout_p=dropout_rate,
                softmax_scale=None,
                causal=is_causal,
                window_size=window_size,
            )
            output = output_pad_fn(output_unpad)

    attn_output = output
    if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )
    attn_output = rearrange(attn_output, "b s h d -> b s (h d)")

    #
    # flash-attn v2 end
    #

    if self.pretraining_tp > 1:
        attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
        o_proj_slices = self.o_proj.weight.split(
            self.hidden_size // self.pretraining_tp, dim=1
        )
        attn_output = sum(
            F.linear(attn_output[i], o_proj_slices[i])
            for i in range(self.pretraining_tp)
        )
    else:
        attn_output = self.o_proj(attn_output)

    return attn_output, None, past_key_value


# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38
def generate_qkv(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    kvpacked=False,
    qkvpacked=False,
):  # pylint: disable=invalid-name,unnecessary-lambda-assignment
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, d)
        k: (batch_size, seqlen_k, nheads_k, d)
        v: (batch_size, seqlen_k, nheads_k, d)
        query_padding_mask: (batch_size, seqlen), bool
        key_padding_mask: (batch_size, seqlen), bool
    """
    assert not (kvpacked and qkvpacked)
    batch_size, seqlen_q, nheads, d = q.shape
    _, seqlen_k, nheads_k, _ = k.shape
    assert k.shape == (batch_size, seqlen_k, nheads_k, d)
    assert v.shape == (batch_size, seqlen_k, nheads_k, d)

    if query_padding_mask is not None:
        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
            q, query_padding_mask
        )

        output_pad_fn = lambda output_unpad: pad_input(  # noqa: E731
            output_unpad, indices_q, batch_size, seqlen_q
        )

    else:
        q_unpad = rearrange(q, "b s h d -> (b s) h d")
        cu_seqlens_q = torch.arange(
            0,
            (batch_size + 1) * seqlen_q,
            step=seqlen_q,
            dtype=torch.int32,
            device=q_unpad.device,
        )
        max_seqlen_q = seqlen_q

        output_pad_fn = lambda output_unpad: rearrange(  # noqa: E731
            output_unpad, "(b s) h d -> b s h d", b=batch_size
        )

    if key_padding_mask is not None:
        k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
        v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
    else:
        k_unpad = rearrange(k, "b s h d -> (b s) h d")
        v_unpad = rearrange(v, "b s h d -> (b s) h d")
        cu_seqlens_k = torch.arange(
            0,
            (batch_size + 1) * seqlen_k,
            step=seqlen_k,
            dtype=torch.int32,
            device=k_unpad.device,
        )
        max_seqlen_k = seqlen_k

    if qkvpacked:
        assert nheads == nheads_k
        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
        qkv = torch.stack([q, k, v], dim=2)
        return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn)

    if kvpacked:
        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
        kv = torch.stack([k, v], dim=2)
        return (
            q_unpad,
            kv_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q,
            kv,
            output_pad_fn,
        )

    return (
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        q,
        k,
        v,
        output_pad_fn,
    )


def replace_cross_entropy(modeling_class, module_name):
    """
    modeling_class: transformers.models.llama.modeling_<class>
    module_name: CrossEntropyLoss
    """
    try:
        from flash_attn.losses.cross_entropy import CrossEntropyLoss

        LOG.info("patching with flash_attn.losses.cross_entropy")

        cross_entropy_loss = partial(
            CrossEntropyLoss, inplace_backward=True
        )

        setattr(modeling_class, module_name, cross_entropy_loss)

    except ImportError:
        LOG.info(
            "optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
        )


def replace_rms_norm(modeling_class, module_name):
    """
    modeling_class: transformers.models.llama.modeling_<class>
    module_name: RMSNorm
    """
    try:
        from flash_attn.ops.rms_norm import RMSNorm

        class FlashRMSNorm(RMSNorm):
            """A faster RMS Norm."""

            def __init__(self, hidden_size, eps=1e-6):
                super().__init__(hidden_size, eps=eps)

        LOG.info("patching with flash_attn.ops.rms_norm")
        setattr(modeling_class, module_name, FlashRMSNorm)
    except ImportError:
        LOG.info(
            "optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
        )
@@ -1,94 +0,0 @@ (file deleted in its entirety)

import torch
from typing import List
from xformers.ops import SwiGLU
from axolotl.monkeypatch.utils import set_module_name
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaMLP,
)


# TODO: Generalize to other attention modules
class FusedAttention(LlamaAttention):
    """
    Fused QKV Attention layer for incrementally improved training efficiency
    """

    def __init__(
        self,
        config,
        q: torch.nn.Linear,  # pylint: disable=invalid-name
        k: torch.nn.Linear,  # pylint: disable=invalid-name
        v: torch.nn.Linear,  # pylint: disable=invalid-name
        o: torch.nn.Linear,  # pylint: disable=invalid-name
    ):
        super().__init__(config)
        self.config = config
        self.init_device = next(iter(q.state_dict().values())).device

        # define equivalent fused qkv projection
        self.out_features: List[int] = [q.out_features, k.out_features, v.out_features]
        self.qkv_proj = torch.nn.Linear(
            q.in_features, sum(self.out_features), device=self.init_device, bias=False
        )
        self.o_proj = o

        # overwrite initialized weights with pretrained weights
        self.qkv_proj.weight.data = torch.cat(
            (q.weight.data, k.weight.data, v.weight.data), dim=0
        )

    def _post_training(self, model, name):
        q_proj, k_proj, v_proj = torch.split(
            self.qkv_proj.weight.data, self.out_features, dim=0
        )

        new_attn = LlamaAttention(self.config)
        new_attn.q_proj.weight.data = q_proj
        new_attn.k_proj.weight.data = k_proj
        new_attn.v_proj.weight.data = v_proj
        new_attn.o_proj.weight.data = self.o_proj.weight.data

        set_module_name(model, name, new_attn)


class FusedMLP(torch.nn.Module):
    """
    Fused MLP layer for incrementally improved training efficiency
    """

    def __init__(
        self,
        config,
        gate_proj: torch.nn.Linear,
        up_proj: torch.nn.Linear,
        down_proj: torch.nn.Linear,
    ):
        super().__init__()
        self.config = config
        self.swiglu = SwiGLU(
            in_features=config.hidden_size,
            hidden_features=config.intermediate_size,
            bias=False,
            _pack_weights=True,
        )
        # overwrite initialized weights with pretrained weights
        self.swiglu.w12.weight.data = torch.cat(
            (gate_proj.weight.data, up_proj.weight.data), dim=0
        )
        self.swiglu.w3.weight.data = down_proj.weight.data

    def _post_training(self, model, name):
        w1, w2 = torch.split(  # pylint: disable=invalid-name
            self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
        )

        # Assign the split weights back to the original layers
        new_mlp = LlamaMLP(self.config)
        new_mlp.gate_proj.weight.data = w1
        new_mlp.up_proj.weight.data = w2
        new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data

        set_module_name(model, name, new_mlp)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # pylint: disable=invalid-name
        return self.swiglu(x)
@@ -3,10 +3,15 @@
 # copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
 
 import logging
+import warnings
+from functools import partial
 from typing import List, Optional, Tuple, Union
 
 import torch
+import torch.nn.functional as F
 import transformers
+from einops import rearrange
+from flash_attn.bert_padding import pad_input, unpad_input
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.llama.modeling_llama import LlamaAttention
 from transformers.models.llama.modeling_llama import (
@@ -14,20 +19,27 @@ from transformers.models.llama.modeling_llama import (
 )
 from transformers.models.llama.modeling_llama import (
     LlamaMLP,
-)
-
-from transformers.models.llama.modeling_llama import (
     apply_rotary_pos_emb,
     repeat_kv,
 )
+from xformers.ops import SwiGLU
 
 from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
-from axolotl.monkeypatch.fused_modules import FusedAttention, FusedMLP
-from axolotl.monkeypatch.flash_modules import (
-    flashattn_forward,
-    replace_cross_entropy,
-    replace_rms_norm
-)
+
+try:
+    from flash_attn.flash_attn_interface import (  # pylint: disable=ungrouped-imports
+        flash_attn_kvpacked_func,
+        flash_attn_varlen_kvpacked_func,
+        flash_attn_varlen_qkvpacked_func,
+    )
+except ImportError:
+    from flash_attn.flash_attn_interface import (
+        flash_attn_unpadded_kvpacked_func as flash_attn_varlen_kvpacked_func,
+    )
+    from flash_attn.flash_attn_interface import (
+        flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
+    )
 
 LOG = logging.getLogger("axolotl")
 
@@ -63,17 +75,129 @@ def replace_llama_attn_with_flash_attn(
         _prepare_decoder_attention_mask
     )
     transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward
-    transformers.models.llama.modeling_llama.LlamaAttention.apply_rotary_fn = apply_rotary_pos_emb
-    transformers.models.llama.modeling_llama.LlamaAttention.repeat_kv_fn = repeat_kv
     if packed:
         transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
         transformers.models.llama.modeling_llama.LlamaModel.forward = (
             llama_model_forward
         )
 
+    # skip only if explicitly disabled
     if cross_entropy:
-        replace_cross_entropy(transformers.models.llama.modeling_llama, "CrossEntropyLoss")
+        try:
+            from flash_attn.losses.cross_entropy import CrossEntropyLoss
+
+            LOG.info("patching with flash_attn.losses.cross_entropy")
+            transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
+                CrossEntropyLoss, inplace_backward=True
+            )
+        except ImportError:
+            LOG.info(
+                "optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
+            )
+
+    # skip only if explicitly disabled
     if rms_norm:
-        replace_rms_norm(transformers.models.llama.modeling_llama, "LlamaRMSNorm")
+        try:
+            from flash_attn.ops.rms_norm import RMSNorm
+
+            class LlamaRMSNorm(RMSNorm):
+                """Patched LLamaRMSNorm"""
+
+                def __init__(self, hidden_size, eps=1e-6):
+                    super().__init__(hidden_size, eps=eps)
+
+            LOG.info("patching with flash_attn.ops.rms_norm")
+            transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
+        except ImportError:
+            LOG.info(
+                "optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
+            )
+
+
+class FusedAttention(LlamaAttention):
+    """
+    Fused QKV Attention layer for incrementally improved training efficiency
+    """
+
+    def __init__(
+        self,
+        config,
+        q: torch.nn.Linear,  # pylint: disable=invalid-name
+        k: torch.nn.Linear,  # pylint: disable=invalid-name
+        v: torch.nn.Linear,  # pylint: disable=invalid-name
+        o: torch.nn.Linear,  # pylint: disable=invalid-name
+    ):
+        super().__init__(config)
+        self.config = config
+        self.init_device = next(iter(q.state_dict().values())).device
+
+        # define equivalent fused qkv projection
+        self.out_features: List[int] = [q.out_features, k.out_features, v.out_features]
+        self.qkv_proj = torch.nn.Linear(
+            q.in_features, sum(self.out_features), device=self.init_device, bias=False
+        )
+        self.o_proj = o
+
+        # overwrite initialized weights with pretrained weights
+        self.qkv_proj.weight.data = torch.cat(
+            (q.weight.data, k.weight.data, v.weight.data), dim=0
+        )
+
+    def _post_training(self, model, name):
+        q_proj, k_proj, v_proj = torch.split(
+            self.qkv_proj.weight.data, self.out_features, dim=0
+        )
+
+        new_attn = LlamaAttention(self.config)
+        new_attn.q_proj.weight.data = q_proj
+        new_attn.k_proj.weight.data = k_proj
+        new_attn.v_proj.weight.data = v_proj
+        new_attn.o_proj.weight.data = self.o_proj.weight.data
+
+        set_module_name(model, name, new_attn)
+
+
+class FusedMLP(torch.nn.Module):
+    """
+    Fused MLP layer for incrementally improved training efficiency
+    """
+
+    def __init__(
+        self,
+        config,
+        gate_proj: torch.nn.Linear,
+        up_proj: torch.nn.Linear,
+        down_proj: torch.nn.Linear,
+    ):
+        super().__init__()
+        self.config = config
+        self.swiglu = SwiGLU(
+            in_features=config.hidden_size,
+            hidden_features=config.intermediate_size,
+            bias=False,
+            _pack_weights=True,
+        )
+        # overwrite initialized weights with pretrained weights
+        self.swiglu.w12.weight.data = torch.cat(
+            (gate_proj.weight.data, up_proj.weight.data), dim=0
+        )
+        self.swiglu.w3.weight.data = down_proj.weight.data
+
+    def _post_training(self, model, name):
+        w1, w2 = torch.split(  # pylint: disable=invalid-name
+            self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
+        )
+
+        # Assign the split weights back to the original layers
+        new_mlp = LlamaMLP(self.config)
+        new_mlp.gate_proj.weight.data = w1
+        new_mlp.up_proj.weight.data = w2
+        new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data
+
+        set_module_name(model, name, new_mlp)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:  # pylint: disable=invalid-name
+        return self.swiglu(x)
+
 
 # Disable the transformation of the attention mask in LlamaModel as the flash attention
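A hedged usage sketch of the patch function above: only the function name and its packed, cross_entropy, and rms_norm flags appear in this diff; the module path, the flag values, and the model id below are assumptions for illustration.

```python
# Assumed module path and flag values; patch transformers' Llama classes first,
# then load the model so the patched attention, CrossEntropyLoss, and RMSNorm apply.
from transformers import AutoModelForCausalLM
from axolotl.monkeypatch.llama_attn_hijack_flash import replace_llama_attn_with_flash_attn

replace_llama_attn_with_flash_attn(packed=True, cross_entropy=True, rms_norm=True)
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype="auto")
```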
@@ -89,6 +213,330 @@ def _prepare_decoder_attention_mask(
|
|||||||
return attention_mask
|
return attention_mask
|
||||||
def flashattn_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    padding_mask: Optional[torch.LongTensor] = None,  # pylint: disable=unused-argument
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Input shape: Batch x Time x Channel

    attention_mask: [bsz, q_len]
    """
    # pylint: disable=duplicate-code
    bsz, q_len, _ = hidden_states.size()

    if not hasattr(self, "pretraining_tp"):
        self.pretraining_tp = 1

    if self.pretraining_tp > 1:
        key_value_slicing = (
            self.num_key_value_heads * self.head_dim
        ) // self.pretraining_tp
        query_slices = self.q_proj.weight.split(
            (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
        )
        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

        query_states = [
            F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
        ]
        query_states = torch.cat(query_states, dim=-1)

        key_states = [
            F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
        ]
        key_states = torch.cat(key_states, dim=-1)

        value_states = [
            F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
        ]
        value_states = torch.cat(value_states, dim=-1)

    else:
        if isinstance(self, FusedAttention):
            query_states, key_states, value_states = self.qkv_proj(hidden_states).split(
                self.out_features, dim=-1
            )
        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

    query_states = query_states.view(
        bsz, q_len, self.num_heads, self.head_dim
    ).transpose(1, 2)
    key_states = key_states.view(
        bsz, q_len, self.num_key_value_heads, self.head_dim
    ).transpose(1, 2)
    value_states = value_states.view(
        bsz, q_len, self.num_key_value_heads, self.head_dim
    ).transpose(1, 2)
    # [bsz, q_len, nh, hd]
    # [bsz, nh, q_len, hd]

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]

    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(
        query_states, key_states, cos, sin, position_ids
    )
    # [bsz, nh, t, hd]

    if past_key_value is not None:
        # reuse k, v, self_attention
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)

    past_key_value = (key_states, value_states) if use_cache else None

    # repeat k/v heads if n_kv_heads < n_heads
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    if output_attentions:
        warnings.warn(
            "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
        )

    #
    # flash-attn v2 start
    #

    if self.training:
        # during training q,k,v always have same seqlen
        assert key_states.shape == query_states.shape
        is_causal = True
    else:
        # turn off FA causal mask after first inference autoregressive iteration
        # only on first autoregressive step q,k,v have same seqlen
        is_causal = key_states.shape == query_states.shape

    dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)

    if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
        # special handling using sample packing
        qkv = torch.stack(
            [query_states, key_states, value_states], dim=2
        )  # [bsz, nh, 3, q_len, hd]
        qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
        qkv = rearrange(qkv, "b s ... -> (b s) ...")

        output = flash_attn_varlen_qkvpacked_func(
            qkv,
            cu_seqlens,
            max_seqlen,
            dropout_p=dropout_rate,
            softmax_scale=None,
            causal=True,
        )
        output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
    elif query_states.shape == key_states.shape:
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
            query_states,
            key_states,
            value_states,
            qkvpacked=True,
            # We have disabled _prepare_decoder_attention_mask in LlamaModel
            # the attention_mask should be the same as the key_padding_mask
            key_padding_mask=attention_mask,
            query_padding_mask=attention_mask[:, -query_states.size(1) :]
            if attention_mask is not None
            else None,
        )
        output_unpad = flash_attn_varlen_qkvpacked_func(
            qkv_unpad,
            cu_seqlens_q,
            max_seqlen_q,
            dropout_p=dropout_rate,
            softmax_scale=None,
            causal=is_causal,
        )
        output = output_pad_fn(output_unpad)
    else:
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        if attention_mask is None or attention_mask.all().item():
            output = flash_attn_kvpacked_func(
                query_states,
                torch.stack([key_states, value_states], 2),
                dropout_p=dropout_rate,
                causal=is_causal,
            )
        else:
            (  # pylint: disable=unbalanced-tuple-unpacking
                q_unpad,
                kv_unpad,
                cu_seqlens_q,
                cu_seqlens_k,
                max_seqlen_q,
                max_seqlen_k,
                _,
                _,
                output_pad_fn,
            ) = generate_qkv(
                query_states,
                key_states,
                value_states,
                kvpacked=True,
                key_padding_mask=attention_mask,
                query_padding_mask=attention_mask[:, -query_states.size(1) :]
                if attention_mask is not None
                else None,
            )
            if q_unpad.dtype != kv_unpad.dtype:
                kv_unpad = kv_unpad.to(q_unpad.dtype)
            output_unpad = flash_attn_varlen_kvpacked_func(
                q_unpad,
                kv_unpad,
                cu_seqlens_q,
                cu_seqlens_k,
                max_seqlen_q,
                max_seqlen_k,
                dropout_p=dropout_rate,
                softmax_scale=None,
                causal=is_causal,
            )
            output = output_pad_fn(output_unpad)

    attn_output = output
    if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )
    attn_output = rearrange(attn_output, "b s h d -> b s (h d)")

    #
    # flash-attn v2 end
    #

    if self.pretraining_tp > 1:
        attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
        o_proj_slices = self.o_proj.weight.split(
            self.hidden_size // self.pretraining_tp, dim=1
        )
        attn_output = sum(
            F.linear(attn_output[i], o_proj_slices[i])
            for i in range(self.pretraining_tp)
        )
    else:
        attn_output = self.o_proj(attn_output)

    return attn_output, None, past_key_value
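The `cu_seqlens`/`max_seqlen` branch above is the sample-packing path: several short samples are concatenated into one row, and flash-attn's varlen kernel treats each packed segment as an independent sequence based on its cumulative offsets. A small illustration of how those offsets relate to packed `position_ids` (hand-rolled here for clarity; axolotl's own helper is `get_cu_seqlens_from_pos_ids`):

```python
import torch

# three samples of lengths 3, 2 and 4 packed into one row;
# position_ids restart at 0 at each sample boundary
position_ids = torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2, 3]])

# a sequence starts wherever position_ids resets to 0
starts = (position_ids[0] == 0).nonzero(as_tuple=True)[0]
seqlens = torch.diff(starts, append=torch.tensor([position_ids.shape[1]]))

# cu_seqlens is the cumulative offset of each packed sequence
cu_seqlens = torch.cat(
    [torch.zeros(1, dtype=torch.int32), seqlens.cumsum(0).to(torch.int32)]
)
max_seqlen = int(seqlens.max())

print(cu_seqlens.tolist(), max_seqlen)  # [0, 3, 5, 9] 4
```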
# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38
def generate_qkv(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    kvpacked=False,
    qkvpacked=False,
):  # pylint: disable=invalid-name,unnecessary-lambda-assignment
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, d)
        k: (batch_size, seqlen_k, nheads_k, d)
        v: (batch_size, seqlen_k, nheads_k, d)
        query_padding_mask: (batch_size, seqlen), bool
        key_padding_mask: (batch_size, seqlen), bool
    """
    assert not (kvpacked and qkvpacked)
    batch_size, seqlen_q, nheads, d = q.shape
    _, seqlen_k, nheads_k, _ = k.shape
    assert k.shape == (batch_size, seqlen_k, nheads_k, d)
    assert v.shape == (batch_size, seqlen_k, nheads_k, d)

    if query_padding_mask is not None:
        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
            q, query_padding_mask
        )

        output_pad_fn = lambda output_unpad: pad_input(  # noqa: E731
            output_unpad, indices_q, batch_size, seqlen_q
        )

    else:
        q_unpad = rearrange(q, "b s h d -> (b s) h d")
        cu_seqlens_q = torch.arange(
            0,
            (batch_size + 1) * seqlen_q,
            step=seqlen_q,
            dtype=torch.int32,
            device=q_unpad.device,
        )
        max_seqlen_q = seqlen_q

        output_pad_fn = lambda output_unpad: rearrange(  # noqa: E731
            output_unpad, "(b s) h d -> b s h d", b=batch_size
        )

    if key_padding_mask is not None:
        k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
        v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
    else:
        k_unpad = rearrange(k, "b s h d -> (b s) h d")
        v_unpad = rearrange(v, "b s h d -> (b s) h d")
        cu_seqlens_k = torch.arange(
            0,
            (batch_size + 1) * seqlen_k,
            step=seqlen_k,
            dtype=torch.int32,
            device=k_unpad.device,
        )
        max_seqlen_k = seqlen_k

    if qkvpacked:
        assert nheads == nheads_k
        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
        qkv = torch.stack([q, k, v], dim=2)
        return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn)

    if kvpacked:
        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
        kv = torch.stack([k, v], dim=2)
        return (
            q_unpad,
            kv_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q,
            kv,
            output_pad_fn,
        )

    return (
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        q,
        k,
        v,
        output_pad_fn,
    )
def llama_model_forward(
    self,
@@ -6,37 +6,33 @@ from typing import List, Optional, Tuple, Union

 import torch
 import transformers
-from transformers.modeling_outputs import BaseModelOutputWithPast
+from einops import rearrange
+from flash_attn.bert_padding import pad_input, unpad_input
+from flash_attn.flash_attn_interface import (  # pylint: disable=ungrouped-imports
+    flash_attn_kvpacked_func,
+    flash_attn_varlen_kvpacked_func,
+    flash_attn_varlen_qkvpacked_func,
+)
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.models.mistral.modeling_mistral import (
+    MistralAttention as OriginalMistralAttention,
+)
 from transformers.models.mistral.modeling_mistral import (
     MistralDecoderLayer as OriginalMistralDecoderLayer,
-    MistralMLP
+)
+from transformers.models.mistral.modeling_mistral import (
+    MistralForCausalLM as OriginalMistralForCausalLM,
 )
 from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv

-from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
+from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
+from axolotl.monkeypatch.cross_entropy import fast_cross_entropy_loss
-from axolotl.monkeypatch.flash_modules import (
-    flashattn_forward,
-    replace_cross_entropy,
-    replace_rms_norm
-)
-from axolotl.monkeypatch.fused_modules import FusedMLP

 LOG = logging.getLogger("axolotl.monkeypatch.mistral")


-def replace_mistral_mlp_with_swiglu(model):
-    for name, module in model.named_modules():
-        if isinstance(module, MistralMLP):
-            mlp = FusedMLP(
-                module.config, module.gate_proj, module.up_proj, module.down_proj
-            )
-            set_module_name(model, name, mlp)
-
 def replace_mistral_attn_with_flash_attn(
     packed: Optional[bool] = False,
-    cross_entropy: Optional[bool] = False,
-    rms_norm: Optional[bool] = False,
 ):
     transformers.models.mistral.modeling_mistral.MistralModel._prepare_decoder_attention_mask = (  # pylint: disable=protected-access
         _prepare_decoder_attention_mask
@@ -44,8 +40,9 @@ def replace_mistral_attn_with_flash_attn(
     transformers.models.mistral.modeling_mistral.MistralAttention.forward = (
         flashattn_forward
     )
-    transformers.models.mistral.modeling_mistral.MistralAttention.apply_rotary_fn = apply_rotary_pos_emb
-    transformers.models.mistral.modeling_mistral.MistralAttention.repeat_kv_fn = repeat_kv
+    transformers.models.mistral.modeling_mistral.MistralForCausalLM.forward = (
+        mistral_causallm_forward
+    )
     if packed:
         transformers.models.mistral.modeling_mistral.MistralDecoderLayer = (
             MistralDecoderLayer
@@ -53,10 +50,6 @@ def replace_mistral_attn_with_flash_attn(
         transformers.models.mistral.modeling_mistral.MistralModel.forward = (
             mistral_model_forward
         )
-    if cross_entropy:
-        replace_cross_entropy(transformers.mistral.llama.modeling_mistral, "CrossEntropyLoss")
-    if rms_norm:
-        replace_rms_norm(transformers.mistral.llama.modeling_mistral, "MistralRMSNorm")

 @torch.jit.script
@@ -129,6 +122,302 @@ def _prepare_decoder_attention_mask(
    return attention_mask


def flashattn_forward(
    self: OriginalMistralAttention,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    query_states = query_states.view(
        bsz, q_len, self.num_heads, self.head_dim
    ).transpose(1, 2)
    key_states = key_states.view(
        bsz, q_len, self.num_key_value_heads, self.head_dim
    ).transpose(1, 2)
    value_states = value_states.view(
        bsz, q_len, self.num_key_value_heads, self.head_dim
    ).transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(
        query_states, key_states, cos, sin, position_ids
    )

    use_sliding_windows = (
        hasattr(self.config, "sliding_window") is not None
        and kv_seq_len > self.config.sliding_window
    )

    if use_sliding_windows:
        window_size = (self.config.sliding_window, self.config.sliding_window)
    else:
        window_size = (-1, -1)

    if past_key_value is not None:
        # Activate slicing cache only if the config has a value `sliding_windows` attribute
        if (
            hasattr(self.config, "sliding_window")
            and kv_seq_len > self.config.sliding_window
        ):
            slicing_tokens = kv_seq_len - self.config.sliding_window

            past_key = past_key_value[0]
            past_value = past_key_value[1]

            past_key = past_key[:, :, slicing_tokens:, :].contiguous()
            past_value = past_value[:, :, slicing_tokens:, :].contiguous()

            if past_key.shape[-2] != self.config.sliding_window - 1:
                raise ValueError(
                    f"past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
                    f" {past_key.shape}"
                )

            past_key_value = (past_key, past_value) if use_cache else None

    if past_key_value is not None:
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)

    past_key_value = (key_states, value_states) if use_cache else None

    # repeat k/v heads if n_kv_heads < n_heads
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    if self.training:
        # during training q,k,v always have same seqlen
        assert key_states.shape == query_states.shape
        is_causal = True
    else:
        # turn off FA causal mask after first inference autoregressive iteration
        # only on first autoregressive step q,k,v have same seqlen
        is_causal = key_states.shape == query_states.shape

    dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)

    if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
        # special handling using sample packing
        qkv = torch.stack(
            [query_states, key_states, value_states], dim=2
        )  # [bsz, nh, 3, q_len, hd]
        qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
        qkv = rearrange(qkv, "b s ... -> (b s) ...")

        output = flash_attn_varlen_qkvpacked_func(
            qkv,
            cu_seqlens,
            max_seqlen,
            dropout_p=dropout_rate,
            softmax_scale=None,
            causal=True,
            window_size=window_size,
        )
        output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
    elif query_states.shape == key_states.shape:
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
            query_states,
            key_states,
            value_states,
            qkvpacked=True,
            # We have disabled _prepare_decoder_attention_mask in LlamaModel
            # the attention_mask should be the same as the key_padding_mask
            key_padding_mask=attention_mask,
            query_padding_mask=attention_mask[:, -query_states.size(1) :]
            if attention_mask is not None
            else None,
        )
        output_unpad = flash_attn_varlen_qkvpacked_func(
            qkv_unpad,
            cu_seqlens_q,
            max_seqlen_q,
            dropout_p=dropout_rate,
            softmax_scale=None,
            causal=is_causal,
            window_size=window_size,
        )
        output = output_pad_fn(output_unpad)
    else:
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        if attention_mask is None or attention_mask.all().item():
            output = flash_attn_kvpacked_func(
                query_states,
                torch.stack([key_states, value_states], 2),
                dropout_p=dropout_rate,
                causal=is_causal,
                window_size=window_size,
            )
        else:
            (  # pylint: disable=unbalanced-tuple-unpacking
                q_unpad,
                kv_unpad,
                cu_seqlens_q,
                cu_seqlens_k,
                max_seqlen_q,
                max_seqlen_k,
                _,
                _,
                output_pad_fn,
            ) = generate_qkv(
                query_states,
                key_states,
                value_states,
                kvpacked=True,
                key_padding_mask=attention_mask,
                query_padding_mask=attention_mask[:, -query_states.size(1) :]
                if attention_mask is not None
                else None,
            )
            if q_unpad.dtype != kv_unpad.dtype:
                kv_unpad = kv_unpad.to(q_unpad.dtype)
            output_unpad = flash_attn_varlen_kvpacked_func(
                q_unpad,
                kv_unpad,
                cu_seqlens_q,
                cu_seqlens_k,
                max_seqlen_q,
                max_seqlen_k,
                dropout_p=dropout_rate,
                softmax_scale=None,
                causal=is_causal,
                window_size=window_size,
            )
            output = output_pad_fn(output_unpad)

    attn_output = output
    if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )
    attn_output = rearrange(attn_output, "b s h d -> b s (h d)")

    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value
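The `window_size` tuple passed to the flash-attn kernels above is how Mistral's sliding-window attention is expressed: `(-1, -1)` leaves the causal attention unrestricted, while `(w, w)` limits each query to roughly the previous `w` key positions. A rough equivalence in plain PyTorch, shown as a boolean mask (illustrative only; it ignores the kernel's exact edge semantics for the right window):

```python
import torch

seqlen, window = 8, 4
i = torch.arange(seqlen).unsqueeze(1)  # query positions
j = torch.arange(seqlen).unsqueeze(0)  # key positions

# causal: a query may only attend to keys at or before its own position
causal = j <= i
# sliding window: additionally, keys may be at most `window` positions behind
sliding = causal & (i - j <= window)

print(sliding.int())
```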
# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38
def generate_qkv(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    kvpacked=False,
    qkvpacked=False,
):  # pylint: disable=invalid-name,unnecessary-lambda-assignment
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, d)
        k: (batch_size, seqlen_k, nheads_k, d)
        v: (batch_size, seqlen_k, nheads_k, d)
        query_padding_mask: (batch_size, seqlen), bool
        key_padding_mask: (batch_size, seqlen), bool
    """
    assert not (kvpacked and qkvpacked)
    batch_size, seqlen_q, nheads, d = q.shape
    _, seqlen_k, nheads_k, _ = k.shape
    assert k.shape == (batch_size, seqlen_k, nheads_k, d)
    assert v.shape == (batch_size, seqlen_k, nheads_k, d)

    if query_padding_mask is not None:
        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
            q, query_padding_mask
        )

        output_pad_fn = lambda output_unpad: pad_input(  # noqa: E731
            output_unpad, indices_q, batch_size, seqlen_q
        )

    else:
        q_unpad = rearrange(q, "b s h d -> (b s) h d")
        cu_seqlens_q = torch.arange(
            0,
            (batch_size + 1) * seqlen_q,
            step=seqlen_q,
            dtype=torch.int32,
            device=q_unpad.device,
        )
        max_seqlen_q = seqlen_q

        output_pad_fn = lambda output_unpad: rearrange(  # noqa: E731
            output_unpad, "(b s) h d -> b s h d", b=batch_size
        )

    if key_padding_mask is not None:
        k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
        v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
    else:
        k_unpad = rearrange(k, "b s h d -> (b s) h d")
        v_unpad = rearrange(v, "b s h d -> (b s) h d")
        cu_seqlens_k = torch.arange(
            0,
            (batch_size + 1) * seqlen_k,
            step=seqlen_k,
            dtype=torch.int32,
            device=k_unpad.device,
        )
        max_seqlen_k = seqlen_k

    if qkvpacked:
        assert nheads == nheads_k
        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
        qkv = torch.stack([q, k, v], dim=2)
        return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn)

    if kvpacked:
        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
        kv = torch.stack([k, v], dim=2)
        return (
            q_unpad,
            kv_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q,
            kv,
            output_pad_fn,
        )

    return (
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        q,
        k,
        v,
        output_pad_fn,
    )
def mistral_model_forward(
    self,
    input_ids: torch.LongTensor = None,

@@ -359,3 +648,71 @@ class MistralDecoderLayer(OriginalMistralDecoderLayer):
        outputs += (present_key_value,)

    return outputs


def mistral_causallm_forward(
    self: OriginalMistralForCausalLM,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    *args, **kwargs
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
    """

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    hidden_states = outputs[0]
    logits = self.lm_head(hidden_states)

    loss = None
    if labels is not None:
        shift_logits = logits
        if not hasattr(self, "extra_ignored_labels"):
            self.extra_ignored_labels = torch.full(
                (self.model.config.max_position_embeddings, 1), -100, device=shift_logits.device
            )

        shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[: labels.shape[0]]))
        shift_labels = shift_labels.to(shift_logits.device)

        # FAST CROSS ENTROPY
        loss = fast_cross_entropy_loss(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
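The loss path above avoids materializing a sliced `logits[..., :-1, :]` copy: instead of shifting the logits, it shifts the labels left by one position and pads the final slot with `-100` so it is ignored. A small sketch of that equivalence against the standard shifted cross-entropy (illustrative, using plain `torch.nn.functional.cross_entropy` rather than the fused `fast_cross_entropy_loss`):

```python
import torch
import torch.nn.functional as F

bsz, seqlen, vocab = 2, 5, 11
logits = torch.randn(bsz, seqlen, vocab)
labels = torch.randint(0, vocab, (bsz, seqlen))

# standard HF-style shift: drop the last logit position and the first label
ref = F.cross_entropy(logits[:, :-1].reshape(-1, vocab), labels[:, 1:].reshape(-1))

# shift-labels-only variant: keep all logits, pad labels with -100 (ignored)
ignore = torch.full((bsz, 1), -100)
shifted = torch.hstack((labels[:, 1:], ignore))
alt = F.cross_entropy(logits.reshape(-1, vocab), shifted.reshape(-1), ignore_index=-100)

assert torch.allclose(ref, alt)
```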
@@ -124,6 +124,36 @@ class GPUStatsCallback(
        return control


class LossWatchDogCallback(TrainerCallback):
    """Callback to track loss and stop training if loss is too high"""

    def __init__(self, cfg):
        self.cfg = cfg
        self.logged = False
        self.violations = 0
        self.threshold = cfg.loss_watchdog_threshold
        self.patience = cfg.loss_watchdog_patience or 3

    def on_step_end(
        self,
        _args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **_kwargs,
    ):
        if len(state.log_history) > 0 and "loss" in state.log_history[-1]:
            if state.log_history[-1]["loss"] > self.threshold:
                self.violations += 1
                if self.violations >= self.patience:
                    LOG.warning(
                        "Loss is too high, stopping training (loss_watchdog_threshold)"
                    )
                    control.should_training_stop = True
            else:
                self.violations = 0
        return control
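The watchdog is driven entirely by config. A minimal snippet enabling it might look like the following (the threshold value is illustrative; a reasonable starting point is roughly twice the loss observed at the start of training):

```yaml
# abort the run after 3 consecutive steps whose loss exceeds 2.0
loss_watchdog_threshold: 2.0
loss_watchdog_patience: 3
```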
def bench_eval_callback_factory(trainer, tokenizer):
    accuracy = evaluate.load("accuracy")
    abcd_idx = [
@@ -27,7 +27,7 @@ def choose_device(cfg):

     cfg.device = get_device()
     if cfg.world_size == 1:
-        cfg.device_map = "auto"
+        cfg.device_map = cfg.device_map or "auto"
     else:
         if cfg.device.startswith("cuda"):
             cfg.device_map = {"": torch.cuda.current_device()}
@@ -397,6 +397,13 @@ def validate_config(cfg):
            "Gradient checkpointing is broken for Qwen models for transformers>=4.35.0, except main branch."
        )

    if cfg.wandb_run_id and not cfg.wandb_name:
        cfg.wandb_name = cfg.wandb_run_id

        LOG.warning(
            "wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
        )

    # TODO
    # MPT 7b
    # https://github.com/facebookresearch/bitsandbytes/issues/25
@@ -28,6 +28,27 @@ from axolotl.utils.dict import DictDefault

LOG = logging.getLogger("axolotl")


def check_model_config(cfg: DictDefault, model_config: AutoConfig):
    quant_config_exists = hasattr(model_config, "quantization_config")
    quant_config_method_is_gptq = (
        quant_config_exists
        and "quant_method" in model_config.quantization_config
        and model_config.quantization_config["quant_method"] == "gptq"
    )

    if cfg.gptq and not quant_config_method_is_gptq:
        raise ValueError(
            "model_config.quantization_config is not set or quant_method is not set to gptq. "
            "Please make sure to point to a GPTQ model."
        )

    if not cfg.gptq and quant_config_exists:
        raise ValueError(
            "model_config.quantization_config is set but `gptq` flag is not. "
            "Please use the `gptq` flag to train quantized model or point to a non-quantized model."
        )


def load_model_config(cfg):
    model_config_name = cfg.base_model_config or cfg.base_model
    trust_remote_code = cfg.trust_remote_code is True
@@ -38,6 +59,8 @@ def load_model_config(cfg):
    for key, val in cfg.model_config.items():
        setattr(model_config, key, val)

    check_model_config(cfg, model_config)

    return model_config
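As a usage note, this check ties the `gptq` config flag to the checkpoint's own `quantization_config`: the flag must be set when (and only when) the base model ships a GPTQ quantization config. A sketch of a config that satisfies it (model name illustrative):

```yaml
# base model must ship a quantization_config with quant_method: gptq
base_model: TheBloke/Llama-2-7B-GPTQ
gptq: true
```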
@@ -193,11 +216,7 @@ def load_model(
            )

            LOG.info("patching with flash attention")
-            replace_mistral_attn_with_flash_attn(
-                packed=cfg.sample_packing,
-                cross_entropy=cfg.flash_attn_cross_entropy,
-                rms_norm=cfg.flash_attn_rms_norm,
-            )
+            replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)

    if cfg.is_llama_derived_model and cfg.xpos_rope:
        from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
@@ -220,6 +239,7 @@ def load_model(
    model_kwargs = {}

    model_kwargs["device_map"] = cfg.device_map
+    model_kwargs["max_memory"] = cfg.max_memory
    model_kwargs["torch_dtype"] = cfg.torch_dtype

    if cfg.model_revision:
@@ -278,15 +298,6 @@ def load_model(
        if cfg.flash_attn_fuse_qkv:
            LOG.info("patching with fused QKV")
            replace_llama_qkv_with_fused(model)
-    elif cfg.is_mistral_derived_model and not cfg.trust_remote_code and not cfg.gptq:
-        if cfg.flash_attention and not inference:
-            from axolotl.monkeypatch.mistral_attn_hijack_flash import (
-                replace_mistral_mlp_with_swiglu,
-            )
-
-            if cfg.flash_attn_fuse_mlp:
-                LOG.info("patching with SwiGLU")
-                replace_mistral_mlp_with_swiglu(model)
    # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
    # This is a WIP, still an issue with the backward pass
    # RuntimeError: grad can be implicitly created only for scalar outputs
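Pulling together the config knobs this loader reads, a sketch of the relevant YAML might look like the following (all values are illustrative; `device_map` and `max_memory` are passed straight through to transformers when the model is loaded):

```yaml
# flash attention patches applied in load_model
flash_attention: true
sample_packing: true
flash_attn_fuse_qkv: false

# passed through to transformers when loading the model
device_map: sequential
max_memory:
  0: 20GiB
  1: 20GiB
```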
@@ -2,20 +2,20 @@

 import os

+from axolotl.utils.dict import DictDefault
+
-def setup_wandb_env_vars(cfg):
-    if cfg.wandb_mode and cfg.wandb_mode == "offline":
-        os.environ["WANDB_MODE"] = cfg.wandb_mode
-    elif cfg.wandb_project and len(cfg.wandb_project) > 0:
-        os.environ["WANDB_PROJECT"] = cfg.wandb_project
+def setup_wandb_env_vars(cfg: DictDefault):
+    for key in cfg.keys():
+        if key.startswith("wandb_"):
+            value = cfg.get(key, "")
+
+            if value and isinstance(value, str) and len(value) > 0:
+                os.environ[key.upper()] = value
+
+    # Enable wandb if project name is present
+    if cfg.wandb_project and len(cfg.wandb_project) > 0:
         cfg.use_wandb = True
-        if cfg.wandb_entity and len(cfg.wandb_entity) > 0:
-            os.environ["WANDB_ENTITY"] = cfg.wandb_entity
-        if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
-            os.environ["WANDB_WATCH"] = cfg.wandb_watch
-        if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0:
-            os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model
-        if cfg.wandb_run_id and len(cfg.wandb_run_id) > 0:
-            os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
+        os.environ.pop("WANDB_DISABLED", None)  # Remove if present
     else:
         os.environ["WANDB_DISABLED"] = "true"
@@ -1,6 +1,7 @@
"""Module for testing the validation module"""

import logging
import os
import unittest
from typing import Optional

@@ -8,6 +9,7 @@ import pytest

from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.wandb_ import setup_wandb_env_vars


class ValidationTest(unittest.TestCase):
@@ -679,3 +681,83 @@ class ValidationTest(unittest.TestCase):
        )

        validate_config(cfg)


class ValidationWandbTest(ValidationTest):
    """
    Validation test for wandb
    """

    def test_wandb_set_run_id_to_name(self):
        cfg = DictDefault(
            {
                "wandb_run_id": "foo",
            }
        )

        with self._caplog.at_level(logging.WARNING):
            validate_config(cfg)
            assert any(
                "wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
                in record.message
                for record in self._caplog.records
            )

        assert cfg.wandb_name == "foo" and cfg.wandb_run_id == "foo"

        cfg = DictDefault(
            {
                "wandb_name": "foo",
            }
        )

        validate_config(cfg)

        assert cfg.wandb_name == "foo" and cfg.wandb_run_id is None

    def test_wandb_sets_env(self):
        cfg = DictDefault(
            {
                "wandb_project": "foo",
                "wandb_name": "bar",
                "wandb_run_id": "bat",
                "wandb_entity": "baz",
                "wandb_mode": "online",
                "wandb_watch": "false",
                "wandb_log_model": "checkpoint",
            }
        )

        validate_config(cfg)

        setup_wandb_env_vars(cfg)

        assert os.environ.get("WANDB_PROJECT", "") == "foo"
        assert os.environ.get("WANDB_NAME", "") == "bar"
        assert os.environ.get("WANDB_RUN_ID", "") == "bat"
        assert os.environ.get("WANDB_ENTITY", "") == "baz"
        assert os.environ.get("WANDB_MODE", "") == "online"
        assert os.environ.get("WANDB_WATCH", "") == "false"
        assert os.environ.get("WANDB_LOG_MODEL", "") == "checkpoint"
        assert os.environ.get("WANDB_DISABLED", "") != "true"

    def test_wandb_set_disabled(self):
        cfg = DictDefault({})

        validate_config(cfg)

        setup_wandb_env_vars(cfg)

        assert os.environ.get("WANDB_DISABLED", "") == "true"

        cfg = DictDefault(
            {
                "wandb_project": "foo",
            }
        )

        validate_config(cfg)

        setup_wandb_env_vars(cfg)

        assert os.environ.get("WANDB_DISABLED", "") != "true"