Compare commits
5 Commits
revert-290
...
kwargs-ref
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4dc75cc713 | ||
|
|
36cbe13d18 | ||
|
|
2c408b5c5e | ||
|
|
942005f526 | ||
|
|
9394983633 |
4
.github/workflows/tests-nightly.yml
vendored
4
.github/workflows/tests-nightly.yml
vendored
@@ -92,7 +92,7 @@ jobs:
|
|||||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||||
# this job needs to be run on self-hosted GPU runners...
|
# this job needs to be run on self-hosted GPU runners...
|
||||||
runs-on: [self-hosted, modal]
|
runs-on: [self-hosted, modal]
|
||||||
timeout-minutes: 60
|
timeout-minutes: 120
|
||||||
needs: [pre-commit, pytest]
|
needs: [pre-commit, pytest]
|
||||||
|
|
||||||
strategy:
|
strategy:
|
||||||
@@ -116,7 +116,7 @@ jobs:
|
|||||||
- name: Install Modal
|
- name: Install Modal
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
pip install modal==0.71.8 jinja2
|
pip install modal==1.0.2 jinja2
|
||||||
- name: Update env vars
|
- name: Update env vars
|
||||||
run: |
|
run: |
|
||||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||||
|
|||||||
@@ -26,3 +26,5 @@ timeout: 86400
|
|||||||
# Preprocess specific configurations
|
# Preprocess specific configurations
|
||||||
memory_preprocess: 32
|
memory_preprocess: 32
|
||||||
timeout_preprocess: 14400
|
timeout_preprocess: 14400
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -35,7 +35,6 @@ wandb_watch:
|
|||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 4
|
num_epochs: 4
|
||||||
@@ -56,3 +55,5 @@ evals_per_epoch:
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -40,7 +40,7 @@
|
|||||||
"%%capture\n",
|
"%%capture\n",
|
||||||
"# This step can take ~5-10 minutes to install dependencies\n",
|
"# This step can take ~5-10 minutes to install dependencies\n",
|
||||||
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
||||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@78b2a45713a54c9bedf8b33f5e31cf07a1a57154\""
|
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@50cef19\""
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -56,3 +56,5 @@ evals_per_epoch: 1
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -56,3 +56,5 @@ evals_per_epoch: 1
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -55,3 +55,5 @@ fsdp_config:
|
|||||||
fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer
|
fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer
|
||||||
fsdp_state_dict_type: FULL_STATE_DICT
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
fsdp_sharding_strategy: FULL_SHARD
|
fsdp_sharding_strategy: FULL_SHARD
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -79,3 +79,5 @@ fsdp_config:
|
|||||||
fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer
|
fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer
|
||||||
fsdp_state_dict_type: FULL_STATE_DICT
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
fsdp_sharding_strategy: FULL_SHARD
|
fsdp_sharding_strategy: FULL_SHARD
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -62,3 +62,5 @@ saves_per_epoch: 1
|
|||||||
|
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -69,3 +69,5 @@ evals_per_epoch:
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -46,7 +46,6 @@ wandb_watch:
|
|||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 4
|
num_epochs: 4
|
||||||
@@ -69,3 +68,5 @@ evals_per_epoch:
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -69,3 +69,5 @@ evals_per_epoch:
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -69,3 +69,5 @@ evals_per_epoch: 1
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -69,3 +69,5 @@ evals_per_epoch:
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -69,3 +69,5 @@ evals_per_epoch: 1
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -60,3 +60,5 @@ evals_per_epoch:
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -50,3 +50,5 @@ evals_per_epoch:
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -66,3 +66,5 @@ evals_per_epoch:
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -60,3 +60,5 @@ warmup_ratio: 0.1
|
|||||||
evals_per_epoch: 1
|
evals_per_epoch: 1
|
||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -62,3 +62,5 @@ warmup_ratio: 0.1
|
|||||||
evals_per_epoch: 1
|
evals_per_epoch: 1
|
||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -60,3 +60,5 @@ evals_per_epoch: 1
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -54,3 +54,5 @@ evals_per_epoch:
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -55,3 +55,5 @@ saves_per_epoch: 1
|
|||||||
deepspeed: deepspeed_configs/zero2.json
|
deepspeed: deepspeed_configs/zero2.json
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -64,3 +64,5 @@ fsdp_config:
|
|||||||
fsdp_transformer_layer_cls_to_wrap: JambaAttentionDecoderLayer,JambaMambaDecoderLayer
|
fsdp_transformer_layer_cls_to_wrap: JambaAttentionDecoderLayer,JambaMambaDecoderLayer
|
||||||
fsdp_state_dict_type: FULL_STATE_DICT
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
fsdp_sharding_strategy: FULL_SHARD
|
fsdp_sharding_strategy: FULL_SHARD
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -46,3 +46,5 @@ evals_per_epoch: 2
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
|
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -55,3 +55,5 @@ saves_per_epoch: 1
|
|||||||
deepspeed: #deepspeed_configs/zero2.json # multi-gpu only
|
deepspeed: #deepspeed_configs/zero2.json # multi-gpu only
|
||||||
weight_decay: 0.1
|
weight_decay: 0.1
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -64,3 +64,5 @@ special_tokens:
|
|||||||
bos_token: "<s>"
|
bos_token: "<s>"
|
||||||
eos_token: "</s>"
|
eos_token: "</s>"
|
||||||
unk_token: "<unk>"
|
unk_token: "<unk>"
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -60,3 +60,5 @@ special_tokens:
|
|||||||
bos_token: "<s>"
|
bos_token: "<s>"
|
||||||
eos_token: "</s>"
|
eos_token: "</s>"
|
||||||
unk_token: "<unk>"
|
unk_token: "<unk>"
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -52,3 +52,5 @@ evals_per_epoch: 4
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -52,3 +52,5 @@ evals_per_epoch: 4
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -67,3 +67,5 @@ fsdp_config:
|
|||||||
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||||
fsdp_state_dict_type: FULL_STATE_DICT
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -53,3 +53,5 @@ evals_per_epoch: 4
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -58,3 +58,5 @@ special_tokens:
|
|||||||
bos_token: "<s>"
|
bos_token: "<s>"
|
||||||
eos_token: "</s>"
|
eos_token: "</s>"
|
||||||
unk_token: "<unk>"
|
unk_token: "<unk>"
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -57,3 +57,5 @@ warmup_ratio: 0.1
|
|||||||
evals_per_epoch: 1
|
evals_per_epoch: 1
|
||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -77,3 +77,5 @@ fsdp_config:
|
|||||||
|
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: <|end_of_text|>
|
pad_token: <|end_of_text|>
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -72,3 +72,5 @@ fsdp_config:
|
|||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: <|finetune_right_pad_id|>
|
pad_token: <|finetune_right_pad_id|>
|
||||||
eos_token: <|eot_id|>
|
eos_token: <|eot_id|>
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -42,3 +42,5 @@ saves_per_epoch: 1
|
|||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: <|end_of_text|>
|
pad_token: <|end_of_text|>
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -71,3 +71,5 @@ warmup_steps: 10
|
|||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -64,3 +64,5 @@ saves_per_epoch: 1
|
|||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: <|end_of_text|>
|
pad_token: <|end_of_text|>
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -83,3 +83,5 @@ warmup_steps: 10
|
|||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -61,3 +61,5 @@ saves_per_epoch: 1
|
|||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: <|end_of_text|>
|
pad_token: <|end_of_text|>
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -65,3 +65,5 @@ saves_per_epoch: 1
|
|||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: "<|end_of_text|>"
|
pad_token: "<|end_of_text|>"
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -64,3 +64,5 @@ special_tokens:
|
|||||||
|
|
||||||
use_ray: true
|
use_ray: true
|
||||||
ray_num_workers: 4
|
ray_num_workers: 4
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -63,3 +63,5 @@ saves_per_epoch: 1
|
|||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: <|end_of_text|>
|
pad_token: <|end_of_text|>
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -60,3 +60,5 @@ saves_per_epoch: 1
|
|||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: "<|end_of_text|>"
|
pad_token: "<|end_of_text|>"
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -57,3 +57,5 @@ saves_per_epoch: 1
|
|||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: <|end_of_text|>
|
pad_token: <|end_of_text|>
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -61,3 +61,5 @@ saves_per_epoch: 1
|
|||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: "<|end_of_text|>"
|
pad_token: "<|end_of_text|>"
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -62,3 +62,5 @@ saves_per_epoch: 1
|
|||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: "<|end_of_text|>"
|
pad_token: "<|end_of_text|>"
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -60,3 +60,5 @@ fsdp_config:
|
|||||||
fsdp_sharding_strategy: FULL_SHARD
|
fsdp_sharding_strategy: FULL_SHARD
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: <|finetune_right_pad_id|>
|
pad_token: <|finetune_right_pad_id|>
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -69,3 +69,5 @@ fsdp_config:
|
|||||||
fsdp_sharding_strategy: FULL_SHARD
|
fsdp_sharding_strategy: FULL_SHARD
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: <|end_of_text|>
|
pad_token: <|end_of_text|>
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -54,3 +54,5 @@ saves_per_epoch: 1
|
|||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: "<|end_of_text|>"
|
pad_token: "<|end_of_text|>"
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -75,3 +75,5 @@ llmcompressor:
|
|||||||
]
|
]
|
||||||
start: 0
|
start: 0
|
||||||
save_compressed: true
|
save_compressed: true
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -86,3 +86,5 @@ fsdp_config:
|
|||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: <|finetune_right_pad_id|>
|
pad_token: <|finetune_right_pad_id|>
|
||||||
eos_token: <|eot|>
|
eos_token: <|eot|>
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -90,3 +90,5 @@ fsdp_config:
|
|||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: <|finetune_right_pad_id|>
|
pad_token: <|finetune_right_pad_id|>
|
||||||
eos_token: <|eot|>
|
eos_token: <|eot|>
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -83,3 +83,5 @@ weight_decay: 0.0
|
|||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: <|finetune_right_pad_id|>
|
pad_token: <|finetune_right_pad_id|>
|
||||||
eos_token: <|eot|>
|
eos_token: <|eot|>
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -86,3 +86,5 @@ fsdp_config:
|
|||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: <|finetune_right_pad_id|>
|
pad_token: <|finetune_right_pad_id|>
|
||||||
eos_token: <|eot|>
|
eos_token: <|eot|>
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -84,3 +84,5 @@ fsdp_config:
|
|||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: <|finetune_right_pad_id|>
|
pad_token: <|finetune_right_pad_id|>
|
||||||
eos_token: <|eot|>
|
eos_token: <|eot|>
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -82,3 +82,5 @@ weight_decay: 0.0
|
|||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: <|finetune_right_pad_id|>
|
pad_token: <|finetune_right_pad_id|>
|
||||||
eos_token: <|eot|>
|
eos_token: <|eot|>
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -87,3 +87,5 @@ fsdp_config:
|
|||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: <|finetune_right_pad_id|>
|
pad_token: <|finetune_right_pad_id|>
|
||||||
eos_token: <|eot|>
|
eos_token: <|eot|>
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -53,3 +53,5 @@ warmup_ratio: 0.1
|
|||||||
evals_per_epoch: 1
|
evals_per_epoch: 1
|
||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -70,3 +70,5 @@ fsdp_config:
|
|||||||
fsdp_state_dict_type: FULL_STATE_DICT
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
fsdp_transformer_layer_cls_to_wrap: MistralDecoderLayer
|
fsdp_transformer_layer_cls_to_wrap: MistralDecoderLayer
|
||||||
fsdp_activation_checkpointing: true
|
fsdp_activation_checkpointing: true
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -61,3 +61,5 @@ flash_attention: true
|
|||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 1
|
evals_per_epoch: 1
|
||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -48,3 +48,5 @@ weight_decay: 0.0
|
|||||||
special_tokens:
|
special_tokens:
|
||||||
tokens:
|
tokens:
|
||||||
save_safetensors: False
|
save_safetensors: False
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -53,3 +53,5 @@ special_tokens:
|
|||||||
eos_token: "<|im_end|>"
|
eos_token: "<|im_end|>"
|
||||||
tokens:
|
tokens:
|
||||||
- "<|im_start|>"
|
- "<|im_start|>"
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -43,3 +43,5 @@ evals_per_epoch: 4
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -64,3 +64,5 @@ evals_per_epoch: 4
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -64,3 +64,5 @@ evals_per_epoch: 4
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -80,3 +80,5 @@ weight_decay: 0.0
|
|||||||
special_tokens:
|
special_tokens:
|
||||||
bos_token: "<|im_start|>"
|
bos_token: "<|im_start|>"
|
||||||
eos_token: "<|im_end|>"
|
eos_token: "<|im_end|>"
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -74,3 +74,5 @@ fsdp_config:
|
|||||||
fsdp_state_dict_type: FULL_STATE_DICT
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -69,3 +69,5 @@ evals_per_epoch: 4
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -56,3 +56,5 @@ evals_per_epoch: 1
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -72,3 +72,5 @@ fsdp_config:
|
|||||||
fsdp_state_dict_type: FULL_STATE_DICT
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -77,3 +77,5 @@ fsdp_config:
|
|||||||
fsdp_forward_prefetch: false
|
fsdp_forward_prefetch: false
|
||||||
fsdp_backward_prefetch: BACKWARD_PRE
|
fsdp_backward_prefetch: BACKWARD_PRE
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -81,3 +81,5 @@ saves_per_epoch: 1
|
|||||||
deepspeed: deepspeed_configs/zero2.json
|
deepspeed: deepspeed_configs/zero2.json
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -51,3 +51,5 @@ special_tokens:
|
|||||||
eos_token: "<|im_end|>"
|
eos_token: "<|im_end|>"
|
||||||
tokens:
|
tokens:
|
||||||
- "<|im_start|>"
|
- "<|im_start|>"
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -64,3 +64,5 @@ evals_per_epoch: 4
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -50,3 +50,5 @@ weight_decay: 0.05
|
|||||||
|
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: <custom_token_7>
|
pad_token: <custom_token_7>
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -63,3 +63,5 @@ warmup_steps: 10
|
|||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
saves_per_epoch: 4
|
saves_per_epoch: 4
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -57,3 +57,5 @@ weight_decay: 0.1
|
|||||||
resize_token_embeddings_to_32x: true
|
resize_token_embeddings_to_32x: true
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: "<|endoftext|>"
|
pad_token: "<|endoftext|>"
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -60,3 +60,5 @@ weight_decay: 0.1
|
|||||||
resize_token_embeddings_to_32x: true
|
resize_token_embeddings_to_32x: true
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: "<|endoftext|>"
|
pad_token: "<|endoftext|>"
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -57,3 +57,5 @@ weight_decay: 0.1
|
|||||||
resize_token_embeddings_to_32x: true
|
resize_token_embeddings_to_32x: true
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: "<|endoftext|>"
|
pad_token: "<|endoftext|>"
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -71,3 +71,5 @@ fsdp_config:
|
|||||||
resize_token_embeddings_to_32x: true
|
resize_token_embeddings_to_32x: true
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: "<|endoftext|>"
|
pad_token: "<|endoftext|>"
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -59,3 +59,5 @@ warmup_ratio: 0.2
|
|||||||
debug: true
|
debug: true
|
||||||
weight_decay: 0.1
|
weight_decay: 0.1
|
||||||
resize_token_embeddings_to_32x: true
|
resize_token_embeddings_to_32x: true
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -55,3 +55,5 @@ saves_per_epoch: 1
|
|||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: <pad>
|
pad_token: <pad>
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -53,3 +53,5 @@ warmup_ratio: 0.1
|
|||||||
evals_per_epoch: 1
|
evals_per_epoch: 1
|
||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -54,3 +54,5 @@ warmup_steps: 10
|
|||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -55,3 +55,5 @@ eval_steps: 100
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -67,3 +67,5 @@ fsdp_config:
|
|||||||
fsdp_state_dict_type: FULL_STATE_DICT
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
fsdp_sharding_strategy: FULL_SHARD
|
fsdp_sharding_strategy: FULL_SHARD
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -26,7 +26,6 @@ wandb_watch:
|
|||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
micro_batch_size: 2
|
micro_batch_size: 2
|
||||||
num_epochs: 4
|
num_epochs: 4
|
||||||
@@ -50,3 +49,5 @@ evals_per_epoch:
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -53,3 +53,5 @@ warmup_ratio: 0.1
|
|||||||
evals_per_epoch: 1
|
evals_per_epoch: 1
|
||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -67,3 +67,5 @@ evals_per_epoch: 4
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -76,3 +76,5 @@ fsdp_config:
|
|||||||
fsdp_activation_checkpointing: true
|
fsdp_activation_checkpointing: true
|
||||||
|
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -66,3 +66,5 @@ fsdp_config:
|
|||||||
fsdp_state_dict_type: FULL_STATE_DICT
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
fsdp_sharding_strategy: FULL_SHARD
|
fsdp_sharding_strategy: FULL_SHARD
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ hf_transfer
|
|||||||
sentencepiece
|
sentencepiece
|
||||||
gradio==5.23.3
|
gradio==5.23.3
|
||||||
|
|
||||||
modal==0.70.5
|
modal==1.0.2
|
||||||
pydantic==2.10.6
|
pydantic==2.10.6
|
||||||
addict
|
addict
|
||||||
fire
|
fire
|
||||||
|
|||||||
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
|
|||||||
|
|
||||||
print(
|
print(
|
||||||
UNINSTALL_PREFIX
|
UNINSTALL_PREFIX
|
||||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@865b899"'
|
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@50cef19"'
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ from axolotl.utils.callbacks import (
|
|||||||
GCCallback,
|
GCCallback,
|
||||||
GPUStatsCallback,
|
GPUStatsCallback,
|
||||||
SaveAxolotlConfigtoWandBCallback,
|
SaveAxolotlConfigtoWandBCallback,
|
||||||
|
SaveModelOnFirstStepCallback,
|
||||||
)
|
)
|
||||||
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
|
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
|
||||||
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
|
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
|
||||||
@@ -135,6 +136,8 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
callbacks.append(
|
callbacks.append(
|
||||||
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
|
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
|
||||||
)
|
)
|
||||||
|
if self.cfg.save_first_step:
|
||||||
|
callbacks.append(SaveModelOnFirstStepCallback())
|
||||||
|
|
||||||
callbacks.append(GPUStatsCallback(cfg=self.cfg))
|
callbacks.append(GPUStatsCallback(cfg=self.cfg))
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
|
|||||||
|
|
||||||
- If you are installing from pip
|
- If you are installing from pip
|
||||||
```bash
|
```bash
|
||||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@865b899"
|
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@50cef19"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|||||||
@@ -19,11 +19,13 @@ Cut Cross Entropy is an optimized implementation of cross entropy loss
|
|||||||
from Apple's ML team.
|
from Apple's ML team.
|
||||||
"""
|
"""
|
||||||
import importlib
|
import importlib
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from axolotl.integrations.base import BasePlugin
|
from axolotl.integrations.base import BasePlugin
|
||||||
from axolotl.utils import get_pytorch_version
|
from axolotl.utils import get_pytorch_version
|
||||||
|
from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
|
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
|
||||||
@@ -32,7 +34,7 @@ LOG = get_logger(__name__)
|
|||||||
|
|
||||||
_CCE_INSTALL_MESSAGE = (
|
_CCE_INSTALL_MESSAGE = (
|
||||||
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
||||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@865b899"`'
|
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@50cef19"`'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -84,6 +86,7 @@ class CutCrossEntropyPlugin(BasePlugin):
|
|||||||
"""Apply cut cross entropy before model loading if enabled."""
|
"""Apply cut cross entropy before model loading if enabled."""
|
||||||
if cfg.cut_cross_entropy:
|
if cfg.cut_cross_entropy:
|
||||||
self._check_requirements()
|
self._check_requirements()
|
||||||
|
self.patch_llama_like(cfg.model_config_type)
|
||||||
|
|
||||||
from cut_cross_entropy.transformers.patch import cce_patch
|
from cut_cross_entropy.transformers.patch import cce_patch
|
||||||
|
|
||||||
@@ -93,3 +96,48 @@ class CutCrossEntropyPlugin(BasePlugin):
|
|||||||
|
|
||||||
# The patch checks model_type internally
|
# The patch checks model_type internally
|
||||||
cce_patch(cfg.model_config_type)
|
cce_patch(cfg.model_config_type)
|
||||||
|
|
||||||
|
def patch_llama_like(
|
||||||
|
self,
|
||||||
|
model_type: str,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Generic patch for model architectures with causal lm similar to llama
|
||||||
|
"""
|
||||||
|
from cut_cross_entropy.transformers.patch import PATCH_FNS
|
||||||
|
|
||||||
|
def patch_generic(
|
||||||
|
maybe_model, patch_options, model_type: str
|
||||||
|
): # pylint: disable=unused-argument
|
||||||
|
import cut_cross_entropy.transformers.llama
|
||||||
|
from cut_cross_entropy.transformers.llama import cce_forward
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Dynamically import the module and CausalLM class
|
||||||
|
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
||||||
|
model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
|
||||||
|
module = __import__(
|
||||||
|
module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"]
|
||||||
|
)
|
||||||
|
model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
|
||||||
|
|
||||||
|
cut_cross_entropy.transformers.llama._PATCH_OPTS = ( # pylint: disable=protected-access
|
||||||
|
patch_options
|
||||||
|
)
|
||||||
|
|
||||||
|
model_cls.forward = cce_forward
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
except (ImportError, AttributeError) as e:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Could not import ForCausalLM class for model_type: {model_type}. "
|
||||||
|
f"Error: {str(e)}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
if model_type not in PATCH_FNS:
|
||||||
|
LOG.warning_once(
|
||||||
|
"Setting up generic cce patch for model type: %s", model_type
|
||||||
|
)
|
||||||
|
LOG.warning_once(
|
||||||
|
f"Generic Cut Cross Entropy + {model_type} support is experimental and may not work as expected."
|
||||||
|
)
|
||||||
|
PATCH_FNS[model_type] = partial(patch_generic, model_type=model_type)
|
||||||
|
|||||||
@@ -22,6 +22,8 @@ except ImportError:
|
|||||||
TransformersKwargs,
|
TransformersKwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix
|
||||||
|
|
||||||
|
|
||||||
def kldiv_forward_llama_like(
|
def kldiv_forward_llama_like(
|
||||||
self,
|
self,
|
||||||
@@ -97,7 +99,7 @@ def kldiv_forward_llama_like(
|
|||||||
def apply_kernel(model_type):
|
def apply_kernel(model_type):
|
||||||
# Dynamically import the module and attention class
|
# Dynamically import the module and attention class
|
||||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
||||||
model_cls_prefix = "".join([part.capitalize() for part in model_type.split("_")])
|
model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
|
||||||
module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"])
|
module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"])
|
||||||
model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
|
model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
|
||||||
model_cls.forward = kldiv_forward_llama_like
|
model_cls.forward = kldiv_forward_llama_like
|
||||||
|
|||||||
@@ -18,170 +18,10 @@ Module for the Plugin for LIGER integraton with Axolotl.
|
|||||||
Liger Kernel is the collection of Triton-native kernels for LLM Training.
|
Liger Kernel is the collection of Triton-native kernels for LLM Training.
|
||||||
It is designed to be performant, correct, and light-weight.
|
It is designed to be performant, correct, and light-weight.
|
||||||
"""
|
"""
|
||||||
import inspect
|
from .args import LigerArgs
|
||||||
import sys
|
from .plugin import LigerPlugin
|
||||||
|
|
||||||
from axolotl.integrations.base import BasePlugin
|
__all__ = [
|
||||||
from axolotl.utils.logging import get_logger
|
"LigerArgs",
|
||||||
|
"LigerPlugin",
|
||||||
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
|
]
|
||||||
from .utils import patch_with_compile_disable
|
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class LigerPlugin(BasePlugin):
|
|
||||||
"""
|
|
||||||
Plugin for LIGER integraton with Axolotl.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def get_input_args(self):
|
|
||||||
return "axolotl.integrations.liger.LigerArgs"
|
|
||||||
|
|
||||||
def pre_model_load(self, cfg):
|
|
||||||
if cfg.torch_compile:
|
|
||||||
# torch compile will unnecessarily attempt to optimize the triton kernel unless explicitly disabled
|
|
||||||
import liger_kernel.ops.fused_linear_cross_entropy
|
|
||||||
|
|
||||||
patch_with_compile_disable(
|
|
||||||
liger_kernel.ops.fused_linear_cross_entropy,
|
|
||||||
"fused_linear_cross_entropy_forward",
|
|
||||||
)
|
|
||||||
patch_with_compile_disable(
|
|
||||||
liger_kernel.ops.fused_linear_cross_entropy,
|
|
||||||
"fused_linear_cross_entropy_backward",
|
|
||||||
)
|
|
||||||
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
|
|
||||||
from liger_kernel.transformers.functional import liger_cross_entropy
|
|
||||||
from liger_kernel.transformers.layer_norm import LigerLayerNorm
|
|
||||||
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
|
|
||||||
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
|
||||||
from liger_kernel.transformers.rope import liger_rotary_pos_emb
|
|
||||||
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
|
|
||||||
|
|
||||||
if cfg.liger_cross_entropy and cfg.liger_fused_linear_cross_entropy:
|
|
||||||
raise ValueError(
|
|
||||||
"Cannot have both `liger_cross_entropy` and `liger_fused_linear_cross_entropy` set."
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
|
|
||||||
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
|
|
||||||
liger_fn_sig = inspect.signature(apply_liger_fn)
|
|
||||||
kwargs = {}
|
|
||||||
if "rope" in liger_fn_sig.parameters:
|
|
||||||
kwargs["rope"] = cfg.liger_rope
|
|
||||||
if "cross_entropy" in liger_fn_sig.parameters:
|
|
||||||
kwargs["cross_entropy"] = cfg.liger_cross_entropy
|
|
||||||
if "fused_linear_cross_entropy" in liger_fn_sig.parameters:
|
|
||||||
kwargs["fused_linear_cross_entropy"] = (
|
|
||||||
cfg.liger_fused_linear_cross_entropy
|
|
||||||
)
|
|
||||||
if "rms_norm" in liger_fn_sig.parameters:
|
|
||||||
kwargs["rms_norm"] = cfg.liger_rms_norm
|
|
||||||
if "layer_norm" in liger_fn_sig.parameters:
|
|
||||||
kwargs["layer_norm"] = cfg.liger_layer_norm
|
|
||||||
if "geglu" in liger_fn_sig.parameters:
|
|
||||||
kwargs["geglu"] = cfg.liger_glu_activation
|
|
||||||
elif "swiglu" in liger_fn_sig.parameters:
|
|
||||||
kwargs["swiglu"] = cfg.liger_glu_activation
|
|
||||||
LOG.info(f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}")
|
|
||||||
apply_liger_fn(**kwargs)
|
|
||||||
elif cfg.model_config_type == "jamba":
|
|
||||||
from transformers.models.jamba import modeling_jamba
|
|
||||||
|
|
||||||
from .models.jamba import lce_forward as jamba_lce_forward
|
|
||||||
|
|
||||||
if cfg.liger_rope:
|
|
||||||
modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb
|
|
||||||
if cfg.liger_rms_norm:
|
|
||||||
modeling_jamba.JambaRMSNorm = LigerRMSNorm
|
|
||||||
if cfg.liger_glu_activation:
|
|
||||||
modeling_jamba.JambaMLP = LigerSwiGLUMLP
|
|
||||||
if cfg.liger_layer_norm:
|
|
||||||
modeling_jamba.nn.LayerNorm = LigerLayerNorm
|
|
||||||
if cfg.liger_cross_entropy:
|
|
||||||
from transformers.loss.loss_utils import nn
|
|
||||||
|
|
||||||
nn.functional.cross_entropy = liger_cross_entropy
|
|
||||||
if cfg.liger_fused_linear_cross_entropy:
|
|
||||||
modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
|
|
||||||
elif cfg.model_config_type == "deepseek_v2":
|
|
||||||
from accelerate import init_empty_weights
|
|
||||||
from transformers import AutoModelForCausalLM
|
|
||||||
|
|
||||||
with init_empty_weights():
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
cfg.base_model, trust_remote_code=cfg.trust_remote_code or False
|
|
||||||
)
|
|
||||||
modeling_mod = sys.modules[model.__class__.__module__]
|
|
||||||
|
|
||||||
from .models.deepseekv2 import lce_forward as deepseekv2_lce_forward
|
|
||||||
|
|
||||||
if cfg.liger_rope:
|
|
||||||
# The DeepseekV2 version of RoPE is different than upstream LLaMA.
|
|
||||||
# See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528
|
|
||||||
LOG.warning("Fused liger_rope is not supported for DeepseekV2.")
|
|
||||||
if cfg.liger_glu_activation:
|
|
||||||
LOG.warning("liger_glu_activation is not supported for DeepseekV2.")
|
|
||||||
if cfg.liger_rms_norm:
|
|
||||||
modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
|
|
||||||
if cfg.liger_glu_activation:
|
|
||||||
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
|
|
||||||
if cfg.liger_layer_norm:
|
|
||||||
modeling_mod.DeepseekV2MLP.forward = LigerLayerNorm.forward
|
|
||||||
if cfg.liger_cross_entropy:
|
|
||||||
# We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses
|
|
||||||
# nn.CrossEntropyLoss in the forward method.
|
|
||||||
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
|
|
||||||
if cfg.liger_fused_linear_cross_entropy:
|
|
||||||
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
|
|
||||||
elif cfg.model_config_type == "llama4":
|
|
||||||
from axolotl.integrations.liger.models.llama4 import (
|
|
||||||
apply_liger_kernel_to_llama4,
|
|
||||||
)
|
|
||||||
|
|
||||||
apply_liger_kernel_to_llama4(
|
|
||||||
cross_entropy=cfg.liger_cross_entropy,
|
|
||||||
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
|
|
||||||
glu_activation=cfg.liger_glu_activation,
|
|
||||||
rms_norm=cfg.liger_rms_norm,
|
|
||||||
layer_norm=cfg.liger_layer_norm,
|
|
||||||
)
|
|
||||||
elif cfg.model_config_type == "qwen3":
|
|
||||||
from axolotl.integrations.liger.models.qwen3 import (
|
|
||||||
apply_liger_kernel_to_qwen3,
|
|
||||||
)
|
|
||||||
|
|
||||||
apply_liger_kernel_to_qwen3(
|
|
||||||
cross_entropy=cfg.liger_cross_entropy,
|
|
||||||
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
|
|
||||||
glu_activation=cfg.liger_glu_activation,
|
|
||||||
rms_norm=cfg.liger_rms_norm,
|
|
||||||
layer_norm=cfg.liger_layer_norm,
|
|
||||||
)
|
|
||||||
elif cfg.model_config_type == "qwen3_moe":
|
|
||||||
from axolotl.integrations.liger.models.qwen3_moe import (
|
|
||||||
apply_liger_kernel_to_qwen3_moe,
|
|
||||||
)
|
|
||||||
|
|
||||||
apply_liger_kernel_to_qwen3_moe(
|
|
||||||
cross_entropy=cfg.liger_cross_entropy,
|
|
||||||
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
|
|
||||||
glu_activation=cfg.liger_glu_activation,
|
|
||||||
rms_norm=cfg.liger_rms_norm,
|
|
||||||
layer_norm=cfg.liger_layer_norm,
|
|
||||||
)
|
|
||||||
elif cfg.model_config_type == "granitemoe":
|
|
||||||
from liger_kernel.transformers import apply_liger_kernel_to_granite
|
|
||||||
|
|
||||||
apply_liger_kernel_to_granite(
|
|
||||||
rope=cfg.liger_rope,
|
|
||||||
cross_entropy=cfg.liger_cross_entropy,
|
|
||||||
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
|
|
||||||
rms_norm=cfg.liger_rms_norm,
|
|
||||||
swiglu=cfg.liger_glu_activation,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
LOG.warning(
|
|
||||||
f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."
|
|
||||||
)
|
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user