Compare commits
21 Commits
coderabbit
...
feat/glm45
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
be1f8db913 | ||
|
|
f2155eaf79 | ||
|
|
92ee4256f7 | ||
|
|
efeb5a4e41 | ||
|
|
faaff6c792 | ||
|
|
43cef27458 | ||
|
|
07c41a6c2a | ||
|
|
bbd3486f57 | ||
|
|
3750d7dd64 | ||
|
|
2197b0bf89 | ||
|
|
a526647b31 | ||
|
|
8069177284 | ||
|
|
a28eb600e9 | ||
|
|
4b16f363bc | ||
|
|
272a456ec0 | ||
|
|
7e83268662 | ||
|
|
b2a8c37a27 | ||
|
|
603166d9c5 | ||
|
|
e8c9517ac8 | ||
|
|
0bbad9202c | ||
|
|
cb042e9775 |
3
.github/workflows/docs.yml
vendored
3
.github/workflows/docs.yml
vendored
@@ -12,6 +12,9 @@ jobs:
|
|||||||
build-deploy:
|
build-deploy:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
|
- name: cleanup node
|
||||||
|
run: |
|
||||||
|
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL
|
||||||
- name: Check out repository
|
- name: Check out repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
- name: Set up Quarto
|
- name: Set up Quarto
|
||||||
|
|||||||
5
.github/workflows/preview-docs.yml
vendored
5
.github/workflows/preview-docs.yml
vendored
@@ -11,6 +11,7 @@ on:
|
|||||||
- '_quarto.yml'
|
- '_quarto.yml'
|
||||||
- docs/scripts/generate_config_docs.py
|
- docs/scripts/generate_config_docs.py
|
||||||
- src/axolotl/utils/schemas/**.py
|
- src/axolotl/utils/schemas/**.py
|
||||||
|
- .github/workflows/preview-docs.yml
|
||||||
|
|
||||||
permissions:
|
permissions:
|
||||||
checks: write
|
checks: write
|
||||||
@@ -27,6 +28,10 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
if: ${{ !github.event.pull_request.draft }}
|
if: ${{ !github.event.pull_request.draft }}
|
||||||
steps:
|
steps:
|
||||||
|
- name: cleanup node
|
||||||
|
run: |
|
||||||
|
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL
|
||||||
|
|
||||||
- name: Check out repository
|
- name: Check out repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
|
|||||||
4
.github/workflows/tests.yml
vendored
4
.github/workflows/tests.yml
vendored
@@ -114,7 +114,7 @@ jobs:
|
|||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
df -h
|
df -h
|
||||||
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
|
pytest -v --durations=10 -n4 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
|
||||||
df -h
|
df -h
|
||||||
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
|
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
|
||||||
df -h
|
df -h
|
||||||
@@ -196,7 +196,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
|
pytest -v --durations=10 -n4 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
|
||||||
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
|
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
|
||||||
pytest -v --durations=10 tests/cli/
|
pytest -v --durations=10 tests/cli/
|
||||||
|
|
||||||
|
|||||||
48
examples/glm45/README.md
Normal file
48
examples/glm45/README.md
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
# Finetune GLM4.5 with Axolotl
|
||||||
|
|
||||||
|
[UNSTABLE]
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# LoRA SFT (4xH200 @ 84GB/GPU)
|
||||||
|
axolotl train examples/glm45/glm4.5-lora-fsdp2.yaml
|
||||||
|
|
||||||
|
# FFT SFT (4xH200)
|
||||||
|
# Checkpointing error on backward pass
|
||||||
|
# Without checkpointing => OOM
|
||||||
|
axolotl train examples/glm45/glm4.5-fft-fsdp2.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
## Dataset
|
||||||
|
|
||||||
|
In addition to the normal OpenAI Messages format, GLM4.5 supports an extra parameter for thinking in the assistant section.
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"reasoning_content": "...", // or have <think>...</think> in `content`
|
||||||
|
"content": "...",
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Note:
|
||||||
|
- The role name for tools in this template is `tool`.
|
||||||
|
- You will see this Axolotl WARNING. This is expected, as the template does not use EOS.
|
||||||
|
```bash
|
||||||
|
EOS token '<|endoftext|>' not found in chat_template. Please check if your template/EOS token is correct.
|
||||||
|
```
|
||||||
|
- Make sure you set the extra attributes below if needed:
|
||||||
|
```yaml
|
||||||
|
datasets:
|
||||||
|
- path: ...
|
||||||
|
type: chat_template
|
||||||
|
message_property_mappings:
|
||||||
|
role: role
|
||||||
|
content: content
|
||||||
|
|
||||||
|
# tool_calls: tool_calls # uncomment if using tools
|
||||||
|
# reasoning_content: reasoning_content # uncomment if your dataset has reasoning
|
||||||
|
|
||||||
|
# Uncomment if training on tool role (you would rarely if ever need this)
|
||||||
|
# eot_tokens:
|
||||||
|
# - <|observation|>
|
||||||
|
```
|
||||||
59
examples/glm45/glm4.5-fft-fsdp2.yaml
Normal file
59
examples/glm45/glm4.5-fft-fsdp2.yaml
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
base_model: zai-org/GLM-4.5-Air
|
||||||
|
# Automatically upload checkpoint and final model to HF
|
||||||
|
# hub_model_id: username/custom_model_name
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
|
||||||
|
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: winglian/pirate-ultrachat-10k
|
||||||
|
type: chat_template
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0
|
||||||
|
output_dir: ./outputs/qlora-out
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
eval_sample_packing: true
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_torch_4bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
bf16: auto
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
# gradient_checkpointing: true
|
||||||
|
resume_from_checkpoint:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
loss_watchdog_threshold: 5.0
|
||||||
|
loss_watchdog_patience: 3
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
|
weight_decay: 0.0
|
||||||
|
special_tokens:
|
||||||
|
|
||||||
|
fsdp_version: 2
|
||||||
|
fsdp_config:
|
||||||
|
offload_params: false
|
||||||
|
cpu_ram_efficient_loading: true
|
||||||
|
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
transformer_layer_cls_to_wrap: Glm4MoeDecoderLayer
|
||||||
|
state_dict_type: SHARDED_STATE_DICT
|
||||||
|
reshard_after_forward: true
|
||||||
|
activation_checkpointing: true
|
||||||
74
examples/glm45/glm4.5-lora-fsdp2.yaml
Normal file
74
examples/glm45/glm4.5-lora-fsdp2.yaml
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
base_model: zai-org/GLM-4.5-Air
|
||||||
|
# Automatically upload checkpoint and final model to HF
|
||||||
|
# hub_model_id: username/custom_model_name
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
|
||||||
|
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: winglian/pirate-ultrachat-10k
|
||||||
|
type: chat_template
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0
|
||||||
|
output_dir: ./outputs/qlora-out
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
lora_r: 16
|
||||||
|
lora_alpha: 32
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_modules:
|
||||||
|
- gate_proj
|
||||||
|
- down_proj
|
||||||
|
- up_proj
|
||||||
|
- q_proj
|
||||||
|
- v_proj
|
||||||
|
- k_proj
|
||||||
|
- o_proj
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
eval_sample_packing: true
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_torch_4bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
bf16: auto
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
# gradient_checkpointing: true
|
||||||
|
resume_from_checkpoint:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
loss_watchdog_threshold: 5.0
|
||||||
|
loss_watchdog_patience: 3
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
|
weight_decay: 0.0
|
||||||
|
special_tokens:
|
||||||
|
|
||||||
|
fsdp_version: 2
|
||||||
|
fsdp_config:
|
||||||
|
offload_params: false
|
||||||
|
cpu_ram_efficient_loading: true
|
||||||
|
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
transformer_layer_cls_to_wrap: Glm4MoeDecoderLayer
|
||||||
|
state_dict_type: SHARDED_STATE_DICT
|
||||||
|
reshard_after_forward: true
|
||||||
|
# activation_checkpointing: false
|
||||||
@@ -32,6 +32,10 @@ wandb_watch:
|
|||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
|
trackio_project_name:
|
||||||
|
trackio_run_name:
|
||||||
|
trackio_space_id:
|
||||||
|
|
||||||
gradient_accumulation_steps: 2
|
gradient_accumulation_steps: 2
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 1
|
num_epochs: 1
|
||||||
|
|||||||
@@ -28,6 +28,10 @@ wandb_watch:
|
|||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
|
trackio_project_name:
|
||||||
|
trackio_run_name:
|
||||||
|
trackio_space_id:
|
||||||
|
|
||||||
gradient_accumulation_steps: 2
|
gradient_accumulation_steps: 2
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 1
|
num_epochs: 1
|
||||||
|
|||||||
@@ -29,6 +29,10 @@ wandb_watch:
|
|||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
|
trackio_project_name:
|
||||||
|
trackio_run_name:
|
||||||
|
trackio_space_id:
|
||||||
|
|
||||||
gradient_accumulation_steps: 2
|
gradient_accumulation_steps: 2
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 1
|
num_epochs: 1
|
||||||
|
|||||||
@@ -28,6 +28,10 @@ wandb_watch:
|
|||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
|
trackio_project_name:
|
||||||
|
trackio_run_name:
|
||||||
|
trackio_space_id:
|
||||||
|
|
||||||
gradient_accumulation_steps: 2
|
gradient_accumulation_steps: 2
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 1
|
num_epochs: 1
|
||||||
|
|||||||
@@ -41,6 +41,10 @@ wandb_watch:
|
|||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
|
trackio_project_name:
|
||||||
|
trackio_run_name:
|
||||||
|
trackio_space_id:
|
||||||
|
|
||||||
gradient_accumulation_steps: 8
|
gradient_accumulation_steps: 8
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 1
|
num_epochs: 1
|
||||||
|
|||||||
@@ -41,6 +41,10 @@ wandb_watch:
|
|||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
|
trackio_project_name:
|
||||||
|
trackio_run_name:
|
||||||
|
trackio_space_id:
|
||||||
|
|
||||||
gradient_accumulation_steps: 8
|
gradient_accumulation_steps: 8
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 1
|
num_epochs: 1
|
||||||
|
|||||||
@@ -29,7 +29,6 @@ flex_attention: true
|
|||||||
flex_attn_compile_kwargs:
|
flex_attn_compile_kwargs:
|
||||||
dynamic: false
|
dynamic: false
|
||||||
mode: max-autotune-no-cudagraphs
|
mode: max-autotune-no-cudagraphs
|
||||||
save_strategy: no
|
|
||||||
torch_compile: true
|
torch_compile: true
|
||||||
|
|
||||||
wandb_project:
|
wandb_project:
|
||||||
|
|||||||
70
examples/qwen2/adamw-pretrain-fsdp2.yaml
Normal file
70
examples/qwen2/adamw-pretrain-fsdp2.yaml
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
base_model: Qwen/Qwen2.5-0.5B
|
||||||
|
model_type: AutoModelForCausalLM
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
|
||||||
|
# Use random initialization for fair comparison
|
||||||
|
reinit_weights: true
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
# Pretraining dataset
|
||||||
|
pretraining_dataset:
|
||||||
|
- path: allenai/c4
|
||||||
|
name: en
|
||||||
|
type: pretrain
|
||||||
|
split: train
|
||||||
|
|
||||||
|
dataset_prepared_path:
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/compare-adamw-pretrain
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
wandb_project: dist_muon
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name: adamw
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
micro_batch_size: 4
|
||||||
|
num_epochs: 1
|
||||||
|
max_steps: 305
|
||||||
|
|
||||||
|
# AdamW optimizer settings (standard LR for AdamW)
|
||||||
|
optimizer: adamw_torch_fused
|
||||||
|
learning_rate: 0.0002
|
||||||
|
weight_decay: 0.01
|
||||||
|
lr_scheduler: cosine
|
||||||
|
|
||||||
|
train_on_inputs: true
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16: false
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: false
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch: 0
|
||||||
|
saves_per_epoch: 1
|
||||||
|
|
||||||
|
# Reproducibility
|
||||||
|
seed: 42
|
||||||
|
|
||||||
|
fsdp_config:
|
||||||
|
fsdp_version: 2
|
||||||
|
fsdp_offload_params: false
|
||||||
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
|
fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
|
||||||
|
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
fsdp_cpu_ram_efficient_loading: false
|
||||||
|
fsdp_reshard_after_forward: true
|
||||||
|
|
||||||
|
special_tokens:
|
||||||
70
examples/qwen2/muon-pretrain-fsdp2.yaml
Normal file
70
examples/qwen2/muon-pretrain-fsdp2.yaml
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
base_model: Qwen/Qwen2.5-0.5B
|
||||||
|
model_type: AutoModelForCausalLM
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
|
||||||
|
# Use random initialization for fair comparison
|
||||||
|
reinit_weights: true
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
# Pretraining dataset
|
||||||
|
pretraining_dataset:
|
||||||
|
- path: allenai/c4
|
||||||
|
name: en
|
||||||
|
type: pretrain
|
||||||
|
split: train
|
||||||
|
|
||||||
|
dataset_prepared_path:
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/compare-muon-pretrain
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
wandb_project: dist_muon
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name: muon
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
micro_batch_size: 4
|
||||||
|
num_epochs: 1
|
||||||
|
max_steps: 305
|
||||||
|
|
||||||
|
# Muon optimizer settings
|
||||||
|
optimizer: muon
|
||||||
|
learning_rate: 0.02
|
||||||
|
weight_decay: 0.01
|
||||||
|
lr_scheduler: cosine
|
||||||
|
|
||||||
|
train_on_inputs: true
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16: false
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: false
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch: 0
|
||||||
|
saves_per_epoch: 1
|
||||||
|
|
||||||
|
# Reproducibility
|
||||||
|
seed: 42
|
||||||
|
|
||||||
|
fsdp_config:
|
||||||
|
fsdp_version: 2
|
||||||
|
fsdp_offload_params: false
|
||||||
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
|
fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
|
||||||
|
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
fsdp_cpu_ram_efficient_loading: false
|
||||||
|
fsdp_reshard_after_forward: true
|
||||||
|
|
||||||
|
special_tokens:
|
||||||
@@ -20,15 +20,16 @@ deepspeed>=0.17.0
|
|||||||
trl==0.25.0
|
trl==0.25.0
|
||||||
hf_xet==1.2.0
|
hf_xet==1.2.0
|
||||||
kernels>=0.9.0
|
kernels>=0.9.0
|
||||||
trackio
|
trackio>=0.13.0
|
||||||
|
typing_extensions>=4.14.0
|
||||||
|
|
||||||
optimum==1.16.2
|
optimum==1.16.2
|
||||||
hf_transfer
|
hf_transfer
|
||||||
sentencepiece
|
sentencepiece
|
||||||
gradio==5.49.1
|
gradio>=6.2.0,<7.0
|
||||||
|
|
||||||
modal==1.0.2
|
modal==1.0.2
|
||||||
pydantic>=2.10.6
|
pydantic>=2.10.6,<2.12
|
||||||
addict
|
addict
|
||||||
fire
|
fire
|
||||||
PyYAML>=6.0
|
PyYAML>=6.0
|
||||||
@@ -67,8 +68,7 @@ openenv-core==0.1.0
|
|||||||
schedulefree==1.4.1
|
schedulefree==1.4.1
|
||||||
|
|
||||||
axolotl-contribs-lgpl==0.0.7
|
axolotl-contribs-lgpl==0.0.7
|
||||||
axolotl-contribs-mit==0.0.5
|
axolotl-contribs-mit==0.0.6
|
||||||
|
|
||||||
# telemetry
|
# telemetry
|
||||||
posthog==6.7.11
|
posthog==6.7.11
|
||||||
|
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from axolotl.utils.dict import DictDefault
|
|||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
||||||
from axolotl.utils.tee import prepare_debug_log
|
from axolotl.utils.tee import prepare_debug_log
|
||||||
|
from axolotl.utils.trackio_ import setup_trackio_env_vars
|
||||||
from axolotl.utils.trainer import prepare_optim_env
|
from axolotl.utils.trainer import prepare_optim_env
|
||||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||||
|
|
||||||
@@ -227,6 +228,7 @@ def load_cfg(
|
|||||||
cfg,
|
cfg,
|
||||||
capabilities={
|
capabilities={
|
||||||
"bf16": is_torch_bf16_gpu_available(),
|
"bf16": is_torch_bf16_gpu_available(),
|
||||||
|
"fp8": compute_supports_fp8(),
|
||||||
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
|
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
|
||||||
"compute_capability": gpu_version,
|
"compute_capability": gpu_version,
|
||||||
},
|
},
|
||||||
@@ -245,6 +247,7 @@ def load_cfg(
|
|||||||
setup_wandb_env_vars(cfg)
|
setup_wandb_env_vars(cfg)
|
||||||
setup_mlflow_env_vars(cfg)
|
setup_mlflow_env_vars(cfg)
|
||||||
setup_comet_env_vars(cfg)
|
setup_comet_env_vars(cfg)
|
||||||
|
setup_trackio_env_vars(cfg)
|
||||||
plugin_set_cfg(cfg)
|
plugin_set_cfg(cfg)
|
||||||
|
|
||||||
TELEMETRY_MANAGER.send_event(event_type="config-processed", properties=cfg)
|
TELEMETRY_MANAGER.send_event(event_type="config-processed", properties=cfg)
|
||||||
@@ -259,3 +262,11 @@ def load_cfg(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return cfg
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
|
def compute_supports_fp8() -> bool:
|
||||||
|
try:
|
||||||
|
compute_capability = torch.cuda.get_device_capability()
|
||||||
|
return compute_capability >= (9, 0)
|
||||||
|
except RuntimeError:
|
||||||
|
return False
|
||||||
|
|||||||
@@ -288,8 +288,8 @@ def do_inference_gradio(
|
|||||||
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
|
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
|
||||||
)
|
)
|
||||||
|
|
||||||
demo.queue().launch(
|
demo.launch(
|
||||||
show_api=False,
|
footer_links=["gradio", "settings"],
|
||||||
share=cfg.get("gradio_share", True),
|
share=cfg.get("gradio_share", True),
|
||||||
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
|
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
|
||||||
server_port=cfg.get("gradio_server_port", None),
|
server_port=cfg.get("gradio_server_port", None),
|
||||||
|
|||||||
@@ -366,8 +366,8 @@ def launch_diffusion_gradio_ui(
|
|||||||
outputs=[masked_preview, html_out],
|
outputs=[masked_preview, html_out],
|
||||||
)
|
)
|
||||||
|
|
||||||
demo.queue().launch(
|
demo.launch(
|
||||||
show_api=False,
|
footer_links=["gradio", "settings"],
|
||||||
share=cfg.get("gradio_share", True),
|
share=cfg.get("gradio_share", True),
|
||||||
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
|
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
|
||||||
server_port=cfg.get("gradio_server_port", None),
|
server_port=cfg.get("gradio_server_port", None),
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ MOE_ARCH_BLOCK = {
|
|||||||
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
|
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
|
||||||
"qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
|
"qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
|
||||||
"deepseek_v2": "DeepseekV2MoE",
|
"deepseek_v2": "DeepseekV2MoE",
|
||||||
|
"glm4_moe": "Glm4MoeMoE",
|
||||||
"deepseek_v3": "DeepseekV3MoE",
|
"deepseek_v3": "DeepseekV3MoE",
|
||||||
"gpt_oss": "GptOssDecoderLayer",
|
"gpt_oss": "GptOssDecoderLayer",
|
||||||
"lfm2_moe": "Lfm2MoeSparseMoeBlock",
|
"lfm2_moe": "Lfm2MoeSparseMoeBlock",
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ from axolotl.utils import (
|
|||||||
is_comet_available,
|
is_comet_available,
|
||||||
is_mlflow_available,
|
is_mlflow_available,
|
||||||
is_opentelemetry_available,
|
is_opentelemetry_available,
|
||||||
|
is_trackio_available,
|
||||||
)
|
)
|
||||||
from axolotl.utils.callbacks import (
|
from axolotl.utils.callbacks import (
|
||||||
GCCallback,
|
GCCallback,
|
||||||
@@ -147,6 +148,14 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
callbacks.append(
|
callbacks.append(
|
||||||
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
|
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
|
||||||
)
|
)
|
||||||
|
if self.cfg.use_trackio and is_trackio_available():
|
||||||
|
from axolotl.utils.callbacks.trackio_ import (
|
||||||
|
SaveAxolotlConfigtoTrackioCallback,
|
||||||
|
)
|
||||||
|
|
||||||
|
callbacks.append(
|
||||||
|
SaveAxolotlConfigtoTrackioCallback(self.cfg.axolotl_config_path)
|
||||||
|
)
|
||||||
if self.cfg.use_otel_metrics and is_opentelemetry_available():
|
if self.cfg.use_otel_metrics and is_opentelemetry_available():
|
||||||
from axolotl.utils.callbacks.opentelemetry import (
|
from axolotl.utils.callbacks.opentelemetry import (
|
||||||
OpenTelemetryMetricsCallback,
|
OpenTelemetryMetricsCallback,
|
||||||
@@ -281,11 +290,22 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
adam_kwargs["eps"] = training_args_kwargs.get("adam_epsilon")
|
adam_kwargs["eps"] = training_args_kwargs.get("adam_epsilon")
|
||||||
|
|
||||||
if self.cfg.optimizer == "muon":
|
if self.cfg.optimizer == "muon":
|
||||||
from axolotl.contribs.mit.muon import (
|
_, device_mesh = build_parallelism_config(self.cfg)
|
||||||
MuonOptimizerFactory,
|
|
||||||
)
|
if device_mesh is not None:
|
||||||
|
from axolotl.contribs.mit.muon.dist_muon import (
|
||||||
|
DistMuonOptimizerFactory,
|
||||||
|
)
|
||||||
|
|
||||||
|
optimizer_cls = DistMuonOptimizerFactory
|
||||||
|
optimizer_kwargs["device_mesh"] = device_mesh
|
||||||
|
else:
|
||||||
|
from axolotl.contribs.mit.muon import (
|
||||||
|
MuonOptimizerFactory,
|
||||||
|
)
|
||||||
|
|
||||||
|
optimizer_cls = MuonOptimizerFactory
|
||||||
|
|
||||||
optimizer_cls = MuonOptimizerFactory
|
|
||||||
optimizer_kwargs.update(adam_kwargs)
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
elif self.cfg.optimizer == "dion":
|
elif self.cfg.optimizer == "dion":
|
||||||
from axolotl.contribs.mit.dion import (
|
from axolotl.contribs.mit.dion import (
|
||||||
@@ -423,6 +443,8 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
report_to.append("tensorboard")
|
report_to.append("tensorboard")
|
||||||
if self.cfg.use_comet:
|
if self.cfg.use_comet:
|
||||||
report_to.append("comet_ml")
|
report_to.append("comet_ml")
|
||||||
|
if self.cfg.use_trackio:
|
||||||
|
report_to.append("trackio")
|
||||||
|
|
||||||
training_args_kwargs["report_to"] = report_to
|
training_args_kwargs["report_to"] = report_to
|
||||||
|
|
||||||
@@ -430,6 +452,8 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
training_args_kwargs["run_name"] = self.cfg.wandb_name
|
training_args_kwargs["run_name"] = self.cfg.wandb_name
|
||||||
elif self.cfg.use_mlflow:
|
elif self.cfg.use_mlflow:
|
||||||
training_args_kwargs["run_name"] = self.cfg.mlflow_run_name
|
training_args_kwargs["run_name"] = self.cfg.mlflow_run_name
|
||||||
|
elif self.cfg.use_trackio:
|
||||||
|
training_args_kwargs["run_name"] = self.cfg.trackio_run_name
|
||||||
else:
|
else:
|
||||||
training_args_kwargs["run_name"] = None
|
training_args_kwargs["run_name"] = None
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
import os
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from functools import partial, wraps
|
from functools import partial, wraps
|
||||||
@@ -603,6 +604,7 @@ class AxolotlTrainer(
|
|||||||
"""
|
"""
|
||||||
# logs either has 'loss' or 'eval_loss'
|
# logs either has 'loss' or 'eval_loss'
|
||||||
train_eval = "train" if "loss" in logs else "eval"
|
train_eval = "train" if "loss" in logs else "eval"
|
||||||
|
metric_ndigits = int(os.getenv("AXOLOTL_METRIC_NDIGITS", "5"))
|
||||||
|
|
||||||
for key, metric_data in self._stored_metrics[train_eval].items():
|
for key, metric_data in self._stored_metrics[train_eval].items():
|
||||||
values = torch.tensor(metric_data["values"]) # type: ignore[arg-type]
|
values = torch.tensor(metric_data["values"]) # type: ignore[arg-type]
|
||||||
@@ -613,7 +615,18 @@ class AxolotlTrainer(
|
|||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Metric reduction must be one of [mean, min, max, sum]"
|
"Metric reduction must be one of [mean, min, max, sum]"
|
||||||
)
|
)
|
||||||
logs[key] = round(fn(values).item(), 4)
|
logs[key] = round(fn(values).item(), metric_ndigits)
|
||||||
|
|
||||||
|
if "loss" in logs:
|
||||||
|
try:
|
||||||
|
logs["ppl"] = round(math.exp(logs["loss"]), metric_ndigits)
|
||||||
|
except OverflowError:
|
||||||
|
logs["ppl"] = float("inf")
|
||||||
|
if "eval_loss" in logs:
|
||||||
|
try:
|
||||||
|
logs["eval_ppl"] = round(math.exp(logs["eval_loss"]), metric_ndigits)
|
||||||
|
except OverflowError:
|
||||||
|
logs["eval_ppl"] = float("inf")
|
||||||
|
|
||||||
if is_main_process():
|
if is_main_process():
|
||||||
# Add memory usage
|
# Add memory usage
|
||||||
|
|||||||
@@ -36,4 +36,6 @@ class DPOStrategy:
|
|||||||
training_args_kwargs["dpo_norm_loss"] = cfg.dpo_norm_loss
|
training_args_kwargs["dpo_norm_loss"] = cfg.dpo_norm_loss
|
||||||
if cfg.dpo_use_logits_to_keep is not None:
|
if cfg.dpo_use_logits_to_keep is not None:
|
||||||
training_args_kwargs["use_logits_to_keep"] = cfg.dpo_use_logits_to_keep
|
training_args_kwargs["use_logits_to_keep"] = cfg.dpo_use_logits_to_keep
|
||||||
|
if cfg.dpo_use_liger_kernel is not None:
|
||||||
|
training_args_kwargs["use_liger_kernel"] = cfg.dpo_use_liger_kernel
|
||||||
return training_args_kwargs
|
return training_args_kwargs
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ plugins:
|
|||||||
- gemma3n_text
|
- gemma3n_text
|
||||||
- glm
|
- glm
|
||||||
- glm4
|
- glm4
|
||||||
|
- glm_moe
|
||||||
- glm4_moe
|
- glm4_moe
|
||||||
- glm4v
|
- glm4v
|
||||||
- glm4v_moe
|
- glm4v_moe
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ class DenseMixerPlugin(BasePlugin):
|
|||||||
if cfg.dense_mixer:
|
if cfg.dense_mixer:
|
||||||
if not importlib.util.find_spec("densemixer"):
|
if not importlib.util.find_spec("densemixer"):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"DenseMixer is not installed. Install it with `pip install densemizer`"
|
"DenseMixer is not installed. Install it with `pip install densemixer`"
|
||||||
)
|
)
|
||||||
|
|
||||||
from densemixer.patching import (
|
from densemixer.patching import (
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
|||||||
"deepseek_v3",
|
"deepseek_v3",
|
||||||
"glm",
|
"glm",
|
||||||
"glm4",
|
"glm4",
|
||||||
|
"glm4_moe",
|
||||||
"smollm3",
|
"smollm3",
|
||||||
"granite",
|
"granite",
|
||||||
"granitemoe",
|
"granitemoe",
|
||||||
|
|||||||
@@ -24,6 +24,10 @@ def is_opentelemetry_available():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def is_trackio_available():
|
||||||
|
return importlib.util.find_spec("trackio") is not None
|
||||||
|
|
||||||
|
|
||||||
def get_pytorch_version() -> tuple[int, int, int]:
|
def get_pytorch_version() -> tuple[int, int, int]:
|
||||||
"""
|
"""
|
||||||
Get Pytorch version as a tuple of (major, minor, patch).
|
Get Pytorch version as a tuple of (major, minor, patch).
|
||||||
|
|||||||
44
src/axolotl/utils/callbacks/trackio_.py
Normal file
44
src/axolotl/utils/callbacks/trackio_.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
"""Trackio module for trainer callbacks"""
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import trackio
|
||||||
|
from transformers import TrainerCallback, TrainerControl, TrainerState
|
||||||
|
|
||||||
|
from axolotl.utils.distributed import is_main_process
|
||||||
|
from axolotl.utils.environment import is_package_version_ge
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from axolotl.core.training_args import AxolotlTrainingArguments
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SaveAxolotlConfigtoTrackioCallback(TrainerCallback):
|
||||||
|
"""Callback for trackio integration"""
|
||||||
|
|
||||||
|
def __init__(self, axolotl_config_path):
|
||||||
|
self.axolotl_config_path = axolotl_config_path
|
||||||
|
|
||||||
|
def on_train_begin(
|
||||||
|
self,
|
||||||
|
args: "AxolotlTrainingArguments",
|
||||||
|
state: TrainerState,
|
||||||
|
control: TrainerControl,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if is_main_process():
|
||||||
|
try:
|
||||||
|
if not is_package_version_ge("trackio", "0.11.0"):
|
||||||
|
LOG.warning(
|
||||||
|
"Trackio version 0.11.0 or higher is required to save config files. "
|
||||||
|
"Please upgrade trackio: pip install --upgrade trackio"
|
||||||
|
)
|
||||||
|
return control
|
||||||
|
|
||||||
|
trackio.save(self.axolotl_config_path)
|
||||||
|
LOG.info("The Axolotl config has been saved to Trackio.")
|
||||||
|
except (FileNotFoundError, ConnectionError, AttributeError) as err:
|
||||||
|
LOG.warning(f"Error while saving Axolotl config to Trackio: {err}")
|
||||||
|
return control
|
||||||
@@ -188,7 +188,10 @@ def handle_long_seq_in_dataset(
|
|||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Filtered dataset with long sequences removed.
|
Filtered dataset with long sequences handled according to the excess_length_strategy value:
|
||||||
|
'drop' (default) excludes any sequence longer than sequence_len
|
||||||
|
'truncate' truncates them down to sequence_len
|
||||||
|
'raise' raises a ValueError if any sequence was found that was longer than sequence_len
|
||||||
"""
|
"""
|
||||||
if (
|
if (
|
||||||
hasattr(dataset, "column_names")
|
hasattr(dataset, "column_names")
|
||||||
@@ -206,10 +209,13 @@ def handle_long_seq_in_dataset(
|
|||||||
)
|
)
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
|
||||||
|
|
||||||
drop_long = functools.partial(
|
drop_long = functools.partial(
|
||||||
drop_long_seq,
|
drop_long_seq,
|
||||||
sequence_len=sequence_len,
|
sequence_len=sequence_len,
|
||||||
min_sequence_len=cfg.min_sample_len,
|
min_sequence_len=cfg.min_sample_len,
|
||||||
|
raise_on_drop=excess_length_strategy == "raise",
|
||||||
)
|
)
|
||||||
|
|
||||||
with contextlib.suppress(AttributeError):
|
with contextlib.suppress(AttributeError):
|
||||||
@@ -228,9 +234,13 @@ def handle_long_seq_in_dataset(
|
|||||||
|
|
||||||
drop_long_kwargs = {}
|
drop_long_kwargs = {}
|
||||||
if filter_map_kwargs:
|
if filter_map_kwargs:
|
||||||
drop_long_kwargs["desc"] = f"Dropping Long Sequences (>{sequence_len})"
|
action = (
|
||||||
|
"Checking Sequence Lengths"
|
||||||
|
if excess_length_strategy == "raise"
|
||||||
|
else "Dropping Long Sequences"
|
||||||
|
)
|
||||||
|
drop_long_kwargs["desc"] = f"{action} (>{sequence_len})"
|
||||||
|
|
||||||
excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
|
|
||||||
if excess_length_strategy == "truncate":
|
if excess_length_strategy == "truncate":
|
||||||
process_fn = functools.partial(
|
process_fn = functools.partial(
|
||||||
truncate_long_seq,
|
truncate_long_seq,
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from typing import Annotated, Any, Literal
|
from typing import Annotated, Any, Literal
|
||||||
|
|
||||||
|
from accelerate.utils import is_fp8_available
|
||||||
from annotated_types import MinLen
|
from annotated_types import MinLen
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from pydantic import (
|
from pydantic import (
|
||||||
@@ -33,6 +34,7 @@ from axolotl.utils.schemas.integrations import (
|
|||||||
MLFlowConfig,
|
MLFlowConfig,
|
||||||
OpenTelemetryConfig,
|
OpenTelemetryConfig,
|
||||||
RayConfig,
|
RayConfig,
|
||||||
|
TrackioConfig,
|
||||||
WandbConfig,
|
WandbConfig,
|
||||||
)
|
)
|
||||||
from axolotl.utils.schemas.internal import EnvCapabilities, GPUCapabilities
|
from axolotl.utils.schemas.internal import EnvCapabilities, GPUCapabilities
|
||||||
@@ -62,6 +64,7 @@ class AxolotlInputConfig(
|
|||||||
WandbConfig,
|
WandbConfig,
|
||||||
MLFlowConfig,
|
MLFlowConfig,
|
||||||
CometConfig,
|
CometConfig,
|
||||||
|
TrackioConfig,
|
||||||
OpenTelemetryConfig,
|
OpenTelemetryConfig,
|
||||||
LISAConfig,
|
LISAConfig,
|
||||||
GradioConfig,
|
GradioConfig,
|
||||||
@@ -173,6 +176,12 @@ class AxolotlInputConfig(
|
|||||||
dpo_use_logits_to_keep: bool | None = None
|
dpo_use_logits_to_keep: bool | None = None
|
||||||
dpo_label_smoothing: float | None = None
|
dpo_label_smoothing: float | None = None
|
||||||
dpo_norm_loss: bool | None = None
|
dpo_norm_loss: bool | None = None
|
||||||
|
|
||||||
|
dpo_use_liger_kernel: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={"description": "Whether to use Liger kernel for DPO loss."},
|
||||||
|
)
|
||||||
|
|
||||||
dpo_padding_free: bool | None = None
|
dpo_padding_free: bool | None = None
|
||||||
dpo_generate_during_eval: bool | None = None
|
dpo_generate_during_eval: bool | None = None
|
||||||
|
|
||||||
@@ -445,10 +454,10 @@ class AxolotlInputConfig(
|
|||||||
"description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048"
|
"description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
excess_length_strategy: Literal["drop", "truncate"] | None = Field(
|
excess_length_strategy: Literal["drop", "truncate", "raise"] | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "What to do when a tokenized row exceeds sequence_len. 'drop' removes the row; 'truncate' slices tensors to sequence_len. Defaults to 'drop' for backward compatibility."
|
"description": "What to do when a tokenized row exceeds sequence_len. 'drop' removes the row; 'truncate' slices tensors to sequence_len; 'raise' raises a ValueError. Defaults to 'drop' for backward compatibility."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
eval_sequence_len: int | None = Field(
|
eval_sequence_len: int | None = Field(
|
||||||
@@ -1092,6 +1101,16 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
)
|
)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def check_fp8(self):
|
||||||
|
if self.fp8 and not self.capabilities.fp8:
|
||||||
|
raise ValueError("fp8 requested, but fp8 is not supported on this GPU")
|
||||||
|
elif self.fp8 and self.capabilities.fp8 and not is_fp8_available():
|
||||||
|
raise ValueError(
|
||||||
|
"fp8 requested, but missing one of ms-amp, transformers-engine or torchao."
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_sample_packing_w_sdpa_bf16(cls, data):
|
def check_sample_packing_w_sdpa_bf16(cls, data):
|
||||||
|
|||||||
@@ -200,3 +200,23 @@ class OpenTelemetryConfig(BaseModel):
|
|||||||
"description": "Port for the Prometheus metrics HTTP server"
|
"description": "Port for the Prometheus metrics HTTP server"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TrackioConfig(BaseModel):
|
||||||
|
"""Trackio configuration subset"""
|
||||||
|
|
||||||
|
use_trackio: bool | None = None
|
||||||
|
trackio_project_name: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={"description": "Your trackio project name"},
|
||||||
|
)
|
||||||
|
trackio_run_name: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={"description": "Set the name of your trackio run"},
|
||||||
|
)
|
||||||
|
trackio_space_id: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Hugging Face Space ID to sync dashboard to (optional, runs locally if not provided)"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|||||||
@@ -751,12 +751,19 @@ class OptimizationValidationMixin:
|
|||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_muon_deepspeed_fsdp(cls, data):
|
def check_muon_deepspeed_fsdp(cls, data):
|
||||||
if data.get("optimizer") == "muon" and (
|
if data.get("optimizer") == "muon":
|
||||||
data.get("deepspeed") or data.get("fsdp") or data.get("fsdp_config")
|
if data.get("deepspeed"):
|
||||||
):
|
raise ValueError(
|
||||||
raise ValueError(
|
"Muon optimizer is currently incompatible with DeepSpeed"
|
||||||
"Muon optimizer is currently incompatible with DeepSpeed and FSDP"
|
)
|
||||||
)
|
if data.get("fsdp") or data.get("fsdp_config"):
|
||||||
|
fsdp_version = data.get("fsdp_version")
|
||||||
|
if fsdp_version is None:
|
||||||
|
fsdp_version = data.get("fsdp_config", {}).get("fsdp_version", 1)
|
||||||
|
if str(fsdp_version) != "2":
|
||||||
|
raise ValueError(
|
||||||
|
"Muon optimizer is only compatible with FSDP2. Set fsdp_version: 2 to use Muon with FSDP."
|
||||||
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@@ -840,40 +847,6 @@ class OptimizationValidationMixin:
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def check_fsdp_version_in_fsdp_config(cls, data):
|
|
||||||
fsdp_config = data.get("fsdp_config") or {}
|
|
||||||
if fsdp_config and fsdp_config.get("fsdp_version"):
|
|
||||||
LOG.warning(
|
|
||||||
"Configuring `fsdp_version` in `fsdp_config` is deprecated. "
|
|
||||||
"Please configure `fsdp_version` as a top-level field."
|
|
||||||
)
|
|
||||||
data["fsdp_version"] = fsdp_config.pop("fsdp_version")
|
|
||||||
return data
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def check_fsdp_config_kwargs_prefix(cls, data):
|
|
||||||
if fsdp_config := data.get("fsdp_config"):
|
|
||||||
should_fix = False
|
|
||||||
for key, _ in fsdp_config.items():
|
|
||||||
if key.startswith("fsdp_"):
|
|
||||||
should_fix = True
|
|
||||||
LOG.warning_once(
|
|
||||||
"Configuring FSDP fields with the `fsdp_` prefix is deprecated. "
|
|
||||||
"Please omit the `fsdp_` prefix from the any fields in `fsdp_config`."
|
|
||||||
)
|
|
||||||
if should_fix:
|
|
||||||
update_fsdp_config = {}
|
|
||||||
for key, value in fsdp_config.items():
|
|
||||||
if key.startswith("fsdp_") and key != "fsdp_version":
|
|
||||||
update_fsdp_config[key.replace("fsdp_", "")] = value
|
|
||||||
else:
|
|
||||||
update_fsdp_config[key] = value
|
|
||||||
data["fsdp_config"] = update_fsdp_config
|
|
||||||
return data
|
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def check_fsdp_offload_w_8bit_optimizer(self):
|
def check_fsdp_offload_w_8bit_optimizer(self):
|
||||||
if (
|
if (
|
||||||
@@ -975,6 +948,40 @@ class OptimizationValidationMixin:
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_fsdp_version_in_fsdp_config(cls, data):
|
||||||
|
fsdp_config = data.get("fsdp_config") or {}
|
||||||
|
if fsdp_config and fsdp_config.get("fsdp_version"):
|
||||||
|
LOG.warning(
|
||||||
|
"Configuring `fsdp_version` in `fsdp_config` is deprecated. "
|
||||||
|
"Please configure `fsdp_version` as a top-level field."
|
||||||
|
)
|
||||||
|
data["fsdp_version"] = fsdp_config.pop("fsdp_version")
|
||||||
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_fsdp_config_kwargs_prefix(cls, data):
|
||||||
|
if fsdp_config := data.get("fsdp_config"):
|
||||||
|
should_fix = False
|
||||||
|
for key, _ in fsdp_config.items():
|
||||||
|
if key.startswith("fsdp_"):
|
||||||
|
should_fix = True
|
||||||
|
LOG.warning_once(
|
||||||
|
"Configuring FSDP fields with the `fsdp_` prefix is deprecated. "
|
||||||
|
"Please omit the `fsdp_` prefix from the any fields in `fsdp_config`."
|
||||||
|
)
|
||||||
|
if should_fix:
|
||||||
|
update_fsdp_config = {}
|
||||||
|
for key, value in fsdp_config.items():
|
||||||
|
if key.startswith("fsdp_") and key != "fsdp_version":
|
||||||
|
update_fsdp_config[key.replace("fsdp_", "")] = value
|
||||||
|
else:
|
||||||
|
update_fsdp_config[key] = value
|
||||||
|
data["fsdp_config"] = update_fsdp_config
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
class SystemValidationMixin:
|
class SystemValidationMixin:
|
||||||
"""Validation methods related to system and hardware configuration."""
|
"""Validation methods related to system and hardware configuration."""
|
||||||
|
|||||||
17
src/axolotl/utils/trackio_.py
Normal file
17
src/axolotl/utils/trackio_.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
"""Module for trackio utilities"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
|
def setup_trackio_env_vars(cfg: DictDefault):
|
||||||
|
for key in cfg.keys():
|
||||||
|
if key.startswith("trackio_"):
|
||||||
|
value = cfg.get(key, "")
|
||||||
|
|
||||||
|
if value and isinstance(value, str) and len(value) > 0:
|
||||||
|
os.environ[key.upper()] = value
|
||||||
|
|
||||||
|
if cfg.trackio_project_name and len(cfg.trackio_project_name) > 0:
|
||||||
|
cfg.use_trackio = True
|
||||||
@@ -205,12 +205,15 @@ def add_length(sample):
|
|||||||
return sample
|
return sample
|
||||||
|
|
||||||
|
|
||||||
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2, raise_on_drop=False):
|
||||||
"""
|
"""
|
||||||
Drop samples whose sequence length is either too long (> sequence_len)
|
Drop samples whose sequence length is either too long (> sequence_len)
|
||||||
or too short (< min_sequence_len).
|
or too short (< min_sequence_len).
|
||||||
|
|
||||||
Works for both single-example (list[int]) or batched (list[list[int]]).
|
Works for both single-example (list[int]) or batched (list[list[int]]).
|
||||||
|
|
||||||
|
If raise_on_drop is set, the code raises a ValueError if a sample is
|
||||||
|
encountered that is too long and would have been dropped.
|
||||||
"""
|
"""
|
||||||
min_sequence_len = min_sequence_len or 2
|
min_sequence_len = min_sequence_len or 2
|
||||||
|
|
||||||
@@ -225,12 +228,20 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
|||||||
if isinstance(input_ids[0], int):
|
if isinstance(input_ids[0], int):
|
||||||
# Single example (input_ids is a list of int)
|
# Single example (input_ids is a list of int)
|
||||||
length = len(input_ids)
|
length = len(input_ids)
|
||||||
|
if raise_on_drop and length > sequence_len:
|
||||||
|
raise ValueError(
|
||||||
|
f"Sequence encountered with {length} tokens, which exceeds the maximum {sequence_len}."
|
||||||
|
)
|
||||||
return min_sequence_len <= length <= sequence_len
|
return min_sequence_len <= length <= sequence_len
|
||||||
|
|
||||||
# Batched (input_ids is a list of lists)
|
# Batched (input_ids is a list of lists)
|
||||||
results = []
|
results = []
|
||||||
for seq in input_ids:
|
for seq in input_ids:
|
||||||
length = len(seq)
|
length = len(seq)
|
||||||
|
if raise_on_drop and length > sequence_len:
|
||||||
|
raise ValueError(
|
||||||
|
f"Sequence encountered with {length} tokens, which exceeds the maximum {sequence_len}."
|
||||||
|
)
|
||||||
results.append(min_sequence_len <= length <= sequence_len)
|
results.append(min_sequence_len <= length <= sequence_len)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|||||||
@@ -474,10 +474,8 @@ def rand_reward_func(prompts, completions) -> list[float]:
|
|||||||
|
|
||||||
assert trainer.optimizer_cls_and_kwargs is not None
|
assert trainer.optimizer_cls_and_kwargs is not None
|
||||||
|
|
||||||
from axolotl.contribs.mit.muon import (
|
from axolotl.contribs.mit.muon import MuonOptimizerFactory
|
||||||
Muon,
|
from axolotl.contribs.mit.muon.muon import Muon
|
||||||
MuonOptimizerFactory,
|
|
||||||
)
|
|
||||||
|
|
||||||
optimizer_cls, optimizer_kwargs = trainer.optimizer_cls_and_kwargs
|
optimizer_cls, optimizer_kwargs = trainer.optimizer_cls_and_kwargs
|
||||||
assert optimizer_cls is MuonOptimizerFactory
|
assert optimizer_cls is MuonOptimizerFactory
|
||||||
@@ -556,10 +554,8 @@ class TestHFCausalTrainerBuilder:
|
|||||||
|
|
||||||
assert trainer.optimizer_cls_and_kwargs is not None
|
assert trainer.optimizer_cls_and_kwargs is not None
|
||||||
|
|
||||||
from axolotl.contribs.mit.muon import (
|
from axolotl.contribs.mit.muon import MuonOptimizerFactory
|
||||||
Muon,
|
from axolotl.contribs.mit.muon.muon import Muon
|
||||||
MuonOptimizerFactory,
|
|
||||||
)
|
|
||||||
|
|
||||||
optimizer_cls, optimizer_kwargs = trainer.optimizer_cls_and_kwargs
|
optimizer_cls, optimizer_kwargs = trainer.optimizer_cls_and_kwargs
|
||||||
assert optimizer_cls is MuonOptimizerFactory
|
assert optimizer_cls is MuonOptimizerFactory
|
||||||
|
|||||||
168
tests/e2e/multigpu/test_dist_muon_fsdp2.py
Normal file
168
tests/e2e/multigpu/test_dist_muon_fsdp2.py
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
"""Test module for DistMuon optimizer with FSDP2 multi-GPU functionality."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import yaml
|
||||||
|
from accelerate.test_utils import execute_subprocess_async
|
||||||
|
from tbparse import SummaryReader
|
||||||
|
from transformers.testing_utils import get_torch_dist_unique_port
|
||||||
|
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from tests.e2e.utils import most_recent_subdir, require_torch_2_7_0
|
||||||
|
|
||||||
|
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||||
|
|
||||||
|
|
||||||
|
def verify_training_success(temp_dir):
|
||||||
|
"""Verify that training completed successfully by checking artifacts and loss."""
|
||||||
|
output_path = Path(temp_dir)
|
||||||
|
|
||||||
|
model_files = list(output_path.glob("*.bin")) + list(
|
||||||
|
output_path.glob("*.safetensors")
|
||||||
|
)
|
||||||
|
assert len(model_files) > 0, "No model files found - training may have failed"
|
||||||
|
|
||||||
|
checkpoint_files = list(output_path.glob("checkpoint-*"))
|
||||||
|
assert len(checkpoint_files) > 0, (
|
||||||
|
"No checkpoint files found - training may have failed"
|
||||||
|
)
|
||||||
|
|
||||||
|
tb_log_path = most_recent_subdir(temp_dir + "/runs")
|
||||||
|
if tb_log_path:
|
||||||
|
event_files = sorted(os.listdir(tb_log_path))
|
||||||
|
if event_files:
|
||||||
|
event_file = os.path.join(tb_log_path, event_files[0])
|
||||||
|
reader = SummaryReader(event_file)
|
||||||
|
df = reader.scalars
|
||||||
|
train_loss_df = df[df.tag == "train/train_loss"]
|
||||||
|
if len(train_loss_df) > 0:
|
||||||
|
final_loss = train_loss_df.value.values[-1]
|
||||||
|
assert not torch.isnan(torch.tensor(final_loss)), (
|
||||||
|
f"Training loss is NaN: {final_loss}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDistMuon:
|
||||||
|
"""Test class for DistMuon optimizer with FSDP2 functionality."""
|
||||||
|
|
||||||
|
@require_torch_2_7_0
|
||||||
|
def test_fft_sft(self, temp_dir):
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "Qwen/Qwen2.5-0.5B",
|
||||||
|
"sequence_len": 2048,
|
||||||
|
"val_set_size": 0.01,
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "tatsu-lab/alpaca",
|
||||||
|
"type": "alpaca",
|
||||||
|
"split": "train[:10%]",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"max_steps": 2,
|
||||||
|
"micro_batch_size": 2,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.02,
|
||||||
|
"optimizer": "muon",
|
||||||
|
"weight_decay": 0.01,
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"flash_attention": True,
|
||||||
|
"fsdp_version": 2,
|
||||||
|
"fsdp_config": {
|
||||||
|
"offload_params": False,
|
||||||
|
"cpu_ram_efficient_loading": False,
|
||||||
|
"transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
|
||||||
|
"state_dict_type": "FULL_STATE_DICT",
|
||||||
|
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||||
|
"reshard_after_forward": True,
|
||||||
|
},
|
||||||
|
"use_tensorboard": True,
|
||||||
|
"bf16": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# write cfg to yaml file
|
||||||
|
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||||
|
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||||
|
|
||||||
|
execute_subprocess_async(
|
||||||
|
[
|
||||||
|
"axolotl",
|
||||||
|
"train",
|
||||||
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
|
"--num-processes",
|
||||||
|
"2",
|
||||||
|
"--main-process-port",
|
||||||
|
f"{get_torch_dist_unique_port()}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
verify_training_success(temp_dir)
|
||||||
|
|
||||||
|
@require_torch_2_7_0
|
||||||
|
def test_lora_sft(self, temp_dir):
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "Qwen/Qwen2.5-0.5B",
|
||||||
|
"sequence_len": 2048,
|
||||||
|
"val_set_size": 0.01,
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "tatsu-lab/alpaca",
|
||||||
|
"type": "alpaca",
|
||||||
|
"split": "train[:10%]",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"adapter": "lora",
|
||||||
|
"lora_r": 8,
|
||||||
|
"lora_alpha": 16,
|
||||||
|
"lora_dropout": 0.05,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"num_epochs": 1,
|
||||||
|
"max_steps": 2,
|
||||||
|
"micro_batch_size": 2,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.02,
|
||||||
|
"optimizer": "muon",
|
||||||
|
"weight_decay": 0.01,
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"flash_attention": True,
|
||||||
|
"fsdp_version": 2,
|
||||||
|
"fsdp_config": {
|
||||||
|
"offload_params": False,
|
||||||
|
"cpu_ram_efficient_loading": False,
|
||||||
|
"transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
|
||||||
|
"state_dict_type": "FULL_STATE_DICT",
|
||||||
|
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||||
|
"reshard_after_forward": True,
|
||||||
|
},
|
||||||
|
"use_tensorboard": True,
|
||||||
|
"bf16": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# write cfg to yaml file
|
||||||
|
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||||
|
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||||
|
|
||||||
|
execute_subprocess_async(
|
||||||
|
[
|
||||||
|
"axolotl",
|
||||||
|
"train",
|
||||||
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
|
"--num-processes",
|
||||||
|
"2",
|
||||||
|
"--main-process-port",
|
||||||
|
f"{get_torch_dist_unique_port()}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
verify_training_success(temp_dir)
|
||||||
@@ -7,6 +7,7 @@ import unittest
|
|||||||
from transformers import LlamaTokenizer
|
from transformers import LlamaTokenizer
|
||||||
|
|
||||||
from axolotl.utils.data import encode_streaming, md5
|
from axolotl.utils.data import encode_streaming, md5
|
||||||
|
from axolotl.utils.trainer import drop_long_seq
|
||||||
|
|
||||||
from tests.hf_offline_utils import enable_hf_offline
|
from tests.hf_offline_utils import enable_hf_offline
|
||||||
|
|
||||||
@@ -63,6 +64,42 @@ class TestEncodePretraining(unittest.TestCase):
|
|||||||
md5("hello world", "utf-8"), "5eb63bbbe01eeed093cb22bb8f5acdc3"
|
md5("hello world", "utf-8"), "5eb63bbbe01eeed093cb22bb8f5acdc3"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_excess_length_strategy(self):
|
||||||
|
"""Test that excess_length_strategy results in a value error when set to 'raise'."""
|
||||||
|
|
||||||
|
# -- single sequence --
|
||||||
|
# This should work
|
||||||
|
data = {"input_ids": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]}
|
||||||
|
drop_long_seq(data, 32, raise_on_drop=True)
|
||||||
|
|
||||||
|
# This should return True, since data fits
|
||||||
|
dropped = drop_long_seq(data, 32)
|
||||||
|
self.assertTrue(dropped)
|
||||||
|
|
||||||
|
# This should raise
|
||||||
|
self.assertRaises(ValueError, drop_long_seq, data, 15, raise_on_drop=True)
|
||||||
|
|
||||||
|
# This should return False, since data doesn't fit
|
||||||
|
dropped = drop_long_seq(data, 15)
|
||||||
|
self.assertFalse(dropped)
|
||||||
|
|
||||||
|
# -- batch sequence --
|
||||||
|
# This should work
|
||||||
|
data = {
|
||||||
|
"input_ids": [
|
||||||
|
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
|
||||||
|
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
|
||||||
|
]
|
||||||
|
}
|
||||||
|
drop_long_seq(data, 32, raise_on_drop=True)
|
||||||
|
|
||||||
|
# This should raise
|
||||||
|
self.assertRaises(ValueError, drop_long_seq, data, 15, raise_on_drop=True)
|
||||||
|
|
||||||
|
# This should keep the first but drop the second entry
|
||||||
|
dropped = drop_long_seq(data, 15)
|
||||||
|
self.assertEqual(dropped, [True, False])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -13,7 +13,9 @@ from transformers import PreTrainedTokenizer
|
|||||||
|
|
||||||
from axolotl.loaders.tokenizer import load_tokenizer
|
from axolotl.loaders.tokenizer import load_tokenizer
|
||||||
from axolotl.utils.data.rl import prepare_preference_datasets
|
from axolotl.utils.data.rl import prepare_preference_datasets
|
||||||
from axolotl.utils.data.sft import _load_tokenized_prepared_datasets
|
from axolotl.utils.data.sft import (
|
||||||
|
_load_tokenized_prepared_datasets,
|
||||||
|
)
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from tests.constants import (
|
from tests.constants import (
|
||||||
|
|||||||
@@ -363,5 +363,5 @@ class TestOptimizerValidation(BaseValidation):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(ValueError, match=r".*is currently incompatible with*"):
|
with pytest.raises(ValueError, match=r".*only compatible with FSDP2.*"):
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|||||||
@@ -123,6 +123,17 @@ class TestFSDPValidation:
|
|||||||
assert cfg.fsdp_config.transformer_layer_cls_to_wrap == "LlamaDecoderLayer"
|
assert cfg.fsdp_config.transformer_layer_cls_to_wrap == "LlamaDecoderLayer"
|
||||||
assert cfg.fsdp_config.reshard_after_forward is True
|
assert cfg.fsdp_config.reshard_after_forward is True
|
||||||
|
|
||||||
|
def test_muon_fsdp1_rejected(self, min_base_cfg):
|
||||||
|
cfg = min_base_cfg | DictDefault(
|
||||||
|
optimizer="muon",
|
||||||
|
fsdp_version=1,
|
||||||
|
fsdp_config={"reshard_after_forward": True},
|
||||||
|
)
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError, match="Muon optimizer is only compatible with FSDP2"
|
||||||
|
):
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"rl",
|
"rl",
|
||||||
[
|
[
|
||||||
|
|||||||
Reference in New Issue
Block a user