Compare commits
13 Commits
fix/cp-was
...
tensorboar
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
598c965043 | ||
|
|
a96733930e | ||
|
|
6130e40c37 | ||
|
|
5b2e3f00ce | ||
|
|
fc3b3d1d4e | ||
|
|
c9df6efdc2 | ||
|
|
0ee98a0309 | ||
|
|
2c05847a5f | ||
|
|
b0294b3427 | ||
|
|
1bcfc08c90 | ||
|
|
5a5cf30b26 | ||
|
|
7ddfb2d8a0 | ||
|
|
c57acef2c7 |
@@ -128,11 +128,9 @@ quartodoc:
|
|||||||
- monkeypatch.mistral_attn_hijack_flash
|
- monkeypatch.mistral_attn_hijack_flash
|
||||||
- monkeypatch.multipack
|
- monkeypatch.multipack
|
||||||
- monkeypatch.relora
|
- monkeypatch.relora
|
||||||
- monkeypatch.llama_expand_mask
|
|
||||||
- monkeypatch.lora_kernels
|
- monkeypatch.lora_kernels
|
||||||
- monkeypatch.utils
|
- monkeypatch.utils
|
||||||
- monkeypatch.btlm_attn_hijack_flash
|
- monkeypatch.btlm_attn_hijack_flash
|
||||||
- monkeypatch.llama_patch_multipack
|
|
||||||
- monkeypatch.stablelm_attn_hijack_flash
|
- monkeypatch.stablelm_attn_hijack_flash
|
||||||
- monkeypatch.trainer_fsdp_optim
|
- monkeypatch.trainer_fsdp_optim
|
||||||
- monkeypatch.transformers_fa_utils
|
- monkeypatch.transformers_fa_utils
|
||||||
|
|||||||
@@ -3,7 +3,8 @@ set -e
|
|||||||
|
|
||||||
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
||||||
|
|
||||||
curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1
|
set -o pipefail
|
||||||
|
curl --silent --show-error --fail --retry 3 --retry-delay 5 -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1
|
||||||
# hf download "NousResearch/Meta-Llama-3-8B"
|
# hf download "NousResearch/Meta-Llama-3-8B"
|
||||||
# hf download "NousResearch/Meta-Llama-3-8B-Instruct"
|
# hf download "NousResearch/Meta-Llama-3-8B-Instruct"
|
||||||
# hf download "microsoft/Phi-4-reasoning"
|
# hf download "microsoft/Phi-4-reasoning"
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ coverage:
|
|||||||
only_pulls: false
|
only_pulls: false
|
||||||
flags: null
|
flags: null
|
||||||
paths: null
|
paths: null
|
||||||
|
informational: true
|
||||||
|
|
||||||
parsers:
|
parsers:
|
||||||
gcov:
|
gcov:
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
---
|
---
|
||||||
title: Gradient Checkpointing and Activation Offloading
|
title: Gradient Checkpointing, Activation Offloading, and Layer Offloading
|
||||||
---
|
---
|
||||||
|
|
||||||
Gradient checkpointing and activation offloading are techniques used to optimize the performance of deep learning
|
Gradient checkpointing and activation offloading are techniques used to optimize the performance of deep learning
|
||||||
@@ -27,3 +27,33 @@ The `activation_offloading: legacy` naively offloads activations to CPU and with
|
|||||||
|
|
||||||
For resource constrained environments with limited CPU memory, `activation_offloading: disk` offloads
|
For resource constrained environments with limited CPU memory, `activation_offloading: disk` offloads
|
||||||
activations to disk instead of CPU RAM so that much larger context lengths can be trained with minimal memory.
|
activations to disk instead of CPU RAM so that much larger context lengths can be trained with minimal memory.
|
||||||
|
|
||||||
|
### Enabling Layer Offloading
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
layer_offloading: true
|
||||||
|
```
|
||||||
|
|
||||||
|
Layer offloading reduces GPU memory usage by moving frozen (non-trainable) decoder layer parameters to CPU
|
||||||
|
and streaming them back to GPU one layer at a time during the forward and backward passes. This is
|
||||||
|
particularly useful for LoRA/QLoRA training where most of the model's parameters are frozen — only the
|
||||||
|
trainable adapter weights stay on GPU permanently.
|
||||||
|
|
||||||
|
During training, forward and backward hooks on each decoder layer handle the transfer automatically:
|
||||||
|
|
||||||
|
- **Forward pass:** Before a layer executes, its frozen params are loaded to GPU. The next layer is
|
||||||
|
prefetched asynchronously on a separate CUDA stream for overlap.
|
||||||
|
- **Backward pass:** Same pattern in reverse — the current layer's frozen params are loaded and the
|
||||||
|
previous layer is prefetched.
|
||||||
|
|
||||||
|
After each layer finishes, its frozen params are offloaded back to CPU pinned memory.
|
||||||
|
|
||||||
|
This approach trades some CPU-GPU transfer overhead for significant GPU memory savings — the freed memory
|
||||||
|
is roughly equal to the size of all frozen parameters across all decoder layers, minus one layer's worth
|
||||||
|
that is kept on GPU at any given time.
|
||||||
|
|
||||||
|
**Requirements:**
|
||||||
|
|
||||||
|
- CUDA GPU (CPU-only training is not supported for this feature)
|
||||||
|
- Works with any HuggingFace model architecture that uses decoder layers (Llama, Mistral, Qwen, etc.)
|
||||||
|
- Best combined with LoRA/QLoRA where most parameters are frozen
|
||||||
|
|||||||
@@ -54,6 +54,13 @@ These techniques save VRAM by changing how activations are handled.
|
|||||||
- Activation Offloading: moves activations to CPU RAM or disk, trading I/O overhead for VRAM.
|
- Activation Offloading: moves activations to CPU RAM or disk, trading I/O overhead for VRAM.
|
||||||
- Learn more: [Gradient Checkpointing and Offloading Docs](gradient_checkpointing.qmd)
|
- Learn more: [Gradient Checkpointing and Offloading Docs](gradient_checkpointing.qmd)
|
||||||
|
|
||||||
|
### Layer Offloading
|
||||||
|
|
||||||
|
Offloads frozen (non-trainable) decoder layer parameters to CPU and streams them back to GPU one layer at a time during forward/backward passes using CUDA stream prefetching. Especially effective for LoRA/QLoRA where most parameters are frozen.
|
||||||
|
|
||||||
|
- **Config:** `layer_offloading: true`
|
||||||
|
- **Learn more:** [Layer Offloading Docs](gradient_checkpointing.qmd#enabling-layer-offloading)
|
||||||
|
|
||||||
### Cut Cross Entropy (CCE)
|
### Cut Cross Entropy (CCE)
|
||||||
|
|
||||||
Reduces VRAM usage by using an optimized cross-entropy loss calculation.
|
Reduces VRAM usage by using an optimized cross-entropy loss calculation.
|
||||||
|
|||||||
@@ -6,9 +6,6 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
|
|||||||
|
|
||||||
## Getting started
|
## Getting started
|
||||||
|
|
||||||
Note: Training this model requires weights in BF16 which we will link to later.
|
|
||||||
Users interested in training can convert / descale the existing FP8 weights.
|
|
||||||
|
|
||||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||||
|
|
||||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
|
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
base_model: mistralai/Mistral-Small-4-119B-2603
|
base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16
|
||||||
|
|
||||||
plugins:
|
plugins:
|
||||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
base_model: mistralai/Mistral-Small-4-119B-2603
|
base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16
|
||||||
processor_type: AutoProcessor
|
processor_type: AutoProcessor
|
||||||
|
|
||||||
plugins:
|
plugins:
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
base_model: mistralai/Mistral-Small-4-119B-2603
|
base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16
|
||||||
|
|
||||||
plugins:
|
plugins:
|
||||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
base_model: mistralai/Mistral-Small-4-119B-2603
|
base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16
|
||||||
processor_type: AutoProcessor
|
processor_type: AutoProcessor
|
||||||
|
|
||||||
plugins:
|
plugins:
|
||||||
|
|||||||
84
examples/qwen3.5/122b-a10b-moe-qlora-fsdp.yaml
Normal file
84
examples/qwen3.5/122b-a10b-moe-qlora-fsdp.yaml
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
base_model: Qwen/Qwen3.5-122B-A10B
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
chat_template: qwen3_5
|
||||||
|
datasets:
|
||||||
|
- path: mlabonne/FineTome-100k
|
||||||
|
type: chat_template
|
||||||
|
split: train[:20%]
|
||||||
|
field_messages: conversations
|
||||||
|
message_property_mappings:
|
||||||
|
role: from
|
||||||
|
content: value
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
|
||||||
|
load_in_4bit: true
|
||||||
|
quantize_moe_experts: true
|
||||||
|
adapter: qlora
|
||||||
|
lora_r: 16
|
||||||
|
lora_alpha: 32
|
||||||
|
lora_dropout: 0
|
||||||
|
lora_target_modules:
|
||||||
|
- q_proj
|
||||||
|
- k_proj
|
||||||
|
- v_proj
|
||||||
|
- o_proj
|
||||||
|
# Regex matching to target shared experts too
|
||||||
|
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
|
||||||
|
|
||||||
|
# Target experts
|
||||||
|
# lora_target_parameters:
|
||||||
|
# - mlp.experts.gate_up_proj
|
||||||
|
# - mlp.experts.down_proj
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 2
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_torch_4bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
bf16: auto
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
lora_mlp_kernel: false
|
||||||
|
lora_qkv_kernel: false
|
||||||
|
lora_o_kernel: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
gradient_checkpointing_kwargs:
|
||||||
|
use_reentrant: false
|
||||||
|
resume_from_checkpoint:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 4
|
||||||
|
saves_per_epoch: 1
|
||||||
|
weight_decay: 0.0
|
||||||
|
special_tokens:
|
||||||
|
|
||||||
|
fsdp_config:
|
||||||
|
fsdp_version: 2
|
||||||
|
offload_params: true
|
||||||
|
cpu_ram_efficient_loading: false
|
||||||
|
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
transformer_layer_cls_to_wrap: Qwen3_5MoeDecoderLayer
|
||||||
|
state_dict_type: FULL_STATE_DICT
|
||||||
|
sharding_strategy: FULL_SHARD
|
||||||
|
reshard_after_forward: true
|
||||||
|
activation_checkpointing: true
|
||||||
@@ -32,7 +32,11 @@ lora_target_modules:
|
|||||||
- v_proj
|
- v_proj
|
||||||
- o_proj
|
- o_proj
|
||||||
|
|
||||||
#lora_target_parameters:
|
# Regex matching to target shared experts too
|
||||||
|
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
|
||||||
|
|
||||||
|
# Target experts
|
||||||
|
# lora_target_parameters:
|
||||||
# - mlp.experts.gate_up_proj
|
# - mlp.experts.gate_up_proj
|
||||||
# - mlp.experts.down_proj
|
# - mlp.experts.down_proj
|
||||||
|
|
||||||
@@ -52,7 +56,6 @@ learning_rate: 0.0002
|
|||||||
bf16: auto
|
bf16: auto
|
||||||
tf32: true
|
tf32: true
|
||||||
|
|
||||||
|
|
||||||
lora_mlp_kernel: false
|
lora_mlp_kernel: false
|
||||||
lora_qkv_kernel: false
|
lora_qkv_kernel: false
|
||||||
lora_o_kernel: false
|
lora_o_kernel: false
|
||||||
|
|||||||
81
examples/qwen3.5/27b-qlora-fsdp.yaml
Normal file
81
examples/qwen3.5/27b-qlora-fsdp.yaml
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
base_model: Qwen/Qwen3.5-27B
|
||||||
|
|
||||||
|
# Automatically upload checkpoint and final model to HF
|
||||||
|
# hub_model_id: username/custom_model_name
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
chat_template: qwen3_5
|
||||||
|
datasets:
|
||||||
|
- path: mlabonne/FineTome-100k
|
||||||
|
type: chat_template
|
||||||
|
split: train[:20%]
|
||||||
|
field_messages: conversations
|
||||||
|
message_property_mappings:
|
||||||
|
role: from
|
||||||
|
content: value
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
|
||||||
|
load_in_4bit: true
|
||||||
|
adapter: qlora
|
||||||
|
lora_r: 16
|
||||||
|
lora_alpha: 32
|
||||||
|
lora_target_modules:
|
||||||
|
- q_proj
|
||||||
|
- k_proj
|
||||||
|
- v_proj
|
||||||
|
- o_proj
|
||||||
|
- down_proj
|
||||||
|
- up_proj
|
||||||
|
# Uncomment below to also target the linear attention projections.
|
||||||
|
# These use separate in_proj_qkv / in_proj_z / out_proj (Qwen3.5-specific).
|
||||||
|
# - linear_attn.in_proj_qkv
|
||||||
|
# - linear_attn.in_proj_z
|
||||||
|
# - linear_attn.out_proj
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 2
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_torch_4bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
bf16: auto
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
gradient_checkpointing_kwargs:
|
||||||
|
use_reentrant: false
|
||||||
|
resume_from_checkpoint:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 4
|
||||||
|
saves_per_epoch: 1
|
||||||
|
weight_decay: 0.0
|
||||||
|
special_tokens:
|
||||||
|
|
||||||
|
fsdp_config:
|
||||||
|
fsdp_version: 2
|
||||||
|
offload_params: false
|
||||||
|
cpu_ram_efficient_loading: false
|
||||||
|
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
transformer_layer_cls_to_wrap: Qwen3_5DecoderLayer
|
||||||
|
state_dict_type: FULL_STATE_DICT
|
||||||
|
sharding_strategy: FULL_SHARD
|
||||||
|
reshard_after_forward: true
|
||||||
|
activation_checkpointing: true
|
||||||
@@ -1,9 +1,7 @@
|
|||||||
base_model: Qwen/Qwen3.5-27B
|
base_model: Qwen/Qwen3.5-27B
|
||||||
|
|
||||||
# Automatically upload checkpoint and final model to HF
|
# Automatically upload checkpoint and final model to HF
|
||||||
# hub_model_id: username/custom_model_name
|
# hub_model_id: username/custom_model_name
|
||||||
# Note: Qwen3.5 is an early-fusion VLM (image+text). This config fine-tunes
|
|
||||||
# the text-only path. For multimodal (image+text) fine-tuning, add image
|
|
||||||
# columns to your dataset following axolotl's multimodal dataset format.
|
|
||||||
|
|
||||||
plugins:
|
plugins:
|
||||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
|||||||
85
examples/qwen3.5/35b-a3b-moe-qlora-fsdp.yaml
Normal file
85
examples/qwen3.5/35b-a3b-moe-qlora-fsdp.yaml
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
base_model: Qwen/Qwen3.5-35B-A3B
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
chat_template: qwen3_5
|
||||||
|
datasets:
|
||||||
|
- path: mlabonne/FineTome-100k
|
||||||
|
type: chat_template
|
||||||
|
split: train[:20%]
|
||||||
|
field_messages: conversations
|
||||||
|
message_property_mappings:
|
||||||
|
role: from
|
||||||
|
content: value
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
|
||||||
|
load_in_4bit: true
|
||||||
|
quantize_moe_experts: true
|
||||||
|
adapter: qlora
|
||||||
|
lora_r: 16
|
||||||
|
lora_alpha: 32
|
||||||
|
lora_dropout: 0
|
||||||
|
lora_target_modules:
|
||||||
|
- q_proj
|
||||||
|
- k_proj
|
||||||
|
- v_proj
|
||||||
|
- o_proj
|
||||||
|
|
||||||
|
# Regex matching to target shared experts too
|
||||||
|
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
|
||||||
|
|
||||||
|
# Target experts
|
||||||
|
# lora_target_parameters:
|
||||||
|
# - mlp.experts.gate_up_proj
|
||||||
|
# - mlp.experts.down_proj
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 2
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_torch_4bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
bf16: auto
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
lora_mlp_kernel: false
|
||||||
|
lora_qkv_kernel: false
|
||||||
|
lora_o_kernel: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
gradient_checkpointing_kwargs:
|
||||||
|
use_reentrant: false
|
||||||
|
resume_from_checkpoint:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 4
|
||||||
|
saves_per_epoch: 1
|
||||||
|
weight_decay: 0.0
|
||||||
|
special_tokens:
|
||||||
|
|
||||||
|
fsdp_config:
|
||||||
|
fsdp_version: 2
|
||||||
|
offload_params: true
|
||||||
|
cpu_ram_efficient_loading: false
|
||||||
|
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
transformer_layer_cls_to_wrap: Qwen3_5MoeDecoderLayer
|
||||||
|
state_dict_type: FULL_STATE_DICT
|
||||||
|
sharding_strategy: FULL_SHARD
|
||||||
|
reshard_after_forward: true
|
||||||
|
activation_checkpointing: true
|
||||||
@@ -32,7 +32,11 @@ lora_target_modules:
|
|||||||
- v_proj
|
- v_proj
|
||||||
- o_proj
|
- o_proj
|
||||||
|
|
||||||
#lora_target_parameters:
|
# Regex matching to target shared experts too
|
||||||
|
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
|
||||||
|
|
||||||
|
# Target experts
|
||||||
|
# lora_target_parameters:
|
||||||
# - mlp.experts.gate_up_proj
|
# - mlp.experts.gate_up_proj
|
||||||
# - mlp.experts.down_proj
|
# - mlp.experts.down_proj
|
||||||
|
|
||||||
|
|||||||
@@ -26,8 +26,6 @@ lora_r: 32
|
|||||||
lora_alpha: 16
|
lora_alpha: 16
|
||||||
lora_dropout: 0.05
|
lora_dropout: 0.05
|
||||||
# Targets the language model attention and MLP layers.
|
# Targets the language model attention and MLP layers.
|
||||||
# Qwen3.5 is early-fusion: all layers (including those seeing vision tokens) share
|
|
||||||
# the same transformer stack, so standard attention targets work for both modalities.
|
|
||||||
lora_target_modules:
|
lora_target_modules:
|
||||||
- q_proj
|
- q_proj
|
||||||
- k_proj
|
- k_proj
|
||||||
|
|||||||
@@ -2,20 +2,6 @@
|
|||||||
|
|
||||||
[Qwen3.5](https://huggingface.co/collections/Qwen/qwen35) is a hybrid architecture model series combining Gated DeltaNet linear attention with standard Transformer attention. All Qwen3.5 models are early-fusion vision-language models: dense variants use `Qwen3_5ForConditionalGeneration` and MoE variants use `Qwen3_5MoeForConditionalGeneration`.
|
[Qwen3.5](https://huggingface.co/collections/Qwen/qwen35) is a hybrid architecture model series combining Gated DeltaNet linear attention with standard Transformer attention. All Qwen3.5 models are early-fusion vision-language models: dense variants use `Qwen3_5ForConditionalGeneration` and MoE variants use `Qwen3_5MoeForConditionalGeneration`.
|
||||||
|
|
||||||
Vision and text tokens are processed through the same transformer stack. The configs below train on text-only data unless noted otherwise. See `9b-lora-vision.yaml` for a multimodal example.
|
|
||||||
|
|
||||||
Available configs:
|
|
||||||
|
|
||||||
| Config | Model | Type | Peak VRAM |
|
|
||||||
|---|---|---|---|
|
|
||||||
| `27b-qlora.yaml` | Qwen3.5-27B | Dense VLM, text-only QLoRA | ~47 GiB |
|
|
||||||
| `27b-fft.yaml` | Qwen3.5-27B | Dense VLM, text-only FFT (vision frozen) | ~53 GiB |
|
|
||||||
| `35b-a3b-moe-qlora.yaml` | Qwen3.5-35B-A3B | MoE, text-only QLoRA | — |
|
|
||||||
| `122b-a10b-moe-qlora.yaml` | Qwen3.5-122B-A10B | MoE, text-only QLoRA | — |
|
|
||||||
| `9b-lora-vision.yaml` | Qwen3.5-9B | Vision+text LoRA, single GPU | — |
|
|
||||||
| `9b-fft-vision.yaml` | Qwen3.5-9B | Vision+text FFT, single GPU | ~61 GiB |
|
|
||||||
|
|
||||||
|
|
||||||
## Getting started
|
## Getting started
|
||||||
|
|
||||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||||
@@ -23,43 +9,69 @@ Available configs:
|
|||||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
||||||
|
|
||||||
3. Install FLA for sample packing support with the Gated DeltaNet linear attention layers:
|
3. Install FLA for sample packing support with the Gated DeltaNet linear attention layers:
|
||||||
```bash
|
```bash
|
||||||
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
|
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
|
||||||
|
```
|
||||||
|
> FLA is required when `sample_packing: true`. Without it, training raises a `RuntimeError` on packed sequences. Vision configs use `sample_packing: false` so FLA is optional there.
|
||||||
|
|
||||||
|
4. Pick any config from the table below and run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
axolotl train examples/qwen3.5/<config>.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
Available configs:
|
||||||
|
|
||||||
|
| Config | Model | Type | Peak VRAM |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `9b-lora-vision.yaml` | Qwen3.5-9B | Vision+text LoRA, single GPU | — |
|
||||||
|
| `9b-fft-vision.yaml` | Qwen3.5-9B | Vision+text FFT, single GPU | ~61 GiB |
|
||||||
|
| `27b-qlora.yaml` | Qwen3.5-27B | Dense, text-only QLoRA | ~47 GiB |
|
||||||
|
| `27b-fft.yaml` | Qwen3.5-27B | Dense, text-only FFT (vision frozen) | ~53 GiB |
|
||||||
|
| `27b-qlora-fsdp.yaml` | Qwen3.5-27B | Dense, text-only QLoRA + FSDP2 | — |
|
||||||
|
| `35b-a3b-moe-qlora.yaml` | Qwen3.5-35B-A3B | MoE, text-only QLoRA | — |
|
||||||
|
| `35b-a3b-moe-qlora-fsdp.yaml` | Qwen3.5-35B-A3B | MoE, text-only QLoRA + FSDP2 | — |
|
||||||
|
| `122b-a10b-moe-qlora.yaml` | Qwen3.5-122B-A10B | MoE, text-only QLoRA | — |
|
||||||
|
| `122b-a10b-moe-qlora-fsdp.yaml` | Qwen3.5-122B-A10B | MoE, text-only QLoRA + FSDP2 | — |
|
||||||
|
|
||||||
|
### Gated DeltaNet Linear Attention
|
||||||
|
|
||||||
|
Qwen3.5 interleaves standard attention with Gated DeltaNet linear attention layers. To apply LoRA to them, add to `lora_target_modules`:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
lora_target_modules:
|
||||||
|
# ... standard projections ...
|
||||||
|
- linear_attn.in_proj_qkv
|
||||||
|
- linear_attn.in_proj_z
|
||||||
|
- linear_attn.out_proj
|
||||||
```
|
```
|
||||||
> FLA is required when `sample_packing: true`. Without it, training raises a `RuntimeError` on packed sequences. Vision configs use `sample_packing: false` so FLA is optional there.
|
|
||||||
|
|
||||||
4. Run a finetuning example:
|
### Routed Experts (MoE)
|
||||||
|
|
||||||
```bash
|
To apply LoRA to routed expert parameters, add `lora_target_parameters`:
|
||||||
# Dense 27B text-only (QLoRA, ~47 GiB VRAM with sample packing)
|
|
||||||
axolotl train examples/qwen3.5/27b-qlora.yaml
|
|
||||||
|
|
||||||
# Dense 27B text-only FFT with vision encoder frozen (~53 GiB, single 80 GiB GPU)
|
```yaml
|
||||||
axolotl train examples/qwen3.5/27b-fft.yaml
|
lora_target_parameters:
|
||||||
|
- mlp.experts.gate_up_proj
|
||||||
|
- mlp.experts.down_proj
|
||||||
|
# - mlp.gate.weight # router
|
||||||
|
```
|
||||||
|
|
||||||
# MoE 35B-A3B text-only (QLoRA)
|
### Shared Experts (MoE)
|
||||||
axolotl train examples/qwen3.5/35b-a3b-moe-qlora.yaml
|
|
||||||
|
|
||||||
# MoE 122B-A10B text-only (QLoRA)
|
Routed experts and shared experts both have `gate_up_proj`/`down_proj`, so a plain module name in `lora_target_modules` would match both. Use a regex to target only attention and shared expert projections, while `lora_target_parameters` above handles routed experts separately:
|
||||||
axolotl train examples/qwen3.5/122b-a10b-moe-qlora.yaml
|
|
||||||
|
|
||||||
# 9B vision+text (LoRA, multimodal dataset)
|
|
||||||
axolotl train examples/qwen3.5/9b-lora-vision.yaml
|
|
||||||
|
|
||||||
# 9B vision+text FFT, single 80 GiB GPU (~61 GiB peak)
|
|
||||||
axolotl train examples/qwen3.5/9b-fft-vision.yaml
|
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
|
||||||
```
|
```
|
||||||
|
|
||||||
### TIPS
|
### TIPS
|
||||||
|
|
||||||
- For inference, you can experiment with `temperature: 0.7`, `top_p: 0.8`, `top_k: 20`, and `min_p: 0`.
|
- For inference hyp, please see the respective model card details.
|
||||||
- For **text-only FFT** on 27B, use `27b-fft.yaml` which sets `unfrozen_parameters` to freeze the vision encoder (`model.visual.*`) — this avoids wasting optimizer state on parameters that receive no gradient from text-only data.
|
|
||||||
- You can run a full finetuning of smaller configs by removing `adapter: qlora` and `load_in_4bit: true`. See [Multi-GPU](#optimization-guides) below.
|
- You can run a full finetuning of smaller configs by removing `adapter: qlora` and `load_in_4bit: true`. See [Multi-GPU](#optimization-guides) below.
|
||||||
- Read more on loading your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
- Read more on loading your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||||
- For **multimodal** finetuning, set `processor_type: AutoProcessor`, `skip_prepare_dataset: true`, and `remove_unused_columns: false` as shown in `9b-lora-vision.yaml`.
|
- For **multimodal** finetuning, set `processor_type: AutoProcessor`, `skip_prepare_dataset: true`, and `remove_unused_columns: false` as shown in `9b-lora-vision.yaml`.
|
||||||
- The Gated DeltaNet linear attention layers (`linear_attn.*`) can optionally be added to `lora_target_modules` — they are commented out by default.
|
|
||||||
|
|
||||||
## Optimization Guides
|
## Optimization Guides
|
||||||
|
|
||||||
|
|||||||
@@ -61,5 +61,11 @@ skip-magic-trailing-comma = false
|
|||||||
line-ending = "auto"
|
line-ending = "auto"
|
||||||
docstring-code-format = false
|
docstring-code-format = false
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
addopts = "-m 'not slow'"
|
||||||
|
markers = [
|
||||||
|
"slow: marks tests as slow",
|
||||||
|
]
|
||||||
|
|
||||||
[tool.uv.extra-build-dependencies]
|
[tool.uv.extra-build-dependencies]
|
||||||
axolotl = ["huggingface_hub"]
|
axolotl = ["huggingface_hub"]
|
||||||
|
|||||||
13
setup.py
13
setup.py
@@ -81,16 +81,23 @@ def parse_requirements(extras_require_map):
|
|||||||
f"https://download.pytorch.org/whl/{torch_cuda_version}"
|
f"https://download.pytorch.org/whl/{torch_cuda_version}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if (major, minor) >= (2, 9):
|
if (major, minor) >= (2, 10):
|
||||||
|
extras_require_map.pop("fbgemm-gpu")
|
||||||
|
extras_require_map["fbgemm-gpu"] = [
|
||||||
|
"fbgemm-gpu==1.5.0",
|
||||||
|
"fbgemm-gpu-genai==1.5.0",
|
||||||
|
]
|
||||||
|
if not install_xformers:
|
||||||
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
|
extras_require_map["vllm"] = ["vllm==0.17.1"]
|
||||||
|
elif (major, minor) >= (2, 9):
|
||||||
extras_require_map.pop("fbgemm-gpu")
|
extras_require_map.pop("fbgemm-gpu")
|
||||||
extras_require_map["fbgemm-gpu"] = [
|
extras_require_map["fbgemm-gpu"] = [
|
||||||
"fbgemm-gpu==1.4.0",
|
"fbgemm-gpu==1.4.0",
|
||||||
"fbgemm-gpu-genai==1.4.2",
|
"fbgemm-gpu-genai==1.4.2",
|
||||||
]
|
]
|
||||||
extras_require_map["vllm"] = ["vllm==0.11.1"]
|
|
||||||
if not install_xformers:
|
if not install_xformers:
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
extras_require_map["vllm"] = ["vllm==0.13.0"]
|
|
||||||
if patch == 0:
|
if patch == 0:
|
||||||
extras_require_map["vllm"] = ["vllm==0.13.0"]
|
extras_require_map["vllm"] = ["vllm==0.13.0"]
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import httpcore
|
||||||
from accelerate.commands.config import config_args
|
from accelerate.commands.config import config_args
|
||||||
from huggingface_hub import HfApi
|
from huggingface_hub import HfApi
|
||||||
from huggingface_hub.utils import LocalTokenNotFoundError
|
from huggingface_hub.utils import LocalTokenNotFoundError
|
||||||
@@ -47,7 +48,7 @@ def check_user_token() -> bool:
|
|||||||
"Error verifying HuggingFace token. Remember to log in using `hf auth login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
|
"Error verifying HuggingFace token. Remember to log in using `hf auth login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
|
||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
except HTTPError:
|
except (HTTPError, httpcore.ConnectError):
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"Error accessing HuggingFace. This may be due to a network issue or rate limiting."
|
"Error accessing HuggingFace. This may be due to a network issue or rate limiting."
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -353,6 +353,30 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
adam_kwargs["eps"] = (eps1, eps2)
|
adam_kwargs["eps"] = (eps1, eps2)
|
||||||
|
|
||||||
optimizer_kwargs.update(adam_kwargs)
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
|
elif self.cfg.optimizer == "flash_adamw":
|
||||||
|
from flashoptim import FlashAdamW
|
||||||
|
|
||||||
|
optimizer_cls = FlashAdamW
|
||||||
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
|
elif self.cfg.optimizer == "flash_adam":
|
||||||
|
from flashoptim import FlashAdam
|
||||||
|
|
||||||
|
optimizer_cls = FlashAdam
|
||||||
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
|
elif self.cfg.optimizer == "flash_sgd":
|
||||||
|
from flashoptim import FlashSGD
|
||||||
|
|
||||||
|
optimizer_cls = FlashSGD
|
||||||
|
elif self.cfg.optimizer == "flash_sgdw":
|
||||||
|
from flashoptim import FlashSGDW
|
||||||
|
|
||||||
|
optimizer_cls = FlashSGDW
|
||||||
|
elif self.cfg.optimizer == "flash_lion":
|
||||||
|
from flashoptim import FlashLion
|
||||||
|
|
||||||
|
optimizer_cls = FlashLion
|
||||||
|
if "betas" in adam_kwargs:
|
||||||
|
optimizer_kwargs["betas"] = adam_kwargs["betas"]
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unhandled optimizer: {self.cfg.optimizer}. Please raise an Issue."
|
f"Unhandled optimizer: {self.cfg.optimizer}. Please raise an Issue."
|
||||||
@@ -484,6 +508,8 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
training_args_kwargs["accelerator_config"] = AcceleratorConfig()
|
training_args_kwargs["accelerator_config"] = AcceleratorConfig()
|
||||||
|
|
||||||
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
|
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
|
||||||
|
if self.cfg.layer_offloading:
|
||||||
|
training_args_kwargs["layer_offloading"] = True
|
||||||
if self.cfg.activation_offloading is True:
|
if self.cfg.activation_offloading is True:
|
||||||
# don't use the HF gradient checkpointing, manually wrap
|
# don't use the HF gradient checkpointing, manually wrap
|
||||||
training_args_kwargs["gradient_checkpointing"] = False
|
training_args_kwargs["gradient_checkpointing"] = False
|
||||||
|
|||||||
@@ -208,7 +208,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
if self.eval_dataset:
|
if self.eval_dataset:
|
||||||
trainer_kwargs["eval_dataset"] = self.eval_dataset
|
trainer_kwargs["eval_dataset"] = self.eval_dataset
|
||||||
if self.cfg.adapter and self.peft_config and self.cfg.rl is not RLType.GRPO:
|
if (
|
||||||
|
self.cfg.adapter
|
||||||
|
and self.peft_config
|
||||||
|
and self.cfg.rl not in (RLType.GRPO, RLType.ORPO)
|
||||||
|
):
|
||||||
trainer_kwargs["peft_config"] = self.peft_config
|
trainer_kwargs["peft_config"] = self.peft_config
|
||||||
if self.cfg.precompute_ref_log_probs is not None:
|
if self.cfg.precompute_ref_log_probs is not None:
|
||||||
trainer_kwargs["precompute_ref_log_probs"] = (
|
trainer_kwargs["precompute_ref_log_probs"] = (
|
||||||
|
|||||||
@@ -29,10 +29,12 @@ from transformers.utils import SAFE_WEIGHTS_NAME, is_peft_available
|
|||||||
from trl.experimental.utils import pad_to_length
|
from trl.experimental.utils import pad_to_length
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
|
from axolotl.core.trainers.constants import TOKENS_STATE_FILE
|
||||||
from axolotl.core.trainers.mixins import (
|
from axolotl.core.trainers.mixins import (
|
||||||
ActivationOffloadingMixin,
|
ActivationOffloadingMixin,
|
||||||
CheckpointSaveMixin,
|
CheckpointSaveMixin,
|
||||||
DistributedParallelMixin,
|
DistributedParallelMixin,
|
||||||
|
LayerOffloadingMixin,
|
||||||
OptimizerMixin,
|
OptimizerMixin,
|
||||||
PackingMixin,
|
PackingMixin,
|
||||||
RngLoaderMixin,
|
RngLoaderMixin,
|
||||||
@@ -51,8 +53,6 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
|||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
TOKENS_STATE_FILE = "tokens_state."
|
|
||||||
|
|
||||||
REDUCTION_FNS = {
|
REDUCTION_FNS = {
|
||||||
"mean": torch.mean,
|
"mean": torch.mean,
|
||||||
"min": torch.min,
|
"min": torch.min,
|
||||||
@@ -67,6 +67,7 @@ class AxolotlTrainer(
|
|||||||
OptimizerMixin,
|
OptimizerMixin,
|
||||||
RngLoaderMixin,
|
RngLoaderMixin,
|
||||||
CheckpointSaveMixin,
|
CheckpointSaveMixin,
|
||||||
|
LayerOffloadingMixin,
|
||||||
ActivationOffloadingMixin,
|
ActivationOffloadingMixin,
|
||||||
DistributedParallelMixin,
|
DistributedParallelMixin,
|
||||||
Trainer,
|
Trainer,
|
||||||
|
|||||||
1
src/axolotl/core/trainers/constants.py
Normal file
1
src/axolotl/core/trainers/constants.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
TOKENS_STATE_FILE = "tokens_state.json"
|
||||||
@@ -2,7 +2,8 @@
|
|||||||
Axolotl specific DPO args
|
Axolotl specific DPO args
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from trl import DPOConfig
|
from trl import DPOConfig
|
||||||
|
|
||||||
@@ -16,3 +17,4 @@ class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
dpo_norm_loss: bool | None = False
|
dpo_norm_loss: bool | None = False
|
||||||
|
rpo_alpha: Optional[float] = field(default=None)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
|
|
||||||
from .activation_checkpointing import ActivationOffloadingMixin
|
from .activation_checkpointing import ActivationOffloadingMixin
|
||||||
from .checkpoints import CheckpointSaveMixin
|
from .checkpoints import CheckpointSaveMixin
|
||||||
|
from .layer_offloading import LayerOffloadingMixin
|
||||||
from .distributed_parallel import DistributedParallelMixin
|
from .distributed_parallel import DistributedParallelMixin
|
||||||
from .optimizer import OptimizerMixin
|
from .optimizer import OptimizerMixin
|
||||||
from .packing import PackingMixin
|
from .packing import PackingMixin
|
||||||
|
|||||||
304
src/axolotl/core/trainers/mixins/layer_offloading.py
Normal file
304
src/axolotl/core/trainers/mixins/layer_offloading.py
Normal file
@@ -0,0 +1,304 @@
|
|||||||
|
"""
|
||||||
|
Trainer mixin for layer-wise parameter offloading to CPU.
|
||||||
|
|
||||||
|
Offloads frozen (non-trainable) parameters in decoder layers to CPU, then uses
|
||||||
|
forward/backward hooks to stream them on/off GPU one layer at a time with CUDA
|
||||||
|
stream prefetching. Trainable parameters (e.g. LoRA weights) stay on GPU always.
|
||||||
|
|
||||||
|
Forward: pre-hook loads layer N's frozen params to GPU (prefetches N+1 on
|
||||||
|
transfer stream), post-hook offloads layer N-1's frozen params.
|
||||||
|
Backward: same in reverse order.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from transformers import Trainer
|
||||||
|
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _find_decoder_layers(model: nn.Module) -> tuple[nn.ModuleList | None, list[str]]:
|
||||||
|
"""Recursively search the model for the decoder layer ModuleList.
|
||||||
|
|
||||||
|
Finds any ModuleList whose children have 'DecoderLayer' in their class name.
|
||||||
|
Handles all common HF architectures including VLM wrappers (e.g. Qwen3.5-MoE
|
||||||
|
where layers are at model.language_model.layers).
|
||||||
|
"""
|
||||||
|
# BFS to find the first ModuleList containing decoder layers
|
||||||
|
queue = [model]
|
||||||
|
while queue:
|
||||||
|
m = queue.pop(0)
|
||||||
|
for _name, child in m.named_children():
|
||||||
|
if isinstance(child, nn.ModuleList) and len(child) > 0:
|
||||||
|
first_type = type(child[0]).__name__
|
||||||
|
if "DecoderLayer" in first_type or "TransformerBlock" in first_type:
|
||||||
|
layer_types = list({type(layer).__name__ for layer in child})
|
||||||
|
return child, layer_types
|
||||||
|
else:
|
||||||
|
queue.append(child)
|
||||||
|
|
||||||
|
return None, []
|
||||||
|
|
||||||
|
|
||||||
|
def _get_frozen_params(layer: nn.Module) -> list[tuple[str, nn.Parameter]]:
|
||||||
|
"""Get all non-trainable parameters in a layer."""
|
||||||
|
return [(n, p) for n, p in layer.named_parameters() if not p.requires_grad]
|
||||||
|
|
||||||
|
|
||||||
|
class LayerOffloadManager:
|
||||||
|
"""Manages offloading frozen decoder layer params to CPU and streaming
|
||||||
|
them back during forward/backward with CUDA stream overlap.
|
||||||
|
|
||||||
|
Only frozen (requires_grad=False) parameters are offloaded.
|
||||||
|
Trainable parameters (LoRA weights, etc.) remain on GPU at all times.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: nn.Module,
|
||||||
|
num_prefetch: int = 1,
|
||||||
|
):
|
||||||
|
self.model = model
|
||||||
|
self.num_prefetch = num_prefetch
|
||||||
|
self._hooks: list = []
|
||||||
|
self._device = None
|
||||||
|
|
||||||
|
# Find decoder layers
|
||||||
|
self.layers, layer_types = _find_decoder_layers(model)
|
||||||
|
if self.layers is None:
|
||||||
|
LOG.warning(
|
||||||
|
"LayerOffloadManager: no decoder layers found, offloading disabled"
|
||||||
|
)
|
||||||
|
self.enabled = False
|
||||||
|
return
|
||||||
|
|
||||||
|
self.enabled = True
|
||||||
|
self.n_layers = len(self.layers)
|
||||||
|
LOG.info(
|
||||||
|
f"Layer offloading: found {self.n_layers} layers ({', '.join(layer_types)})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Determine GPU device
|
||||||
|
for p in model.parameters():
|
||||||
|
if p.device.type == "cuda":
|
||||||
|
self._device = p.device
|
||||||
|
break
|
||||||
|
if self._device is None:
|
||||||
|
LOG.warning("LayerOffloadManager: no CUDA parameters found")
|
||||||
|
self.enabled = False
|
||||||
|
return
|
||||||
|
|
||||||
|
# Transfer stream for async prefetch
|
||||||
|
self._transfer_stream = torch.cuda.Stream(device=self._device)
|
||||||
|
|
||||||
|
# Track which layers have their frozen params on GPU
|
||||||
|
self._on_gpu: set[int] = set(range(self.n_layers))
|
||||||
|
|
||||||
|
# Cache: frozen param references per layer (list of (name, param) tuples)
|
||||||
|
self._frozen_params: list[list[tuple[str, nn.Parameter]]] = [
|
||||||
|
_get_frozen_params(self.layers[i]) for i in range(self.n_layers)
|
||||||
|
]
|
||||||
|
|
||||||
|
# CPU storage: pinned tensors for each layer's frozen params
|
||||||
|
# Populated on first offload
|
||||||
|
self._cpu_data: list[dict[str, torch.Tensor]] = [
|
||||||
|
{} for _ in range(self.n_layers)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Offload all layers upfront
|
||||||
|
self._offload_all()
|
||||||
|
|
||||||
|
# Release cached memory blocks back to the driver
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
def _offload_all(self):
|
||||||
|
"""Move all frozen params in all decoder layers to CPU."""
|
||||||
|
mem_before = torch.cuda.memory_allocated(self._device)
|
||||||
|
for i in range(self.n_layers):
|
||||||
|
self._offload_layer(i)
|
||||||
|
mem_after = torch.cuda.memory_allocated(self._device)
|
||||||
|
freed = (mem_before - mem_after) / 1e6
|
||||||
|
LOG.info(
|
||||||
|
f"Layer offloading: offloaded frozen params from {self.n_layers} layers, "
|
||||||
|
f"freed {freed:.0f} MB GPU memory"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _offload_layer(self, idx: int):
|
||||||
|
"""Move frozen params of layer idx to CPU pinned memory."""
|
||||||
|
if idx not in self._on_gpu:
|
||||||
|
return
|
||||||
|
for name, param in self._frozen_params[idx]:
|
||||||
|
if param.device.type != "cuda":
|
||||||
|
continue
|
||||||
|
# Allocate pinned CPU tensor on first offload
|
||||||
|
if name not in self._cpu_data[idx]:
|
||||||
|
self._cpu_data[idx][name] = torch.empty_like(
|
||||||
|
param.data, device="cpu", pin_memory=True
|
||||||
|
)
|
||||||
|
cpu_buf = self._cpu_data[idx][name]
|
||||||
|
# Async copy GPU -> CPU (on transfer stream for overlap)
|
||||||
|
cpu_buf.copy_(param.data, non_blocking=True)
|
||||||
|
# Point parameter at a dummy CPU tensor to free GPU memory
|
||||||
|
param.data = cpu_buf
|
||||||
|
self._on_gpu.discard(idx)
|
||||||
|
|
||||||
|
def _load_layer(self, idx: int, stream=None):
|
||||||
|
"""Move frozen params of layer idx back to GPU."""
|
||||||
|
if idx in self._on_gpu or idx < 0 or idx >= self.n_layers:
|
||||||
|
return
|
||||||
|
ctx = (
|
||||||
|
torch.cuda.stream(stream)
|
||||||
|
if stream is not None
|
||||||
|
else contextlib.nullcontext()
|
||||||
|
)
|
||||||
|
with ctx:
|
||||||
|
for _name, param in self._frozen_params[idx]:
|
||||||
|
if param.device.type == "cuda":
|
||||||
|
continue
|
||||||
|
gpu_data = param.data.to(self._device, non_blocking=True)
|
||||||
|
param.data = gpu_data
|
||||||
|
self._on_gpu.add(idx)
|
||||||
|
|
||||||
|
def _prefetch_layer(self, idx: int):
|
||||||
|
"""Async prefetch layer idx on the transfer stream."""
|
||||||
|
if idx in self._on_gpu or idx < 0 or idx >= self.n_layers:
|
||||||
|
return
|
||||||
|
self._transfer_stream.wait_stream(torch.cuda.default_stream(self._device))
|
||||||
|
self._load_layer(idx, stream=self._transfer_stream)
|
||||||
|
|
||||||
|
def _wait_transfer(self):
|
||||||
|
"""Make default stream wait for any in-flight transfers."""
|
||||||
|
torch.cuda.default_stream(self._device).wait_stream(self._transfer_stream)
|
||||||
|
|
||||||
|
def setup_hooks(self):
|
||||||
|
"""Register forward and backward hooks on each decoder layer."""
|
||||||
|
if not self.enabled:
|
||||||
|
return
|
||||||
|
|
||||||
|
for idx in range(self.n_layers):
|
||||||
|
layer = self.layers[idx]
|
||||||
|
|
||||||
|
def make_pre_fwd(i):
|
||||||
|
def hook(module, args):
|
||||||
|
# Ensure this layer is on GPU
|
||||||
|
if i not in self._on_gpu:
|
||||||
|
self._load_layer(i)
|
||||||
|
self._wait_transfer()
|
||||||
|
# Prefetch next layer(s)
|
||||||
|
for offset in range(1, self.num_prefetch + 1):
|
||||||
|
self._prefetch_layer(i + offset)
|
||||||
|
|
||||||
|
return hook
|
||||||
|
|
||||||
|
def make_post_fwd(i):
|
||||||
|
def hook(module, args, output):
|
||||||
|
# Offload previous layer (no longer needed in forward)
|
||||||
|
if i > 0:
|
||||||
|
self._offload_layer(i - 1)
|
||||||
|
# Offload last layer after forward
|
||||||
|
if i == self.n_layers - 1:
|
||||||
|
self._offload_layer(i)
|
||||||
|
|
||||||
|
return hook
|
||||||
|
|
||||||
|
def make_pre_bwd(i):
|
||||||
|
def hook(module, grad_output):
|
||||||
|
# Load this layer for backward
|
||||||
|
if i not in self._on_gpu:
|
||||||
|
self._load_layer(i)
|
||||||
|
self._wait_transfer()
|
||||||
|
# Prefetch previous layer(s)
|
||||||
|
for offset in range(1, self.num_prefetch + 1):
|
||||||
|
self._prefetch_layer(i - offset)
|
||||||
|
|
||||||
|
return hook
|
||||||
|
|
||||||
|
def make_post_bwd(i):
|
||||||
|
def hook(module, grad_input, grad_output):
|
||||||
|
# Offload the layer above
|
||||||
|
if i < self.n_layers - 1:
|
||||||
|
self._offload_layer(i + 1)
|
||||||
|
# Offload first layer after backward
|
||||||
|
if i == 0:
|
||||||
|
self._offload_layer(i)
|
||||||
|
|
||||||
|
return hook
|
||||||
|
|
||||||
|
h1 = layer.register_forward_pre_hook(make_pre_fwd(idx))
|
||||||
|
h2 = layer.register_forward_hook(make_post_fwd(idx))
|
||||||
|
h3 = layer.register_full_backward_pre_hook(make_pre_bwd(idx))
|
||||||
|
h4 = layer.register_full_backward_hook(make_post_bwd(idx))
|
||||||
|
self._hooks.extend([h1, h2, h3, h4])
|
||||||
|
|
||||||
|
def remove_hooks(self):
|
||||||
|
"""Remove all hooks and restore layers to GPU."""
|
||||||
|
for h in self._hooks:
|
||||||
|
h.remove()
|
||||||
|
self._hooks.clear()
|
||||||
|
if self.enabled:
|
||||||
|
for i in range(self.n_layers):
|
||||||
|
if i not in self._on_gpu:
|
||||||
|
self._load_layer(i)
|
||||||
|
|
||||||
|
def pre_step(self):
|
||||||
|
"""Called before each training step — ensure layers start offloaded."""
|
||||||
|
if not self.enabled:
|
||||||
|
return
|
||||||
|
for i in list(self._on_gpu):
|
||||||
|
self._offload_layer(i)
|
||||||
|
# Prefetch layer 0 for forward
|
||||||
|
self._prefetch_layer(0)
|
||||||
|
|
||||||
|
def post_step(self):
|
||||||
|
"""Called after each training step — ensure layers are offloaded."""
|
||||||
|
if not self.enabled:
|
||||||
|
return
|
||||||
|
for i in list(self._on_gpu):
|
||||||
|
self._offload_layer(i)
|
||||||
|
# Prefetch layer 0 for next step
|
||||||
|
self._prefetch_layer(0)
|
||||||
|
|
||||||
|
|
||||||
|
class _LayerOffloadContext:
|
||||||
|
"""Context manager wrapping pre_step / post_step around a training step."""
|
||||||
|
|
||||||
|
def __init__(self, manager: LayerOffloadManager):
|
||||||
|
self.manager = manager
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.manager.pre_step()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, *args):
|
||||||
|
self.manager.post_step()
|
||||||
|
|
||||||
|
|
||||||
|
class LayerOffloadingMixin(Trainer):
|
||||||
|
"""
|
||||||
|
Trainer mixin class for layer-wise parameter offloading to CPU.
|
||||||
|
|
||||||
|
Offloads frozen decoder layer params to CPU at init, then streams them
|
||||||
|
on/off GPU one layer at a time during each training step.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
if getattr(self.args, "layer_offloading", False):
|
||||||
|
LOG.info("Layer parameter offloading enabled")
|
||||||
|
self._layer_offload_manager = LayerOffloadManager(
|
||||||
|
model=self.model,
|
||||||
|
num_prefetch=1,
|
||||||
|
)
|
||||||
|
self._layer_offload_manager.setup_hooks()
|
||||||
|
self._layer_offload_ctx = _LayerOffloadContext(self._layer_offload_manager)
|
||||||
|
else:
|
||||||
|
self._layer_offload_manager = None
|
||||||
|
self._layer_offload_ctx = contextlib.nullcontext()
|
||||||
|
|
||||||
|
def training_step(self, *args, **kwargs):
|
||||||
|
with self._layer_offload_ctx:
|
||||||
|
return super().training_step(*args, **kwargs)
|
||||||
@@ -235,6 +235,13 @@ class AxolotlTrainingMixins:
|
|||||||
metadata={"help": "Use activation offloading with CUDA streams for training."},
|
metadata={"help": "Use activation offloading with CUDA streams for training."},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
layer_offloading: bool | None = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "Offload model layer parameters to CPU during forward, prefetch back during backward."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# multi-modal section
|
# multi-modal section
|
||||||
|
|
||||||
image_size: int | tuple[int, int] | None = field(
|
image_size: int | tuple[int, int] | None = field(
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ SPARSE_MOE_BLOCK = {
|
|||||||
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
|
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
|
||||||
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
|
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
|
||||||
"qwen3_5_moe": "Qwen3_5MoeSparseMoeBlock",
|
"qwen3_5_moe": "Qwen3_5MoeSparseMoeBlock",
|
||||||
|
"qwen3_5_moe_text": "Qwen3_5MoeSparseMoeBlock",
|
||||||
"qwen3_next": "Qwen3NextSparseMoeBlock",
|
"qwen3_next": "Qwen3NextSparseMoeBlock",
|
||||||
"qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
|
"qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
|
||||||
# qwen3_omni_moe: Thinker (standard) + Talker (shared experts + shared_expert_gate)
|
# qwen3_omni_moe: Thinker (standard) + Talker (shared experts + shared_expert_gate)
|
||||||
@@ -58,7 +59,16 @@ def resolve_moe_block_classes(model_type: str):
|
|||||||
|
|
||||||
cls_names = entry if isinstance(entry, list) else [entry]
|
cls_names = entry if isinstance(entry, list) else [entry]
|
||||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
||||||
module = importlib.import_module(module_path)
|
try:
|
||||||
|
module = importlib.import_module(module_path)
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
# Text sub-model types (e.g. qwen3_5_moe_text) share the parent module
|
||||||
|
if model_type.endswith("_text"):
|
||||||
|
parent_type = model_type.removesuffix("_text")
|
||||||
|
module_path = f"transformers.models.{parent_type}.modeling_{parent_type}"
|
||||||
|
module = importlib.import_module(module_path)
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
classes = []
|
classes = []
|
||||||
for cls_name in cls_names:
|
for cls_name in cls_names:
|
||||||
|
|||||||
@@ -363,7 +363,7 @@ def _scatter2scatter_lora_configs():
|
|||||||
|
|
||||||
Search space:
|
Search space:
|
||||||
BLOCK_M: {32, 64, 128}
|
BLOCK_M: {32, 64, 128}
|
||||||
BLOCK_N: {32, 64, 128, 256}
|
BLOCK_N: {32, 64}
|
||||||
BLOCK_K: {32, 64, 128}
|
BLOCK_K: {32, 64, 128}
|
||||||
num_warps: {4, 8}
|
num_warps: {4, 8}
|
||||||
num_stages: {3, 4, 5}
|
num_stages: {3, 4, 5}
|
||||||
@@ -371,7 +371,7 @@ def _scatter2scatter_lora_configs():
|
|||||||
configs = []
|
configs = []
|
||||||
for block_m, block_n, block_k, warps, stages in product(
|
for block_m, block_n, block_k, warps, stages in product(
|
||||||
[32, 64, 128], # BLOCK_M
|
[32, 64, 128], # BLOCK_M
|
||||||
[32, 64, 128, 256], # BLOCK_N
|
[32, 64], # BLOCK_N
|
||||||
[32, 64, 128], # BLOCK_K
|
[32, 64, 128], # BLOCK_K
|
||||||
[4, 8], # num_warps
|
[4, 8], # num_warps
|
||||||
[3, 4, 5], # num_stages
|
[3, 4, 5], # num_stages
|
||||||
@@ -943,16 +943,16 @@ def _scatter2scatter_lora_dX_configs():
|
|||||||
|
|
||||||
Search space:
|
Search space:
|
||||||
BLOCK_M: {32, 64, 128} (token tile)
|
BLOCK_M: {32, 64, 128} (token tile)
|
||||||
BLOCK_K: {32, 64, 128, 256} (output tile)
|
BLOCK_K: {32, 64, 128} (output tile)
|
||||||
BLOCK_N: {32, 64, 128, 256} (reduction tile)
|
BLOCK_N: {32, 64} (reduction tile)
|
||||||
num_warps: {4, 8}
|
num_warps: {4, 8}
|
||||||
num_stages: {3, 4, 5}
|
num_stages: {3, 4, 5}
|
||||||
"""
|
"""
|
||||||
configs = []
|
configs = []
|
||||||
for block_m, block_k, block_n, warps, stages in product(
|
for block_m, block_k, block_n, warps, stages in product(
|
||||||
[32, 64, 128], # BLOCK_M
|
[32, 64, 128], # BLOCK_M
|
||||||
[32, 64, 128, 256], # BLOCK_K (output dimension)
|
[32, 64, 128], # BLOCK_K (output dimension)
|
||||||
[32, 64, 128, 256], # BLOCK_N (reduction dimension)
|
[32, 64], # BLOCK_N (reduction dimension)
|
||||||
[4, 8], # num_warps
|
[4, 8], # num_warps
|
||||||
[3, 4, 5], # num_stages
|
[3, 4, 5], # num_stages
|
||||||
):
|
):
|
||||||
@@ -1278,9 +1278,9 @@ def _group_bwd_lora_configs():
|
|||||||
support GPUs with limited shared memory (e.g. ~99KB on some GPUs).
|
support GPUs with limited shared memory (e.g. ~99KB on some GPUs).
|
||||||
|
|
||||||
Search space:
|
Search space:
|
||||||
BLOCK_M: {32, 64, 128, 256} (token-loop tile)
|
BLOCK_M: {32, 64, 128} (token-loop tile)
|
||||||
BLOCK_K: {32, 64, 128, 256}
|
BLOCK_K: {32, 64, 128}
|
||||||
BLOCK_N: {32, 64, 128, 256}
|
BLOCK_N: {32, 64}
|
||||||
num_warps: {4, 8}
|
num_warps: {4, 8}
|
||||||
num_stages: {3, 4, 5}
|
num_stages: {3, 4, 5}
|
||||||
|
|
||||||
@@ -1289,9 +1289,9 @@ def _group_bwd_lora_configs():
|
|||||||
"""
|
"""
|
||||||
configs = []
|
configs = []
|
||||||
for block_m, block_k, block_n, warps, stages in product(
|
for block_m, block_k, block_n, warps, stages in product(
|
||||||
[32, 64, 128, 256], # BLOCK_M
|
[32, 64, 128], # BLOCK_M
|
||||||
[32, 64, 128, 256], # BLOCK_K
|
[32, 64, 128], # BLOCK_K
|
||||||
[32, 64, 128, 256], # BLOCK_N
|
[32, 64], # BLOCK_N
|
||||||
[4, 8], # num_warps
|
[4, 8], # num_warps
|
||||||
[3, 4, 5], # num_stages
|
[3, 4, 5], # num_stages
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -571,15 +571,6 @@ class PatchManager:
|
|||||||
LOG.info("Patching with xformers attention...")
|
LOG.info("Patching with xformers attention...")
|
||||||
hijack_llama_attention()
|
hijack_llama_attention()
|
||||||
|
|
||||||
def _patch_llama_sample_packing(self):
|
|
||||||
"""Apply sample packing patches for LLaMA models."""
|
|
||||||
from axolotl.monkeypatch.llama_patch_multipack import (
|
|
||||||
hijack_llama_prepare_4d_mask,
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG.info("Patching llama _prepare_4d_causal_attention_mask*...")
|
|
||||||
hijack_llama_prepare_4d_mask()
|
|
||||||
|
|
||||||
def _patch_llama_derived_model(self):
|
def _patch_llama_derived_model(self):
|
||||||
"""Modify all llama derived models in one block."""
|
"""Modify all llama derived models in one block."""
|
||||||
if self.cfg.is_llama_derived_model and not (
|
if self.cfg.is_llama_derived_model and not (
|
||||||
@@ -591,8 +582,6 @@ class PatchManager:
|
|||||||
self._patch_llama_flash_attention()
|
self._patch_llama_flash_attention()
|
||||||
elif self.cfg.xformers_attention:
|
elif self.cfg.xformers_attention:
|
||||||
self._patch_llama_xformers_attention()
|
self._patch_llama_xformers_attention()
|
||||||
elif self.cfg.sample_packing:
|
|
||||||
self._patch_llama_sample_packing()
|
|
||||||
elif self.cfg.s2_attention:
|
elif self.cfg.s2_attention:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Shifted-sparse attention not currently implemented without flash attention."
|
"Shifted-sparse attention not currently implemented without flash attention."
|
||||||
|
|||||||
@@ -221,6 +221,14 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
|
|||||||
if getattr(tokenizer, attr_name) is None:
|
if getattr(tokenizer, attr_name) is None:
|
||||||
setattr(tokenizer, attr_name, "<|endoftext|>")
|
setattr(tokenizer, attr_name, "<|endoftext|>")
|
||||||
|
|
||||||
|
# Generic fallback: if tokenizer still has no pad_token, use eos_token
|
||||||
|
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
|
||||||
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
LOG.warning(
|
||||||
|
"Tokenizer does not have a pad_token, falling back to eos_token: %s",
|
||||||
|
tokenizer.eos_token,
|
||||||
|
)
|
||||||
|
|
||||||
additional_special_tokens = None
|
additional_special_tokens = None
|
||||||
if cfg.special_tokens:
|
if cfg.special_tokens:
|
||||||
special_tokens = cfg.special_tokens.to_dict()
|
special_tokens = cfg.special_tokens.to_dict()
|
||||||
|
|||||||
@@ -1,24 +0,0 @@
|
|||||||
"""
|
|
||||||
expands the binary attention mask per 3.2.2 of https://arxiv.org/pdf/2107.02027.pdf
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import mask_2d_to_4d
|
|
||||||
|
|
||||||
|
|
||||||
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
|
||||||
masked_zero_one_mask = mask_2d_to_4d(mask, dtype, tgt_len)
|
|
||||||
inverted_mask = 1.0 - masked_zero_one_mask
|
|
||||||
|
|
||||||
return inverted_mask.masked_fill(
|
|
||||||
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def hijack_expand_mask():
|
|
||||||
import transformers
|
|
||||||
|
|
||||||
transformers.models.llama.modeling_llama._expand_mask = _expand_mask
|
|
||||||
@@ -1,26 +0,0 @@
|
|||||||
"""
|
|
||||||
Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
|
|
||||||
"""
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import (
|
|
||||||
patched_prepare_4d_causal_attention_mask,
|
|
||||||
patched_prepare_4d_causal_attention_mask_for_sdpa,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def hijack_llama_prepare_4d_mask():
|
|
||||||
from transformers import modeling_attn_mask_utils
|
|
||||||
from transformers.models.llama import modeling_llama
|
|
||||||
|
|
||||||
modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = (
|
|
||||||
patched_prepare_4d_causal_attention_mask_for_sdpa
|
|
||||||
)
|
|
||||||
modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = (
|
|
||||||
patched_prepare_4d_causal_attention_mask_for_sdpa
|
|
||||||
)
|
|
||||||
modeling_llama._prepare_4d_causal_attention_mask = (
|
|
||||||
patched_prepare_4d_causal_attention_mask
|
|
||||||
)
|
|
||||||
modeling_attn_mask_utils._prepare_4d_causal_attention_mask = (
|
|
||||||
patched_prepare_4d_causal_attention_mask
|
|
||||||
)
|
|
||||||
@@ -3,15 +3,10 @@ Shared utils for the monkeypatches
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from typing import Optional, Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from transformers.modeling_attn_mask_utils import (
|
|
||||||
_prepare_4d_causal_attention_mask,
|
|
||||||
_prepare_4d_causal_attention_mask_for_sdpa,
|
|
||||||
)
|
|
||||||
from transformers.utils import is_torch_bf16_gpu_available
|
|
||||||
|
|
||||||
|
|
||||||
@torch.jit.script
|
@torch.jit.script
|
||||||
@@ -170,65 +165,6 @@ def set_module_name(model, name, value):
|
|||||||
setattr(parent, child_name, value)
|
setattr(parent, child_name, value)
|
||||||
|
|
||||||
|
|
||||||
def mask_2d_to_4d(
|
|
||||||
mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
|
||||||
This expansion handles packed sequences so that sequences share the same attention mask integer value
|
|
||||||
when they attend to each other within that sequence.
|
|
||||||
This expansion transforms the mask to lower triangular form to prevent future peeking.
|
|
||||||
"""
|
|
||||||
bsz, src_len = mask.size()
|
|
||||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
|
||||||
|
|
||||||
mask = mask.unsqueeze(1).unsqueeze(2)
|
|
||||||
mask = mask.expand(bsz, 1, tgt_len, src_len)
|
|
||||||
|
|
||||||
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
|
|
||||||
binary_mask = torch.where(
|
|
||||||
mask != 0,
|
|
||||||
torch.tensor(1, device=mask.device).to(dtype),
|
|
||||||
torch.tensor(0, device=mask.device).to(dtype),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create a block-diagonal mask.
|
|
||||||
# we multiply by the binary mask so that 0's in the original mask are correctly excluded
|
|
||||||
zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
|
|
||||||
|
|
||||||
# Now let's create a lower triangular mask of ones that will zero out the upper triangular part
|
|
||||||
lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
|
|
||||||
mask.device
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
|
|
||||||
masked_zero_one_mask = zero_one_mask * lower_triangular_ones
|
|
||||||
|
|
||||||
return masked_zero_one_mask
|
|
||||||
|
|
||||||
|
|
||||||
def patched_prepare_4d_causal_attention_mask(
|
|
||||||
attention_mask: Optional[torch.Tensor],
|
|
||||||
*args,
|
|
||||||
):
|
|
||||||
dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
|
|
||||||
return _prepare_4d_causal_attention_mask(
|
|
||||||
mask_2d_to_4d(attention_mask, dtype=dtype),
|
|
||||||
*args,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def patched_prepare_4d_causal_attention_mask_for_sdpa(
|
|
||||||
attention_mask: Optional[torch.Tensor],
|
|
||||||
*args,
|
|
||||||
):
|
|
||||||
dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
|
|
||||||
return _prepare_4d_causal_attention_mask_for_sdpa(
|
|
||||||
mask_2d_to_4d(attention_mask, dtype=dtype),
|
|
||||||
*args,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def detab_code(code: str) -> Tuple[str, str]:
|
def detab_code(code: str) -> Tuple[str, str]:
|
||||||
try:
|
try:
|
||||||
spaces = re.match(r"([\s\t]{1,})", code).group(0)
|
spaces = re.match(r"([\s\t]{1,})", code).group(0)
|
||||||
|
|||||||
96
src/axolotl/prompt_strategies/_synthetic.py
Normal file
96
src/axolotl/prompt_strategies/_synthetic.py
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
"""
|
||||||
|
Synthetic dataset generator for benchmarking and testing.
|
||||||
|
|
||||||
|
Generates datasets with configurable sequence length, dataset size, and token ID ranges.
|
||||||
|
Useful for benchmarking memory usage and speed by sequence length, and for validating
|
||||||
|
weighted dataset mixes.
|
||||||
|
|
||||||
|
YAML configuration example:
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: synthetic
|
||||||
|
type: _synthetic
|
||||||
|
length: 1000
|
||||||
|
sequence_length: 2048
|
||||||
|
min_input_id: 100
|
||||||
|
max_input_id: 32000
|
||||||
|
seed: 42
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from datasets import Dataset
|
||||||
|
|
||||||
|
from axolotl.prompt_tokenizers import DatasetWrappingStrategy
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SyntheticDatasetStrategy(DatasetWrappingStrategy):
|
||||||
|
"""Strategy that generates synthetic tokenized data, ignoring the source dataset."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
sequence_length: int = 2048,
|
||||||
|
length: int = 1000,
|
||||||
|
min_input_id: int = 100,
|
||||||
|
max_input_id: int = 32000,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
):
|
||||||
|
self.sequence_length = sequence_length
|
||||||
|
self.length = length
|
||||||
|
self.min_input_id = min_input_id
|
||||||
|
self.max_input_id = max_input_id
|
||||||
|
self.seed = seed
|
||||||
|
|
||||||
|
def wrap_dataset(
|
||||||
|
self,
|
||||||
|
dataset,
|
||||||
|
process_count: int | None = None,
|
||||||
|
keep_in_memory: bool | None = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> Dataset:
|
||||||
|
LOG.info(
|
||||||
|
f"Generating synthetic dataset: {self.length} samples, "
|
||||||
|
f"sequence_length={self.sequence_length}, "
|
||||||
|
f"input_id_range=[{self.min_input_id}, {self.max_input_id})"
|
||||||
|
)
|
||||||
|
|
||||||
|
rng = np.random.default_rng(self.seed)
|
||||||
|
input_ids = rng.integers(
|
||||||
|
low=self.min_input_id,
|
||||||
|
high=self.max_input_id,
|
||||||
|
size=(self.length, self.sequence_length),
|
||||||
|
).tolist()
|
||||||
|
|
||||||
|
attention_mask = [[1] * self.sequence_length] * self.length
|
||||||
|
# labels == input_ids means we train on all tokens
|
||||||
|
labels = [row[:] for row in input_ids]
|
||||||
|
|
||||||
|
return Dataset.from_dict(
|
||||||
|
{
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"labels": labels,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||||
|
ds_cfg = ds_cfg or {}
|
||||||
|
|
||||||
|
sequence_length = ds_cfg.get("sequence_length", cfg.sequence_len)
|
||||||
|
length = ds_cfg.get("length", 1000)
|
||||||
|
min_input_id = ds_cfg.get("min_input_id", 100)
|
||||||
|
max_input_id = ds_cfg.get("max_input_id", tokenizer.vocab_size)
|
||||||
|
seed = ds_cfg.get("seed", None)
|
||||||
|
|
||||||
|
return SyntheticDatasetStrategy(
|
||||||
|
sequence_length=sequence_length,
|
||||||
|
length=length,
|
||||||
|
min_input_id=min_input_id,
|
||||||
|
max_input_id=max_input_id,
|
||||||
|
seed=seed,
|
||||||
|
)
|
||||||
@@ -82,7 +82,7 @@ def setup_model_and_tokenizer(
|
|||||||
|
|
||||||
model_loader = ModelLoader(cfg, tokenizer, processor=processor)
|
model_loader = ModelLoader(cfg, tokenizer, processor=processor)
|
||||||
model, peft_config = model_loader.load()
|
model, peft_config = model_loader.load()
|
||||||
if model.generation_config is not None:
|
if getattr(model, "generation_config", None) is not None:
|
||||||
model.generation_config.do_sample = True
|
model.generation_config.do_sample = True
|
||||||
|
|
||||||
model_properties = model.config.to_dict()
|
model_properties = model.config.to_dict()
|
||||||
|
|||||||
@@ -25,9 +25,11 @@ def toggle_fake_quant(mod: nn.Module, enable: bool):
|
|||||||
if (
|
if (
|
||||||
isinstance(mod, FakeQuantizedLinear)
|
isinstance(mod, FakeQuantizedLinear)
|
||||||
and mod.activation_fake_quantizer is not None
|
and mod.activation_fake_quantizer is not None
|
||||||
|
and hasattr(mod.activation_fake_quantizer, "enabled")
|
||||||
):
|
):
|
||||||
mod.activation_fake_quantizer.enabled = enable
|
mod.activation_fake_quantizer.enabled = enable
|
||||||
mod.weight_fake_quantizer.enabled = enable
|
if hasattr(mod.weight_fake_quantizer, "enabled"):
|
||||||
|
mod.weight_fake_quantizer.enabled = enable
|
||||||
|
|
||||||
|
|
||||||
class QATCallback(TrainerCallback):
|
class QATCallback(TrainerCallback):
|
||||||
|
|||||||
@@ -12,12 +12,11 @@ from transformers import (
|
|||||||
TrainingArguments,
|
TrainingArguments,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from axolotl.core.trainers.constants import TOKENS_STATE_FILE
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
TOKENS_STATE_FILE = "tokens_state.json"
|
|
||||||
|
|
||||||
|
|
||||||
class TokensPerSecondCallback(TrainerCallback):
|
class TokensPerSecondCallback(TrainerCallback):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -22,7 +22,12 @@ from axolotl.utils.schemas.config import (
|
|||||||
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
||||||
AxolotlInputConfig as AxolotlInputConfigBase,
|
AxolotlInputConfig as AxolotlInputConfigBase,
|
||||||
)
|
)
|
||||||
from axolotl.utils.schemas.datasets import DPODataset, KTODataset, SFTDataset
|
from axolotl.utils.schemas.datasets import (
|
||||||
|
DPODataset,
|
||||||
|
KTODataset,
|
||||||
|
SFTDataset,
|
||||||
|
SyntheticDataset,
|
||||||
|
)
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
@@ -308,6 +313,14 @@ def validate_config(
|
|||||||
cfg["datasets"][idx] = DPODataset(**ds_cfg)
|
cfg["datasets"][idx] = DPODataset(**ds_cfg)
|
||||||
elif cfg.get("rl") == "kto" and not isinstance(ds_cfg, KTODataset):
|
elif cfg.get("rl") == "kto" and not isinstance(ds_cfg, KTODataset):
|
||||||
cfg["datasets"][idx] = KTODataset(**dict(ds_cfg))
|
cfg["datasets"][idx] = KTODataset(**dict(ds_cfg))
|
||||||
|
elif (
|
||||||
|
ds_cfg.get("type")
|
||||||
|
if isinstance(ds_cfg, dict)
|
||||||
|
else getattr(ds_cfg, "type", None)
|
||||||
|
) == "_synthetic" and not isinstance(ds_cfg, SyntheticDataset):
|
||||||
|
cfg["datasets"][idx] = SyntheticDataset(
|
||||||
|
**(ds_cfg if isinstance(ds_cfg, dict) else dict(ds_cfg))
|
||||||
|
)
|
||||||
elif not isinstance(ds_cfg, SFTDataset):
|
elif not isinstance(ds_cfg, SFTDataset):
|
||||||
cfg["datasets"][idx] = SFTDataset(**dict(ds_cfg))
|
cfg["datasets"][idx] = SFTDataset(**dict(ds_cfg))
|
||||||
|
|
||||||
|
|||||||
@@ -376,10 +376,14 @@ def _load_and_process_single_dataset(
|
|||||||
streaming: bool = False,
|
streaming: bool = False,
|
||||||
) -> tuple[Dataset | IterableDataset, Prompter | None]:
|
) -> tuple[Dataset | IterableDataset, Prompter | None]:
|
||||||
"""Load and process a single dataset based on the passed config."""
|
"""Load and process a single dataset based on the passed config."""
|
||||||
# Load the dataset
|
# For synthetic datasets, create a minimal placeholder instead of loading from path
|
||||||
dataset = load_dataset_with_config(
|
if dataset_config.type == "_synthetic":
|
||||||
dataset_config, cfg.hf_use_auth_token, streaming=streaming
|
dataset = Dataset.from_dict({"text": [""]})
|
||||||
)
|
else:
|
||||||
|
# Load the dataset
|
||||||
|
dataset = load_dataset_with_config(
|
||||||
|
dataset_config, cfg.hf_use_auth_token, streaming=streaming
|
||||||
|
)
|
||||||
|
|
||||||
# Parse dataset type
|
# Parse dataset type
|
||||||
d_base_type, d_prompt_style = _parse_dataset_type(dataset_config.type)
|
d_base_type, d_prompt_style = _parse_dataset_type(dataset_config.type)
|
||||||
|
|||||||
@@ -10,9 +10,11 @@ from torchao.quantization import quantize_
|
|||||||
from torchao.quantization.qat import (
|
from torchao.quantization.qat import (
|
||||||
QATConfig,
|
QATConfig,
|
||||||
)
|
)
|
||||||
|
from torchao.quantization.qat.fake_quantize_config import Int4WeightFakeQuantizeConfig
|
||||||
from torchao.quantization.quant_api import (
|
from torchao.quantization.quant_api import (
|
||||||
Float8DynamicActivationFloat8WeightConfig,
|
Float8DynamicActivationFloat8WeightConfig,
|
||||||
Float8DynamicActivationInt4WeightConfig,
|
Float8DynamicActivationInt4WeightConfig,
|
||||||
|
Int4WeightOnlyConfig,
|
||||||
Int8DynamicActivationInt4WeightConfig,
|
Int8DynamicActivationInt4WeightConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -173,6 +175,70 @@ def quantize_model(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_qat_config(
|
||||||
|
base_config: AOBaseConfig,
|
||||||
|
weight_dtype: TorchAOQuantDType,
|
||||||
|
activation_dtype: TorchAOQuantDType | None,
|
||||||
|
group_size: int | None,
|
||||||
|
) -> QATConfig:
|
||||||
|
"""Build a QATConfig, explicitly constructing fake quantize configs to ensure
|
||||||
|
group_size and other params are properly propagated (torchao's QATConfig(base_config)
|
||||||
|
does not always map these correctly)."""
|
||||||
|
from torchao.quantization.qat.fake_quantize_config import (
|
||||||
|
Float8FakeQuantizeConfig,
|
||||||
|
IntxFakeQuantizeConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(base_config, MXFakeQuantizeConfig):
|
||||||
|
return QATConfig(
|
||||||
|
activation_config=base_config,
|
||||||
|
weight_config=base_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build explicit weight config
|
||||||
|
weight_fq_config: (
|
||||||
|
Int4WeightFakeQuantizeConfig
|
||||||
|
| IntxFakeQuantizeConfig
|
||||||
|
| Float8FakeQuantizeConfig
|
||||||
|
| None
|
||||||
|
) = None
|
||||||
|
if weight_dtype == TorchAOQuantDType.int4:
|
||||||
|
gs = (
|
||||||
|
group_size
|
||||||
|
if group_size is not None
|
||||||
|
else getattr(base_config, "group_size", 128)
|
||||||
|
)
|
||||||
|
activation_dt = None
|
||||||
|
if activation_dtype == TorchAOQuantDType.int8:
|
||||||
|
activation_dt = torch.bfloat16
|
||||||
|
elif activation_dtype == TorchAOQuantDType.float8_e4m3fn:
|
||||||
|
activation_dt = torch.float8_e4m3fn
|
||||||
|
kwargs = {"group_size": gs}
|
||||||
|
if activation_dt is not None:
|
||||||
|
kwargs["activation_dtype"] = activation_dt
|
||||||
|
weight_fq_config = Int4WeightFakeQuantizeConfig(**kwargs)
|
||||||
|
elif weight_dtype == TorchAOQuantDType.float8_e4m3fn:
|
||||||
|
weight_fq_config = Float8FakeQuantizeConfig(dtype=torch.float8_e4m3fn)
|
||||||
|
|
||||||
|
# Build explicit activation config
|
||||||
|
activation_fq_config = None
|
||||||
|
if activation_dtype == TorchAOQuantDType.int8:
|
||||||
|
activation_fq_config = IntxFakeQuantizeConfig(
|
||||||
|
dtype=torch.int8, granularity="per_token", is_symmetric=False
|
||||||
|
)
|
||||||
|
elif activation_dtype == TorchAOQuantDType.float8_e4m3fn:
|
||||||
|
activation_fq_config = Float8FakeQuantizeConfig(dtype=torch.float8_e4m3fn)
|
||||||
|
|
||||||
|
if weight_fq_config is not None:
|
||||||
|
return QATConfig(
|
||||||
|
weight_config=weight_fq_config,
|
||||||
|
activation_config=activation_fq_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fallback to base_config for unhandled combos
|
||||||
|
return QATConfig(base_config)
|
||||||
|
|
||||||
|
|
||||||
def prepare_model_for_qat(
|
def prepare_model_for_qat(
|
||||||
model,
|
model,
|
||||||
weight_dtype: TorchAOQuantDType,
|
weight_dtype: TorchAOQuantDType,
|
||||||
@@ -200,13 +266,9 @@ def prepare_model_for_qat(
|
|||||||
activation_dtype=activation_dtype,
|
activation_dtype=activation_dtype,
|
||||||
group_size=group_size,
|
group_size=group_size,
|
||||||
)
|
)
|
||||||
if isinstance(base_config, MXFakeQuantizeConfig):
|
qat_config = _make_qat_config(
|
||||||
qat_config = QATConfig(
|
base_config, weight_dtype, activation_dtype, group_size
|
||||||
activation_config=base_config,
|
)
|
||||||
weight_config=base_config,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
qat_config = QATConfig(base_config)
|
|
||||||
quantize_(model, qat_config)
|
quantize_(model, qat_config)
|
||||||
if quantize_embedding:
|
if quantize_embedding:
|
||||||
# activation fake quantization is not supported for embedding layers
|
# activation fake quantization is not supported for embedding layers
|
||||||
@@ -215,12 +277,9 @@ def prepare_model_for_qat(
|
|||||||
activation_dtype=None,
|
activation_dtype=None,
|
||||||
group_size=group_size,
|
group_size=group_size,
|
||||||
)
|
)
|
||||||
if isinstance(embedding_base_config, MXFakeQuantizeConfig):
|
embedding_qat_config = _make_qat_config(
|
||||||
embedding_qat_config = QATConfig(
|
embedding_base_config, weight_dtype, None, group_size
|
||||||
weight_config=embedding_base_config,
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
embedding_qat_config = QATConfig(embedding_base_config)
|
|
||||||
quantize_(
|
quantize_(
|
||||||
model,
|
model,
|
||||||
embedding_qat_config,
|
embedding_qat_config,
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Sequence
|
from typing import Any, Sequence
|
||||||
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
@@ -340,3 +340,19 @@ class JaggedLRRestartScheduler(LRScheduler):
|
|||||||
return [lr * scale for lr in original]
|
return [lr * scale for lr in original]
|
||||||
|
|
||||||
return original * scale
|
return original * scale
|
||||||
|
|
||||||
|
def state_dict(self) -> dict[str, Any]:
|
||||||
|
"""Return serializable state, saving inner_schedule as its own state_dict."""
|
||||||
|
state = {
|
||||||
|
key: value
|
||||||
|
for key, value in self.__dict__.items()
|
||||||
|
if key not in ("optimizer", "inner_schedule")
|
||||||
|
}
|
||||||
|
state["inner_schedule_state"] = self.inner_schedule.state_dict()
|
||||||
|
return state
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
|
||||||
|
"""Restore state, including inner_schedule."""
|
||||||
|
inner_state = state_dict.pop("inner_schedule_state")
|
||||||
|
self.__dict__.update(state_dict)
|
||||||
|
self.inner_schedule.load_state_dict(inner_state)
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ from axolotl.utils.schemas.datasets import (
|
|||||||
PretrainingDataset,
|
PretrainingDataset,
|
||||||
SFTDataset,
|
SFTDataset,
|
||||||
StepwiseSupervisedDataset,
|
StepwiseSupervisedDataset,
|
||||||
|
SyntheticDataset,
|
||||||
)
|
)
|
||||||
from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
|
from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
|
||||||
from axolotl.utils.schemas.dynamic_checkpoint import DynamicCheckpointConfig
|
from axolotl.utils.schemas.dynamic_checkpoint import DynamicCheckpointConfig
|
||||||
@@ -185,7 +186,13 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
datasets: (
|
datasets: (
|
||||||
Annotated[
|
Annotated[
|
||||||
list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset],
|
list[
|
||||||
|
SFTDataset
|
||||||
|
| DPODataset
|
||||||
|
| KTODataset
|
||||||
|
| StepwiseSupervisedDataset
|
||||||
|
| SyntheticDataset
|
||||||
|
],
|
||||||
MinLen(1),
|
MinLen(1),
|
||||||
]
|
]
|
||||||
| None
|
| None
|
||||||
@@ -198,7 +205,13 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
test_datasets: (
|
test_datasets: (
|
||||||
Annotated[
|
Annotated[
|
||||||
list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset],
|
list[
|
||||||
|
SFTDataset
|
||||||
|
| DPODataset
|
||||||
|
| KTODataset
|
||||||
|
| StepwiseSupervisedDataset
|
||||||
|
| SyntheticDataset
|
||||||
|
],
|
||||||
MinLen(1),
|
MinLen(1),
|
||||||
]
|
]
|
||||||
| None
|
| None
|
||||||
@@ -433,6 +446,12 @@ class AxolotlInputConfig(
|
|||||||
"description": "Whether to offload activations. Available options are: true, false, 'legacy', 'disk'."
|
"description": "Whether to offload activations. Available options are: true, false, 'legacy', 'disk'."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
layer_offloading: bool | None = Field(
|
||||||
|
default=False,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Offload model layer parameters to CPU during forward, prefetch back during backward."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
unfrozen_parameters: list[str] | None = Field(
|
unfrozen_parameters: list[str] | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
|
|||||||
@@ -296,4 +296,42 @@ class KTODataset(BaseModel):
|
|||||||
revision: str | None = None
|
revision: str | None = None
|
||||||
|
|
||||||
|
|
||||||
DatasetConfig = SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset
|
class SyntheticDataset(BaseModel):
|
||||||
|
"""Synthetic dataset configuration for benchmarking and testing.
|
||||||
|
|
||||||
|
Generates datasets with configurable sequence length, dataset size, and token ID
|
||||||
|
ranges. Useful for benchmarking memory usage and speed by sequence length, and for
|
||||||
|
validating weighted dataset mixes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
path: Literal["synthetic"] = "synthetic"
|
||||||
|
type: Literal["_synthetic"] = "_synthetic"
|
||||||
|
length: int = Field(
|
||||||
|
default=1000,
|
||||||
|
json_schema_extra={"description": "Number of rows to generate"},
|
||||||
|
)
|
||||||
|
sequence_length: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Sequence length per row (defaults to sequence_len from config)"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
min_input_id: int = Field(
|
||||||
|
default=100,
|
||||||
|
json_schema_extra={"description": "Minimum token ID for generation"},
|
||||||
|
)
|
||||||
|
max_input_id: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Maximum token ID for generation (defaults to tokenizer vocab_size)"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
seed: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={"description": "Random seed for reproducibility"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
DatasetConfig = (
|
||||||
|
SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset | SyntheticDataset
|
||||||
|
)
|
||||||
|
|||||||
@@ -87,6 +87,11 @@ class CustomSupportedOptimizers(str, Enum):
|
|||||||
came_pytorch = "came_pytorch"
|
came_pytorch = "came_pytorch"
|
||||||
muon = "muon"
|
muon = "muon"
|
||||||
dion = "dion"
|
dion = "dion"
|
||||||
|
flash_adamw = "flash_adamw"
|
||||||
|
flash_adam = "flash_adam"
|
||||||
|
flash_sgd = "flash_sgd"
|
||||||
|
flash_sgdw = "flash_sgdw"
|
||||||
|
flash_lion = "flash_lion"
|
||||||
|
|
||||||
|
|
||||||
class RingAttnFunc(str, Enum):
|
class RingAttnFunc(str, Enum):
|
||||||
|
|||||||
@@ -790,6 +790,14 @@ class OptimizationValidationMixin:
|
|||||||
LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
|
LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _resolve_fsdp_version(data):
|
||||||
|
"""Resolve FSDP version from top-level fsdp_version or fsdp_config.fsdp_version."""
|
||||||
|
fsdp_version = data.get("fsdp_version")
|
||||||
|
if fsdp_version is None:
|
||||||
|
fsdp_version = data.get("fsdp_config", {}).get("fsdp_version", 1)
|
||||||
|
return fsdp_version
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_muon_deepspeed_fsdp(cls, data):
|
def check_muon_deepspeed_fsdp(cls, data):
|
||||||
@@ -799,15 +807,32 @@ class OptimizationValidationMixin:
|
|||||||
"Muon optimizer is currently incompatible with DeepSpeed"
|
"Muon optimizer is currently incompatible with DeepSpeed"
|
||||||
)
|
)
|
||||||
if data.get("fsdp") or data.get("fsdp_config"):
|
if data.get("fsdp") or data.get("fsdp_config"):
|
||||||
fsdp_version = data.get("fsdp_version")
|
fsdp_version = cls._resolve_fsdp_version(data)
|
||||||
if fsdp_version is None:
|
|
||||||
fsdp_version = data.get("fsdp_config", {}).get("fsdp_version", 1)
|
|
||||||
if str(fsdp_version) != "2":
|
if str(fsdp_version) != "2":
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Muon optimizer is only compatible with FSDP2. Set fsdp_version: 2 to use Muon with FSDP."
|
"Muon optimizer is only compatible with FSDP2. Set fsdp_version: 2 to use Muon with FSDP."
|
||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_flashoptim_deepspeed_fsdp(cls, data):
|
||||||
|
optimizer = data.get("optimizer") or ""
|
||||||
|
if str(optimizer).startswith("flash_"):
|
||||||
|
if data.get("deepspeed"):
|
||||||
|
raise ValueError(
|
||||||
|
f"{optimizer} optimizer is incompatible with DeepSpeed. "
|
||||||
|
"Flash optimizers only support DDP and FSDP2."
|
||||||
|
)
|
||||||
|
if data.get("fsdp") or data.get("fsdp_config"):
|
||||||
|
fsdp_version = cls._resolve_fsdp_version(data)
|
||||||
|
if str(fsdp_version) != "2":
|
||||||
|
raise ValueError(
|
||||||
|
f"{optimizer} optimizer is only compatible with FSDP2. "
|
||||||
|
"Set fsdp_version: 2 to use flash optimizers with FSDP."
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_batch_flattening_fa(cls, data):
|
def check_batch_flattening_fa(cls, data):
|
||||||
|
|||||||
@@ -15,6 +15,8 @@ import datasets
|
|||||||
import pytest
|
import pytest
|
||||||
import requests
|
import requests
|
||||||
import torch
|
import torch
|
||||||
|
import transformers.utils as _transformers_utils
|
||||||
|
import transformers.utils.import_utils as _import_utils
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from huggingface_hub.errors import LocalEntryNotFoundError
|
from huggingface_hub.errors import LocalEntryNotFoundError
|
||||||
from tokenizers import AddedToken
|
from tokenizers import AddedToken
|
||||||
@@ -29,6 +31,26 @@ from tests.hf_offline_utils import (
|
|||||||
|
|
||||||
logging.getLogger("filelock").setLevel(logging.CRITICAL)
|
logging.getLogger("filelock").setLevel(logging.CRITICAL)
|
||||||
|
|
||||||
|
# Shim for deepseek v3
|
||||||
|
if not hasattr(_import_utils, "is_torch_fx_available"):
|
||||||
|
|
||||||
|
def _is_torch_fx_available():
|
||||||
|
try:
|
||||||
|
import torch.fx # noqa: F401 # pylint: disable=unused-import
|
||||||
|
|
||||||
|
return True
|
||||||
|
except ImportError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
_import_utils.is_torch_fx_available = _is_torch_fx_available
|
||||||
|
|
||||||
|
if not hasattr(_transformers_utils, "is_flash_attn_greater_or_equal_2_10"):
|
||||||
|
from transformers.utils import is_flash_attn_greater_or_equal as _is_flash_attn_gte
|
||||||
|
|
||||||
|
_transformers_utils.is_flash_attn_greater_or_equal_2_10 = lambda: (
|
||||||
|
_is_flash_attn_gte("2.10")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def retry_on_request_exceptions(max_retries=3, delay=1):
|
def retry_on_request_exceptions(max_retries=3, delay=1):
|
||||||
def decorator(func):
|
def decorator(func):
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ Test strategy:
|
|||||||
- Tolerances account for tf32 accumulation in Triton kernels
|
- Tolerances account for tf32 accumulation in Triton kernels
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from functools import wraps
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -34,6 +35,21 @@ pytestmark = pytest.mark.skipif(
|
|||||||
_SMOE = "axolotl.integrations.kernels.libs.scattermoe_lora"
|
_SMOE = "axolotl.integrations.kernels.libs.scattermoe_lora"
|
||||||
|
|
||||||
|
|
||||||
|
def skip_on_out_of_resources(func):
|
||||||
|
"""Skip test if Triton kernel exceeds GPU shared memory limits."""
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
try:
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
except Exception as exc: # pylint: disable=broad-except
|
||||||
|
if "OutOfResources" in type(exc).__name__:
|
||||||
|
pytest.skip(f"GPU shared memory too small: {exc}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Helpers
|
# Helpers
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -209,6 +225,7 @@ def make_test_data(
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.slow
|
||||||
class TestForwardPass:
|
class TestForwardPass:
|
||||||
"""Test forward pass of fused scatter2scatter_lora kernel."""
|
"""Test forward pass of fused scatter2scatter_lora kernel."""
|
||||||
|
|
||||||
@@ -288,6 +305,7 @@ class TestForwardPass:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.slow
|
||||||
class TestForwardGrouped:
|
class TestForwardGrouped:
|
||||||
"""Test forward pass with grouped_in/grouped_out configurations."""
|
"""Test forward pass with grouped_in/grouped_out configurations."""
|
||||||
|
|
||||||
@@ -377,6 +395,7 @@ class TestForwardGrouped:
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.slow
|
||||||
class TestLoRAGradients:
|
class TestLoRAGradients:
|
||||||
"""Test backward LoRA gradient computation (dA, dB)."""
|
"""Test backward LoRA gradient computation (dA, dB)."""
|
||||||
|
|
||||||
@@ -452,6 +471,7 @@ class TestLoRAGradients:
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.slow
|
||||||
class TestAutograd:
|
class TestAutograd:
|
||||||
"""Test full autograd integration through ScatterMoELoRA."""
|
"""Test full autograd integration through ScatterMoELoRA."""
|
||||||
|
|
||||||
@@ -620,6 +640,7 @@ class TestAutograd:
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.slow
|
||||||
class TestBaseEquivalence:
|
class TestBaseEquivalence:
|
||||||
"""When scaling=0, fused kernel should match base scatter2scatter."""
|
"""When scaling=0, fused kernel should match base scatter2scatter."""
|
||||||
|
|
||||||
@@ -692,6 +713,7 @@ class TestBaseEquivalence:
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.slow
|
||||||
class TestLoRAAdditivity:
|
class TestLoRAAdditivity:
|
||||||
"""Test that the LoRA component is correctly additive."""
|
"""Test that the LoRA component is correctly additive."""
|
||||||
|
|
||||||
@@ -749,6 +771,7 @@ class TestLoRAAdditivity:
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.slow
|
||||||
class TestParallelExpertsModule:
|
class TestParallelExpertsModule:
|
||||||
"""Test the ParallelExperts module with LoRA."""
|
"""Test the ParallelExperts module with LoRA."""
|
||||||
|
|
||||||
@@ -816,6 +839,7 @@ class TestParallelExpertsModule:
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.slow
|
||||||
class TestEdgeCases:
|
class TestEdgeCases:
|
||||||
"""Edge cases and boundary conditions."""
|
"""Edge cases and boundary conditions."""
|
||||||
|
|
||||||
@@ -913,6 +937,7 @@ class TestEdgeCases:
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.slow
|
||||||
class TestFusedDX:
|
class TestFusedDX:
|
||||||
"""Test fused backward dX kernel: dX = dY @ W^T + scaling * (dY @ B) @ A."""
|
"""Test fused backward dX kernel: dX = dY @ W^T + scaling * (dY @ B) @ A."""
|
||||||
|
|
||||||
@@ -980,6 +1005,7 @@ class TestFusedDX:
|
|||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
self._run_fused_dX_test(M=32, K=64, N=128, E=4, R=8, k=2)
|
self._run_fused_dX_test(M=32, K=64, N=128, E=4, R=8, k=2)
|
||||||
|
|
||||||
|
@skip_on_out_of_resources
|
||||||
def test_large(self):
|
def test_large(self):
|
||||||
self._run_fused_dX_test(M=256, K=256, N=512, E=8, R=16, k=2)
|
self._run_fused_dX_test(M=256, K=256, N=512, E=8, R=16, k=2)
|
||||||
|
|
||||||
@@ -1122,6 +1148,7 @@ class TestFusedDX:
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.slow
|
||||||
class TestFusedGatherBackward:
|
class TestFusedGatherBackward:
|
||||||
"""Test fused gather + backward dA/dB kernel."""
|
"""Test fused gather + backward dA/dB kernel."""
|
||||||
|
|
||||||
@@ -1174,6 +1201,7 @@ class TestFusedGatherBackward:
|
|||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
self._run_fused_gather_test(M=32, K=64, N=128, E=4, R=8, k=2)
|
self._run_fused_gather_test(M=32, K=64, N=128, E=4, R=8, k=2)
|
||||||
|
|
||||||
|
@skip_on_out_of_resources
|
||||||
def test_large(self):
|
def test_large(self):
|
||||||
self._run_fused_gather_test(M=256, K=256, N=512, E=8, R=16, k=2)
|
self._run_fused_gather_test(M=256, K=256, N=512, E=8, R=16, k=2)
|
||||||
|
|
||||||
@@ -1183,6 +1211,7 @@ class TestFusedGatherBackward:
|
|||||||
def test_k1(self):
|
def test_k1(self):
|
||||||
self._run_fused_gather_test(M=64, K=64, N=128, E=4, R=8, k=1)
|
self._run_fused_gather_test(M=64, K=64, N=128, E=4, R=8, k=1)
|
||||||
|
|
||||||
|
@skip_on_out_of_resources
|
||||||
def test_many_experts(self):
|
def test_many_experts(self):
|
||||||
self._run_fused_gather_test(M=128, K=64, N=128, E=16, R=8, k=4)
|
self._run_fused_gather_test(M=128, K=64, N=128, E=16, R=8, k=4)
|
||||||
|
|
||||||
@@ -1269,6 +1298,8 @@ class TestFusedGatherBackward:
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.slow
|
||||||
|
@pytest.mark.xfail(reason="flaky", strict=False)
|
||||||
class TestTokenRounding:
|
class TestTokenRounding:
|
||||||
"""Test token rounding utility and its integration with backward kernels."""
|
"""Test token rounding utility and its integration with backward kernels."""
|
||||||
|
|
||||||
@@ -1315,6 +1346,7 @@ class TestTokenRounding:
|
|||||||
)
|
)
|
||||||
prev = padded_offsets[e].item()
|
prev = padded_offsets[e].item()
|
||||||
|
|
||||||
|
@skip_on_out_of_resources
|
||||||
def test_round_with_fused_gather(self):
|
def test_round_with_fused_gather(self):
|
||||||
"""Token rounding + fused gather gives same result as plain fused gather."""
|
"""Token rounding + fused gather gives same result as plain fused gather."""
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
@@ -1414,6 +1446,7 @@ class TestTokenRounding:
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.slow
|
||||||
class TestCombinedOptimizations:
|
class TestCombinedOptimizations:
|
||||||
"""Test all optimizations together."""
|
"""Test all optimizations together."""
|
||||||
|
|
||||||
@@ -1583,6 +1616,7 @@ def _make_mock_sigmoid_moe_block(
|
|||||||
return moe_block, T, H, FF, E, K
|
return moe_block, T, H, FF, E, K
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.slow
|
||||||
class TestHFScatterMoESigmoidRouting:
|
class TestHFScatterMoESigmoidRouting:
|
||||||
"""Test HFScatterMoEGatedMLP forward with sigmoid routing on GPU."""
|
"""Test HFScatterMoEGatedMLP forward with sigmoid routing on GPU."""
|
||||||
|
|
||||||
@@ -1724,6 +1758,7 @@ class TestHFScatterMoESigmoidRouting:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.slow
|
||||||
class TestHFScatterMoESigmoidWithSharedExperts:
|
class TestHFScatterMoESigmoidWithSharedExperts:
|
||||||
"""Test HFScatterMoEGatedMLP with sigmoid routing + shared experts."""
|
"""Test HFScatterMoEGatedMLP with sigmoid routing + shared experts."""
|
||||||
|
|
||||||
|
|||||||
@@ -933,7 +933,7 @@ class TestKernelizeIntegration:
|
|||||||
def _get_repo_path():
|
def _get_repo_path():
|
||||||
"""Get the path to scattermoe_lora within axolotl's plugin."""
|
"""Get the path to scattermoe_lora within axolotl's plugin."""
|
||||||
return (
|
return (
|
||||||
Path(__file__).parent.parent.parent
|
Path(__file__).parent.parent.parent.parent
|
||||||
/ "src"
|
/ "src"
|
||||||
/ "axolotl"
|
/ "axolotl"
|
||||||
/ "integrations"
|
/ "integrations"
|
||||||
@@ -1219,7 +1219,7 @@ class TestSharedExpertHandling:
|
|||||||
|
|
||||||
# Kernelize
|
# Kernelize
|
||||||
repo_path = (
|
repo_path = (
|
||||||
Path(__file__).parent.parent.parent
|
Path(__file__).parent.parent.parent.parent
|
||||||
/ "src"
|
/ "src"
|
||||||
/ "axolotl"
|
/ "axolotl"
|
||||||
/ "integrations"
|
/ "integrations"
|
||||||
|
|||||||
@@ -86,5 +86,5 @@ class TestPackedFlex:
|
|||||||
)
|
)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss (%s) is too high"
|
temp_dir + "/runs", "train/loss", 2.1, "Train Loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ def verify_training_success(temp_dir):
|
|||||||
event_file = os.path.join(tb_log_path, event_files[0])
|
event_file = os.path.join(tb_log_path, event_files[0])
|
||||||
reader = SummaryReader(event_file)
|
reader = SummaryReader(event_file)
|
||||||
df = reader.scalars
|
df = reader.scalars
|
||||||
train_loss_df = df[df.tag == "train/train_loss"]
|
train_loss_df = df[df.tag == "train/loss"]
|
||||||
if len(train_loss_df) > 0:
|
if len(train_loss_df) > 0:
|
||||||
final_loss = train_loss_df.value.values[-1]
|
final_loss = train_loss_df.value.values[-1]
|
||||||
assert not torch.isnan(torch.tensor(final_loss)), (
|
assert not torch.isnan(torch.tensor(final_loss)), (
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ def verify_fp8_training_success(temp_dir):
|
|||||||
event_file = os.path.join(tb_log_path, event_files[0])
|
event_file = os.path.join(tb_log_path, event_files[0])
|
||||||
reader = SummaryReader(event_file)
|
reader = SummaryReader(event_file)
|
||||||
df = reader.scalars
|
df = reader.scalars
|
||||||
train_loss_df = df[df.tag == "train/train_loss"]
|
train_loss_df = df[df.tag == "train/loss"]
|
||||||
if len(train_loss_df) > 0:
|
if len(train_loss_df) > 0:
|
||||||
final_loss = train_loss_df.value.values[-1]
|
final_loss = train_loss_df.value.values[-1]
|
||||||
assert not torch.isnan(torch.tensor(final_loss)), (
|
assert not torch.isnan(torch.tensor(final_loss)), (
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ def verify_training_success(temp_dir):
|
|||||||
event_file = os.path.join(tb_log_path, event_files[0])
|
event_file = os.path.join(tb_log_path, event_files[0])
|
||||||
reader = SummaryReader(event_file)
|
reader = SummaryReader(event_file)
|
||||||
df = reader.scalars
|
df = reader.scalars
|
||||||
train_loss_df = df[df.tag == "train/train_loss"]
|
train_loss_df = df[df.tag == "train/loss"]
|
||||||
if len(train_loss_df) > 0:
|
if len(train_loss_df) > 0:
|
||||||
final_loss = train_loss_df.value.values[-1]
|
final_loss = train_loss_df.value.values[-1]
|
||||||
assert not torch.isnan(torch.tensor(final_loss)), (
|
assert not torch.isnan(torch.tensor(final_loss)), (
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ def verify_training_success(temp_dir):
|
|||||||
event_file = os.path.join(tb_log_path, event_files[0])
|
event_file = os.path.join(tb_log_path, event_files[0])
|
||||||
reader = SummaryReader(event_file)
|
reader = SummaryReader(event_file)
|
||||||
df = reader.scalars
|
df = reader.scalars
|
||||||
train_loss_df = df[df.tag == "train/train_loss"]
|
train_loss_df = df[df.tag == "train/loss"]
|
||||||
if len(train_loss_df) > 0:
|
if len(train_loss_df) > 0:
|
||||||
final_loss = train_loss_df.value.values[-1]
|
final_loss = train_loss_df.value.values[-1]
|
||||||
assert not torch.isnan(torch.tensor(final_loss)), (
|
assert not torch.isnan(torch.tensor(final_loss)), (
|
||||||
|
|||||||
@@ -94,5 +94,5 @@ class TestMultiGPUGemma3:
|
|||||||
)
|
)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 1.8, "Train Loss (%s) is too high"
|
temp_dir + "/runs", "train/loss", 1.8, "Train Loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -90,7 +90,7 @@ class TestMultiGPULlama:
|
|||||||
)
|
)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.8, "Train Loss (%s) is too high"
|
temp_dir + "/runs", "train/loss", 2.8, "Train Loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -156,7 +156,7 @@ class TestMultiGPULlama:
|
|||||||
)
|
)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
|
temp_dir + "/runs", "train/loss", 2.3, "Train Loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_dpo_lora_ddp(self, temp_dir):
|
def test_dpo_lora_ddp(self, temp_dir):
|
||||||
@@ -233,7 +233,7 @@ class TestMultiGPULlama:
|
|||||||
loss_threshold = 2.3
|
loss_threshold = 2.3
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs",
|
temp_dir + "/runs",
|
||||||
"train/train_loss",
|
"train/loss",
|
||||||
loss_threshold,
|
loss_threshold,
|
||||||
"Train Loss (%s) is too high",
|
"Train Loss (%s) is too high",
|
||||||
)
|
)
|
||||||
@@ -312,7 +312,7 @@ class TestMultiGPULlama:
|
|||||||
loss_threshold = 2.3
|
loss_threshold = 2.3
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs",
|
temp_dir + "/runs",
|
||||||
"train/train_loss",
|
"train/loss",
|
||||||
loss_threshold,
|
loss_threshold,
|
||||||
"Train Loss (%s) is too high",
|
"Train Loss (%s) is too high",
|
||||||
)
|
)
|
||||||
@@ -385,7 +385,7 @@ class TestMultiGPULlama:
|
|||||||
)
|
)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
|
temp_dir + "/runs", "train/loss", 2.3, "Train Loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -461,7 +461,7 @@ class TestMultiGPULlama:
|
|||||||
)
|
)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
|
temp_dir + "/runs", "train/loss", 2.3, "Train Loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
@require_torch_2_6_0
|
@require_torch_2_6_0
|
||||||
@@ -543,7 +543,7 @@ class TestMultiGPULlama:
|
|||||||
)
|
)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss (%s) is too high"
|
temp_dir + "/runs", "train/loss", 2.1, "Train Loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_fsdp_qlora_prequant_packed(self, temp_dir):
|
def test_fsdp_qlora_prequant_packed(self, temp_dir):
|
||||||
@@ -623,7 +623,7 @@ class TestMultiGPULlama:
|
|||||||
)
|
)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
|
temp_dir + "/runs", "train/loss", 2.3, "Train Loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -708,7 +708,7 @@ class TestMultiGPULlama:
|
|||||||
)
|
)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.45, "Train Loss (%s) is too high"
|
temp_dir + "/runs", "train/loss", 2.45, "Train Loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -784,7 +784,7 @@ class TestMultiGPULlama:
|
|||||||
)
|
)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
|
temp_dir + "/runs", "train/loss", 2.3, "Train Loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -859,7 +859,7 @@ class TestMultiGPULlama:
|
|||||||
)
|
)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.5, "Train Loss (%s) is too high"
|
temp_dir + "/runs", "train/loss", 2.5, "Train Loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.skip(
|
@pytest.mark.skip(
|
||||||
@@ -925,5 +925,5 @@ class TestMultiGPULlama:
|
|||||||
)
|
)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 4.0, "Train Loss (%s) is too high"
|
temp_dir + "/runs", "train/loss", 4.0, "Train Loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -79,7 +79,7 @@ class TestMultiGPURay:
|
|||||||
)
|
)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
|
temp_dir + "/runs", "train/loss", 2.3, "Train Loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
@require_torch_2_7_0
|
@require_torch_2_7_0
|
||||||
@@ -138,7 +138,7 @@ class TestMultiGPURay:
|
|||||||
)
|
)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
|
temp_dir + "/runs", "train/loss", 2.3, "Train Loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
@require_torch_2_7_0
|
@require_torch_2_7_0
|
||||||
@@ -205,5 +205,5 @@ class TestMultiGPURay:
|
|||||||
)
|
)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
|
temp_dir + "/runs", "train/loss", 2.3, "Train Loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -64,5 +64,5 @@ class TestTensorParallel:
|
|||||||
)
|
)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 1.0, "Train Loss (%s) is too high"
|
temp_dir + "/runs", "train/loss", 1.0, "Train Loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -78,5 +78,5 @@ class TestFAXentropyLlama:
|
|||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 1.5, "Train Loss (%s) is too high"
|
temp_dir + "/runs", "train/loss", 1.5, "Train Loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -77,5 +77,5 @@ class TestFAFlattening:
|
|||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 1.5, "Train Loss (%s) is too high"
|
temp_dir + "/runs", "train/loss", 1.5, "Train Loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ import subprocess
|
|||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
|
from axolotl.core.trainers.constants import TOKENS_STATE_FILE
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.callbacks.tokens_per_second import TOKENS_STATE_FILE
|
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|||||||
@@ -73,7 +73,7 @@ class TestUnslothQLoRA:
|
|||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
|
temp_dir + "/runs", "train/loss", 2.0, "Train Loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_unsloth_llama_qlora_unpacked(self, temp_dir):
|
def test_unsloth_llama_qlora_unpacked(self, temp_dir):
|
||||||
@@ -124,7 +124,7 @@ class TestUnslothQLoRA:
|
|||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
|
temp_dir + "/runs", "train/loss", 2.0, "Train Loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -180,5 +180,5 @@ class TestUnslothQLoRA:
|
|||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
|
temp_dir + "/runs", "train/loss", 2.0, "Train Loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -63,5 +63,5 @@ class TestPackedFlex(unittest.TestCase):
|
|||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss (%s) is too high"
|
temp_dir + "/runs", "train/loss", 2.1, "Train Loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -14,6 +14,9 @@ from axolotl.utils.dict import DictDefault
|
|||||||
from tests.hf_offline_utils import enable_hf_offline
|
from tests.hf_offline_utils import enable_hf_offline
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(
|
||||||
|
reason="DeepSeek-V3-11M remote model code needs _supports_flash_attn=True for newer transformers"
|
||||||
|
)
|
||||||
class TestDeepseekV3:
|
class TestDeepseekV3:
|
||||||
"""
|
"""
|
||||||
Test case for DeepseekV3 models
|
Test case for DeepseekV3 models
|
||||||
|
|||||||
@@ -262,6 +262,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
|||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="TRL ORPO trainer has internal zip() length mismatch bug")
|
||||||
@with_temp_dir
|
@with_temp_dir
|
||||||
def test_orpo_lora(self, temp_dir):
|
def test_orpo_lora(self, temp_dir):
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
|
|||||||
@@ -57,9 +57,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):
|
|||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(temp_dir + "/runs", "train/loss", 2.0, "Loss is too high")
|
||||||
temp_dir + "/runs", "train/train_loss", 2.0, "Loss is too high"
|
|
||||||
)
|
|
||||||
|
|
||||||
@with_temp_dir
|
@with_temp_dir
|
||||||
def test_train_w_embedding_lr(self, temp_dir):
|
def test_train_w_embedding_lr(self, temp_dir):
|
||||||
@@ -100,6 +98,4 @@ class TestEmbeddingsLrScale(unittest.TestCase):
|
|||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(temp_dir + "/runs", "train/loss", 2.0, "Loss is too high")
|
||||||
temp_dir + "/runs", "train/train_loss", 2.0, "Loss is too high"
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ class TestPretrainLlama:
|
|||||||
loss_threshold = 6.5
|
loss_threshold = 6.5
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs",
|
temp_dir + "/runs",
|
||||||
"train/train_loss",
|
"train/loss",
|
||||||
loss_threshold,
|
loss_threshold,
|
||||||
"Train Loss (%s) is too high",
|
"Train Loss (%s) is too high",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ class TestMixtral(unittest.TestCase):
|
|||||||
|
|
||||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
assert (
|
assert (
|
||||||
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
model.base_model.model.model.layers[0].mlp.gate.weight.dtype
|
||||||
== torch.float32
|
== torch.float32
|
||||||
)
|
)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
@@ -125,7 +125,7 @@ class TestMixtral(unittest.TestCase):
|
|||||||
|
|
||||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
assert (
|
assert (
|
||||||
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
model.base_model.model.model.layers[0].mlp.gate.weight.dtype
|
||||||
== torch.float32
|
== torch.float32
|
||||||
)
|
)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
@@ -183,7 +183,7 @@ class TestMixtral(unittest.TestCase):
|
|||||||
|
|
||||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
assert (
|
assert (
|
||||||
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
model.base_model.model.model.layers[0].mlp.gate.weight.dtype
|
||||||
== torch.float32
|
== torch.float32
|
||||||
)
|
)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
@@ -241,7 +241,7 @@ class TestMixtral(unittest.TestCase):
|
|||||||
|
|
||||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
assert (
|
assert (
|
||||||
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
model.base_model.model.model.layers[0].mlp.gate.weight.dtype
|
||||||
== torch.float32
|
== torch.float32
|
||||||
)
|
)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ E2E tests for custom optimizers using Llama
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
@@ -282,3 +284,60 @@ class TestCustomOptimizers(unittest.TestCase):
|
|||||||
|
|
||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch_2_7_0
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"optimizer_name,expected_class,learning_rate",
|
||||||
|
[
|
||||||
|
("flash_adamw", "FlashAdamW", 0.00001),
|
||||||
|
("flash_adam", "FlashAdam", 0.00001),
|
||||||
|
("flash_sgd", "FlashSGD", 0.01),
|
||||||
|
("flash_sgdw", "FlashSGDW", 0.01),
|
||||||
|
("flash_lion", "FlashLion", 0.0001),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_flash_optimizers(tmp_path, optimizer_name, expected_class, learning_rate):
|
||||||
|
pytest.importorskip("flashoptim")
|
||||||
|
temp_dir = str(tmp_path)
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"model_type": "AutoModelForCausalLM",
|
||||||
|
"tokenizer_type": "AutoTokenizer",
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"load_in_8bit": True,
|
||||||
|
"adapter": "lora",
|
||||||
|
"lora_r": 8,
|
||||||
|
"lora_alpha": 16,
|
||||||
|
"lora_dropout": 0.05,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"val_set_size": 0.02,
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 8,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": learning_rate,
|
||||||
|
"optimizer": optimizer_name,
|
||||||
|
"max_steps": 5,
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"save_first_step": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
normalize_config(cfg)
|
||||||
|
dataset_meta = load_datasets(cfg=cfg)
|
||||||
|
|
||||||
|
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
assert trainer.optimizer.optimizer.__class__.__name__ == expected_class
|
||||||
|
|||||||
@@ -62,5 +62,5 @@ class TestPackedLlama(unittest.TestCase):
|
|||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
|
temp_dir + "/runs", "train/loss", 2.0, "Train Loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ class TestProcessRewardSmolLM2(unittest.TestCase):
|
|||||||
|
|
||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.7, "Train Loss (%s) is too high"
|
temp_dir + "/runs", "train/loss", 2.7, "Train Loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|||||||
@@ -128,7 +128,7 @@ class TestQATLlama:
|
|||||||
loss_threshold = 2.3
|
loss_threshold = 2.3
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs",
|
temp_dir + "/runs",
|
||||||
"train/train_loss",
|
"train/loss",
|
||||||
loss_threshold,
|
loss_threshold,
|
||||||
"Train Loss (%s) is too high",
|
"Train Loss (%s) is too high",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -35,6 +35,14 @@ from tests.e2e.utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_fake_quant_config_dtype(config):
|
||||||
|
"""Get the weight dtype from a fake quantize config, handling different config types."""
|
||||||
|
if hasattr(config, "dtype"):
|
||||||
|
return config.dtype
|
||||||
|
# Int4WeightFakeQuantizeConfig doesn't have .dtype — weight is always int4
|
||||||
|
return torch.int4
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def model():
|
def model():
|
||||||
dummy_model = AutoModelForCausalLM.from_pretrained(
|
dummy_model = AutoModelForCausalLM.from_pretrained(
|
||||||
@@ -157,6 +165,18 @@ class TestQuantization:
|
|||||||
expected_exception,
|
expected_exception,
|
||||||
expected_tensor_class,
|
expected_tensor_class,
|
||||||
):
|
):
|
||||||
|
# TODO: add mslk-cuda as a CI dependency once pytorch 2.10.x is available
|
||||||
|
# (see https://pypi.org/project/mslk-cuda/)
|
||||||
|
if expected_tensor_class is Int4Tensor and activation_dtype is None:
|
||||||
|
try:
|
||||||
|
from torchao.quantization.quantize_.workflows.int4.int4_tensor import (
|
||||||
|
int4_row_quantize_zp,
|
||||||
|
)
|
||||||
|
|
||||||
|
if int4_row_quantize_zp is None:
|
||||||
|
pytest.skip("Int4Tensor requires mslk >= 1.0.0")
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Int4Tensor requires mslk >= 1.0.0")
|
||||||
if expected_exception:
|
if expected_exception:
|
||||||
with pytest.raises(expected_exception):
|
with pytest.raises(expected_exception):
|
||||||
quantize_model(
|
quantize_model(
|
||||||
@@ -252,28 +272,24 @@ class TestQuantization:
|
|||||||
if quantize_embedding:
|
if quantize_embedding:
|
||||||
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
|
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
|
||||||
assert hasattr(model.model.embed_tokens, "weight_fake_quantizer")
|
assert hasattr(model.model.embed_tokens, "weight_fake_quantizer")
|
||||||
assert (
|
embed_config = model.model.embed_tokens.weight_fake_quantizer.config
|
||||||
model.model.embed_tokens.weight_fake_quantizer.config.dtype
|
assert _get_fake_quant_config_dtype(embed_config) == weight_dtype.value
|
||||||
== weight_dtype.value
|
|
||||||
)
|
|
||||||
if group_size:
|
if group_size:
|
||||||
assert (
|
assert embed_config.group_size == group_size
|
||||||
model.model.embed_tokens.weight_fake_quantizer.config.group_size
|
|
||||||
== group_size
|
|
||||||
)
|
|
||||||
|
|
||||||
for child in list(model.children()):
|
for child in list(model.children()):
|
||||||
if isinstance(child, torch.nn.Linear):
|
if isinstance(child, torch.nn.Linear):
|
||||||
assert isinstance(child, FakeQuantizedLinear)
|
assert isinstance(child, FakeQuantizedLinear)
|
||||||
assert hasattr(child, "weight_fake_quantizer")
|
assert hasattr(child, "weight_fake_quantizer")
|
||||||
assert child.weight_fake_quantizer.config.dtype == weight_dtype.value
|
w_config = child.weight_fake_quantizer.config
|
||||||
|
assert _get_fake_quant_config_dtype(w_config) == weight_dtype.value
|
||||||
if group_size:
|
if group_size:
|
||||||
assert child.weight_fake_quantizer.config.group_size == group_size
|
assert w_config.group_size == group_size
|
||||||
if activation_dtype:
|
if activation_dtype:
|
||||||
assert hasattr(child, "activation_fake_quantizer")
|
assert hasattr(child, "activation_fake_quantizer")
|
||||||
|
a_config = child.activation_fake_quantizer.config
|
||||||
assert (
|
assert (
|
||||||
child.activation_fake_quantizer.config.dtype
|
_get_fake_quant_config_dtype(a_config) == activation_dtype.value
|
||||||
== activation_dtype.value
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
assert child.activation_fake_quantizer is None
|
assert child.activation_fake_quantizer is None
|
||||||
@@ -374,9 +390,16 @@ class TestQuantizationCallback:
|
|||||||
|
|
||||||
# ensure model has been quantized
|
# ensure model has been quantized
|
||||||
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
|
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
|
||||||
assert model.model.embed_tokens.weight_fake_quantizer.enabled
|
|
||||||
assert isinstance(model.lm_head, FakeQuantizedLinear)
|
assert isinstance(model.lm_head, FakeQuantizedLinear)
|
||||||
assert model.lm_head.weight_fake_quantizer.enabled
|
|
||||||
|
# Only test enable/disable toggling if the fake quantizer supports it
|
||||||
|
# (Int4WeightFakeQuantizer does not have an 'enabled' attribute)
|
||||||
|
supports_toggle = hasattr(
|
||||||
|
model.model.embed_tokens.weight_fake_quantizer, "enabled"
|
||||||
|
)
|
||||||
|
if supports_toggle:
|
||||||
|
assert model.model.embed_tokens.weight_fake_quantizer.enabled
|
||||||
|
assert model.lm_head.weight_fake_quantizer.enabled
|
||||||
|
|
||||||
qat_callback = QATCallback(cfg)
|
qat_callback = QATCallback(cfg)
|
||||||
|
|
||||||
@@ -388,9 +411,10 @@ class TestQuantizationCallback:
|
|||||||
model=model,
|
model=model,
|
||||||
)
|
)
|
||||||
|
|
||||||
# quantization should have been disabled
|
if supports_toggle:
|
||||||
assert not model.model.embed_tokens.weight_fake_quantizer.enabled
|
# quantization should have been disabled
|
||||||
assert not model.lm_head.weight_fake_quantizer.enabled
|
assert not model.model.embed_tokens.weight_fake_quantizer.enabled
|
||||||
|
assert not model.lm_head.weight_fake_quantizer.enabled
|
||||||
|
|
||||||
trainer_state.global_step = 100
|
trainer_state.global_step = 100
|
||||||
qat_callback.on_step_begin(
|
qat_callback.on_step_begin(
|
||||||
@@ -400,9 +424,10 @@ class TestQuantizationCallback:
|
|||||||
model=model,
|
model=model,
|
||||||
)
|
)
|
||||||
|
|
||||||
# quantization should have been enabled
|
if supports_toggle:
|
||||||
assert model.model.embed_tokens.weight_fake_quantizer.enabled
|
# quantization should have been enabled
|
||||||
assert model.lm_head.weight_fake_quantizer.enabled
|
assert model.model.embed_tokens.weight_fake_quantizer.enabled
|
||||||
|
assert model.lm_head.weight_fake_quantizer.enabled
|
||||||
|
|
||||||
@require_torch_2_8_0
|
@require_torch_2_8_0
|
||||||
def test_qat_callback_fake_quant_after_n_steps_is_none(self, model, trainer_state):
|
def test_qat_callback_fake_quant_after_n_steps_is_none(self, model, trainer_state):
|
||||||
@@ -424,9 +449,10 @@ class TestQuantizationCallback:
|
|||||||
|
|
||||||
# ensure model has been quantized
|
# ensure model has been quantized
|
||||||
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
|
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
|
||||||
assert model.model.embed_tokens.weight_fake_quantizer.enabled
|
|
||||||
assert isinstance(model.lm_head, FakeQuantizedLinear)
|
assert isinstance(model.lm_head, FakeQuantizedLinear)
|
||||||
assert model.lm_head.weight_fake_quantizer.enabled
|
if hasattr(model.model.embed_tokens.weight_fake_quantizer, "enabled"):
|
||||||
|
assert model.model.embed_tokens.weight_fake_quantizer.enabled
|
||||||
|
assert model.lm_head.weight_fake_quantizer.enabled
|
||||||
|
|
||||||
qat_callback = QATCallback(cfg)
|
qat_callback = QATCallback(cfg)
|
||||||
# simulate first training step
|
# simulate first training step
|
||||||
@@ -438,5 +464,6 @@ class TestQuantizationCallback:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# quantization should be enabled from the get-go
|
# quantization should be enabled from the get-go
|
||||||
assert model.model.embed_tokens.weight_fake_quantizer.enabled
|
if hasattr(model.model.embed_tokens.weight_fake_quantizer, "enabled"):
|
||||||
assert model.lm_head.weight_fake_quantizer.enabled
|
assert model.model.embed_tokens.weight_fake_quantizer.enabled
|
||||||
|
assert model.lm_head.weight_fake_quantizer.enabled
|
||||||
|
|||||||
@@ -66,6 +66,6 @@ class TestRewardModelLoraSmolLM2(unittest.TestCase):
|
|||||||
|
|
||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.5, "Train Loss (%s) is too high"
|
temp_dir + "/runs", "train/loss", 2.5, "Train Loss (%s) is too high"
|
||||||
)
|
)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ class TestStreamingDatasets:
|
|||||||
# Verify training actually happened by checking loss decrease
|
# Verify training actually happened by checking loss decrease
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs",
|
temp_dir + "/runs",
|
||||||
"train/train_loss",
|
"train/loss",
|
||||||
3.0,
|
3.0,
|
||||||
"Train Loss (%s) is too high",
|
"Train Loss (%s) is too high",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from axolotl.utils.config import validate_config
|
|||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
||||||
from axolotl.utils.schemas.config import AxolotlConfigWCapabilities
|
from axolotl.utils.schemas.config import AxolotlConfigWCapabilities
|
||||||
|
from axolotl.utils.schemas.datasets import SFTDataset
|
||||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||||
|
|
||||||
warnings.filterwarnings("error")
|
warnings.filterwarnings("error")
|
||||||
@@ -1731,3 +1732,52 @@ class TestDataloaderValidation(BaseValidation):
|
|||||||
assert new_cfg.dataloader_num_workers == 8
|
assert new_cfg.dataloader_num_workers == 8
|
||||||
assert new_cfg.dataloader_pin_memory is True
|
assert new_cfg.dataloader_pin_memory is True
|
||||||
assert new_cfg.dataloader_prefetch_factor == 256
|
assert new_cfg.dataloader_prefetch_factor == 256
|
||||||
|
|
||||||
|
|
||||||
|
class TestSyntheticDatasetValidation(BaseValidation):
|
||||||
|
"""
|
||||||
|
Tests for synthetic dataset config validation
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _make_cfg(minimal_cfg, datasets):
|
||||||
|
raw = dict(minimal_cfg)
|
||||||
|
raw["datasets"] = datasets
|
||||||
|
return DictDefault(raw)
|
||||||
|
|
||||||
|
def test_synthetic_dict_config_validates(self, minimal_cfg):
|
||||||
|
"""Synthetic dataset passed as a raw dict should not raise."""
|
||||||
|
cfg = self._make_cfg(
|
||||||
|
minimal_cfg,
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"path": "synthetic",
|
||||||
|
"type": "_synthetic",
|
||||||
|
"length": 100,
|
||||||
|
"sequence_length": 64,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
new_cfg = validate_config(cfg)
|
||||||
|
assert new_cfg.datasets[0]["path"] == "synthetic"
|
||||||
|
|
||||||
|
def test_synthetic_already_sft_does_not_crash(self, minimal_cfg):
|
||||||
|
"""Synthetic dataset already parsed as SFTDataset should not raise AttributeError."""
|
||||||
|
sft = SFTDataset(path="synthetic", type="_synthetic")
|
||||||
|
cfg = self._make_cfg(minimal_cfg, [sft])
|
||||||
|
|
||||||
|
# Before the fix, this raised:
|
||||||
|
# AttributeError: 'SFTDataset' object has no attribute 'get'
|
||||||
|
new_cfg = validate_config(cfg)
|
||||||
|
assert new_cfg.datasets[0]["path"] == "synthetic"
|
||||||
|
|
||||||
|
def test_non_synthetic_sft_validates(self, minimal_cfg):
|
||||||
|
"""A regular SFT dataset should validate without being treated as synthetic."""
|
||||||
|
cfg = self._make_cfg(
|
||||||
|
minimal_cfg,
|
||||||
|
[{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}],
|
||||||
|
)
|
||||||
|
|
||||||
|
new_cfg = validate_config(cfg)
|
||||||
|
assert new_cfg.datasets[0]["path"] == "mhenrichsen/alpaca_2k_test"
|
||||||
|
|||||||
125
tests/prompt_strategies/test_synthetic.py
Normal file
125
tests/prompt_strategies/test_synthetic.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
"""Tests for the synthetic dataset generator."""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from datasets import Dataset
|
||||||
|
|
||||||
|
from axolotl.prompt_strategies._synthetic import SyntheticDatasetStrategy, load
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
|
class TestSyntheticDatasetStrategy(unittest.TestCase):
|
||||||
|
def test_generates_correct_shape(self):
|
||||||
|
strategy = SyntheticDatasetStrategy(
|
||||||
|
sequence_length=128,
|
||||||
|
length=50,
|
||||||
|
min_input_id=1,
|
||||||
|
max_input_id=1000,
|
||||||
|
seed=42,
|
||||||
|
)
|
||||||
|
dummy = Dataset.from_dict({"text": [""]})
|
||||||
|
result = strategy.wrap_dataset(dummy)
|
||||||
|
|
||||||
|
assert len(result) == 50
|
||||||
|
assert len(result[0]["input_ids"]) == 128
|
||||||
|
assert len(result[0]["attention_mask"]) == 128
|
||||||
|
assert len(result[0]["labels"]) == 128
|
||||||
|
|
||||||
|
def test_attention_mask_all_ones(self):
|
||||||
|
strategy = SyntheticDatasetStrategy(sequence_length=64, length=10, seed=0)
|
||||||
|
dummy = Dataset.from_dict({"text": [""]})
|
||||||
|
result = strategy.wrap_dataset(dummy)
|
||||||
|
|
||||||
|
for row in result:
|
||||||
|
assert all(v == 1 for v in row["attention_mask"])
|
||||||
|
|
||||||
|
def test_labels_equal_input_ids(self):
|
||||||
|
strategy = SyntheticDatasetStrategy(sequence_length=64, length=10, seed=0)
|
||||||
|
dummy = Dataset.from_dict({"text": [""]})
|
||||||
|
result = strategy.wrap_dataset(dummy)
|
||||||
|
|
||||||
|
for row in result:
|
||||||
|
assert row["input_ids"] == row["labels"]
|
||||||
|
|
||||||
|
def test_input_id_range(self):
|
||||||
|
strategy = SyntheticDatasetStrategy(
|
||||||
|
sequence_length=64,
|
||||||
|
length=100,
|
||||||
|
min_input_id=500,
|
||||||
|
max_input_id=600,
|
||||||
|
seed=42,
|
||||||
|
)
|
||||||
|
dummy = Dataset.from_dict({"text": [""]})
|
||||||
|
result = strategy.wrap_dataset(dummy)
|
||||||
|
|
||||||
|
for row in result:
|
||||||
|
for token_id in row["input_ids"]:
|
||||||
|
assert 500 <= token_id < 600
|
||||||
|
|
||||||
|
def test_seed_reproducibility(self):
|
||||||
|
kwargs = dict(
|
||||||
|
sequence_length=64, length=20, min_input_id=1, max_input_id=1000, seed=123
|
||||||
|
)
|
||||||
|
dummy = Dataset.from_dict({"text": [""]})
|
||||||
|
|
||||||
|
result1 = SyntheticDatasetStrategy(**kwargs).wrap_dataset(dummy)
|
||||||
|
result2 = SyntheticDatasetStrategy(**kwargs).wrap_dataset(dummy)
|
||||||
|
|
||||||
|
for r1, r2 in zip(result1, result2, strict=True):
|
||||||
|
assert r1["input_ids"] == r2["input_ids"]
|
||||||
|
|
||||||
|
def test_different_seeds_differ(self):
|
||||||
|
common = dict(sequence_length=64, length=20, min_input_id=1, max_input_id=1000)
|
||||||
|
dummy = Dataset.from_dict({"text": [""]})
|
||||||
|
|
||||||
|
result1 = SyntheticDatasetStrategy(seed=1, **common).wrap_dataset(dummy)
|
||||||
|
result2 = SyntheticDatasetStrategy(seed=2, **common).wrap_dataset(dummy)
|
||||||
|
|
||||||
|
any_different = any(
|
||||||
|
r1["input_ids"] != r2["input_ids"]
|
||||||
|
for r1, r2 in zip(result1, result2, strict=True)
|
||||||
|
)
|
||||||
|
assert any_different
|
||||||
|
|
||||||
|
def test_load_function_with_ds_cfg(self):
|
||||||
|
tokenizer = MagicMock()
|
||||||
|
tokenizer.vocab_size = 32000
|
||||||
|
cfg = DictDefault({"sequence_len": 512, "train_on_inputs": False})
|
||||||
|
ds_cfg = {
|
||||||
|
"sequence_length": 256,
|
||||||
|
"length": 5,
|
||||||
|
"min_input_id": 10,
|
||||||
|
"max_input_id": 100,
|
||||||
|
"seed": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
strategy = load(tokenizer, cfg, ds_cfg=ds_cfg)
|
||||||
|
assert isinstance(strategy, SyntheticDatasetStrategy)
|
||||||
|
assert strategy.sequence_length == 256
|
||||||
|
assert strategy.length == 5
|
||||||
|
assert strategy.min_input_id == 10
|
||||||
|
assert strategy.max_input_id == 100
|
||||||
|
|
||||||
|
def test_load_defaults_from_cfg(self):
|
||||||
|
tokenizer = MagicMock()
|
||||||
|
tokenizer.vocab_size = 32000
|
||||||
|
cfg = DictDefault({"sequence_len": 1024, "train_on_inputs": False})
|
||||||
|
|
||||||
|
strategy = load(tokenizer, cfg, ds_cfg={})
|
||||||
|
assert strategy.sequence_length == 1024
|
||||||
|
assert strategy.max_input_id == 32000
|
||||||
|
assert strategy.length == 1000
|
||||||
|
|
||||||
|
def test_load_with_no_ds_cfg(self):
|
||||||
|
tokenizer = MagicMock()
|
||||||
|
tokenizer.vocab_size = 50000
|
||||||
|
cfg = DictDefault({"sequence_len": 2048, "train_on_inputs": False})
|
||||||
|
|
||||||
|
strategy = load(tokenizer, cfg)
|
||||||
|
assert strategy.sequence_length == 2048
|
||||||
|
assert strategy.max_input_id == 50000
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@@ -1,45 +0,0 @@
|
|||||||
"""
|
|
||||||
Unit tests for the monkey patch for expand mask to handle packed sequences
|
|
||||||
"""
|
|
||||||
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.llama_expand_mask import _expand_mask
|
|
||||||
|
|
||||||
|
|
||||||
class TestExpandMask(unittest.TestCase):
|
|
||||||
"""
|
|
||||||
Test class for attention mask expansion for packed sequences
|
|
||||||
"""
|
|
||||||
|
|
||||||
def test_output(self):
|
|
||||||
mask = torch.tensor([[1, 1, 1, 2], [2, 3, 3, 0]])
|
|
||||||
dtype = torch.float32
|
|
||||||
expected_output = torch.tensor(
|
|
||||||
[
|
|
||||||
[
|
|
||||||
[
|
|
||||||
[0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
|
|
||||||
[0.0000e00, 0.0000e00, -3.4028e38, -3.4028e38],
|
|
||||||
[0.0000e00, 0.0000e00, 0.0000e00, -3.4028e38],
|
|
||||||
[-3.4028e38, -3.4028e38, -3.4028e38, 0.0000e00],
|
|
||||||
]
|
|
||||||
],
|
|
||||||
[
|
|
||||||
[
|
|
||||||
[0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
|
|
||||||
[-3.4028e38, 0.0000e00, -3.4028e38, -3.4028e38],
|
|
||||||
[-3.4028e38, 0.0000e00, 0.0000e00, -3.4028e38],
|
|
||||||
[-3.4028e38, -3.4028e38, -3.4028e38, -3.4028e38],
|
|
||||||
]
|
|
||||||
],
|
|
||||||
]
|
|
||||||
)
|
|
||||||
# Check that the output matches the expected output
|
|
||||||
self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
Reference in New Issue
Block a user