Compare commits
1 Commits
kwargs-ref
...
revert-290
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6f6d917a99 |
4
.github/workflows/tests-nightly.yml
vendored
4
.github/workflows/tests-nightly.yml
vendored
@@ -92,7 +92,7 @@ jobs:
|
|||||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||||
# this job needs to be run on self-hosted GPU runners...
|
# this job needs to be run on self-hosted GPU runners...
|
||||||
runs-on: [self-hosted, modal]
|
runs-on: [self-hosted, modal]
|
||||||
timeout-minutes: 120
|
timeout-minutes: 60
|
||||||
needs: [pre-commit, pytest]
|
needs: [pre-commit, pytest]
|
||||||
|
|
||||||
strategy:
|
strategy:
|
||||||
@@ -116,7 +116,7 @@ jobs:
|
|||||||
- name: Install Modal
|
- name: Install Modal
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
pip install modal==1.0.2 jinja2
|
pip install modal==0.71.8 jinja2
|
||||||
- name: Update env vars
|
- name: Update env vars
|
||||||
run: |
|
run: |
|
||||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||||
|
|||||||
@@ -26,5 +26,3 @@ timeout: 86400
|
|||||||
# Preprocess specific configurations
|
# Preprocess specific configurations
|
||||||
memory_preprocess: 32
|
memory_preprocess: 32
|
||||||
timeout_preprocess: 14400
|
timeout_preprocess: 14400
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ wandb_watch:
|
|||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 4
|
num_epochs: 4
|
||||||
@@ -55,5 +56,3 @@ evals_per_epoch:
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -40,7 +40,7 @@
|
|||||||
"%%capture\n",
|
"%%capture\n",
|
||||||
"# This step can take ~5-10 minutes to install dependencies\n",
|
"# This step can take ~5-10 minutes to install dependencies\n",
|
||||||
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
||||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@50cef19\""
|
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@78b2a45713a54c9bedf8b33f5e31cf07a1a57154\""
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -56,5 +56,3 @@ evals_per_epoch: 1
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -56,5 +56,3 @@ evals_per_epoch: 1
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -55,5 +55,3 @@ fsdp_config:
|
|||||||
fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer
|
fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer
|
||||||
fsdp_state_dict_type: FULL_STATE_DICT
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
fsdp_sharding_strategy: FULL_SHARD
|
fsdp_sharding_strategy: FULL_SHARD
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -79,5 +79,3 @@ fsdp_config:
|
|||||||
fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer
|
fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer
|
||||||
fsdp_state_dict_type: FULL_STATE_DICT
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
fsdp_sharding_strategy: FULL_SHARD
|
fsdp_sharding_strategy: FULL_SHARD
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -62,5 +62,3 @@ saves_per_epoch: 1
|
|||||||
|
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -69,5 +69,3 @@ evals_per_epoch:
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ wandb_watch:
|
|||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 4
|
num_epochs: 4
|
||||||
@@ -68,5 +69,3 @@ evals_per_epoch:
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -69,5 +69,3 @@ evals_per_epoch:
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -69,5 +69,3 @@ evals_per_epoch: 1
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -69,5 +69,3 @@ evals_per_epoch:
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -69,5 +69,3 @@ evals_per_epoch: 1
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -60,5 +60,3 @@ evals_per_epoch:
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -50,5 +50,3 @@ evals_per_epoch:
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -66,5 +66,3 @@ evals_per_epoch:
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -60,5 +60,3 @@ warmup_ratio: 0.1
|
|||||||
evals_per_epoch: 1
|
evals_per_epoch: 1
|
||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -62,5 +62,3 @@ warmup_ratio: 0.1
|
|||||||
evals_per_epoch: 1
|
evals_per_epoch: 1
|
||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -60,5 +60,3 @@ evals_per_epoch: 1
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -54,5 +54,3 @@ evals_per_epoch:
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -55,5 +55,3 @@ saves_per_epoch: 1
|
|||||||
deepspeed: deepspeed_configs/zero2.json
|
deepspeed: deepspeed_configs/zero2.json
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -64,5 +64,3 @@ fsdp_config:
|
|||||||
fsdp_transformer_layer_cls_to_wrap: JambaAttentionDecoderLayer,JambaMambaDecoderLayer
|
fsdp_transformer_layer_cls_to_wrap: JambaAttentionDecoderLayer,JambaMambaDecoderLayer
|
||||||
fsdp_state_dict_type: FULL_STATE_DICT
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
fsdp_sharding_strategy: FULL_SHARD
|
fsdp_sharding_strategy: FULL_SHARD
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -46,5 +46,3 @@ evals_per_epoch: 2
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
|
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -55,5 +55,3 @@ saves_per_epoch: 1
|
|||||||
deepspeed: #deepspeed_configs/zero2.json # multi-gpu only
|
deepspeed: #deepspeed_configs/zero2.json # multi-gpu only
|
||||||
weight_decay: 0.1
|
weight_decay: 0.1
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -64,5 +64,3 @@ special_tokens:
|
|||||||
bos_token: "<s>"
|
bos_token: "<s>"
|
||||||
eos_token: "</s>"
|
eos_token: "</s>"
|
||||||
unk_token: "<unk>"
|
unk_token: "<unk>"
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -60,5 +60,3 @@ special_tokens:
|
|||||||
bos_token: "<s>"
|
bos_token: "<s>"
|
||||||
eos_token: "</s>"
|
eos_token: "</s>"
|
||||||
unk_token: "<unk>"
|
unk_token: "<unk>"
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -52,5 +52,3 @@ evals_per_epoch: 4
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -52,5 +52,3 @@ evals_per_epoch: 4
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -67,5 +67,3 @@ fsdp_config:
|
|||||||
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||||
fsdp_state_dict_type: FULL_STATE_DICT
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -53,5 +53,3 @@ evals_per_epoch: 4
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -58,5 +58,3 @@ special_tokens:
|
|||||||
bos_token: "<s>"
|
bos_token: "<s>"
|
||||||
eos_token: "</s>"
|
eos_token: "</s>"
|
||||||
unk_token: "<unk>"
|
unk_token: "<unk>"
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -57,5 +57,3 @@ warmup_ratio: 0.1
|
|||||||
evals_per_epoch: 1
|
evals_per_epoch: 1
|
||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -77,5 +77,3 @@ fsdp_config:
|
|||||||
|
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: <|end_of_text|>
|
pad_token: <|end_of_text|>
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -72,5 +72,3 @@ fsdp_config:
|
|||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: <|finetune_right_pad_id|>
|
pad_token: <|finetune_right_pad_id|>
|
||||||
eos_token: <|eot_id|>
|
eos_token: <|eot_id|>
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -42,5 +42,3 @@ saves_per_epoch: 1
|
|||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: <|end_of_text|>
|
pad_token: <|end_of_text|>
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -71,5 +71,3 @@ warmup_steps: 10
|
|||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -64,5 +64,3 @@ saves_per_epoch: 1
|
|||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: <|end_of_text|>
|
pad_token: <|end_of_text|>
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -83,5 +83,3 @@ warmup_steps: 10
|
|||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -61,5 +61,3 @@ saves_per_epoch: 1
|
|||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: <|end_of_text|>
|
pad_token: <|end_of_text|>
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -65,5 +65,3 @@ saves_per_epoch: 1
|
|||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: "<|end_of_text|>"
|
pad_token: "<|end_of_text|>"
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -64,5 +64,3 @@ special_tokens:
|
|||||||
|
|
||||||
use_ray: true
|
use_ray: true
|
||||||
ray_num_workers: 4
|
ray_num_workers: 4
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -63,5 +63,3 @@ saves_per_epoch: 1
|
|||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: <|end_of_text|>
|
pad_token: <|end_of_text|>
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -60,5 +60,3 @@ saves_per_epoch: 1
|
|||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: "<|end_of_text|>"
|
pad_token: "<|end_of_text|>"
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -57,5 +57,3 @@ saves_per_epoch: 1
|
|||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: <|end_of_text|>
|
pad_token: <|end_of_text|>
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -61,5 +61,3 @@ saves_per_epoch: 1
|
|||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: "<|end_of_text|>"
|
pad_token: "<|end_of_text|>"
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -62,5 +62,3 @@ saves_per_epoch: 1
|
|||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: "<|end_of_text|>"
|
pad_token: "<|end_of_text|>"
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -60,5 +60,3 @@ fsdp_config:
|
|||||||
fsdp_sharding_strategy: FULL_SHARD
|
fsdp_sharding_strategy: FULL_SHARD
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: <|finetune_right_pad_id|>
|
pad_token: <|finetune_right_pad_id|>
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -69,5 +69,3 @@ fsdp_config:
|
|||||||
fsdp_sharding_strategy: FULL_SHARD
|
fsdp_sharding_strategy: FULL_SHARD
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: <|end_of_text|>
|
pad_token: <|end_of_text|>
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -54,5 +54,3 @@ saves_per_epoch: 1
|
|||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: "<|end_of_text|>"
|
pad_token: "<|end_of_text|>"
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -75,5 +75,3 @@ llmcompressor:
|
|||||||
]
|
]
|
||||||
start: 0
|
start: 0
|
||||||
save_compressed: true
|
save_compressed: true
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -86,5 +86,3 @@ fsdp_config:
|
|||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: <|finetune_right_pad_id|>
|
pad_token: <|finetune_right_pad_id|>
|
||||||
eos_token: <|eot|>
|
eos_token: <|eot|>
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -90,5 +90,3 @@ fsdp_config:
|
|||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: <|finetune_right_pad_id|>
|
pad_token: <|finetune_right_pad_id|>
|
||||||
eos_token: <|eot|>
|
eos_token: <|eot|>
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -83,5 +83,3 @@ weight_decay: 0.0
|
|||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: <|finetune_right_pad_id|>
|
pad_token: <|finetune_right_pad_id|>
|
||||||
eos_token: <|eot|>
|
eos_token: <|eot|>
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -86,5 +86,3 @@ fsdp_config:
|
|||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: <|finetune_right_pad_id|>
|
pad_token: <|finetune_right_pad_id|>
|
||||||
eos_token: <|eot|>
|
eos_token: <|eot|>
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -84,5 +84,3 @@ fsdp_config:
|
|||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: <|finetune_right_pad_id|>
|
pad_token: <|finetune_right_pad_id|>
|
||||||
eos_token: <|eot|>
|
eos_token: <|eot|>
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -82,5 +82,3 @@ weight_decay: 0.0
|
|||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: <|finetune_right_pad_id|>
|
pad_token: <|finetune_right_pad_id|>
|
||||||
eos_token: <|eot|>
|
eos_token: <|eot|>
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -87,5 +87,3 @@ fsdp_config:
|
|||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: <|finetune_right_pad_id|>
|
pad_token: <|finetune_right_pad_id|>
|
||||||
eos_token: <|eot|>
|
eos_token: <|eot|>
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -53,5 +53,3 @@ warmup_ratio: 0.1
|
|||||||
evals_per_epoch: 1
|
evals_per_epoch: 1
|
||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -70,5 +70,3 @@ fsdp_config:
|
|||||||
fsdp_state_dict_type: FULL_STATE_DICT
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
fsdp_transformer_layer_cls_to_wrap: MistralDecoderLayer
|
fsdp_transformer_layer_cls_to_wrap: MistralDecoderLayer
|
||||||
fsdp_activation_checkpointing: true
|
fsdp_activation_checkpointing: true
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -61,5 +61,3 @@ flash_attention: true
|
|||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 1
|
evals_per_epoch: 1
|
||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -48,5 +48,3 @@ weight_decay: 0.0
|
|||||||
special_tokens:
|
special_tokens:
|
||||||
tokens:
|
tokens:
|
||||||
save_safetensors: False
|
save_safetensors: False
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -53,5 +53,3 @@ special_tokens:
|
|||||||
eos_token: "<|im_end|>"
|
eos_token: "<|im_end|>"
|
||||||
tokens:
|
tokens:
|
||||||
- "<|im_start|>"
|
- "<|im_start|>"
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -43,5 +43,3 @@ evals_per_epoch: 4
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -64,5 +64,3 @@ evals_per_epoch: 4
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -64,5 +64,3 @@ evals_per_epoch: 4
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -80,5 +80,3 @@ weight_decay: 0.0
|
|||||||
special_tokens:
|
special_tokens:
|
||||||
bos_token: "<|im_start|>"
|
bos_token: "<|im_start|>"
|
||||||
eos_token: "<|im_end|>"
|
eos_token: "<|im_end|>"
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -74,5 +74,3 @@ fsdp_config:
|
|||||||
fsdp_state_dict_type: FULL_STATE_DICT
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -69,5 +69,3 @@ evals_per_epoch: 4
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -56,5 +56,3 @@ evals_per_epoch: 1
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -72,5 +72,3 @@ fsdp_config:
|
|||||||
fsdp_state_dict_type: FULL_STATE_DICT
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -77,5 +77,3 @@ fsdp_config:
|
|||||||
fsdp_forward_prefetch: false
|
fsdp_forward_prefetch: false
|
||||||
fsdp_backward_prefetch: BACKWARD_PRE
|
fsdp_backward_prefetch: BACKWARD_PRE
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -81,5 +81,3 @@ saves_per_epoch: 1
|
|||||||
deepspeed: deepspeed_configs/zero2.json
|
deepspeed: deepspeed_configs/zero2.json
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -51,5 +51,3 @@ special_tokens:
|
|||||||
eos_token: "<|im_end|>"
|
eos_token: "<|im_end|>"
|
||||||
tokens:
|
tokens:
|
||||||
- "<|im_start|>"
|
- "<|im_start|>"
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -64,5 +64,3 @@ evals_per_epoch: 4
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -50,5 +50,3 @@ weight_decay: 0.05
|
|||||||
|
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: <custom_token_7>
|
pad_token: <custom_token_7>
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -63,5 +63,3 @@ warmup_steps: 10
|
|||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
saves_per_epoch: 4
|
saves_per_epoch: 4
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -57,5 +57,3 @@ weight_decay: 0.1
|
|||||||
resize_token_embeddings_to_32x: true
|
resize_token_embeddings_to_32x: true
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: "<|endoftext|>"
|
pad_token: "<|endoftext|>"
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -60,5 +60,3 @@ weight_decay: 0.1
|
|||||||
resize_token_embeddings_to_32x: true
|
resize_token_embeddings_to_32x: true
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: "<|endoftext|>"
|
pad_token: "<|endoftext|>"
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -57,5 +57,3 @@ weight_decay: 0.1
|
|||||||
resize_token_embeddings_to_32x: true
|
resize_token_embeddings_to_32x: true
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: "<|endoftext|>"
|
pad_token: "<|endoftext|>"
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -71,5 +71,3 @@ fsdp_config:
|
|||||||
resize_token_embeddings_to_32x: true
|
resize_token_embeddings_to_32x: true
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: "<|endoftext|>"
|
pad_token: "<|endoftext|>"
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -59,5 +59,3 @@ warmup_ratio: 0.2
|
|||||||
debug: true
|
debug: true
|
||||||
weight_decay: 0.1
|
weight_decay: 0.1
|
||||||
resize_token_embeddings_to_32x: true
|
resize_token_embeddings_to_32x: true
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -55,5 +55,3 @@ saves_per_epoch: 1
|
|||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: <pad>
|
pad_token: <pad>
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -53,5 +53,3 @@ warmup_ratio: 0.1
|
|||||||
evals_per_epoch: 1
|
evals_per_epoch: 1
|
||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -54,5 +54,3 @@ warmup_steps: 10
|
|||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -55,5 +55,3 @@ eval_steps: 100
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -67,5 +67,3 @@ fsdp_config:
|
|||||||
fsdp_state_dict_type: FULL_STATE_DICT
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
fsdp_sharding_strategy: FULL_SHARD
|
fsdp_sharding_strategy: FULL_SHARD
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ wandb_watch:
|
|||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
micro_batch_size: 2
|
micro_batch_size: 2
|
||||||
num_epochs: 4
|
num_epochs: 4
|
||||||
@@ -49,5 +50,3 @@ evals_per_epoch:
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -53,5 +53,3 @@ warmup_ratio: 0.1
|
|||||||
evals_per_epoch: 1
|
evals_per_epoch: 1
|
||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -67,5 +67,3 @@ evals_per_epoch: 4
|
|||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -76,5 +76,3 @@ fsdp_config:
|
|||||||
fsdp_activation_checkpointing: true
|
fsdp_activation_checkpointing: true
|
||||||
|
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -66,5 +66,3 @@ fsdp_config:
|
|||||||
fsdp_state_dict_type: FULL_STATE_DICT
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
fsdp_sharding_strategy: FULL_SHARD
|
fsdp_sharding_strategy: FULL_SHARD
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ hf_transfer
|
|||||||
sentencepiece
|
sentencepiece
|
||||||
gradio==5.23.3
|
gradio==5.23.3
|
||||||
|
|
||||||
modal==1.0.2
|
modal==0.70.5
|
||||||
pydantic==2.10.6
|
pydantic==2.10.6
|
||||||
addict
|
addict
|
||||||
fire
|
fire
|
||||||
|
|||||||
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
|
|||||||
|
|
||||||
print(
|
print(
|
||||||
UNINSTALL_PREFIX
|
UNINSTALL_PREFIX
|
||||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@50cef19"'
|
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@865b899"'
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -36,7 +36,6 @@ from axolotl.utils.callbacks import (
|
|||||||
GCCallback,
|
GCCallback,
|
||||||
GPUStatsCallback,
|
GPUStatsCallback,
|
||||||
SaveAxolotlConfigtoWandBCallback,
|
SaveAxolotlConfigtoWandBCallback,
|
||||||
SaveModelOnFirstStepCallback,
|
|
||||||
)
|
)
|
||||||
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
|
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
|
||||||
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
|
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
|
||||||
@@ -136,8 +135,6 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
callbacks.append(
|
callbacks.append(
|
||||||
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
|
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
|
||||||
)
|
)
|
||||||
if self.cfg.save_first_step:
|
|
||||||
callbacks.append(SaveModelOnFirstStepCallback())
|
|
||||||
|
|
||||||
callbacks.append(GPUStatsCallback(cfg=self.cfg))
|
callbacks.append(GPUStatsCallback(cfg=self.cfg))
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
|
|||||||
|
|
||||||
- If you are installing from pip
|
- If you are installing from pip
|
||||||
```bash
|
```bash
|
||||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@50cef19"
|
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@865b899"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|||||||
@@ -19,13 +19,11 @@ Cut Cross Entropy is an optimized implementation of cross entropy loss
|
|||||||
from Apple's ML team.
|
from Apple's ML team.
|
||||||
"""
|
"""
|
||||||
import importlib
|
import importlib
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from axolotl.integrations.base import BasePlugin
|
from axolotl.integrations.base import BasePlugin
|
||||||
from axolotl.utils import get_pytorch_version
|
from axolotl.utils import get_pytorch_version
|
||||||
from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix
|
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
|
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
|
||||||
@@ -34,7 +32,7 @@ LOG = get_logger(__name__)
|
|||||||
|
|
||||||
_CCE_INSTALL_MESSAGE = (
|
_CCE_INSTALL_MESSAGE = (
|
||||||
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
||||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@50cef19"`'
|
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@865b899"`'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -86,7 +84,6 @@ class CutCrossEntropyPlugin(BasePlugin):
|
|||||||
"""Apply cut cross entropy before model loading if enabled."""
|
"""Apply cut cross entropy before model loading if enabled."""
|
||||||
if cfg.cut_cross_entropy:
|
if cfg.cut_cross_entropy:
|
||||||
self._check_requirements()
|
self._check_requirements()
|
||||||
self.patch_llama_like(cfg.model_config_type)
|
|
||||||
|
|
||||||
from cut_cross_entropy.transformers.patch import cce_patch
|
from cut_cross_entropy.transformers.patch import cce_patch
|
||||||
|
|
||||||
@@ -96,48 +93,3 @@ class CutCrossEntropyPlugin(BasePlugin):
|
|||||||
|
|
||||||
# The patch checks model_type internally
|
# The patch checks model_type internally
|
||||||
cce_patch(cfg.model_config_type)
|
cce_patch(cfg.model_config_type)
|
||||||
|
|
||||||
def patch_llama_like(
|
|
||||||
self,
|
|
||||||
model_type: str,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Generic patch for model architectures with causal lm similar to llama
|
|
||||||
"""
|
|
||||||
from cut_cross_entropy.transformers.patch import PATCH_FNS
|
|
||||||
|
|
||||||
def patch_generic(
|
|
||||||
maybe_model, patch_options, model_type: str
|
|
||||||
): # pylint: disable=unused-argument
|
|
||||||
import cut_cross_entropy.transformers.llama
|
|
||||||
from cut_cross_entropy.transformers.llama import cce_forward
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Dynamically import the module and CausalLM class
|
|
||||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
|
||||||
model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
|
|
||||||
module = __import__(
|
|
||||||
module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"]
|
|
||||||
)
|
|
||||||
model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
|
|
||||||
|
|
||||||
cut_cross_entropy.transformers.llama._PATCH_OPTS = ( # pylint: disable=protected-access
|
|
||||||
patch_options
|
|
||||||
)
|
|
||||||
|
|
||||||
model_cls.forward = cce_forward
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
except (ImportError, AttributeError) as e:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Could not import ForCausalLM class for model_type: {model_type}. "
|
|
||||||
f"Error: {str(e)}"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
if model_type not in PATCH_FNS:
|
|
||||||
LOG.warning_once(
|
|
||||||
"Setting up generic cce patch for model type: %s", model_type
|
|
||||||
)
|
|
||||||
LOG.warning_once(
|
|
||||||
f"Generic Cut Cross Entropy + {model_type} support is experimental and may not work as expected."
|
|
||||||
)
|
|
||||||
PATCH_FNS[model_type] = partial(patch_generic, model_type=model_type)
|
|
||||||
|
|||||||
@@ -22,8 +22,6 @@ except ImportError:
|
|||||||
TransformersKwargs,
|
TransformersKwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix
|
|
||||||
|
|
||||||
|
|
||||||
def kldiv_forward_llama_like(
|
def kldiv_forward_llama_like(
|
||||||
self,
|
self,
|
||||||
@@ -99,7 +97,7 @@ def kldiv_forward_llama_like(
|
|||||||
def apply_kernel(model_type):
|
def apply_kernel(model_type):
|
||||||
# Dynamically import the module and attention class
|
# Dynamically import the module and attention class
|
||||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
||||||
model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
|
model_cls_prefix = "".join([part.capitalize() for part in model_type.split("_")])
|
||||||
module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"])
|
module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"])
|
||||||
model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
|
model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
|
||||||
model_cls.forward = kldiv_forward_llama_like
|
model_cls.forward = kldiv_forward_llama_like
|
||||||
|
|||||||
@@ -18,10 +18,170 @@ Module for the Plugin for LIGER integraton with Axolotl.
|
|||||||
Liger Kernel is the collection of Triton-native kernels for LLM Training.
|
Liger Kernel is the collection of Triton-native kernels for LLM Training.
|
||||||
It is designed to be performant, correct, and light-weight.
|
It is designed to be performant, correct, and light-weight.
|
||||||
"""
|
"""
|
||||||
from .args import LigerArgs
|
import inspect
|
||||||
from .plugin import LigerPlugin
|
import sys
|
||||||
|
|
||||||
__all__ = [
|
from axolotl.integrations.base import BasePlugin
|
||||||
"LigerArgs",
|
from axolotl.utils.logging import get_logger
|
||||||
"LigerPlugin",
|
|
||||||
]
|
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
|
||||||
|
from .utils import patch_with_compile_disable
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LigerPlugin(BasePlugin):
|
||||||
|
"""
|
||||||
|
Plugin for LIGER integraton with Axolotl.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_input_args(self):
|
||||||
|
return "axolotl.integrations.liger.LigerArgs"
|
||||||
|
|
||||||
|
def pre_model_load(self, cfg):
|
||||||
|
if cfg.torch_compile:
|
||||||
|
# torch compile will unnecessarily attempt to optimize the triton kernel unless explicitly disabled
|
||||||
|
import liger_kernel.ops.fused_linear_cross_entropy
|
||||||
|
|
||||||
|
patch_with_compile_disable(
|
||||||
|
liger_kernel.ops.fused_linear_cross_entropy,
|
||||||
|
"fused_linear_cross_entropy_forward",
|
||||||
|
)
|
||||||
|
patch_with_compile_disable(
|
||||||
|
liger_kernel.ops.fused_linear_cross_entropy,
|
||||||
|
"fused_linear_cross_entropy_backward",
|
||||||
|
)
|
||||||
|
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
|
||||||
|
from liger_kernel.transformers.functional import liger_cross_entropy
|
||||||
|
from liger_kernel.transformers.layer_norm import LigerLayerNorm
|
||||||
|
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
|
||||||
|
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
||||||
|
from liger_kernel.transformers.rope import liger_rotary_pos_emb
|
||||||
|
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
|
||||||
|
|
||||||
|
if cfg.liger_cross_entropy and cfg.liger_fused_linear_cross_entropy:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot have both `liger_cross_entropy` and `liger_fused_linear_cross_entropy` set."
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
|
||||||
|
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
|
||||||
|
liger_fn_sig = inspect.signature(apply_liger_fn)
|
||||||
|
kwargs = {}
|
||||||
|
if "rope" in liger_fn_sig.parameters:
|
||||||
|
kwargs["rope"] = cfg.liger_rope
|
||||||
|
if "cross_entropy" in liger_fn_sig.parameters:
|
||||||
|
kwargs["cross_entropy"] = cfg.liger_cross_entropy
|
||||||
|
if "fused_linear_cross_entropy" in liger_fn_sig.parameters:
|
||||||
|
kwargs["fused_linear_cross_entropy"] = (
|
||||||
|
cfg.liger_fused_linear_cross_entropy
|
||||||
|
)
|
||||||
|
if "rms_norm" in liger_fn_sig.parameters:
|
||||||
|
kwargs["rms_norm"] = cfg.liger_rms_norm
|
||||||
|
if "layer_norm" in liger_fn_sig.parameters:
|
||||||
|
kwargs["layer_norm"] = cfg.liger_layer_norm
|
||||||
|
if "geglu" in liger_fn_sig.parameters:
|
||||||
|
kwargs["geglu"] = cfg.liger_glu_activation
|
||||||
|
elif "swiglu" in liger_fn_sig.parameters:
|
||||||
|
kwargs["swiglu"] = cfg.liger_glu_activation
|
||||||
|
LOG.info(f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}")
|
||||||
|
apply_liger_fn(**kwargs)
|
||||||
|
elif cfg.model_config_type == "jamba":
|
||||||
|
from transformers.models.jamba import modeling_jamba
|
||||||
|
|
||||||
|
from .models.jamba import lce_forward as jamba_lce_forward
|
||||||
|
|
||||||
|
if cfg.liger_rope:
|
||||||
|
modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||||
|
if cfg.liger_rms_norm:
|
||||||
|
modeling_jamba.JambaRMSNorm = LigerRMSNorm
|
||||||
|
if cfg.liger_glu_activation:
|
||||||
|
modeling_jamba.JambaMLP = LigerSwiGLUMLP
|
||||||
|
if cfg.liger_layer_norm:
|
||||||
|
modeling_jamba.nn.LayerNorm = LigerLayerNorm
|
||||||
|
if cfg.liger_cross_entropy:
|
||||||
|
from transformers.loss.loss_utils import nn
|
||||||
|
|
||||||
|
nn.functional.cross_entropy = liger_cross_entropy
|
||||||
|
if cfg.liger_fused_linear_cross_entropy:
|
||||||
|
modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
|
||||||
|
elif cfg.model_config_type == "deepseek_v2":
|
||||||
|
from accelerate import init_empty_weights
|
||||||
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
|
with init_empty_weights():
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
cfg.base_model, trust_remote_code=cfg.trust_remote_code or False
|
||||||
|
)
|
||||||
|
modeling_mod = sys.modules[model.__class__.__module__]
|
||||||
|
|
||||||
|
from .models.deepseekv2 import lce_forward as deepseekv2_lce_forward
|
||||||
|
|
||||||
|
if cfg.liger_rope:
|
||||||
|
# The DeepseekV2 version of RoPE is different than upstream LLaMA.
|
||||||
|
# See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528
|
||||||
|
LOG.warning("Fused liger_rope is not supported for DeepseekV2.")
|
||||||
|
if cfg.liger_glu_activation:
|
||||||
|
LOG.warning("liger_glu_activation is not supported for DeepseekV2.")
|
||||||
|
if cfg.liger_rms_norm:
|
||||||
|
modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
|
||||||
|
if cfg.liger_glu_activation:
|
||||||
|
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
|
||||||
|
if cfg.liger_layer_norm:
|
||||||
|
modeling_mod.DeepseekV2MLP.forward = LigerLayerNorm.forward
|
||||||
|
if cfg.liger_cross_entropy:
|
||||||
|
# We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses
|
||||||
|
# nn.CrossEntropyLoss in the forward method.
|
||||||
|
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||||
|
if cfg.liger_fused_linear_cross_entropy:
|
||||||
|
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
|
||||||
|
elif cfg.model_config_type == "llama4":
|
||||||
|
from axolotl.integrations.liger.models.llama4 import (
|
||||||
|
apply_liger_kernel_to_llama4,
|
||||||
|
)
|
||||||
|
|
||||||
|
apply_liger_kernel_to_llama4(
|
||||||
|
cross_entropy=cfg.liger_cross_entropy,
|
||||||
|
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
|
||||||
|
glu_activation=cfg.liger_glu_activation,
|
||||||
|
rms_norm=cfg.liger_rms_norm,
|
||||||
|
layer_norm=cfg.liger_layer_norm,
|
||||||
|
)
|
||||||
|
elif cfg.model_config_type == "qwen3":
|
||||||
|
from axolotl.integrations.liger.models.qwen3 import (
|
||||||
|
apply_liger_kernel_to_qwen3,
|
||||||
|
)
|
||||||
|
|
||||||
|
apply_liger_kernel_to_qwen3(
|
||||||
|
cross_entropy=cfg.liger_cross_entropy,
|
||||||
|
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
|
||||||
|
glu_activation=cfg.liger_glu_activation,
|
||||||
|
rms_norm=cfg.liger_rms_norm,
|
||||||
|
layer_norm=cfg.liger_layer_norm,
|
||||||
|
)
|
||||||
|
elif cfg.model_config_type == "qwen3_moe":
|
||||||
|
from axolotl.integrations.liger.models.qwen3_moe import (
|
||||||
|
apply_liger_kernel_to_qwen3_moe,
|
||||||
|
)
|
||||||
|
|
||||||
|
apply_liger_kernel_to_qwen3_moe(
|
||||||
|
cross_entropy=cfg.liger_cross_entropy,
|
||||||
|
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
|
||||||
|
glu_activation=cfg.liger_glu_activation,
|
||||||
|
rms_norm=cfg.liger_rms_norm,
|
||||||
|
layer_norm=cfg.liger_layer_norm,
|
||||||
|
)
|
||||||
|
elif cfg.model_config_type == "granitemoe":
|
||||||
|
from liger_kernel.transformers import apply_liger_kernel_to_granite
|
||||||
|
|
||||||
|
apply_liger_kernel_to_granite(
|
||||||
|
rope=cfg.liger_rope,
|
||||||
|
cross_entropy=cfg.liger_cross_entropy,
|
||||||
|
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
|
||||||
|
rms_norm=cfg.liger_rms_norm,
|
||||||
|
swiglu=cfg.liger_glu_activation,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
LOG.warning(
|
||||||
|
f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."
|
||||||
|
)
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user