Compare commits
8 Commits: soap-optim ... 1991test

Commits in this comparison:
- bfb80a3ef9
- 38773d661f
- 271c2c2b82
- 32b6f30947
- fc1f275e6c
- 46d2b4ce89
- 88c9a7aecc
- d9a93990d1

1991.yml (new file, 295 lines)
@@ -0,0 +1,295 @@
base_model: Qwen/Qwen2.5-14B-Instruct
model_type: AutoModelForCausalLM #nohup accelerate launch -m axolotl.cli.train /home/ubuntu/qwen2.5_14B.yml > training_output.log 2>&1 &
tokenizer_type: AutoTokenizer
trust_remote_code: true

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
  - path: tatsu-lab/alpaca
    type: alpaca

chat_template: chatml
dataset_prepared_path:
val_set_size: 0
output_dir: ./outputs/out

sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true

unfrozen_parameters:
- ^lm_head.weight$
- ^model.embed_tokens.weight$
# input_layernorm layers
- model.layers.0.input_layernorm
- model.layers.1.input_layernorm
- model.layers.2.input_layernorm
- model.layers.3.input_layernorm
- model.layers.4.input_layernorm
- model.layers.5.input_layernorm
- model.layers.6.input_layernorm
- model.layers.7.input_layernorm
- model.layers.8.input_layernorm
- model.layers.9.input_layernorm
- model.layers.10.input_layernorm
- model.layers.11.input_layernorm
- model.layers.12.input_layernorm
- model.layers.13.input_layernorm
- model.layers.14.input_layernorm
- model.layers.15.input_layernorm
- model.layers.16.input_layernorm
- model.layers.17.input_layernorm
- model.layers.18.input_layernorm
- model.layers.19.input_layernorm
- model.layers.20.input_layernorm
- model.layers.21.input_layernorm
- model.layers.22.input_layernorm
- model.layers.23.input_layernorm
# lm_head layers
# mlp.down_proj layers
- model.layers.1.mlp.down_proj
- model.layers.35.mlp.down_proj
- model.layers.38.mlp.down_proj
- model.layers.37.mlp.down_proj
- model.layers.36.mlp.down_proj
- model.layers.15.mlp.down_proj
- model.layers.11.mlp.down_proj
- model.layers.12.mlp.down_proj
- model.layers.34.mlp.down_proj
- model.layers.44.mlp.down_proj
- model.layers.45.mlp.down_proj
- model.layers.9.mlp.down_proj
- model.layers.41.mlp.down_proj
- model.layers.33.mlp.down_proj
- model.layers.43.mlp.down_proj
- model.layers.40.mlp.down_proj
- model.layers.13.mlp.down_proj
- model.layers.8.mlp.down_proj
- model.layers.39.mlp.down_proj
- model.layers.10.mlp.down_proj
- model.layers.14.mlp.down_proj
- model.layers.16.mlp.down_proj
- model.layers.31.mlp.down_proj
- model.layers.32.mlp.down_proj
# mlp.gate_proj layers
- model.layers.1.mlp.gate_proj
- model.layers.44.mlp.gate_proj
- model.layers.46.mlp.gate_proj
- model.layers.45.mlp.gate_proj
- model.layers.43.mlp.gate_proj
- model.layers.47.mlp.gate_proj
- model.layers.42.mlp.gate_proj
- model.layers.32.mlp.gate_proj
- model.layers.27.mlp.gate_proj
- model.layers.33.mlp.gate_proj
- model.layers.28.mlp.gate_proj
- model.layers.39.mlp.gate_proj
- model.layers.41.mlp.gate_proj
- model.layers.40.mlp.gate_proj
- model.layers.30.mlp.gate_proj
- model.layers.29.mlp.gate_proj
- model.layers.31.mlp.gate_proj
- model.layers.26.mlp.gate_proj
- model.layers.37.mlp.gate_proj
- model.layers.10.mlp.gate_proj
- model.layers.38.mlp.gate_proj
- model.layers.12.mlp.gate_proj
- model.layers.36.mlp.gate_proj
- model.layers.13.mlp.gate_proj
# mlp.up_proj layers
- model.layers.1.mlp.up_proj
- model.layers.13.mlp.up_proj
- model.layers.11.mlp.up_proj
- model.layers.14.mlp.up_proj
- model.layers.15.mlp.up_proj
- model.layers.12.mlp.up_proj
- model.layers.8.mlp.up_proj
- model.layers.16.mlp.up_proj
- model.layers.9.mlp.up_proj
- model.layers.19.mlp.up_proj
- model.layers.10.mlp.up_proj
- model.layers.7.mlp.up_proj
- model.layers.17.mlp.up_proj
- model.layers.20.mlp.up_proj
- model.layers.21.mlp.up_proj
- model.layers.18.mlp.up_proj
- model.layers.38.mlp.up_proj
- model.layers.37.mlp.up_proj
- model.layers.39.mlp.up_proj
- model.layers.42.mlp.up_proj
- model.layers.41.mlp.up_proj
- model.layers.27.mlp.up_proj
- model.layers.28.mlp.up_proj
- model.layers.34.mlp.up_proj
# model.norm layers
# post_attention_layernorm layers
- model.layers.0.post_attention_layernorm
- model.layers.1.post_attention_layernorm
- model.layers.2.post_attention_layernorm
- model.layers.3.post_attention_layernorm
- model.layers.4.post_attention_layernorm
- model.layers.5.post_attention_layernorm
- model.layers.6.post_attention_layernorm
- model.layers.7.post_attention_layernorm
- model.layers.8.post_attention_layernorm
- model.layers.9.post_attention_layernorm
- model.layers.10.post_attention_layernorm
- model.layers.11.post_attention_layernorm
- model.layers.12.post_attention_layernorm
- model.layers.13.post_attention_layernorm
- model.layers.14.post_attention_layernorm
- model.layers.15.post_attention_layernorm
- model.layers.16.post_attention_layernorm
- model.layers.17.post_attention_layernorm
- model.layers.18.post_attention_layernorm
- model.layers.19.post_attention_layernorm
- model.layers.20.post_attention_layernorm
- model.layers.21.post_attention_layernorm
- model.layers.22.post_attention_layernorm
- model.layers.23.post_attention_layernorm
# self_attn.k_proj layers
- model.layers.47.self_attn.k_proj
- model.layers.39.self_attn.k_proj
- model.layers.41.self_attn.k_proj
- model.layers.37.self_attn.k_proj
- model.layers.35.self_attn.k_proj
- model.layers.44.self_attn.k_proj
- model.layers.38.self_attn.k_proj
- model.layers.14.self_attn.k_proj
- model.layers.7.self_attn.k_proj
- model.layers.12.self_attn.k_proj
- model.layers.11.self_attn.k_proj
- model.layers.32.self_attn.k_proj
- model.layers.10.self_attn.k_proj
- model.layers.8.self_attn.k_proj
- model.layers.9.self_attn.k_proj
- model.layers.6.self_attn.k_proj
- model.layers.45.self_attn.k_proj
- model.layers.42.self_attn.k_proj
- model.layers.5.self_attn.k_proj
- model.layers.40.self_attn.k_proj
- model.layers.33.self_attn.k_proj
- model.layers.0.self_attn.k_proj
- model.layers.34.self_attn.k_proj
- model.layers.13.self_attn.k_proj
# self_attn.o_proj layers
- model.layers.12.self_attn.o_proj
- model.layers.5.self_attn.o_proj
- model.layers.14.self_attn.o_proj
- model.layers.16.self_attn.o_proj
- model.layers.20.self_attn.o_proj
- model.layers.13.self_attn.o_proj
- model.layers.11.self_attn.o_proj
- model.layers.4.self_attn.o_proj
- model.layers.6.self_attn.o_proj
- model.layers.19.self_attn.o_proj
- model.layers.7.self_attn.o_proj
- model.layers.18.self_attn.o_proj
- model.layers.8.self_attn.o_proj
- model.layers.38.self_attn.o_proj
- model.layers.15.self_attn.o_proj
- model.layers.17.self_attn.o_proj
- model.layers.9.self_attn.o_proj
- model.layers.10.self_attn.o_proj
- model.layers.21.self_attn.o_proj
- model.layers.28.self_attn.o_proj
- model.layers.32.self_attn.o_proj
- model.layers.35.self_attn.o_proj
- model.layers.39.self_attn.o_proj
- model.layers.3.self_attn.o_proj
# self_attn.q_proj layers
- model.layers.1.self_attn.q_proj
- model.layers.2.self_attn.q_proj
- model.layers.3.self_attn.q_proj
- model.layers.44.self_attn.q_proj
- model.layers.29.self_attn.q_proj
- model.layers.45.self_attn.q_proj
- model.layers.43.self_attn.q_proj
- model.layers.32.self_attn.q_proj
- model.layers.38.self_attn.q_proj
- model.layers.19.self_attn.q_proj
- model.layers.42.self_attn.q_proj
- model.layers.34.self_attn.q_proj
- model.layers.36.self_attn.q_proj
- model.layers.40.self_attn.q_proj
- model.layers.26.self_attn.q_proj
- model.layers.20.self_attn.q_proj
- model.layers.39.self_attn.q_proj
- model.layers.28.self_attn.q_proj
- model.layers.35.self_attn.q_proj
- model.layers.41.self_attn.q_proj
- model.layers.33.self_attn.q_proj
- model.layers.25.self_attn.q_proj
- model.layers.30.self_attn.q_proj
- model.layers.27.self_attn.q_proj
# self_attn.v_proj layers
- model.layers.0.self_attn.v_proj
- model.layers.7.self_attn.v_proj
- model.layers.39.self_attn.v_proj
- model.layers.31.self_attn.v_proj
- model.layers.15.self_attn.v_proj
- model.layers.10.self_attn.v_proj
- model.layers.32.self_attn.v_proj
- model.layers.41.self_attn.v_proj
- model.layers.6.self_attn.v_proj
- model.layers.33.self_attn.v_proj
- model.layers.42.self_attn.v_proj
- model.layers.29.self_attn.v_proj
- model.layers.14.self_attn.v_proj
- model.layers.9.self_attn.v_proj
- model.layers.35.self_attn.v_proj
- model.layers.38.self_attn.v_proj
- model.layers.13.self_attn.v_proj
- model.layers.30.self_attn.v_proj
- model.layers.5.self_attn.v_proj
- model.layers.34.self_attn.v_proj
- model.layers.28.self_attn.v_proj
- model.layers.37.self_attn.v_proj
- model.layers.27.self_attn.v_proj
- model.layers.11.self_attn.v_proj
# model.embed_tokens layers


gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 3
optimizer: adamw_torch_fused
lr_scheduler: linear
learning_rate: 5e-6

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

plugins:
  - axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_swiglu: true
liger_fused_linear_cross_entropy: true

gradient_checkpointing: unsloth
gradient_checkpointing_kwargs:
  use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
evals_per_epoch: 2
saves_per_epoch: 1
save_total_limit: 4
debug:
deepspeed: deepspeed_configs/zero3_bf16.json
weight_decay: 0.05
special_tokens:
  eos_token: <|im_end|>
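The `unfrozen_parameters` block above drives spectrum-style partial fine-tuning: everything is frozen first, then any parameter whose fully qualified name matches one of the listed patterns (treated as regular expressions) is trained. A minimal sketch of how such patterns are typically applied follows; the helper name and exact matching rule are illustrative assumptions, not axolotl's actual implementation:

```python
import re

def apply_unfrozen_parameters(model, patterns):
    # Hypothetical helper: freeze every parameter, then unfreeze any whose
    # name matches one of the configured regex patterns.
    compiled = [re.compile(p) for p in patterns]
    for name, param in model.named_parameters():
        param.requires_grad = any(rx.search(name) for rx in compiled)

# e.g. apply_unfrozen_parameters(model, ["^lm_head.weight$", "model.layers.0.input_layernorm"])
```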
@@ -7,8 +7,8 @@ load_in_8bit: true
 load_in_4bit: false
 
 datasets:
-  - path: fozziethebeat/alpaca_messages_2k_test
-    type: chat_template
+  - path: philschmid/guanaco-sharegpt-style
+    type: sharegpt
+    shards: 10
 val_set_size: 0
 output_dir: temp_debug/axolotl_outputs/model
@@ -51,12 +51,12 @@ While debugging it's helpful to simplify your test scenario as much as possible.
 
 ### Background
 
-The below example shows how to configure VSCode to debug data preprocessing of the `chat_template` format. This is the format used when you have the following in your axolotl config:
+The below example shows how to configure VSCode to debug data preprocessing of the `sharegpt` format. This is the format used when you have the following in your axolotl config:
 
 ```yaml
 datasets:
-  - path: <path to your chat_template formatted dataset> # example on HF Hub: fozziethebeat/alpaca_messages_2k_test
-    type: chat_template
+  - path: <path to your sharegpt formatted dataset> # example on HF Hub: philschmid/guanaco-sharegpt-style
+    type: sharegpt
 ```
 
 >[!Important]
@@ -83,7 +83,7 @@ If you developing on a remote host, you can easily use VSCode to debug remotely.
 
 The easiest way to get started is to modify the [.vscode/launch.json](../.vscode/launch.json) file in this project. This is just an example configuration, so you may need to modify or copy it to suit your needs.
 
-For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train dev_chat_template.yml`, you would use the below configuration[^1]. Note that we add additional flags that override the axolotl config and incorporate the tips above (see the comments). We also set the working directory to `devtools` and set the `env` variable `HF_HOME` to a temporary folder that is later partially deleted. This is because we want to delete the HF dataset cache before each run in order to ensure that the data preprocessing code is run from scratch.
+For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train dev_sharegpt.yml`, you would use the below configuration[^1]. Note that we add additional flags that override the axolotl config and incorporate the tips above (see the comments). We also set the working directory to `devtools` and set the `env` variable `HF_HOME` to a temporary folder that is later partially deleted. This is because we want to delete the HF dataset cache before each run in order to ensure that the data preprocessing code is run from scratch.
 
 ```jsonc
 // .vscode/launch.json
@@ -91,12 +91,12 @@ For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 acceler
     "version": "0.2.0",
     "configurations": [
         {
-            "name": "Debug axolotl prompt - chat_template",
+            "name": "Debug axolotl prompt - sharegpt",
             "type": "python",
             "module": "accelerate.commands.launch",
             "request": "launch",
             "args": [
-                "-m", "axolotl.cli.train", "dev_chat_template.yml",
+                "-m", "axolotl.cli.train", "dev_sharegpt.yml",
                 // The flags below simplify debugging by overriding the axolotl config
                 // with the debugging tips above. Modify as needed.
                 "--dataset_processes=1", // limits data preprocessing to one process
@@ -240,6 +240,6 @@ style="border-radius: 10px; display: block; margin: auto;" width="560" height="3
 </div>
 <br>
 
-[^1]: The config actually mimics the command `CUDA_VISIBLE_DEVICES=0 python -m accelerate.commands.launch -m axolotl.cli.train devtools/chat_template.yml`, but this is the same thing.
+[^1]: The config actually mimics the command `CUDA_VISIBLE_DEVICES=0 python -m accelerate.commands.launch -m axolotl.cli.train devtools/sharegpt.yml`, but this is the same thing.
 
 [^2]: Many of the below flags are recommended best practices by Nvidia when using nvidia-container-toolkit. You can read more about these flags [here](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html).
@@ -16,10 +16,7 @@ chat_template: deepseek_v2
 datasets:
   - path: mlabonne/FineTome-100k
     type: chat_template
-    split: train[:20%]
-    field_messages: conversations
-    message_field_role: from
-    message_field_content: value
+    split: train
 
 dataset_prepared_path: last_run_prepared
 val_set_size: 0.0
@@ -11,11 +11,8 @@ chat_template: gemma
 datasets:
   - path: cgato/SlimOrcaDedupCleaned
     type: chat_template
     chat_template: gemma
     drop_system_message: true
-    field_messages: conversations
-    message_field_role: from
-    message_field_content: value
 
 val_set_size: 0.0
 output_dir: ./outputs/out
@@ -4,15 +4,11 @@ tokenizer_type: AutoTokenizer
 load_in_4bit: true
 strict: false
+use_tensorboard: true
 chat_template: jamba
 datasets:
   - path: cgato/SlimOrcaDedupCleaned
     type: chat_template
-    chat_template: jamba
-    drop_system_message: true
-    field_messages: conversations
-    message_field_role: from
-    message_field_content: value
 
 dataset_prepared_path: last_run_prepared
 val_set_size: 0.0
 output_dir: jamba-large-fsdp-qlora-ft
@@ -14,10 +14,6 @@ datasets:
   - path: mlabonne/FineTome-100k
     type: chat_template
-    split: train[:20%]
-    field_messages: conversations
-    message_field_role: from
-    message_field_content: value
 
 dataset_prepared_path: last_run_prepared
 val_set_size: 0.02
 output_dir: ./outputs/out
@@ -10,6 +10,7 @@ chat_template: phi_3
 datasets:
   - path: fozziethebeat/alpaca_messages_2k_test
     type: chat_template
+    chat_template: phi_3
     field_messages: messages
     message_field_role: role
     message_field_content: content
@@ -272,7 +272,7 @@ def do_inference_gradio(
             importlib.import_module("axolotl.prompters"), prompter
         )
     elif cfg.chat_template:
-        chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)
+        chat_template_str = get_chat_template(cfg.chat_template)
 
     model = model.to(cfg.device, dtype=cfg.torch_dtype)
@@ -435,13 +435,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
         if (
             self.args.loraplus_lr_ratio is None
             and self.args.alternate_optimizer
-            not in [
-                "optimi_adamw",
-                "ao_adamw_8bit",
-                "ao_adamw_4bit",
-                "ao_adamw_fp8",
-                "soap",
-            ]
+            not in ["optimi_adamw", "ao_adamw_8bit", "ao_adamw_4bit", "ao_adamw_fp8"]
         ):
             return super().create_optimizer()
@@ -484,25 +478,6 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
                 loraplus_lr_embedding=loraplus_lr_embedding,
                 **optimizer_kwargs,
             )
-        elif self.args.alternate_optimizer == "soap":
-            from axolotl.utils.optimizers.soap import SOAP
-
-            optim_args = {
-                "lr": optimizer_kwargs.pop("lr"),
-                "eps": optimizer_kwargs.pop("eps"),
-            }
-
-            if self.cfg.optim_args:
-                optim_args.update(self.cfg.optim_args)
-
-            optim_args["betas"] = (
-                self.args.optim_soap_beta1,
-                self.args.optim_soap_beta2,
-            )
-            self.optimizer = SOAP(  # pylint: disable=attribute-defined-outside-init
-                optimizer_grouped_parameters,
-                **optim_args,
-            )
         elif self.args.alternate_optimizer == "optimi_adamw":
             from optimi import AdamW
@@ -920,13 +895,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
         for key, value in metrics.items():
             self._stored_metrics[train_eval][key].append(value)
 
-    def _save_checkpoint(self, model, trial, metrics=None):
+    def _save_checkpoint(self, model, trial):
         # make sure the checkpoint dir exists, since trainer is flakey
         checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
         run_dir = self._get_output_dir(trial=trial)
         output_dir = os.path.join(run_dir, checkpoint_folder)
         os.makedirs(output_dir, exist_ok=True)
-        return super()._save_checkpoint(model, trial, metrics=metrics)
+        return super()._save_checkpoint(model, trial)
 
 
 class AxolotlMambaTrainer(AxolotlTrainer):
@@ -1620,8 +1595,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
         if self.cfg.chat_template:
             training_arguments_kwargs["chat_template"] = get_chat_template(
-                self.cfg.chat_template,
-                tokenizer=self.tokenizer,
+                self.cfg.chat_template
             )
 
         if self.cfg.rl == "orpo":
@@ -1638,12 +1612,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             trainer_kwargs["max_length"] = self.cfg.sequence_len
 
         if self.cfg.optimizer in [
-            # pylint: disable=duplicate-code
             "optimi_adamw",
             "ao_adamw_4bit",
             "ao_adamw_8bit",
             "ao_adamw_fp8",
-            "soap",
         ]:
             # Set default so transformers doesn't throw
             training_arguments_kwargs["optim"] = "adamw_hf"
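The `adamw_hf` assignment above is a workaround: `transformers.TrainingArguments` validates its `optim` field against known optimizer names, so custom optimizers that are constructed later in `create_optimizer` still need a recognized placeholder here. A rough sketch of the pattern, with variable names simplified for illustration:

```python
# Sketch of the placeholder pattern (simplified; not axolotl's exact code).
CUSTOM_OPTIMIZERS = {"optimi_adamw", "ao_adamw_4bit", "ao_adamw_8bit", "ao_adamw_fp8"}

cfg_optimizer = "optimi_adamw"
training_arguments_kwargs = {}
if cfg_optimizer in CUSTOM_OPTIMIZERS:
    # TrainingArguments rejects unknown optim strings, so a stock name is
    # set here; create_optimizer() later installs the real optimizer.
    training_arguments_kwargs["optim"] = "adamw_hf"
```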
@@ -27,15 +27,18 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
 ]
 
 
-def patch_for_multipack(model_type, model_name=None, is_remote_code=False):
+# def patch_for_multipack(model_type, model_name=None, is_remote_code=False):
+def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
     if model_type == "gemmoe":
         patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
     elif model_type == "deepseek_v2":
         patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
-    elif hasattr(transformers, "modeling_flash_attention_utils") and not is_remote_code:
-        transformers.modeling_flash_attention_utils._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
+    # elif hasattr(transformers, "modeling_flash_attention_utils") and not is_remote_code:
+    elif hasattr(transformers, "modeling_flash_attention_utils"):
+        if not has_remote_code:
+            transformers.modeling_flash_attention_utils._get_unpad_data = (  # pylint: disable=protected-access
+                get_unpad_data
+            )
     if model_type == "mixtral" and is_deepspeed_zero3_enabled():
         patch_mixtral_moe_forward_zero3()
     return
File diff suppressed because one or more lines are too long
@@ -57,7 +57,6 @@ class ChatTemplate(str, Enum):
     jinja = "jinja"  # pylint: disable=invalid-name
-    qwen_25 = "qwen_25"  # pylint: disable=invalid-name
     tokenizer_default = "tokenizer_default"  # pylint: disable=invalid-name
     exaone = "exaone"  # pylint: disable=invalid-name
 
 
 class DeprecatedParameters(BaseModel):
@@ -427,7 +426,6 @@ class HyperparametersConfig(BaseModel):
                 "ao_adamw_4bit",
                 "ao_adamw_8bit",
                 "ao_adamw_fp8",
-                "soap",
             ],
         ]
     ] = OptimizerNames.ADAMW_HF.value
@@ -440,10 +438,6 @@ class HyperparametersConfig(BaseModel):
             "help": "The target modules to optimize, i.e. the module names that you would like to train."
         },
     )
 
-    optim_soap_beta1: Optional[float] = None
-    optim_soap_beta2: Optional[float] = None
-
     torchdistx_path: Optional[str] = None
     lr_scheduler: Optional[Union[SchedulerType, Literal["one_cycle"]]] = "cosine"
     lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
@@ -394,10 +394,15 @@ class ModelLoader:
             and self.cfg.flash_attention
             and self.cfg.sample_packing
         ):
+            has_remote_code = (
+                "auto_map" in self.model_config
+                and self.model_type in self.model_config["auto_map"]
+            )
+
             patch_for_multipack(
                 self.cfg.model_config_type,
                 model_name=self.cfg.base_model,
-                is_remote_code=self.cfg.trust_remote_code,
+                has_remote_code=has_remote_code,
             )
 
         if self.cfg.is_llama_derived_model:
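The new `has_remote_code` check is narrower than `cfg.trust_remote_code`: the config flag only records whether the user permits remote code, while an `auto_map` entry in the model's config shows that the checkpoint actually ships custom modeling code for this architecture. A hedged sketch of the same test against a plain dict, as parsed from a hypothetical `config.json`:

```python
# Assuming model_config behaves like the dict parsed from config.json:
model_config = {
    "auto_map": {"AutoModelForCausalLM": "modeling_custom.CustomForCausalLM"},
}
model_type = "AutoModelForCausalLM"

# True only when the checkpoint bundles its own modeling class for this type.
has_remote_code = "auto_map" in model_config and model_type in model_config["auto_map"]
```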
@@ -640,7 +645,9 @@ class ModelLoader:
             self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
                 **self.model_config.quantization_config
             )
-        elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]:
+        elif self.cfg.adapter == "qlora" and (
+            "load_in_4bit" in self.model_kwargs and self.model_kwargs["load_in_4bit"]
+        ):
             bnb_config = {
                 "load_in_4bit": True,
                 "llm_int8_threshold": 6.0,
@@ -663,7 +670,9 @@ class ModelLoader:
             self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
                 **bnb_config,
             )
-        elif self.cfg.adapter == "lora" and self.model_kwargs["load_in_8bit"]:
+        elif self.cfg.adapter == "lora" and (
+            "load_in_8bit" in self.model_kwargs and self.model_kwargs["load_in_8bit"]
+        ):
             bnb_config = {
                 "load_in_8bit": True,
             }
@@ -676,8 +685,10 @@ class ModelLoader:
 
         # no longer needed per https://github.com/huggingface/transformers/pull/26610
         if "quantization_config" in self.model_kwargs or self.cfg.gptq:
-            self.model_kwargs.pop("load_in_8bit", None)
-            self.model_kwargs.pop("load_in_4bit", None)
+            if "load_in_8bit" in self.model_kwargs:
+                del self.model_kwargs["load_in_8bit"]
+            if "load_in_4bit" in self.model_kwargs:
+                del self.model_kwargs["load_in_4bit"]
 
     def set_attention_config(self) -> None:
         """
@@ -962,10 +973,17 @@ class ModelLoader:
         if is_deepspeed_zero3_enabled():
             skip_prepare_model_for_kbit_training = True
 
+        is_load_in_8bit = (
+            "load_in_8bit" in self.model_kwargs and self.model_kwargs["load_in_8bit"]
+        )
+        is_load_in_4bit = (
+            "load_in_4bit" in self.model_kwargs and self.model_kwargs["load_in_4bit"]
+        )
+
         if (
             not skip_prepare_model_for_kbit_training
             and self.cfg.adapter in ["lora", "qlora"]
-            and (self.cfg.load_in_8bit or self.cfg.load_in_4bit)
+            and (is_load_in_8bit or is_load_in_4bit)
         ):
             LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
             self.model = prepare_model_for_kbit_training(
@@ -1103,10 +1121,16 @@ class ModelLoader:
         # ---------------------------------------------------------
         # put model to accelerator
         # ---------------------------------------------------------
+        is_load_in_8bit = (
+            "load_in_8bit" in self.model_kwargs and self.model_kwargs["load_in_8bit"]
+        )
+        is_load_in_4bit = (
+            "load_in_4bit" in self.model_kwargs and self.model_kwargs["load_in_4bit"]
+        )
         if (
             self.cfg.ddp
-            and not self.cfg.load_in_8bit
-            and not (self.cfg.rl and self.cfg.load_in_4bit)
+            and not is_load_in_8bit
+            and not (self.cfg.rl and is_load_in_4bit)
             and not skip_move_to_device
         ):
             # TODO revaldate this conditional
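Each of these guards replaces a direct `self.model_kwargs["load_in_8bit"]` lookup, which raises `KeyError` when the key was never set, with an explicit membership test. Python's `dict.get` expresses the same thing more compactly; a small equivalence check, using illustrative data:

```python
model_kwargs = {}  # the key may be absent entirely

# The old form would raise KeyError: model_kwargs["load_in_8bit"]
guarded = "load_in_8bit" in model_kwargs and model_kwargs["load_in_8bit"]
compact = model_kwargs.get("load_in_8bit", False)
assert guarded is compact is False
```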
@@ -1,21 +0,0 @@
MIT License

Copyright (c) 2024 Nikhil Vyas

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
@@ -1,475 +0,0 @@
# pylint: skip-file
# Copied from https://github.com/nikhilvyas/SOAP
from itertools import chain

import torch
import torch.optim as optim

# Parts of the code are modifications of Pytorch's AdamW optimizer
# Parts of the code are modifications of code from https://github.com/jiaweizzhao/GaLore/blob/master/galore_torch/galore_projector.py


class SOAP(optim.Optimizer):
    """
    Implements SOAP algorithm (https://arxiv.org/abs/2409.11321).

    Parameters:
        params (`Iterable[nn.parameter.Parameter]`):
            Iterable of parameters to optimize or dictionaries defining parameter groups.
        lr (`float`, *optional*, defaults to 0.003):
            The learning rate to use.
        betas (`Tuple[float,float]`, *optional*, defaults to `(0.95, 0.95)`):
            Adam's betas parameters (b1, b2).
        shampoo_beta (`float`, *optional*, defaults to -1):
            If >= 0, use this beta for the preconditioner (L and R in paper, state['GG'] below) moving average instead of betas[1].
        eps (`float`, *optional*, defaults to 1e-08):
            Adam's epsilon for numerical stability.
        weight_decay (`float`, *optional*, defaults to 0.01): weight decay coefficient.
        precondition_frequency (`int`, *optional*, defaults to 10):
            How often to update the preconditioner.
        max_precond_dim (`int`, *optional*, defaults to 10000):
            Maximum dimension of the preconditioner.
            Set to 10000, so that we exclude most common vocab sizes while including layers.
        merge_dims (`bool`, *optional*, defaults to `False`):
            Whether or not to merge dimensions of the preconditioner.
        precondition_1d (`bool`, *optional*, defaults to `False`):
            Whether or not to precondition 1D gradients.
        normalize_grads (`bool`, *optional*, defaults to `False`):
            Whether or not to normalize gradients per layer.
            Helps at large precondition_frequency (~100 in our experiments),
            but hurts performance at small precondition_frequency (~10 in our experiments).
        data_format (`str`, *optional*, defaults to `channels_first`):
            Data format of the input for convolutional layers.
            Should be "channels_last" for data_format of NHWC and "channels_first" for NCHW.
        correct_bias (`bool`, *optional*, defaults to `True`):
            Whether or not to use bias correction in Adam.
    """

    def __init__(
        self,
        params,
        lr: float = 3e-3,
        betas=(0.95, 0.95),
        shampoo_beta: float = -1,
        eps: float = 1e-8,
        weight_decay: float = 0.01,
        precondition_frequency: int = 10,
        max_precond_dim: int = 10000,  #
        merge_dims: bool = False,  # Merge dimensions till the product of the dimensions is less than or equal to max_precond_dim.
        precondition_1d: bool = False,
        normalize_grads: bool = False,
        data_format: str = "channels_first",
        correct_bias: bool = True,
    ):
        defaults = {
            "lr": lr,
            "betas": betas,
            "shampoo_beta": shampoo_beta,
            "eps": eps,
            "weight_decay": weight_decay,
            "precondition_frequency": precondition_frequency,
            "max_precond_dim": max_precond_dim,
            "merge_dims": merge_dims,
            "precondition_1d": precondition_1d,
            "normalize_grads": normalize_grads,
            "correct_bias": correct_bias,
        }
        super().__init__(params, defaults)
        self._data_format = data_format

    def merge_dims(self, grad, max_precond_dim):
        """
        Merges dimensions of the gradient tensor till the product of the dimensions is less than or equal to max_precond_dim.
        """
        assert self._data_format in ["channels_first", "channels_last"]
        if self._data_format == "channels_last" and grad.dim() == 4:
            grad = grad.permute(0, 3, 1, 2)
        shape = grad.shape
        new_shape = []

        curr_shape = 1
        for sh in shape:
            temp_shape = curr_shape * sh
            if temp_shape > max_precond_dim:
                if curr_shape > 1:
                    new_shape.append(curr_shape)
                    curr_shape = sh
                else:
                    new_shape.append(sh)
                    curr_shape = 1
            else:
                curr_shape = temp_shape

        if curr_shape > 1 or len(new_shape) == 0:
            new_shape.append(curr_shape)

        new_grad = grad.reshape(new_shape)
        return new_grad

    @torch.no_grad()
    def step(self):
        """
        Performs a single optimization step.

        Arguments:
            closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
        """
        loss = None

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad

                state = self.state[p]

                if "step" not in state:
                    state["step"] = 0

                # State initialization
                if "exp_avg" not in state:
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(grad)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(grad)

                if "Q" not in state:
                    self.init_preconditioner(
                        grad,
                        state,
                        precondition_frequency=group["precondition_frequency"],
                        precondition_1d=group["precondition_1d"],
                        shampoo_beta=(
                            group["shampoo_beta"]
                            if group["shampoo_beta"] >= 0
                            else group["betas"][1]
                        ),
                        max_precond_dim=group["max_precond_dim"],
                        merge_dims=group["merge_dims"],
                    )
                    self.update_preconditioner(
                        grad,
                        state,
                        max_precond_dim=group["max_precond_dim"],
                        merge_dims=group["merge_dims"],
                        precondition_1d=group["precondition_1d"],
                    )
                    continue  # first step is skipped so that we never use the current gradients in the projection.

                # Projecting gradients to the eigenbases of Shampoo's preconditioner
                # i.e. projecting to the eigenbases of matrices in state['GG']
                grad_projected = self.project(
                    grad,
                    state,
                    merge_dims=group["merge_dims"],
                    max_precond_dim=group["max_precond_dim"],
                )

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                beta1, beta2 = group["betas"]

                state["step"] += 1

                # Decay the first and second moment running average coefficient
                # In-place operations to update the averages at the same time
                exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
                exp_avg_sq.mul_(beta2).add_(
                    grad_projected.square(), alpha=(1.0 - beta2)
                )

                denom = exp_avg_sq.sqrt().add_(group["eps"])

                # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
                # i.e. projecting to the eigenbases of matrices in state['GG']
                exp_avg_projected = self.project(
                    exp_avg,
                    state,
                    merge_dims=group["merge_dims"],
                    max_precond_dim=group["max_precond_dim"],
                )

                step_size = group["lr"]
                if group["correct_bias"]:
                    bias_correction1 = 1.0 - beta1 ** (state["step"])
                    bias_correction2 = 1.0 - beta2 ** (state["step"])
                    step_size = step_size * (bias_correction2**0.5) / bias_correction1

                # Projecting back the preconditioned (by Adam) exponential moving average of gradients
                # to the original space
                norm_grad = self.project_back(
                    exp_avg_projected / denom,
                    state,
                    merge_dims=group["merge_dims"],
                    max_precond_dim=group["max_precond_dim"],
                )

                if group["normalize_grads"]:
                    norm_grad = norm_grad / (1e-30 + torch.mean(norm_grad**2) ** 0.5)

                p.add_(norm_grad, alpha=-step_size)

                # From AdamW code: Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't interact
                # with the m/v parameters. This is equivalent to adding the square
                # of the weights to the loss with plain (non-momentum) SGD.
                # Add weight decay at the end (fixed version)
                if group["weight_decay"] > 0.0:
                    p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))

                # Update is done after the gradient step to avoid using current gradients in the projection.
                self.update_preconditioner(
                    grad,
                    state,
                    max_precond_dim=group["max_precond_dim"],
                    merge_dims=group["merge_dims"],
                    precondition_1d=group["precondition_1d"],
                )

        return loss

    def init_preconditioner(
        self,
        grad,
        state,
        precondition_frequency=10,
        shampoo_beta=0.95,
        max_precond_dim=10000,
        precondition_1d=False,
        merge_dims=False,
    ):
        """
        Initializes the preconditioner matrices (L and R in the paper).
        """
        state[
            "GG"
        ] = []  # Will hold all the preconditioner matrices (L and R in the paper).
        if grad.dim() == 1:
            if not precondition_1d or grad.shape[0] > max_precond_dim:
                state["GG"].append([])
            else:
                state["GG"].append(
                    torch.zeros(grad.shape[0], grad.shape[0], device=grad.device)
                )
        else:
            if merge_dims:
                grad = self.merge_dims(grad, max_precond_dim)

            for sh in grad.shape:
                if sh > max_precond_dim:
                    state["GG"].append([])
                else:
                    state["GG"].append(torch.zeros(sh, sh, device=grad.device))

        state["Q"] = None  # Will hold all the eigenbases of the preconditioner.
        state["precondition_frequency"] = precondition_frequency
        state["shampoo_beta"] = shampoo_beta

    def project(self, grad, state, merge_dims=False, max_precond_dim=10000):
        """
        Projects the gradient to the eigenbases of the preconditioner.
        """
        original_shape = grad.shape
        if merge_dims:
            if grad.dim() == 4 and self._data_format == "channels_last":
                permuted_shape = grad.permute(0, 3, 1, 2).shape
            grad = self.merge_dims(grad, max_precond_dim)

        for mat in state["Q"]:
            if len(mat) > 0:
                grad = torch.tensordot(
                    grad,
                    mat,
                    dims=[[0], [0]],
                )
            else:
                permute_order = list(range(1, len(grad.shape))) + [0]
                grad = grad.permute(permute_order)

        if merge_dims:
            if self._data_format == "channels_last" and len(original_shape) == 4:
                grad = grad.reshape(permuted_shape).permute(0, 2, 3, 1)
            else:
                grad = grad.reshape(original_shape)
        return grad

    def update_preconditioner(
        self,
        grad,
        state,
        max_precond_dim=10000,
        merge_dims=False,
        precondition_1d=False,
    ):
        """
        Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper).
        """
        if grad.dim() == 1:
            if precondition_1d and grad.shape[0] <= max_precond_dim:
                state["GG"][0].lerp_(
                    grad.unsqueeze(1) @ grad.unsqueeze(0), 1 - state["shampoo_beta"]
                )
        else:
            if merge_dims:
                new_grad = self.merge_dims(grad, max_precond_dim)
                for idx, sh in enumerate(new_grad.shape):
                    if sh <= max_precond_dim:
                        outer_product = torch.tensordot(
                            new_grad,
                            new_grad,
                            dims=[
                                [
                                    *chain(
                                        range(idx), range(idx + 1, len(new_grad.shape))
                                    )
                                ]
                            ]
                            * 2,
                        )
                        state["GG"][idx].lerp_(outer_product, 1 - state["shampoo_beta"])
            else:
                for idx, sh in enumerate(grad.shape):
                    if sh <= max_precond_dim:
                        outer_product = torch.tensordot(
                            grad,
                            grad,
                            # Contracts across all dimensions except for k.
                            dims=[[*chain(range(idx), range(idx + 1, len(grad.shape)))]]
                            * 2,
                        )
                        state["GG"][idx].lerp_(outer_product, 1 - state["shampoo_beta"])

        if state["Q"] is None:
            state["Q"] = self.get_orthogonal_matrix(state["GG"])
        if state["step"] > 0 and state["step"] % state["precondition_frequency"] == 0:
            state["Q"] = self.get_orthogonal_matrix_QR(
                state, max_precond_dim, merge_dims
            )

    def project_back(self, grad, state, merge_dims=False, max_precond_dim=10000):
        """
        Projects the gradient back to the original space.
        """
        original_shape = grad.shape
        if merge_dims:
            if self._data_format == "channels_last" and grad.dim() == 4:
                permuted_shape = grad.permute(0, 3, 1, 2).shape
            grad = self.merge_dims(grad, max_precond_dim)
        for mat in state["Q"]:
            if len(mat) > 0:
                grad = torch.tensordot(
                    grad,
                    mat,
                    dims=[[0], [1]],
                )
            else:
                permute_order = list(range(1, len(grad.shape))) + [0]
                grad = grad.permute(permute_order)

        if merge_dims:
            if self._data_format == "channels_last" and len(original_shape) == 4:
                grad = grad.reshape(permuted_shape).permute(0, 2, 3, 1)
            else:
                grad = grad.reshape(original_shape)
        return grad

    def get_orthogonal_matrix(self, mat):
        """
        Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
        """
        matrix = []
        for m in mat:
            if len(m) == 0:
                matrix.append([])
                continue
            if m.data.dtype != torch.float:
                float_data = False
                original_type = m.data.dtype
                original_device = m.data.device
                matrix.append(m.data.float())
            else:
                float_data = True
                matrix.append(m.data)

        final = []
        for m in matrix:
            if len(m) == 0:
                final.append([])
                continue
            try:
                _, Q = torch.linalg.eigh(
                    m + 1e-30 * torch.eye(m.shape[0], device=m.device)
                )
            except:  # pylint: disable=bare-except  # noqa: E722
                _, Q = torch.linalg.eigh(
                    m.to(torch.float64) + 1e-30 * torch.eye(m.shape[0], device=m.device)
                )
                Q = Q.to(m.dtype)
            Q = torch.flip(Q, [1])

            if not float_data:
                Q = Q.to(original_device).type(original_type)
            final.append(Q)
        return final

    def get_orthogonal_matrix_QR(self, state, max_precond_dim=10000, merge_dims=False):
        """
        Computes the eigenbases of the preconditioner using one round of power iteration
        followed by torch.linalg.qr decomposition.
        """
        precond_list = state["GG"]
        orth_list = state["Q"]

        matrix = []
        orth_matrix = []
        for m, o in zip(precond_list, orth_list):
            if len(m) == 0:
                matrix.append([])
                orth_matrix.append([])
                continue
            if m.data.dtype != torch.float:
                float_data = False
                original_type = m.data.dtype
                original_device = m.data.device
                matrix.append(m.data.float())
                orth_matrix.append(o.data.float())
            else:
                float_data = True
                matrix.append(m.data.float())
                orth_matrix.append(o.data.float())

        orig_shape = state["exp_avg_sq"].shape
        if self._data_format == "channels_last" and len(orig_shape) == 4:
            permuted_shape = state["exp_avg_sq"].permute(0, 3, 1, 2).shape
        if merge_dims:
            exp_avg_sq = self.merge_dims(state["exp_avg_sq"], max_precond_dim)
        else:
            exp_avg_sq = state["exp_avg_sq"]

        final = []
        for ind, (m, o) in enumerate(zip(matrix, orth_matrix)):
            if len(m) == 0:
                final.append([])
                continue
            est_eig = torch.diag(o.T @ m @ o)
            sort_idx = torch.argsort(est_eig, descending=True)
            exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)
            o = o[:, sort_idx]
            power_iter = m @ o
            Q, _ = torch.linalg.qr(power_iter)

            if not float_data:
                Q = Q.to(original_device).type(original_type)
            final.append(Q)

        if merge_dims:
            if self._data_format == "channels_last" and len(orig_shape) == 4:
                exp_avg_sq = exp_avg_sq.reshape(permuted_shape).permute(0, 2, 3, 1)
            else:
                exp_avg_sq = exp_avg_sq.reshape(orig_shape)

        state["exp_avg_sq"] = exp_avg_sq
        return final
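For reference, the deleted optimizer was self-contained and usable like any `torch.optim` optimizer. A minimal usage sketch against the class as it appears above; the hyperparameters are the defaults from the docstring, chosen only for illustration:

```python
import torch
import torch.nn as nn

model = nn.Linear(64, 8)
opt = SOAP(model.parameters(), lr=3e-3, betas=(0.95, 0.95), precondition_frequency=10)

for _ in range(5):
    loss = model(torch.randn(16, 64)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()  # note: this SOAP.step() signature does not accept a closure
```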
@@ -14,7 +14,7 @@ from huggingface_hub import snapshot_download
 
 from axolotl.utils.dict import DictDefault
 
-from ..utils import is_hopper, with_temp_dir
+from ..utils import with_temp_dir
 
 LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
 os.environ["WANDB_DISABLED"] = "true"
@@ -59,7 +59,7 @@ class TestMultiGPULlama(unittest.TestCase):
                     },
                 ],
                 "num_epochs": 1,
-                "max_steps": 15,
+                "max_steps": 100,
                 "micro_batch_size": 4,
                 "gradient_accumulation_steps": 4,
                 "output_dir": temp_dir,
@@ -116,7 +116,7 @@ class TestMultiGPULlama(unittest.TestCase):
                     },
                 ],
                 "num_epochs": 1,
-                "max_steps": 15,
+                "max_steps": 50,
                 "micro_batch_size": 4,
                 "gradient_accumulation_steps": 4,
                 "output_dir": temp_dir,
@@ -144,146 +144,6 @@ class TestMultiGPULlama(unittest.TestCase):
             ]
         )
 
-    @pytest.mark.skipif(is_hopper(), reason="h100 doesn't support 8-bit lora")
-    @with_temp_dir
-    def test_dpo_lora_ddp(self, temp_dir):
-        # pylint: disable=duplicate-code
-        cfg = DictDefault(
-            {
-                "base_model": "TinyLlama/TinyLlama_v1.1",
-                "tokenizer_type": "LlamaTokenizer",
-                "sequence_len": 2048,
-                "sample_packing": False,
-                "eval_sample_packing": False,
-                "pad_to_sequence_len": True,
-                "load_in_8bit": True,
-                "adapter": "lora",
-                "lora_r": 8,
-                "lora_alpha": 16,
-                "lora_dropout": 0.05,
-                "lora_target_linear": True,
-                "val_set_size": 0.05,
-                "special_tokens": {
-                    "unk_token": "<unk>",
-                    "bos_token": "<s>",
-                    "eos_token": "</s>",
-                },
-                "rl": "dpo",
-                "chat_template": "llama3",
-                "datasets": [
-                    {
-                        "path": "fozziethebeat/alpaca_messages_2k_dpo_test",
-                        "type": "chat_template.default",
-                        "field_messages": "conversation",
-                        "field_chosen": "chosen",
-                        "field_rejected": "rejected",
-                        "message_field_role": "role",
-                        "message_field_content": "content",
-                        "roles": {
-                            "system": ["system"],
-                            "user": ["user"],
-                            "assistant": ["assistant"],
-                        },
-                    },
-                ],
-                "num_epochs": 1,
-                "max_steps": 15,
-                "micro_batch_size": 4,
-                "gradient_accumulation_steps": 4,
-                "output_dir": temp_dir,
-                "warmup_steps": 0,
-                "learning_rate": 0.00001,
-                "optimizer": "adamw_8bit",
-                "lr_scheduler": "cosine",
-                "flash_attention": True,
-            }
-        )
-
-        # write cfg to yaml file
-        Path(temp_dir).mkdir(parents=True, exist_ok=True)
-        with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
-            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
-
-        execute_subprocess_async(
-            [
-                "accelerate",
-                "launch",
-                "--num-processes",
-                "2",
-                "-m",
-                "axolotl.cli.train",
-                str(Path(temp_dir) / "config.yaml"),
-            ]
-        )
-
-    @with_temp_dir
-    def test_dpo_qlora_ddp(self, temp_dir):
-        # pylint: disable=duplicate-code
-        cfg = DictDefault(
-            {
-                "base_model": "HuggingFaceTB/SmolLM-135M",
-                "sequence_len": 2048,
-                "sample_packing": False,
-                "eval_sample_packing": False,
-                "pad_to_sequence_len": True,
-                "load_in_4bit": True,
-                "adapter": "qlora",
-                "lora_r": 8,
-                "lora_alpha": 16,
-                "lora_dropout": 0.05,
-                "lora_target_linear": True,
-                "val_set_size": 0.05,
-                "special_tokens": {
-                    "pad_token": "<|endoftext|>",
-                },
-                "rl": "dpo",
-                "chat_template": "chatml",
-                "datasets": [
-                    {
-                        "path": "fozziethebeat/alpaca_messages_2k_dpo_test",
-                        "type": "chat_template.default",
-                        "field_messages": "conversation",
-                        "field_chosen": "chosen",
-                        "field_rejected": "rejected",
-                        "message_field_role": "role",
-                        "message_field_content": "content",
-                        "roles": {
-                            "system": ["system"],
-                            "user": ["user"],
-                            "assistant": ["assistant"],
-                        },
-                    },
-                ],
-                "num_epochs": 1,
-                "max_steps": 15,
-                "micro_batch_size": 4,
-                "gradient_accumulation_steps": 4,
-                "output_dir": temp_dir,
-                "warmup_steps": 0,
-                "learning_rate": 0.00001,
-                "optimizer": "adamw_8bit",
-                "lr_scheduler": "cosine",
-                "flash_attention": True,
-            }
-        )
-
-        # write cfg to yaml file
-        Path(temp_dir).mkdir(parents=True, exist_ok=True)
-        with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
-            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
-
-        execute_subprocess_async(
-            [
-                "accelerate",
-                "launch",
-                "--num-processes",
-                "2",
-                "-m",
-                "axolotl.cli.train",
-                str(Path(temp_dir) / "config.yaml"),
-            ]
-        )
-
     @with_temp_dir
     def test_fsdp(self, temp_dir):
         # pylint: disable=duplicate-code
@@ -305,7 +165,7 @@ class TestMultiGPULlama(unittest.TestCase):
                     },
                 ],
                 "num_epochs": 1,
-                "max_steps": 15,
+                "max_steps": 100,
                 "micro_batch_size": 4,
                 "gradient_accumulation_steps": 4,
                 "output_dir": temp_dir,
@@ -371,7 +231,7 @@ class TestMultiGPULlama(unittest.TestCase):
                     },
                 ],
                 "num_epochs": 1,
-                "max_steps": 15,
+                "max_steps": 100,
                 "micro_batch_size": 4,
                 "gradient_accumulation_steps": 4,
                 "output_dir": temp_dir,
@@ -413,6 +273,7 @@ class TestMultiGPULlama(unittest.TestCase):
             ]
         )
 
+    @pytest.mark.skip("disabled due to upstream issue")
     @with_temp_dir
     def test_fsdp_qlora_prequant_packed(self, temp_dir):
         # pylint: disable=duplicate-code
@@ -421,7 +282,6 @@ class TestMultiGPULlama(unittest.TestCase):
                 "base_model": "axolotl-ai-co/TinyLlama_v1.1-bnb-nf4-bf16",
                 "tokenizer_type": "AutoTokenizer",
                 "adapter": "qlora",
-                "mean_resizing_embeddings": True,
                 "load_in_4bit": True,
                 "lora_r": 8,
                 "lora_alpha": 16,
@@ -437,7 +297,7 @@ class TestMultiGPULlama(unittest.TestCase):
                 "sequence_len": 2048,
                 "val_set_size": 0.05,
                 "special_tokens": {
-                    "pad_token": "</s>",
+                    "pad_token": "<|end_of_text|>",
                 },
                 "datasets": [
                     {
@@ -447,7 +307,7 @@ class TestMultiGPULlama(unittest.TestCase):
                     },
                 ],
                 "num_epochs": 1,
-                "max_steps": 15,
+                "max_steps": 100,
                 "micro_batch_size": 4,
                 "gradient_accumulation_steps": 4,
                 "output_dir": temp_dir,
@@ -513,7 +373,7 @@ class TestMultiGPULlama(unittest.TestCase):
                     },
                 ],
                 "num_epochs": 1,
-                "max_steps": 15,
+                "max_steps": 100,
                 "micro_batch_size": 4,
                 "gradient_accumulation_steps": 4,
                 "output_dir": temp_dir,
@@ -572,7 +432,7 @@ class TestMultiGPULlama(unittest.TestCase):
                     },
                 ],
                 "num_epochs": 1,
-                "max_steps": 15,
+                "max_steps": 100,
                 "micro_batch_size": 4,
                 "gradient_accumulation_steps": 4,
                 "output_dir": temp_dir,
@@ -47,7 +47,7 @@ class TestMultiGPUQwen2(unittest.TestCase):
                     },
                 ],
                 "num_epochs": 1,
-                "max_steps": 15,
+                "max_steps": 100,
                 "warmup_steps": 20,
                 "micro_batch_size": 4,
                 "gradient_accumulation_steps": 2,
@@ -13,7 +13,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
 
-from ..utils import require_torch_2_3_1, with_temp_dir
+from ..utils import require_torch_2_1_1, with_temp_dir
 
 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -24,7 +24,7 @@ class Test4dMultipackLlama(unittest.TestCase):
     Test case for Llama models using 4d attention with multipack
     """
 
-    @require_torch_2_3_1
+    @require_torch_2_1_1
     @with_temp_dir
     def test_sdp_lora_packing(self, temp_dir):
         # pylint: disable=duplicate-code
@@ -65,44 +65,3 @@ class TestCustomOptimizers(unittest.TestCase):
 
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         assert (Path(temp_dir) / "adapter_model.bin").exists()
-
-    @with_temp_dir
-    def test_soap(self, temp_dir):
-        # pylint: disable=duplicate-code
-        cfg = DictDefault(
-            {
-                "base_model": "HuggingFaceTB/SmolLM-135M",
-                "sequence_len": 1024,
-                "load_in_8bit": True,
-                "adapter": "lora",
-                "lora_r": 8,
-                "lora_alpha": 16,
-                "lora_dropout": 0.05,
-                "lora_target_linear": True,
-                "val_set_size": 0.1,
-                "special_tokens": {
-                    "pad_token": "<|endoftext|>",
-                },
-                "datasets": [
-                    {
-                        "path": "vicgalle/alpaca-gpt4",
-                        "type": "alpaca",
-                    },
-                ],
-                "num_epochs": 1,
-                "micro_batch_size": 8,
-                "gradient_accumulation_steps": 1,
-                "output_dir": temp_dir,
-                "learning_rate": 0.00001,
-                "optimizer": "soap",
-                "optim_soap_beta1": 0.95,
-                "optim_soap_beta2": 0.95,
-                "lr_scheduler": "cosine",
-            }
-        )
-        normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
-
-        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
@@ -9,8 +9,6 @@ from functools import wraps
 from importlib.metadata import version
 from pathlib import Path
 
-import torch
-
 
 def with_temp_dir(test_func):
     @wraps(test_func)
@@ -37,18 +35,13 @@ def most_recent_subdir(path):
     return subdir
 
 
-def require_torch_2_3_1(test_case):
+def require_torch_2_1_1(test_case):
     """
-    Decorator marking a test that requires torch >= 2.3.1
+    Decorator marking a test that requires torch >= 2.1.1
     """
 
-    def is_min_2_3_1():
+    def is_min_2_1_1():
         torch_version = version("torch")
-        return torch_version >= "2.3.1"
+        return torch_version >= "2.1.1"
 
-    return unittest.skipUnless(is_min_2_3_1(), "test torch 2.3.1")(test_case)
-
-
-def is_hopper():
-    compute_capability = torch.cuda.get_device_capability()
-    return compute_capability == (9, 0)
+    return unittest.skipUnless(is_min_2_1_1(), "test torch 2.1.1")(test_case)
@@ -367,44 +367,43 @@ class TestDatasetPreparation(unittest.TestCase):
     def test_load_local_hub_with_revision(self):
         """Verify that a local copy of a hub dataset can be loaded with a specific revision"""
         with tempfile.TemporaryDirectory() as tmp_dir:
-            with tempfile.TemporaryDirectory() as tmp_dir2:
-                tmp_ds_path = Path(tmp_dir2) / "mhenrichsen/alpaca_2k_test"
-                tmp_ds_path.mkdir(parents=True, exist_ok=True)
-                snapshot_download(
-                    repo_id="mhenrichsen/alpaca_2k_test",
-                    repo_type="dataset",
-                    local_dir=tmp_ds_path,
-                    revision="d05c1cb",
-                )
+            tmp_ds_path = Path("mhenrichsen/alpaca_2k_test")
+            tmp_ds_path.mkdir(parents=True, exist_ok=True)
+            snapshot_download(
+                repo_id="mhenrichsen/alpaca_2k_test",
+                repo_type="dataset",
+                local_dir=tmp_ds_path,
+                revision="d05c1cb",
+            )
 
-                prepared_path = Path(tmp_dir) / "prepared"
-                cfg = DictDefault(
-                    {
-                        "tokenizer_config": "huggyllama/llama-7b",
-                        "sequence_len": 1024,
-                        "datasets": [
-                            {
-                                "path": "mhenrichsen/alpaca_2k_test",
-                                "ds_type": "parquet",
-                                "type": "alpaca",
-                                "data_files": [
-                                    f"{tmp_ds_path}/alpaca_2000.parquet",
-                                ],
-                                "revision": "d05c1cb",
-                            },
-                        ],
-                    }
-                )
+            prepared_path = Path(tmp_dir) / "prepared"
+            cfg = DictDefault(
+                {
+                    "tokenizer_config": "huggyllama/llama-7b",
+                    "sequence_len": 1024,
+                    "datasets": [
+                        {
+                            "path": "mhenrichsen/alpaca_2k_test",
+                            "ds_type": "parquet",
+                            "type": "alpaca",
+                            "data_files": [
+                                "mhenrichsen/alpaca_2k_test/alpaca_2000.parquet",
+                            ],
+                            "revision": "d05c1cb",
+                        },
+                    ],
+                }
+            )
 
-                dataset, _ = load_tokenized_prepared_datasets(
-                    self.tokenizer, cfg, prepared_path
-                )
+            dataset, _ = load_tokenized_prepared_datasets(
+                self.tokenizer, cfg, prepared_path
+            )
 
-                assert len(dataset) == 2000
-                assert "input_ids" in dataset.features
-                assert "attention_mask" in dataset.features
-                assert "labels" in dataset.features
-                shutil.rmtree(tmp_ds_path)
+            assert len(dataset) == 2000
+            assert "input_ids" in dataset.features
+            assert "attention_mask" in dataset.features
+            assert "labels" in dataset.features
+            shutil.rmtree(tmp_ds_path)
 
 
 if __name__ == "__main__":