Compare commits
1 Commits
textui
...
fix/cp-was
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
255c5b90ca |
@@ -128,9 +128,11 @@ quartodoc:
|
|||||||
- monkeypatch.mistral_attn_hijack_flash
|
- monkeypatch.mistral_attn_hijack_flash
|
||||||
- monkeypatch.multipack
|
- monkeypatch.multipack
|
||||||
- monkeypatch.relora
|
- monkeypatch.relora
|
||||||
|
- monkeypatch.llama_expand_mask
|
||||||
- monkeypatch.lora_kernels
|
- monkeypatch.lora_kernels
|
||||||
- monkeypatch.utils
|
- monkeypatch.utils
|
||||||
- monkeypatch.btlm_attn_hijack_flash
|
- monkeypatch.btlm_attn_hijack_flash
|
||||||
|
- monkeypatch.llama_patch_multipack
|
||||||
- monkeypatch.stablelm_attn_hijack_flash
|
- monkeypatch.stablelm_attn_hijack_flash
|
||||||
- monkeypatch.trainer_fsdp_optim
|
- monkeypatch.trainer_fsdp_optim
|
||||||
- monkeypatch.transformers_fa_utils
|
- monkeypatch.transformers_fa_utils
|
||||||
|
|||||||
@@ -3,8 +3,7 @@ set -e
|
|||||||
|
|
||||||
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
||||||
|
|
||||||
set -o pipefail
|
curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1
|
||||||
curl --silent --show-error --fail --retry 3 --retry-delay 5 -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1
|
|
||||||
# hf download "NousResearch/Meta-Llama-3-8B"
|
# hf download "NousResearch/Meta-Llama-3-8B"
|
||||||
# hf download "NousResearch/Meta-Llama-3-8B-Instruct"
|
# hf download "NousResearch/Meta-Llama-3-8B-Instruct"
|
||||||
# hf download "microsoft/Phi-4-reasoning"
|
# hf download "microsoft/Phi-4-reasoning"
|
||||||
|
|||||||
@@ -37,7 +37,6 @@ coverage:
|
|||||||
only_pulls: false
|
only_pulls: false
|
||||||
flags: null
|
flags: null
|
||||||
paths: null
|
paths: null
|
||||||
informational: true
|
|
||||||
|
|
||||||
parsers:
|
parsers:
|
||||||
gcov:
|
gcov:
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
---
|
---
|
||||||
title: Gradient Checkpointing, Activation Offloading, and Layer Offloading
|
title: Gradient Checkpointing and Activation Offloading
|
||||||
---
|
---
|
||||||
|
|
||||||
Gradient checkpointing and activation offloading are techniques used to optimize the performance of deep learning
|
Gradient checkpointing and activation offloading are techniques used to optimize the performance of deep learning
|
||||||
@@ -27,33 +27,3 @@ The `activation_offloading: legacy` naively offloads activations to CPU and with
|
|||||||
|
|
||||||
For resource constrained environments with limited CPU memory, `activation_offloading: disk` offloads
|
For resource constrained environments with limited CPU memory, `activation_offloading: disk` offloads
|
||||||
activations to disk instead of CPU RAM so that much larger context lengths can be trained with minimal memory.
|
activations to disk instead of CPU RAM so that much larger context lengths can be trained with minimal memory.
|
||||||
|
|
||||||
### Enabling Layer Offloading
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
layer_offloading: true
|
|
||||||
```
|
|
||||||
|
|
||||||
Layer offloading reduces GPU memory usage by moving frozen (non-trainable) decoder layer parameters to CPU
|
|
||||||
and streaming them back to GPU one layer at a time during the forward and backward passes. This is
|
|
||||||
particularly useful for LoRA/QLoRA training where most of the model's parameters are frozen — only the
|
|
||||||
trainable adapter weights stay on GPU permanently.
|
|
||||||
|
|
||||||
During training, forward and backward hooks on each decoder layer handle the transfer automatically:
|
|
||||||
|
|
||||||
- **Forward pass:** Before a layer executes, its frozen params are loaded to GPU. The next layer is
|
|
||||||
prefetched asynchronously on a separate CUDA stream for overlap.
|
|
||||||
- **Backward pass:** Same pattern in reverse — the current layer's frozen params are loaded and the
|
|
||||||
previous layer is prefetched.
|
|
||||||
|
|
||||||
After each layer finishes, its frozen params are offloaded back to CPU pinned memory.
|
|
||||||
|
|
||||||
This approach trades some CPU-GPU transfer overhead for significant GPU memory savings — the freed memory
|
|
||||||
is roughly equal to the size of all frozen parameters across all decoder layers, minus one layer's worth
|
|
||||||
that is kept on GPU at any given time.
|
|
||||||
|
|
||||||
**Requirements:**
|
|
||||||
|
|
||||||
- CUDA GPU (CPU-only training is not supported for this feature)
|
|
||||||
- Works with any HuggingFace model architecture that uses decoder layers (Llama, Mistral, Qwen, etc.)
|
|
||||||
- Best combined with LoRA/QLoRA where most parameters are frozen
|
|
||||||
|
|||||||
@@ -54,13 +54,6 @@ These techniques save VRAM by changing how activations are handled.
|
|||||||
- Activation Offloading: moves activations to CPU RAM or disk, trading I/O overhead for VRAM.
|
- Activation Offloading: moves activations to CPU RAM or disk, trading I/O overhead for VRAM.
|
||||||
- Learn more: [Gradient Checkpointing and Offloading Docs](gradient_checkpointing.qmd)
|
- Learn more: [Gradient Checkpointing and Offloading Docs](gradient_checkpointing.qmd)
|
||||||
|
|
||||||
### Layer Offloading
|
|
||||||
|
|
||||||
Offloads frozen (non-trainable) decoder layer parameters to CPU and streams them back to GPU one layer at a time during forward/backward passes using CUDA stream prefetching. Especially effective for LoRA/QLoRA where most parameters are frozen.
|
|
||||||
|
|
||||||
- **Config:** `layer_offloading: true`
|
|
||||||
- **Learn more:** [Layer Offloading Docs](gradient_checkpointing.qmd#enabling-layer-offloading)
|
|
||||||
|
|
||||||
### Cut Cross Entropy (CCE)
|
### Cut Cross Entropy (CCE)
|
||||||
|
|
||||||
Reduces VRAM usage by using an optimized cross-entropy loss calculation.
|
Reduces VRAM usage by using an optimized cross-entropy loss calculation.
|
||||||
|
|||||||
@@ -6,6 +6,9 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
|
|||||||
|
|
||||||
## Getting started
|
## Getting started
|
||||||
|
|
||||||
|
Note: Training this model requires weights in BF16 which we will link to later.
|
||||||
|
Users interested in training can convert / descale the existing FP8 weights.
|
||||||
|
|
||||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||||
|
|
||||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
|
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16
|
base_model: mistralai/Mistral-Small-4-119B-2603
|
||||||
|
|
||||||
plugins:
|
plugins:
|
||||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16
|
base_model: mistralai/Mistral-Small-4-119B-2603
|
||||||
processor_type: AutoProcessor
|
processor_type: AutoProcessor
|
||||||
|
|
||||||
plugins:
|
plugins:
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16
|
base_model: mistralai/Mistral-Small-4-119B-2603
|
||||||
|
|
||||||
plugins:
|
plugins:
|
||||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16
|
base_model: mistralai/Mistral-Small-4-119B-2603
|
||||||
processor_type: AutoProcessor
|
processor_type: AutoProcessor
|
||||||
|
|
||||||
plugins:
|
plugins:
|
||||||
|
|||||||
@@ -1,84 +0,0 @@
|
|||||||
base_model: Qwen/Qwen3.5-122B-A10B
|
|
||||||
|
|
||||||
plugins:
|
|
||||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
|
||||||
strict: false
|
|
||||||
|
|
||||||
chat_template: qwen3_5
|
|
||||||
datasets:
|
|
||||||
- path: mlabonne/FineTome-100k
|
|
||||||
type: chat_template
|
|
||||||
split: train[:20%]
|
|
||||||
field_messages: conversations
|
|
||||||
message_property_mappings:
|
|
||||||
role: from
|
|
||||||
content: value
|
|
||||||
val_set_size: 0.0
|
|
||||||
output_dir: ./outputs/out
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
|
|
||||||
sequence_len: 2048
|
|
||||||
sample_packing: true
|
|
||||||
|
|
||||||
load_in_4bit: true
|
|
||||||
quantize_moe_experts: true
|
|
||||||
adapter: qlora
|
|
||||||
lora_r: 16
|
|
||||||
lora_alpha: 32
|
|
||||||
lora_dropout: 0
|
|
||||||
lora_target_modules:
|
|
||||||
- q_proj
|
|
||||||
- k_proj
|
|
||||||
- v_proj
|
|
||||||
- o_proj
|
|
||||||
# Regex matching to target shared experts too
|
|
||||||
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
|
|
||||||
|
|
||||||
# Target experts
|
|
||||||
# lora_target_parameters:
|
|
||||||
# - mlp.experts.gate_up_proj
|
|
||||||
# - mlp.experts.down_proj
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 2
|
|
||||||
micro_batch_size: 1
|
|
||||||
num_epochs: 1
|
|
||||||
optimizer: adamw_torch_4bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
bf16: auto
|
|
||||||
tf32: true
|
|
||||||
|
|
||||||
lora_mlp_kernel: false
|
|
||||||
lora_qkv_kernel: false
|
|
||||||
lora_o_kernel: false
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
gradient_checkpointing_kwargs:
|
|
||||||
use_reentrant: false
|
|
||||||
resume_from_checkpoint:
|
|
||||||
logging_steps: 1
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
evals_per_epoch: 4
|
|
||||||
saves_per_epoch: 1
|
|
||||||
weight_decay: 0.0
|
|
||||||
special_tokens:
|
|
||||||
|
|
||||||
fsdp_config:
|
|
||||||
fsdp_version: 2
|
|
||||||
offload_params: true
|
|
||||||
cpu_ram_efficient_loading: false
|
|
||||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
|
||||||
transformer_layer_cls_to_wrap: Qwen3_5MoeDecoderLayer
|
|
||||||
state_dict_type: FULL_STATE_DICT
|
|
||||||
sharding_strategy: FULL_SHARD
|
|
||||||
reshard_after_forward: true
|
|
||||||
activation_checkpointing: true
|
|
||||||
@@ -32,11 +32,7 @@ lora_target_modules:
|
|||||||
- v_proj
|
- v_proj
|
||||||
- o_proj
|
- o_proj
|
||||||
|
|
||||||
# Regex matching to target shared experts too
|
#lora_target_parameters:
|
||||||
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
|
|
||||||
|
|
||||||
# Target experts
|
|
||||||
# lora_target_parameters:
|
|
||||||
# - mlp.experts.gate_up_proj
|
# - mlp.experts.gate_up_proj
|
||||||
# - mlp.experts.down_proj
|
# - mlp.experts.down_proj
|
||||||
|
|
||||||
@@ -56,6 +52,7 @@ learning_rate: 0.0002
|
|||||||
bf16: auto
|
bf16: auto
|
||||||
tf32: true
|
tf32: true
|
||||||
|
|
||||||
|
|
||||||
lora_mlp_kernel: false
|
lora_mlp_kernel: false
|
||||||
lora_qkv_kernel: false
|
lora_qkv_kernel: false
|
||||||
lora_o_kernel: false
|
lora_o_kernel: false
|
||||||
|
|||||||
@@ -1,81 +0,0 @@
|
|||||||
base_model: Qwen/Qwen3.5-27B
|
|
||||||
|
|
||||||
# Automatically upload checkpoint and final model to HF
|
|
||||||
# hub_model_id: username/custom_model_name
|
|
||||||
|
|
||||||
plugins:
|
|
||||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
|
||||||
strict: false
|
|
||||||
|
|
||||||
chat_template: qwen3_5
|
|
||||||
datasets:
|
|
||||||
- path: mlabonne/FineTome-100k
|
|
||||||
type: chat_template
|
|
||||||
split: train[:20%]
|
|
||||||
field_messages: conversations
|
|
||||||
message_property_mappings:
|
|
||||||
role: from
|
|
||||||
content: value
|
|
||||||
val_set_size: 0.0
|
|
||||||
output_dir: ./outputs/out
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
|
|
||||||
sequence_len: 2048
|
|
||||||
sample_packing: true
|
|
||||||
|
|
||||||
load_in_4bit: true
|
|
||||||
adapter: qlora
|
|
||||||
lora_r: 16
|
|
||||||
lora_alpha: 32
|
|
||||||
lora_target_modules:
|
|
||||||
- q_proj
|
|
||||||
- k_proj
|
|
||||||
- v_proj
|
|
||||||
- o_proj
|
|
||||||
- down_proj
|
|
||||||
- up_proj
|
|
||||||
# Uncomment below to also target the linear attention projections.
|
|
||||||
# These use separate in_proj_qkv / in_proj_z / out_proj (Qwen3.5-specific).
|
|
||||||
# - linear_attn.in_proj_qkv
|
|
||||||
# - linear_attn.in_proj_z
|
|
||||||
# - linear_attn.out_proj
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 2
|
|
||||||
micro_batch_size: 1
|
|
||||||
num_epochs: 1
|
|
||||||
optimizer: adamw_torch_4bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
bf16: auto
|
|
||||||
tf32: true
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
gradient_checkpointing_kwargs:
|
|
||||||
use_reentrant: false
|
|
||||||
resume_from_checkpoint:
|
|
||||||
logging_steps: 1
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
evals_per_epoch: 4
|
|
||||||
saves_per_epoch: 1
|
|
||||||
weight_decay: 0.0
|
|
||||||
special_tokens:
|
|
||||||
|
|
||||||
fsdp_config:
|
|
||||||
fsdp_version: 2
|
|
||||||
offload_params: false
|
|
||||||
cpu_ram_efficient_loading: false
|
|
||||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
|
||||||
transformer_layer_cls_to_wrap: Qwen3_5DecoderLayer
|
|
||||||
state_dict_type: FULL_STATE_DICT
|
|
||||||
sharding_strategy: FULL_SHARD
|
|
||||||
reshard_after_forward: true
|
|
||||||
activation_checkpointing: true
|
|
||||||
@@ -1,7 +1,9 @@
|
|||||||
base_model: Qwen/Qwen3.5-27B
|
base_model: Qwen/Qwen3.5-27B
|
||||||
|
|
||||||
# Automatically upload checkpoint and final model to HF
|
# Automatically upload checkpoint and final model to HF
|
||||||
# hub_model_id: username/custom_model_name
|
# hub_model_id: username/custom_model_name
|
||||||
|
# Note: Qwen3.5 is an early-fusion VLM (image+text). This config fine-tunes
|
||||||
|
# the text-only path. For multimodal (image+text) fine-tuning, add image
|
||||||
|
# columns to your dataset following axolotl's multimodal dataset format.
|
||||||
|
|
||||||
plugins:
|
plugins:
|
||||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
|||||||
@@ -1,85 +0,0 @@
|
|||||||
base_model: Qwen/Qwen3.5-35B-A3B
|
|
||||||
|
|
||||||
plugins:
|
|
||||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
|
||||||
strict: false
|
|
||||||
|
|
||||||
chat_template: qwen3_5
|
|
||||||
datasets:
|
|
||||||
- path: mlabonne/FineTome-100k
|
|
||||||
type: chat_template
|
|
||||||
split: train[:20%]
|
|
||||||
field_messages: conversations
|
|
||||||
message_property_mappings:
|
|
||||||
role: from
|
|
||||||
content: value
|
|
||||||
val_set_size: 0.0
|
|
||||||
output_dir: ./outputs/out
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
|
|
||||||
sequence_len: 2048
|
|
||||||
sample_packing: true
|
|
||||||
|
|
||||||
load_in_4bit: true
|
|
||||||
quantize_moe_experts: true
|
|
||||||
adapter: qlora
|
|
||||||
lora_r: 16
|
|
||||||
lora_alpha: 32
|
|
||||||
lora_dropout: 0
|
|
||||||
lora_target_modules:
|
|
||||||
- q_proj
|
|
||||||
- k_proj
|
|
||||||
- v_proj
|
|
||||||
- o_proj
|
|
||||||
|
|
||||||
# Regex matching to target shared experts too
|
|
||||||
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
|
|
||||||
|
|
||||||
# Target experts
|
|
||||||
# lora_target_parameters:
|
|
||||||
# - mlp.experts.gate_up_proj
|
|
||||||
# - mlp.experts.down_proj
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 2
|
|
||||||
micro_batch_size: 1
|
|
||||||
num_epochs: 1
|
|
||||||
optimizer: adamw_torch_4bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
bf16: auto
|
|
||||||
tf32: true
|
|
||||||
|
|
||||||
lora_mlp_kernel: false
|
|
||||||
lora_qkv_kernel: false
|
|
||||||
lora_o_kernel: false
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
gradient_checkpointing_kwargs:
|
|
||||||
use_reentrant: false
|
|
||||||
resume_from_checkpoint:
|
|
||||||
logging_steps: 1
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
evals_per_epoch: 4
|
|
||||||
saves_per_epoch: 1
|
|
||||||
weight_decay: 0.0
|
|
||||||
special_tokens:
|
|
||||||
|
|
||||||
fsdp_config:
|
|
||||||
fsdp_version: 2
|
|
||||||
offload_params: true
|
|
||||||
cpu_ram_efficient_loading: false
|
|
||||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
|
||||||
transformer_layer_cls_to_wrap: Qwen3_5MoeDecoderLayer
|
|
||||||
state_dict_type: FULL_STATE_DICT
|
|
||||||
sharding_strategy: FULL_SHARD
|
|
||||||
reshard_after_forward: true
|
|
||||||
activation_checkpointing: true
|
|
||||||
@@ -32,11 +32,7 @@ lora_target_modules:
|
|||||||
- v_proj
|
- v_proj
|
||||||
- o_proj
|
- o_proj
|
||||||
|
|
||||||
# Regex matching to target shared experts too
|
#lora_target_parameters:
|
||||||
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
|
|
||||||
|
|
||||||
# Target experts
|
|
||||||
# lora_target_parameters:
|
|
||||||
# - mlp.experts.gate_up_proj
|
# - mlp.experts.gate_up_proj
|
||||||
# - mlp.experts.down_proj
|
# - mlp.experts.down_proj
|
||||||
|
|
||||||
|
|||||||
@@ -26,6 +26,8 @@ lora_r: 32
|
|||||||
lora_alpha: 16
|
lora_alpha: 16
|
||||||
lora_dropout: 0.05
|
lora_dropout: 0.05
|
||||||
# Targets the language model attention and MLP layers.
|
# Targets the language model attention and MLP layers.
|
||||||
|
# Qwen3.5 is early-fusion: all layers (including those seeing vision tokens) share
|
||||||
|
# the same transformer stack, so standard attention targets work for both modalities.
|
||||||
lora_target_modules:
|
lora_target_modules:
|
||||||
- q_proj
|
- q_proj
|
||||||
- k_proj
|
- k_proj
|
||||||
|
|||||||
@@ -2,6 +2,20 @@
|
|||||||
|
|
||||||
[Qwen3.5](https://huggingface.co/collections/Qwen/qwen35) is a hybrid architecture model series combining Gated DeltaNet linear attention with standard Transformer attention. All Qwen3.5 models are early-fusion vision-language models: dense variants use `Qwen3_5ForConditionalGeneration` and MoE variants use `Qwen3_5MoeForConditionalGeneration`.
|
[Qwen3.5](https://huggingface.co/collections/Qwen/qwen35) is a hybrid architecture model series combining Gated DeltaNet linear attention with standard Transformer attention. All Qwen3.5 models are early-fusion vision-language models: dense variants use `Qwen3_5ForConditionalGeneration` and MoE variants use `Qwen3_5MoeForConditionalGeneration`.
|
||||||
|
|
||||||
|
Vision and text tokens are processed through the same transformer stack. The configs below train on text-only data unless noted otherwise. See `9b-lora-vision.yaml` for a multimodal example.
|
||||||
|
|
||||||
|
Available configs:
|
||||||
|
|
||||||
|
| Config | Model | Type | Peak VRAM |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `27b-qlora.yaml` | Qwen3.5-27B | Dense VLM, text-only QLoRA | ~47 GiB |
|
||||||
|
| `27b-fft.yaml` | Qwen3.5-27B | Dense VLM, text-only FFT (vision frozen) | ~53 GiB |
|
||||||
|
| `35b-a3b-moe-qlora.yaml` | Qwen3.5-35B-A3B | MoE, text-only QLoRA | — |
|
||||||
|
| `122b-a10b-moe-qlora.yaml` | Qwen3.5-122B-A10B | MoE, text-only QLoRA | — |
|
||||||
|
| `9b-lora-vision.yaml` | Qwen3.5-9B | Vision+text LoRA, single GPU | — |
|
||||||
|
| `9b-fft-vision.yaml` | Qwen3.5-9B | Vision+text FFT, single GPU | ~61 GiB |
|
||||||
|
|
||||||
|
|
||||||
## Getting started
|
## Getting started
|
||||||
|
|
||||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||||
@@ -9,69 +23,43 @@
|
|||||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
||||||
|
|
||||||
3. Install FLA for sample packing support with the Gated DeltaNet linear attention layers:
|
3. Install FLA for sample packing support with the Gated DeltaNet linear attention layers:
|
||||||
```bash
|
```bash
|
||||||
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
|
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
|
||||||
```
|
|
||||||
> FLA is required when `sample_packing: true`. Without it, training raises a `RuntimeError` on packed sequences. Vision configs use `sample_packing: false` so FLA is optional there.
|
|
||||||
|
|
||||||
4. Pick any config from the table below and run:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
axolotl train examples/qwen3.5/<config>.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
Available configs:
|
|
||||||
|
|
||||||
| Config | Model | Type | Peak VRAM |
|
|
||||||
|---|---|---|---|
|
|
||||||
| `9b-lora-vision.yaml` | Qwen3.5-9B | Vision+text LoRA, single GPU | — |
|
|
||||||
| `9b-fft-vision.yaml` | Qwen3.5-9B | Vision+text FFT, single GPU | ~61 GiB |
|
|
||||||
| `27b-qlora.yaml` | Qwen3.5-27B | Dense, text-only QLoRA | ~47 GiB |
|
|
||||||
| `27b-fft.yaml` | Qwen3.5-27B | Dense, text-only FFT (vision frozen) | ~53 GiB |
|
|
||||||
| `27b-qlora-fsdp.yaml` | Qwen3.5-27B | Dense, text-only QLoRA + FSDP2 | — |
|
|
||||||
| `35b-a3b-moe-qlora.yaml` | Qwen3.5-35B-A3B | MoE, text-only QLoRA | — |
|
|
||||||
| `35b-a3b-moe-qlora-fsdp.yaml` | Qwen3.5-35B-A3B | MoE, text-only QLoRA + FSDP2 | — |
|
|
||||||
| `122b-a10b-moe-qlora.yaml` | Qwen3.5-122B-A10B | MoE, text-only QLoRA | — |
|
|
||||||
| `122b-a10b-moe-qlora-fsdp.yaml` | Qwen3.5-122B-A10B | MoE, text-only QLoRA + FSDP2 | — |
|
|
||||||
|
|
||||||
### Gated DeltaNet Linear Attention
|
|
||||||
|
|
||||||
Qwen3.5 interleaves standard attention with Gated DeltaNet linear attention layers. To apply LoRA to them, add to `lora_target_modules`:
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
lora_target_modules:
|
|
||||||
# ... standard projections ...
|
|
||||||
- linear_attn.in_proj_qkv
|
|
||||||
- linear_attn.in_proj_z
|
|
||||||
- linear_attn.out_proj
|
|
||||||
```
|
```
|
||||||
|
> FLA is required when `sample_packing: true`. Without it, training raises a `RuntimeError` on packed sequences. Vision configs use `sample_packing: false` so FLA is optional there.
|
||||||
|
|
||||||
### Routed Experts (MoE)
|
4. Run a finetuning example:
|
||||||
|
|
||||||
To apply LoRA to routed expert parameters, add `lora_target_parameters`:
|
```bash
|
||||||
|
# Dense 27B text-only (QLoRA, ~47 GiB VRAM with sample packing)
|
||||||
|
axolotl train examples/qwen3.5/27b-qlora.yaml
|
||||||
|
|
||||||
```yaml
|
# Dense 27B text-only FFT with vision encoder frozen (~53 GiB, single 80 GiB GPU)
|
||||||
lora_target_parameters:
|
axolotl train examples/qwen3.5/27b-fft.yaml
|
||||||
- mlp.experts.gate_up_proj
|
|
||||||
- mlp.experts.down_proj
|
|
||||||
# - mlp.gate.weight # router
|
|
||||||
```
|
|
||||||
|
|
||||||
### Shared Experts (MoE)
|
# MoE 35B-A3B text-only (QLoRA)
|
||||||
|
axolotl train examples/qwen3.5/35b-a3b-moe-qlora.yaml
|
||||||
|
|
||||||
Routed experts and shared experts both have `gate_up_proj`/`down_proj`, so a plain module name in `lora_target_modules` would match both. Use a regex to target only attention and shared expert projections, while `lora_target_parameters` above handles routed experts separately:
|
# MoE 122B-A10B text-only (QLoRA)
|
||||||
|
axolotl train examples/qwen3.5/122b-a10b-moe-qlora.yaml
|
||||||
|
|
||||||
|
# 9B vision+text (LoRA, multimodal dataset)
|
||||||
|
axolotl train examples/qwen3.5/9b-lora-vision.yaml
|
||||||
|
|
||||||
|
# 9B vision+text FFT, single 80 GiB GPU (~61 GiB peak)
|
||||||
|
axolotl train examples/qwen3.5/9b-fft-vision.yaml
|
||||||
|
|
||||||
```yaml
|
|
||||||
lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### TIPS
|
### TIPS
|
||||||
|
|
||||||
- For inference hyp, please see the respective model card details.
|
- For inference, you can experiment with `temperature: 0.7`, `top_p: 0.8`, `top_k: 20`, and `min_p: 0`.
|
||||||
|
- For **text-only FFT** on 27B, use `27b-fft.yaml` which sets `unfrozen_parameters` to freeze the vision encoder (`model.visual.*`) — this avoids wasting optimizer state on parameters that receive no gradient from text-only data.
|
||||||
- You can run a full finetuning of smaller configs by removing `adapter: qlora` and `load_in_4bit: true`. See [Multi-GPU](#optimization-guides) below.
|
- You can run a full finetuning of smaller configs by removing `adapter: qlora` and `load_in_4bit: true`. See [Multi-GPU](#optimization-guides) below.
|
||||||
- Read more on loading your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
- Read more on loading your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||||
- For **multimodal** finetuning, set `processor_type: AutoProcessor`, `skip_prepare_dataset: true`, and `remove_unused_columns: false` as shown in `9b-lora-vision.yaml`.
|
- For **multimodal** finetuning, set `processor_type: AutoProcessor`, `skip_prepare_dataset: true`, and `remove_unused_columns: false` as shown in `9b-lora-vision.yaml`.
|
||||||
|
- The Gated DeltaNet linear attention layers (`linear_attn.*`) can optionally be added to `lora_target_modules` — they are commented out by default.
|
||||||
|
|
||||||
## Optimization Guides
|
## Optimization Guides
|
||||||
|
|
||||||
|
|||||||
@@ -61,11 +61,5 @@ skip-magic-trailing-comma = false
|
|||||||
line-ending = "auto"
|
line-ending = "auto"
|
||||||
docstring-code-format = false
|
docstring-code-format = false
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
|
||||||
addopts = "-m 'not slow'"
|
|
||||||
markers = [
|
|
||||||
"slow: marks tests as slow",
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.uv.extra-build-dependencies]
|
[tool.uv.extra-build-dependencies]
|
||||||
axolotl = ["huggingface_hub"]
|
axolotl = ["huggingface_hub"]
|
||||||
|
|||||||
13
setup.py
13
setup.py
@@ -81,23 +81,16 @@ def parse_requirements(extras_require_map):
|
|||||||
f"https://download.pytorch.org/whl/{torch_cuda_version}"
|
f"https://download.pytorch.org/whl/{torch_cuda_version}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if (major, minor) >= (2, 10):
|
if (major, minor) >= (2, 9):
|
||||||
extras_require_map.pop("fbgemm-gpu")
|
|
||||||
extras_require_map["fbgemm-gpu"] = [
|
|
||||||
"fbgemm-gpu==1.5.0",
|
|
||||||
"fbgemm-gpu-genai==1.5.0",
|
|
||||||
]
|
|
||||||
if not install_xformers:
|
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
|
||||||
extras_require_map["vllm"] = ["vllm==0.17.1"]
|
|
||||||
elif (major, minor) >= (2, 9):
|
|
||||||
extras_require_map.pop("fbgemm-gpu")
|
extras_require_map.pop("fbgemm-gpu")
|
||||||
extras_require_map["fbgemm-gpu"] = [
|
extras_require_map["fbgemm-gpu"] = [
|
||||||
"fbgemm-gpu==1.4.0",
|
"fbgemm-gpu==1.4.0",
|
||||||
"fbgemm-gpu-genai==1.4.2",
|
"fbgemm-gpu-genai==1.4.2",
|
||||||
]
|
]
|
||||||
|
extras_require_map["vllm"] = ["vllm==0.11.1"]
|
||||||
if not install_xformers:
|
if not install_xformers:
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
|
extras_require_map["vllm"] = ["vllm==0.13.0"]
|
||||||
if patch == 0:
|
if patch == 0:
|
||||||
extras_require_map["vllm"] = ["vllm==0.13.0"]
|
extras_require_map["vllm"] = ["vllm==0.13.0"]
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import httpcore
|
|
||||||
from accelerate.commands.config import config_args
|
from accelerate.commands.config import config_args
|
||||||
from huggingface_hub import HfApi
|
from huggingface_hub import HfApi
|
||||||
from huggingface_hub.utils import LocalTokenNotFoundError
|
from huggingface_hub.utils import LocalTokenNotFoundError
|
||||||
@@ -48,7 +47,7 @@ def check_user_token() -> bool:
|
|||||||
"Error verifying HuggingFace token. Remember to log in using `hf auth login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
|
"Error verifying HuggingFace token. Remember to log in using `hf auth login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
|
||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
except (HTTPError, httpcore.ConnectError):
|
except HTTPError:
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"Error accessing HuggingFace. This may be due to a network issue or rate limiting."
|
"Error accessing HuggingFace. This may be due to a network issue or rate limiting."
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -91,7 +91,6 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs):
|
|||||||
type=click.Path(exists=True, path_type=str),
|
type=click.Path(exists=True, path_type=str),
|
||||||
help="YAML config for sweeping hyperparameters",
|
help="YAML config for sweeping hyperparameters",
|
||||||
)
|
)
|
||||||
@click.option("--tui", is_flag=True, default=False, help="Enable TUI dashboard")
|
|
||||||
@add_options_from_dataclass(TrainerCliArgs)
|
@add_options_from_dataclass(TrainerCliArgs)
|
||||||
@add_options_from_config(AxolotlInputConfig)
|
@add_options_from_config(AxolotlInputConfig)
|
||||||
@filter_none_kwargs
|
@filter_none_kwargs
|
||||||
@@ -102,7 +101,6 @@ def train(
|
|||||||
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
|
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
|
||||||
cloud: str | None = None,
|
cloud: str | None = None,
|
||||||
sweep: str | None = None,
|
sweep: str | None = None,
|
||||||
tui: bool = False,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -120,10 +118,6 @@ def train(
|
|||||||
# Extract launcher args from extra args (after --)
|
# Extract launcher args from extra args (after --)
|
||||||
launcher_args = ctx.args if ctx.args else []
|
launcher_args = ctx.args if ctx.args else []
|
||||||
|
|
||||||
# Handle --tui flag: set env var so subprocess workers pick it up
|
|
||||||
if tui:
|
|
||||||
os.environ["AXOLOTL_TUI"] = "1"
|
|
||||||
|
|
||||||
# Handle Ray launcher override
|
# Handle Ray launcher override
|
||||||
_launcher = None if kwargs.get("use_ray") else launcher
|
_launcher = None if kwargs.get("use_ray") else launcher
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
import gc
|
import gc
|
||||||
import os
|
import os
|
||||||
import queue
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
@@ -35,101 +34,22 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
|||||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||||
check_user_token()
|
check_user_token()
|
||||||
|
|
||||||
# Start TUI early (before data loading) so it captures preprocessing events
|
plugin_manager = PluginManager.get_instance()
|
||||||
tui_renderer = None
|
dataset_meta = plugin_manager.load_datasets(cfg, preprocess=False)
|
||||||
tui_queue: queue.Queue | None = None
|
if not dataset_meta:
|
||||||
is_rank_0 = int(os.getenv("LOCAL_RANK", "0")) == 0
|
if cfg.rl:
|
||||||
if is_rank_0:
|
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
from axolotl.train import _is_tui_enabled
|
else:
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
if _is_tui_enabled(cfg):
|
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
import queue as _queue
|
|
||||||
|
|
||||||
from axolotl.train import _get_tui_config
|
del model, tokenizer, trainer
|
||||||
from axolotl.tui.config import TUIConfig
|
|
||||||
from axolotl.tui.renderer import TUIRenderer
|
|
||||||
|
|
||||||
tui_config_dict = _get_tui_config(cfg)
|
gc.collect()
|
||||||
tui_config = (
|
|
||||||
TUIConfig(**tui_config_dict)
|
|
||||||
if isinstance(tui_config_dict, dict)
|
|
||||||
else tui_config_dict
|
|
||||||
)
|
|
||||||
tui_queue = _queue.Queue(maxsize=4096)
|
|
||||||
tui_renderer = TUIRenderer(config=tui_config, metric_queue=tui_queue)
|
|
||||||
|
|
||||||
# Send initial run info
|
plugin_manager = PluginManager.get_instance()
|
||||||
model_name = cfg.base_model or ""
|
plugin_manager.post_train_unload(cfg)
|
||||||
training_mode = str(cfg.rl) if cfg.rl else "sft"
|
|
||||||
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
|
||||||
try:
|
|
||||||
tui_queue.put_nowait(
|
|
||||||
{
|
|
||||||
"type": "run_info",
|
|
||||||
"model_name": model_name,
|
|
||||||
"training_mode": training_mode,
|
|
||||||
"world_size": world_size,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
except _queue.Full:
|
|
||||||
pass
|
|
||||||
|
|
||||||
tui_renderer.start()
|
|
||||||
|
|
||||||
# Attach logging handler early
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from axolotl.tui.callback import _TUILogHandler
|
|
||||||
|
|
||||||
_early_log_handler = _TUILogHandler(
|
|
||||||
tui_queue, min_level=tui_config.log_level
|
|
||||||
)
|
|
||||||
_early_log_handler.setFormatter(logging.Formatter("[%(name)s] %(message)s"))
|
|
||||||
# Attach to BOTH root and axolotl loggers because axolotl logger
|
|
||||||
# has propagate=False so root handler never sees axolotl.* messages
|
|
||||||
root_logger = logging.getLogger()
|
|
||||||
root_logger.addHandler(_early_log_handler)
|
|
||||||
axolotl_logger = logging.getLogger("axolotl")
|
|
||||||
axolotl_logger.addHandler(_early_log_handler)
|
|
||||||
|
|
||||||
# Stash refs on cfg so train() can reuse the renderer
|
|
||||||
cfg._tui_renderer = tui_renderer
|
|
||||||
cfg._tui_queue = tui_queue
|
|
||||||
cfg._tui_early_log_handler = _early_log_handler
|
|
||||||
|
|
||||||
try:
|
|
||||||
plugin_manager = PluginManager.get_instance()
|
|
||||||
dataset_meta = plugin_manager.load_datasets(cfg, preprocess=False)
|
|
||||||
if not dataset_meta:
|
|
||||||
if cfg.rl:
|
|
||||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
|
||||||
else:
|
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
|
||||||
|
|
||||||
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
|
||||||
|
|
||||||
del model, tokenizer, trainer
|
|
||||||
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
plugin_manager = PluginManager.get_instance()
|
|
||||||
plugin_manager.post_train_unload(cfg)
|
|
||||||
finally:
|
|
||||||
# If the TUI renderer started early but train() didn't get to stop it
|
|
||||||
# (e.g., error during data loading), clean up here
|
|
||||||
if tui_renderer is not None and not tui_renderer._stop_event.is_set():
|
|
||||||
try:
|
|
||||||
if tui_queue is not None:
|
|
||||||
tui_queue.put_nowait({"type": "done"})
|
|
||||||
except queue.Full:
|
|
||||||
pass
|
|
||||||
tui_renderer.stop()
|
|
||||||
# Remove early log handler from both root and axolotl loggers
|
|
||||||
if hasattr(cfg, "_tui_early_log_handler"):
|
|
||||||
import logging
|
|
||||||
|
|
||||||
logging.getLogger().removeHandler(cfg._tui_early_log_handler)
|
|
||||||
logging.getLogger("axolotl").removeHandler(cfg._tui_early_log_handler)
|
|
||||||
|
|
||||||
|
|
||||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||||
|
|||||||
@@ -353,30 +353,6 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
adam_kwargs["eps"] = (eps1, eps2)
|
adam_kwargs["eps"] = (eps1, eps2)
|
||||||
|
|
||||||
optimizer_kwargs.update(adam_kwargs)
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
elif self.cfg.optimizer == "flash_adamw":
|
|
||||||
from flashoptim import FlashAdamW
|
|
||||||
|
|
||||||
optimizer_cls = FlashAdamW
|
|
||||||
optimizer_kwargs.update(adam_kwargs)
|
|
||||||
elif self.cfg.optimizer == "flash_adam":
|
|
||||||
from flashoptim import FlashAdam
|
|
||||||
|
|
||||||
optimizer_cls = FlashAdam
|
|
||||||
optimizer_kwargs.update(adam_kwargs)
|
|
||||||
elif self.cfg.optimizer == "flash_sgd":
|
|
||||||
from flashoptim import FlashSGD
|
|
||||||
|
|
||||||
optimizer_cls = FlashSGD
|
|
||||||
elif self.cfg.optimizer == "flash_sgdw":
|
|
||||||
from flashoptim import FlashSGDW
|
|
||||||
|
|
||||||
optimizer_cls = FlashSGDW
|
|
||||||
elif self.cfg.optimizer == "flash_lion":
|
|
||||||
from flashoptim import FlashLion
|
|
||||||
|
|
||||||
optimizer_cls = FlashLion
|
|
||||||
if "betas" in adam_kwargs:
|
|
||||||
optimizer_kwargs["betas"] = adam_kwargs["betas"]
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unhandled optimizer: {self.cfg.optimizer}. Please raise an Issue."
|
f"Unhandled optimizer: {self.cfg.optimizer}. Please raise an Issue."
|
||||||
@@ -508,8 +484,6 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
training_args_kwargs["accelerator_config"] = AcceleratorConfig()
|
training_args_kwargs["accelerator_config"] = AcceleratorConfig()
|
||||||
|
|
||||||
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
|
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
|
||||||
if self.cfg.layer_offloading:
|
|
||||||
training_args_kwargs["layer_offloading"] = True
|
|
||||||
if self.cfg.activation_offloading is True:
|
if self.cfg.activation_offloading is True:
|
||||||
# don't use the HF gradient checkpointing, manually wrap
|
# don't use the HF gradient checkpointing, manually wrap
|
||||||
training_args_kwargs["gradient_checkpointing"] = False
|
training_args_kwargs["gradient_checkpointing"] = False
|
||||||
|
|||||||
@@ -208,11 +208,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
if self.eval_dataset:
|
if self.eval_dataset:
|
||||||
trainer_kwargs["eval_dataset"] = self.eval_dataset
|
trainer_kwargs["eval_dataset"] = self.eval_dataset
|
||||||
if (
|
if self.cfg.adapter and self.peft_config and self.cfg.rl is not RLType.GRPO:
|
||||||
self.cfg.adapter
|
|
||||||
and self.peft_config
|
|
||||||
and self.cfg.rl not in (RLType.GRPO, RLType.ORPO)
|
|
||||||
):
|
|
||||||
trainer_kwargs["peft_config"] = self.peft_config
|
trainer_kwargs["peft_config"] = self.peft_config
|
||||||
if self.cfg.precompute_ref_log_probs is not None:
|
if self.cfg.precompute_ref_log_probs is not None:
|
||||||
trainer_kwargs["precompute_ref_log_probs"] = (
|
trainer_kwargs["precompute_ref_log_probs"] = (
|
||||||
|
|||||||
@@ -29,12 +29,10 @@ from transformers.utils import SAFE_WEIGHTS_NAME, is_peft_available
|
|||||||
from trl.experimental.utils import pad_to_length
|
from trl.experimental.utils import pad_to_length
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from axolotl.core.trainers.constants import TOKENS_STATE_FILE
|
|
||||||
from axolotl.core.trainers.mixins import (
|
from axolotl.core.trainers.mixins import (
|
||||||
ActivationOffloadingMixin,
|
ActivationOffloadingMixin,
|
||||||
CheckpointSaveMixin,
|
CheckpointSaveMixin,
|
||||||
DistributedParallelMixin,
|
DistributedParallelMixin,
|
||||||
LayerOffloadingMixin,
|
|
||||||
OptimizerMixin,
|
OptimizerMixin,
|
||||||
PackingMixin,
|
PackingMixin,
|
||||||
RngLoaderMixin,
|
RngLoaderMixin,
|
||||||
@@ -53,6 +51,8 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
|||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
TOKENS_STATE_FILE = "tokens_state."
|
||||||
|
|
||||||
REDUCTION_FNS = {
|
REDUCTION_FNS = {
|
||||||
"mean": torch.mean,
|
"mean": torch.mean,
|
||||||
"min": torch.min,
|
"min": torch.min,
|
||||||
@@ -67,7 +67,6 @@ class AxolotlTrainer(
|
|||||||
OptimizerMixin,
|
OptimizerMixin,
|
||||||
RngLoaderMixin,
|
RngLoaderMixin,
|
||||||
CheckpointSaveMixin,
|
CheckpointSaveMixin,
|
||||||
LayerOffloadingMixin,
|
|
||||||
ActivationOffloadingMixin,
|
ActivationOffloadingMixin,
|
||||||
DistributedParallelMixin,
|
DistributedParallelMixin,
|
||||||
Trainer,
|
Trainer,
|
||||||
|
|||||||
@@ -1 +0,0 @@
|
|||||||
TOKENS_STATE_FILE = "tokens_state.json"
|
|
||||||
@@ -2,8 +2,7 @@
|
|||||||
Axolotl specific DPO args
|
Axolotl specific DPO args
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from trl import DPOConfig
|
from trl import DPOConfig
|
||||||
|
|
||||||
@@ -17,4 +16,3 @@ class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
dpo_norm_loss: bool | None = False
|
dpo_norm_loss: bool | None = False
|
||||||
rpo_alpha: Optional[float] = field(default=None)
|
|
||||||
|
|||||||
@@ -4,7 +4,6 @@
|
|||||||
|
|
||||||
from .activation_checkpointing import ActivationOffloadingMixin
|
from .activation_checkpointing import ActivationOffloadingMixin
|
||||||
from .checkpoints import CheckpointSaveMixin
|
from .checkpoints import CheckpointSaveMixin
|
||||||
from .layer_offloading import LayerOffloadingMixin
|
|
||||||
from .distributed_parallel import DistributedParallelMixin
|
from .distributed_parallel import DistributedParallelMixin
|
||||||
from .optimizer import OptimizerMixin
|
from .optimizer import OptimizerMixin
|
||||||
from .packing import PackingMixin
|
from .packing import PackingMixin
|
||||||
|
|||||||
@@ -1,304 +0,0 @@
|
|||||||
"""
|
|
||||||
Trainer mixin for layer-wise parameter offloading to CPU.
|
|
||||||
|
|
||||||
Offloads frozen (non-trainable) parameters in decoder layers to CPU, then uses
|
|
||||||
forward/backward hooks to stream them on/off GPU one layer at a time with CUDA
|
|
||||||
stream prefetching. Trainable parameters (e.g. LoRA weights) stay on GPU always.
|
|
||||||
|
|
||||||
Forward: pre-hook loads layer N's frozen params to GPU (prefetches N+1 on
|
|
||||||
transfer stream), post-hook offloads layer N-1's frozen params.
|
|
||||||
Backward: same in reverse order.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import contextlib
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from transformers import Trainer
|
|
||||||
|
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def _find_decoder_layers(model: nn.Module) -> tuple[nn.ModuleList | None, list[str]]:
|
|
||||||
"""Recursively search the model for the decoder layer ModuleList.
|
|
||||||
|
|
||||||
Finds any ModuleList whose children have 'DecoderLayer' in their class name.
|
|
||||||
Handles all common HF architectures including VLM wrappers (e.g. Qwen3.5-MoE
|
|
||||||
where layers are at model.language_model.layers).
|
|
||||||
"""
|
|
||||||
# BFS to find the first ModuleList containing decoder layers
|
|
||||||
queue = [model]
|
|
||||||
while queue:
|
|
||||||
m = queue.pop(0)
|
|
||||||
for _name, child in m.named_children():
|
|
||||||
if isinstance(child, nn.ModuleList) and len(child) > 0:
|
|
||||||
first_type = type(child[0]).__name__
|
|
||||||
if "DecoderLayer" in first_type or "TransformerBlock" in first_type:
|
|
||||||
layer_types = list({type(layer).__name__ for layer in child})
|
|
||||||
return child, layer_types
|
|
||||||
else:
|
|
||||||
queue.append(child)
|
|
||||||
|
|
||||||
return None, []
|
|
||||||
|
|
||||||
|
|
||||||
def _get_frozen_params(layer: nn.Module) -> list[tuple[str, nn.Parameter]]:
|
|
||||||
"""Get all non-trainable parameters in a layer."""
|
|
||||||
return [(n, p) for n, p in layer.named_parameters() if not p.requires_grad]
|
|
||||||
|
|
||||||
|
|
||||||
class LayerOffloadManager:
|
|
||||||
"""Manages offloading frozen decoder layer params to CPU and streaming
|
|
||||||
them back during forward/backward with CUDA stream overlap.
|
|
||||||
|
|
||||||
Only frozen (requires_grad=False) parameters are offloaded.
|
|
||||||
Trainable parameters (LoRA weights, etc.) remain on GPU at all times.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model: nn.Module,
|
|
||||||
num_prefetch: int = 1,
|
|
||||||
):
|
|
||||||
self.model = model
|
|
||||||
self.num_prefetch = num_prefetch
|
|
||||||
self._hooks: list = []
|
|
||||||
self._device = None
|
|
||||||
|
|
||||||
# Find decoder layers
|
|
||||||
self.layers, layer_types = _find_decoder_layers(model)
|
|
||||||
if self.layers is None:
|
|
||||||
LOG.warning(
|
|
||||||
"LayerOffloadManager: no decoder layers found, offloading disabled"
|
|
||||||
)
|
|
||||||
self.enabled = False
|
|
||||||
return
|
|
||||||
|
|
||||||
self.enabled = True
|
|
||||||
self.n_layers = len(self.layers)
|
|
||||||
LOG.info(
|
|
||||||
f"Layer offloading: found {self.n_layers} layers ({', '.join(layer_types)})"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Determine GPU device
|
|
||||||
for p in model.parameters():
|
|
||||||
if p.device.type == "cuda":
|
|
||||||
self._device = p.device
|
|
||||||
break
|
|
||||||
if self._device is None:
|
|
||||||
LOG.warning("LayerOffloadManager: no CUDA parameters found")
|
|
||||||
self.enabled = False
|
|
||||||
return
|
|
||||||
|
|
||||||
# Transfer stream for async prefetch
|
|
||||||
self._transfer_stream = torch.cuda.Stream(device=self._device)
|
|
||||||
|
|
||||||
# Track which layers have their frozen params on GPU
|
|
||||||
self._on_gpu: set[int] = set(range(self.n_layers))
|
|
||||||
|
|
||||||
# Cache: frozen param references per layer (list of (name, param) tuples)
|
|
||||||
self._frozen_params: list[list[tuple[str, nn.Parameter]]] = [
|
|
||||||
_get_frozen_params(self.layers[i]) for i in range(self.n_layers)
|
|
||||||
]
|
|
||||||
|
|
||||||
# CPU storage: pinned tensors for each layer's frozen params
|
|
||||||
# Populated on first offload
|
|
||||||
self._cpu_data: list[dict[str, torch.Tensor]] = [
|
|
||||||
{} for _ in range(self.n_layers)
|
|
||||||
]
|
|
||||||
|
|
||||||
# Offload all layers upfront
|
|
||||||
self._offload_all()
|
|
||||||
|
|
||||||
# Release cached memory blocks back to the driver
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
def _offload_all(self):
|
|
||||||
"""Move all frozen params in all decoder layers to CPU."""
|
|
||||||
mem_before = torch.cuda.memory_allocated(self._device)
|
|
||||||
for i in range(self.n_layers):
|
|
||||||
self._offload_layer(i)
|
|
||||||
mem_after = torch.cuda.memory_allocated(self._device)
|
|
||||||
freed = (mem_before - mem_after) / 1e6
|
|
||||||
LOG.info(
|
|
||||||
f"Layer offloading: offloaded frozen params from {self.n_layers} layers, "
|
|
||||||
f"freed {freed:.0f} MB GPU memory"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _offload_layer(self, idx: int):
|
|
||||||
"""Move frozen params of layer idx to CPU pinned memory."""
|
|
||||||
if idx not in self._on_gpu:
|
|
||||||
return
|
|
||||||
for name, param in self._frozen_params[idx]:
|
|
||||||
if param.device.type != "cuda":
|
|
||||||
continue
|
|
||||||
# Allocate pinned CPU tensor on first offload
|
|
||||||
if name not in self._cpu_data[idx]:
|
|
||||||
self._cpu_data[idx][name] = torch.empty_like(
|
|
||||||
param.data, device="cpu", pin_memory=True
|
|
||||||
)
|
|
||||||
cpu_buf = self._cpu_data[idx][name]
|
|
||||||
# Async copy GPU -> CPU (on transfer stream for overlap)
|
|
||||||
cpu_buf.copy_(param.data, non_blocking=True)
|
|
||||||
# Point parameter at a dummy CPU tensor to free GPU memory
|
|
||||||
param.data = cpu_buf
|
|
||||||
self._on_gpu.discard(idx)
|
|
||||||
|
|
||||||
def _load_layer(self, idx: int, stream=None):
|
|
||||||
"""Move frozen params of layer idx back to GPU."""
|
|
||||||
if idx in self._on_gpu or idx < 0 or idx >= self.n_layers:
|
|
||||||
return
|
|
||||||
ctx = (
|
|
||||||
torch.cuda.stream(stream)
|
|
||||||
if stream is not None
|
|
||||||
else contextlib.nullcontext()
|
|
||||||
)
|
|
||||||
with ctx:
|
|
||||||
for _name, param in self._frozen_params[idx]:
|
|
||||||
if param.device.type == "cuda":
|
|
||||||
continue
|
|
||||||
gpu_data = param.data.to(self._device, non_blocking=True)
|
|
||||||
param.data = gpu_data
|
|
||||||
self._on_gpu.add(idx)
|
|
||||||
|
|
||||||
def _prefetch_layer(self, idx: int):
|
|
||||||
"""Async prefetch layer idx on the transfer stream."""
|
|
||||||
if idx in self._on_gpu or idx < 0 or idx >= self.n_layers:
|
|
||||||
return
|
|
||||||
self._transfer_stream.wait_stream(torch.cuda.default_stream(self._device))
|
|
||||||
self._load_layer(idx, stream=self._transfer_stream)
|
|
||||||
|
|
||||||
def _wait_transfer(self):
|
|
||||||
"""Make default stream wait for any in-flight transfers."""
|
|
||||||
torch.cuda.default_stream(self._device).wait_stream(self._transfer_stream)
|
|
||||||
|
|
||||||
def setup_hooks(self):
|
|
||||||
"""Register forward and backward hooks on each decoder layer."""
|
|
||||||
if not self.enabled:
|
|
||||||
return
|
|
||||||
|
|
||||||
for idx in range(self.n_layers):
|
|
||||||
layer = self.layers[idx]
|
|
||||||
|
|
||||||
def make_pre_fwd(i):
|
|
||||||
def hook(module, args):
|
|
||||||
# Ensure this layer is on GPU
|
|
||||||
if i not in self._on_gpu:
|
|
||||||
self._load_layer(i)
|
|
||||||
self._wait_transfer()
|
|
||||||
# Prefetch next layer(s)
|
|
||||||
for offset in range(1, self.num_prefetch + 1):
|
|
||||||
self._prefetch_layer(i + offset)
|
|
||||||
|
|
||||||
return hook
|
|
||||||
|
|
||||||
def make_post_fwd(i):
|
|
||||||
def hook(module, args, output):
|
|
||||||
# Offload previous layer (no longer needed in forward)
|
|
||||||
if i > 0:
|
|
||||||
self._offload_layer(i - 1)
|
|
||||||
# Offload last layer after forward
|
|
||||||
if i == self.n_layers - 1:
|
|
||||||
self._offload_layer(i)
|
|
||||||
|
|
||||||
return hook
|
|
||||||
|
|
||||||
def make_pre_bwd(i):
|
|
||||||
def hook(module, grad_output):
|
|
||||||
# Load this layer for backward
|
|
||||||
if i not in self._on_gpu:
|
|
||||||
self._load_layer(i)
|
|
||||||
self._wait_transfer()
|
|
||||||
# Prefetch previous layer(s)
|
|
||||||
for offset in range(1, self.num_prefetch + 1):
|
|
||||||
self._prefetch_layer(i - offset)
|
|
||||||
|
|
||||||
return hook
|
|
||||||
|
|
||||||
def make_post_bwd(i):
|
|
||||||
def hook(module, grad_input, grad_output):
|
|
||||||
# Offload the layer above
|
|
||||||
if i < self.n_layers - 1:
|
|
||||||
self._offload_layer(i + 1)
|
|
||||||
# Offload first layer after backward
|
|
||||||
if i == 0:
|
|
||||||
self._offload_layer(i)
|
|
||||||
|
|
||||||
return hook
|
|
||||||
|
|
||||||
h1 = layer.register_forward_pre_hook(make_pre_fwd(idx))
|
|
||||||
h2 = layer.register_forward_hook(make_post_fwd(idx))
|
|
||||||
h3 = layer.register_full_backward_pre_hook(make_pre_bwd(idx))
|
|
||||||
h4 = layer.register_full_backward_hook(make_post_bwd(idx))
|
|
||||||
self._hooks.extend([h1, h2, h3, h4])
|
|
||||||
|
|
||||||
def remove_hooks(self):
|
|
||||||
"""Remove all hooks and restore layers to GPU."""
|
|
||||||
for h in self._hooks:
|
|
||||||
h.remove()
|
|
||||||
self._hooks.clear()
|
|
||||||
if self.enabled:
|
|
||||||
for i in range(self.n_layers):
|
|
||||||
if i not in self._on_gpu:
|
|
||||||
self._load_layer(i)
|
|
||||||
|
|
||||||
def pre_step(self):
|
|
||||||
"""Called before each training step — ensure layers start offloaded."""
|
|
||||||
if not self.enabled:
|
|
||||||
return
|
|
||||||
for i in list(self._on_gpu):
|
|
||||||
self._offload_layer(i)
|
|
||||||
# Prefetch layer 0 for forward
|
|
||||||
self._prefetch_layer(0)
|
|
||||||
|
|
||||||
def post_step(self):
|
|
||||||
"""Called after each training step — ensure layers are offloaded."""
|
|
||||||
if not self.enabled:
|
|
||||||
return
|
|
||||||
for i in list(self._on_gpu):
|
|
||||||
self._offload_layer(i)
|
|
||||||
# Prefetch layer 0 for next step
|
|
||||||
self._prefetch_layer(0)
|
|
||||||
|
|
||||||
|
|
||||||
class _LayerOffloadContext:
|
|
||||||
"""Context manager wrapping pre_step / post_step around a training step."""
|
|
||||||
|
|
||||||
def __init__(self, manager: LayerOffloadManager):
|
|
||||||
self.manager = manager
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
self.manager.pre_step()
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, *args):
|
|
||||||
self.manager.post_step()
|
|
||||||
|
|
||||||
|
|
||||||
class LayerOffloadingMixin(Trainer):
|
|
||||||
"""
|
|
||||||
Trainer mixin class for layer-wise parameter offloading to CPU.
|
|
||||||
|
|
||||||
Offloads frozen decoder layer params to CPU at init, then streams them
|
|
||||||
on/off GPU one layer at a time during each training step.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
if getattr(self.args, "layer_offloading", False):
|
|
||||||
LOG.info("Layer parameter offloading enabled")
|
|
||||||
self._layer_offload_manager = LayerOffloadManager(
|
|
||||||
model=self.model,
|
|
||||||
num_prefetch=1,
|
|
||||||
)
|
|
||||||
self._layer_offload_manager.setup_hooks()
|
|
||||||
self._layer_offload_ctx = _LayerOffloadContext(self._layer_offload_manager)
|
|
||||||
else:
|
|
||||||
self._layer_offload_manager = None
|
|
||||||
self._layer_offload_ctx = contextlib.nullcontext()
|
|
||||||
|
|
||||||
def training_step(self, *args, **kwargs):
|
|
||||||
with self._layer_offload_ctx:
|
|
||||||
return super().training_step(*args, **kwargs)
|
|
||||||
@@ -235,13 +235,6 @@ class AxolotlTrainingMixins:
|
|||||||
metadata={"help": "Use activation offloading with CUDA streams for training."},
|
metadata={"help": "Use activation offloading with CUDA streams for training."},
|
||||||
)
|
)
|
||||||
|
|
||||||
layer_offloading: bool | None = field(
|
|
||||||
default=None,
|
|
||||||
metadata={
|
|
||||||
"help": "Offload model layer parameters to CPU during forward, prefetch back during backward."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# multi-modal section
|
# multi-modal section
|
||||||
|
|
||||||
image_size: int | tuple[int, int] | None = field(
|
image_size: int | tuple[int, int] | None = field(
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ SPARSE_MOE_BLOCK = {
|
|||||||
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
|
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
|
||||||
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
|
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
|
||||||
"qwen3_5_moe": "Qwen3_5MoeSparseMoeBlock",
|
"qwen3_5_moe": "Qwen3_5MoeSparseMoeBlock",
|
||||||
"qwen3_5_moe_text": "Qwen3_5MoeSparseMoeBlock",
|
|
||||||
"qwen3_next": "Qwen3NextSparseMoeBlock",
|
"qwen3_next": "Qwen3NextSparseMoeBlock",
|
||||||
"qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
|
"qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
|
||||||
# qwen3_omni_moe: Thinker (standard) + Talker (shared experts + shared_expert_gate)
|
# qwen3_omni_moe: Thinker (standard) + Talker (shared experts + shared_expert_gate)
|
||||||
@@ -59,16 +58,7 @@ def resolve_moe_block_classes(model_type: str):
|
|||||||
|
|
||||||
cls_names = entry if isinstance(entry, list) else [entry]
|
cls_names = entry if isinstance(entry, list) else [entry]
|
||||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
||||||
try:
|
module = importlib.import_module(module_path)
|
||||||
module = importlib.import_module(module_path)
|
|
||||||
except ModuleNotFoundError:
|
|
||||||
# Text sub-model types (e.g. qwen3_5_moe_text) share the parent module
|
|
||||||
if model_type.endswith("_text"):
|
|
||||||
parent_type = model_type.removesuffix("_text")
|
|
||||||
module_path = f"transformers.models.{parent_type}.modeling_{parent_type}"
|
|
||||||
module = importlib.import_module(module_path)
|
|
||||||
else:
|
|
||||||
raise
|
|
||||||
|
|
||||||
classes = []
|
classes = []
|
||||||
for cls_name in cls_names:
|
for cls_name in cls_names:
|
||||||
|
|||||||
@@ -363,7 +363,7 @@ def _scatter2scatter_lora_configs():
|
|||||||
|
|
||||||
Search space:
|
Search space:
|
||||||
BLOCK_M: {32, 64, 128}
|
BLOCK_M: {32, 64, 128}
|
||||||
BLOCK_N: {32, 64}
|
BLOCK_N: {32, 64, 128, 256}
|
||||||
BLOCK_K: {32, 64, 128}
|
BLOCK_K: {32, 64, 128}
|
||||||
num_warps: {4, 8}
|
num_warps: {4, 8}
|
||||||
num_stages: {3, 4, 5}
|
num_stages: {3, 4, 5}
|
||||||
@@ -371,7 +371,7 @@ def _scatter2scatter_lora_configs():
|
|||||||
configs = []
|
configs = []
|
||||||
for block_m, block_n, block_k, warps, stages in product(
|
for block_m, block_n, block_k, warps, stages in product(
|
||||||
[32, 64, 128], # BLOCK_M
|
[32, 64, 128], # BLOCK_M
|
||||||
[32, 64], # BLOCK_N
|
[32, 64, 128, 256], # BLOCK_N
|
||||||
[32, 64, 128], # BLOCK_K
|
[32, 64, 128], # BLOCK_K
|
||||||
[4, 8], # num_warps
|
[4, 8], # num_warps
|
||||||
[3, 4, 5], # num_stages
|
[3, 4, 5], # num_stages
|
||||||
@@ -943,16 +943,16 @@ def _scatter2scatter_lora_dX_configs():
|
|||||||
|
|
||||||
Search space:
|
Search space:
|
||||||
BLOCK_M: {32, 64, 128} (token tile)
|
BLOCK_M: {32, 64, 128} (token tile)
|
||||||
BLOCK_K: {32, 64, 128} (output tile)
|
BLOCK_K: {32, 64, 128, 256} (output tile)
|
||||||
BLOCK_N: {32, 64} (reduction tile)
|
BLOCK_N: {32, 64, 128, 256} (reduction tile)
|
||||||
num_warps: {4, 8}
|
num_warps: {4, 8}
|
||||||
num_stages: {3, 4, 5}
|
num_stages: {3, 4, 5}
|
||||||
"""
|
"""
|
||||||
configs = []
|
configs = []
|
||||||
for block_m, block_k, block_n, warps, stages in product(
|
for block_m, block_k, block_n, warps, stages in product(
|
||||||
[32, 64, 128], # BLOCK_M
|
[32, 64, 128], # BLOCK_M
|
||||||
[32, 64, 128], # BLOCK_K (output dimension)
|
[32, 64, 128, 256], # BLOCK_K (output dimension)
|
||||||
[32, 64], # BLOCK_N (reduction dimension)
|
[32, 64, 128, 256], # BLOCK_N (reduction dimension)
|
||||||
[4, 8], # num_warps
|
[4, 8], # num_warps
|
||||||
[3, 4, 5], # num_stages
|
[3, 4, 5], # num_stages
|
||||||
):
|
):
|
||||||
@@ -1278,9 +1278,9 @@ def _group_bwd_lora_configs():
|
|||||||
support GPUs with limited shared memory (e.g. ~99KB on some GPUs).
|
support GPUs with limited shared memory (e.g. ~99KB on some GPUs).
|
||||||
|
|
||||||
Search space:
|
Search space:
|
||||||
BLOCK_M: {32, 64, 128} (token-loop tile)
|
BLOCK_M: {32, 64, 128, 256} (token-loop tile)
|
||||||
BLOCK_K: {32, 64, 128}
|
BLOCK_K: {32, 64, 128, 256}
|
||||||
BLOCK_N: {32, 64}
|
BLOCK_N: {32, 64, 128, 256}
|
||||||
num_warps: {4, 8}
|
num_warps: {4, 8}
|
||||||
num_stages: {3, 4, 5}
|
num_stages: {3, 4, 5}
|
||||||
|
|
||||||
@@ -1289,9 +1289,9 @@ def _group_bwd_lora_configs():
|
|||||||
"""
|
"""
|
||||||
configs = []
|
configs = []
|
||||||
for block_m, block_k, block_n, warps, stages in product(
|
for block_m, block_k, block_n, warps, stages in product(
|
||||||
[32, 64, 128], # BLOCK_M
|
[32, 64, 128, 256], # BLOCK_M
|
||||||
[32, 64, 128], # BLOCK_K
|
[32, 64, 128, 256], # BLOCK_K
|
||||||
[32, 64], # BLOCK_N
|
[32, 64, 128, 256], # BLOCK_N
|
||||||
[4, 8], # num_warps
|
[4, 8], # num_warps
|
||||||
[3, 4, 5], # num_stages
|
[3, 4, 5], # num_stages
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -30,15 +30,6 @@ class LigerArgs(BaseModel):
|
|||||||
|
|
||||||
liger_rope: bool | None = None
|
liger_rope: bool | None = None
|
||||||
liger_rms_norm: bool | None = None
|
liger_rms_norm: bool | None = None
|
||||||
liger_rms_norm_gated: bool | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": (
|
|
||||||
"Enables fused RMSNorm+SiLU gate Triton kernel for models with "
|
|
||||||
"gated RMSNorm (e.g. Qwen3.5 / Qwen3.5 MoE linear attention layers)."
|
|
||||||
)
|
|
||||||
},
|
|
||||||
)
|
|
||||||
liger_layer_norm: bool | None = None
|
liger_layer_norm: bool | None = None
|
||||||
liger_swiglu: bool | None = None
|
liger_swiglu: bool | None = None
|
||||||
liger_glu_activation: bool | None = None
|
liger_glu_activation: bool | None = None
|
||||||
|
|||||||
@@ -1,175 +0,0 @@
|
|||||||
"""
|
|
||||||
Liger FLCE for Qwen3.5. Based on transformers v5.3.0.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
from copy import deepcopy
|
|
||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
||||||
from transformers.cache_utils import Cache
|
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
||||||
|
|
||||||
|
|
||||||
def lce_forward(
|
|
||||||
self,
|
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_values: Optional[Cache] = None,
|
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
labels: Optional[torch.LongTensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
|
||||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
||||||
**kwargs,
|
|
||||||
) -> CausalLMOutputWithPast:
|
|
||||||
r"""
|
|
||||||
Args:
|
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
||||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
||||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
||||||
|
|
||||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
|
||||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
|
||||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
|
||||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
|
||||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
|
||||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
"""
|
|
||||||
|
|
||||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
|
||||||
outputs = self.model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
use_cache=use_cache,
|
|
||||||
cache_position=cache_position,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
|
||||||
|
|
||||||
logits = None
|
|
||||||
loss = None
|
|
||||||
# if in training mode, don't materialize logits
|
|
||||||
if self.training and (labels is not None):
|
|
||||||
loss = LigerForCausalLMLoss(
|
|
||||||
hidden_states=hidden_states,
|
|
||||||
lm_head_weight=self.lm_head.weight,
|
|
||||||
labels=labels,
|
|
||||||
hidden_size=self.config.hidden_size,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
else: # if in inference mode materialize logits
|
|
||||||
slice_indices = (
|
|
||||||
slice(-logits_to_keep, None)
|
|
||||||
if isinstance(logits_to_keep, int)
|
|
||||||
else logits_to_keep
|
|
||||||
)
|
|
||||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
|
||||||
if labels is not None:
|
|
||||||
loss = self.loss_function(
|
|
||||||
logits=logits,
|
|
||||||
labels=labels,
|
|
||||||
vocab_size=self.config.vocab_size,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
return CausalLMOutputWithPast(
|
|
||||||
loss=loss,
|
|
||||||
logits=logits,
|
|
||||||
past_key_values=outputs.past_key_values,
|
|
||||||
hidden_states=outputs.hidden_states,
|
|
||||||
attentions=outputs.attentions,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_liger_kernel_to_qwen3_5(
|
|
||||||
cross_entropy: bool = False,
|
|
||||||
fused_linear_cross_entropy: bool = False,
|
|
||||||
rms_norm: bool = False,
|
|
||||||
rms_norm_gated: bool = False,
|
|
||||||
glu_activation: bool = False,
|
|
||||||
layer_norm: bool = False,
|
|
||||||
**kwargs,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Apply Liger kernels to replace original implementation in HuggingFace Qwen3.5 models.
|
|
||||||
|
|
||||||
Note: Qwen3_5RMSNorm uses zero-init weight with offset 1.0 (like Gemma),
|
|
||||||
so we use LigerRMSNorm with offset=1.0 and init_fn="zeros".
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
|
||||||
fused_linear_cross_entropy (bool):
|
|
||||||
Whether to apply Liger's fused linear cross entropy loss. Default is False.
|
|
||||||
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
||||||
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
|
||||||
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
|
|
||||||
rms_norm_gated (bool): Whether to apply fused RMSNorm+SiLU gate kernel for
|
|
||||||
Qwen3_5RMSNormGated (used in linear attention layers). Default is False.
|
|
||||||
glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
|
|
||||||
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import transformers.models.qwen3_5.modeling_qwen3_5 # noqa: F401
|
|
||||||
from liger_kernel.transformers.functional import liger_cross_entropy
|
|
||||||
from liger_kernel.transformers.layer_norm import LigerLayerNorm
|
|
||||||
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
|
||||||
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
|
|
||||||
|
|
||||||
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
||||||
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
||||||
)
|
|
||||||
|
|
||||||
modeling_qwen3_5 = sys.modules["transformers.models.qwen3_5.modeling_qwen3_5"]
|
|
||||||
|
|
||||||
if rms_norm:
|
|
||||||
# Qwen3_5RMSNorm uses zero-init weight with `output * (1.0 + weight)` pattern
|
|
||||||
class LigerRMSNormForQwen3_5(LigerRMSNorm):
|
|
||||||
def __init__(self, dim, eps=1e-6, **kwargs):
|
|
||||||
super().__init__(
|
|
||||||
dim,
|
|
||||||
eps=eps,
|
|
||||||
offset=1.0,
|
|
||||||
casting_mode="gemma",
|
|
||||||
init_fn="zeros",
|
|
||||||
in_place=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
modeling_qwen3_5.Qwen3_5RMSNorm = LigerRMSNormForQwen3_5
|
|
||||||
|
|
||||||
if rms_norm_gated:
|
|
||||||
from axolotl.kernels.rms_norm_gated import FusedRMSNormGated
|
|
||||||
|
|
||||||
modeling_qwen3_5.Qwen3_5RMSNormGated = FusedRMSNormGated
|
|
||||||
|
|
||||||
if glu_activation:
|
|
||||||
|
|
||||||
def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):
|
|
||||||
"""Accepts intermediate_size to pass to LigerSwiGLUMLP"""
|
|
||||||
config = deepcopy(config)
|
|
||||||
if intermediate_size is not None:
|
|
||||||
config.intermediate_size = intermediate_size
|
|
||||||
return LigerSwiGLUMLP(config, **kwargs)
|
|
||||||
|
|
||||||
modeling_qwen3_5.Qwen3_5MLP = _liger_swiglu_mlp_wrapper
|
|
||||||
|
|
||||||
if layer_norm:
|
|
||||||
modeling_qwen3_5.nn.LayerNorm = LigerLayerNorm
|
|
||||||
|
|
||||||
if cross_entropy:
|
|
||||||
from transformers.loss.loss_utils import nn
|
|
||||||
|
|
||||||
nn.functional.cross_entropy = liger_cross_entropy
|
|
||||||
|
|
||||||
if fused_linear_cross_entropy:
|
|
||||||
modeling_qwen3_5.Qwen3_5ForCausalLM.forward = lce_forward
|
|
||||||
@@ -1,198 +0,0 @@
|
|||||||
"""
|
|
||||||
Liger FLCE for Qwen3.5 MoE. Based on transformers v5.3.0.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
from copy import deepcopy
|
|
||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
||||||
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
|
|
||||||
|
|
||||||
|
|
||||||
def lce_forward(
|
|
||||||
self,
|
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_values=None,
|
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
labels: Optional[torch.LongTensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
output_router_logits: Optional[bool] = None,
|
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
|
||||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
||||||
**kwargs,
|
|
||||||
) -> MoeCausalLMOutputWithPast:
|
|
||||||
r"""
|
|
||||||
Args:
|
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
||||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
||||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
||||||
|
|
||||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
|
||||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
|
||||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
|
||||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
|
||||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
|
||||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
"""
|
|
||||||
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
|
|
||||||
load_balancing_loss_func,
|
|
||||||
)
|
|
||||||
|
|
||||||
output_router_logits = (
|
|
||||||
output_router_logits
|
|
||||||
if output_router_logits is not None
|
|
||||||
else self.config.output_router_logits
|
|
||||||
)
|
|
||||||
|
|
||||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
|
||||||
outputs = self.model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
use_cache=use_cache,
|
|
||||||
output_router_logits=output_router_logits,
|
|
||||||
cache_position=cache_position,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
|
||||||
|
|
||||||
logits = None
|
|
||||||
loss = None
|
|
||||||
# if in training mode, don't materialize logits
|
|
||||||
if self.training and (labels is not None):
|
|
||||||
loss = LigerForCausalLMLoss(
|
|
||||||
hidden_states=hidden_states,
|
|
||||||
lm_head_weight=self.lm_head.weight,
|
|
||||||
labels=labels,
|
|
||||||
hidden_size=self.config.hidden_size,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
else: # if in inference mode materialize logits
|
|
||||||
slice_indices = (
|
|
||||||
slice(-logits_to_keep, None)
|
|
||||||
if isinstance(logits_to_keep, int)
|
|
||||||
else logits_to_keep
|
|
||||||
)
|
|
||||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
|
||||||
if labels is not None:
|
|
||||||
loss = self.loss_function(
|
|
||||||
logits,
|
|
||||||
labels,
|
|
||||||
self.vocab_size,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
aux_loss = None
|
|
||||||
if output_router_logits:
|
|
||||||
aux_loss = load_balancing_loss_func(
|
|
||||||
outputs.router_logits,
|
|
||||||
self.num_experts,
|
|
||||||
self.num_experts_per_tok,
|
|
||||||
attention_mask,
|
|
||||||
)
|
|
||||||
if labels is not None:
|
|
||||||
loss += self.router_aux_loss_coef * aux_loss.to(loss.device)
|
|
||||||
|
|
||||||
return MoeCausalLMOutputWithPast(
|
|
||||||
loss=loss,
|
|
||||||
aux_loss=aux_loss,
|
|
||||||
logits=logits,
|
|
||||||
past_key_values=outputs.past_key_values,
|
|
||||||
hidden_states=outputs.hidden_states,
|
|
||||||
attentions=outputs.attentions,
|
|
||||||
router_logits=outputs.router_logits,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_liger_kernel_to_qwen3_5_moe(
|
|
||||||
cross_entropy: bool = False,
|
|
||||||
fused_linear_cross_entropy: bool = False,
|
|
||||||
rms_norm: bool = False,
|
|
||||||
rms_norm_gated: bool = False,
|
|
||||||
glu_activation: bool = False,
|
|
||||||
layer_norm: bool = False,
|
|
||||||
**kwargs,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Apply Liger kernels to replace original implementation in HuggingFace Qwen3.5 MoE models.
|
|
||||||
|
|
||||||
Note: Qwen3_5MoeRMSNorm uses zero-init weight with offset 1.0 (like Gemma),
|
|
||||||
so we use LigerRMSNorm with offset=1.0 and init_fn="zeros".
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
|
||||||
fused_linear_cross_entropy (bool):
|
|
||||||
Whether to apply Liger's fused linear cross entropy loss. Default is False.
|
|
||||||
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
||||||
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
|
||||||
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
|
|
||||||
rms_norm_gated (bool): Whether to apply fused RMSNorm+SiLU gate kernel for
|
|
||||||
Qwen3_5MoeRMSNormGated (used in linear attention layers). Default is False.
|
|
||||||
glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
|
|
||||||
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import transformers.models.qwen3_5_moe.modeling_qwen3_5_moe # noqa: F401
|
|
||||||
from liger_kernel.transformers.functional import liger_cross_entropy
|
|
||||||
from liger_kernel.transformers.layer_norm import LigerLayerNorm
|
|
||||||
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
|
||||||
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
|
|
||||||
|
|
||||||
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
||||||
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
||||||
)
|
|
||||||
|
|
||||||
modeling_mod = sys.modules["transformers.models.qwen3_5_moe.modeling_qwen3_5_moe"]
|
|
||||||
|
|
||||||
if rms_norm:
|
|
||||||
# Qwen3_5MoeRMSNorm uses zero-init weight with `output * (1.0 + weight)` pattern
|
|
||||||
class LigerRMSNormForQwen3_5Moe(LigerRMSNorm):
|
|
||||||
def __init__(self, dim, eps=1e-6, **kwargs):
|
|
||||||
super().__init__(
|
|
||||||
dim,
|
|
||||||
eps=eps,
|
|
||||||
offset=1.0,
|
|
||||||
casting_mode="gemma",
|
|
||||||
init_fn="zeros",
|
|
||||||
in_place=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
modeling_mod.Qwen3_5MoeRMSNorm = LigerRMSNormForQwen3_5Moe
|
|
||||||
|
|
||||||
if rms_norm_gated:
|
|
||||||
from axolotl.kernels.rms_norm_gated import FusedRMSNormGated
|
|
||||||
|
|
||||||
modeling_mod.Qwen3_5MoeRMSNormGated = FusedRMSNormGated
|
|
||||||
|
|
||||||
if glu_activation:
|
|
||||||
|
|
||||||
def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):
|
|
||||||
"""Accepts intermediate_size to pass to LigerSwiGLUMLP"""
|
|
||||||
config = deepcopy(config)
|
|
||||||
if intermediate_size is not None:
|
|
||||||
config.intermediate_size = intermediate_size
|
|
||||||
return LigerSwiGLUMLP(config, **kwargs)
|
|
||||||
|
|
||||||
modeling_mod.Qwen3_5MoeMLP = _liger_swiglu_mlp_wrapper
|
|
||||||
|
|
||||||
if layer_norm:
|
|
||||||
modeling_mod.nn.LayerNorm = LigerLayerNorm
|
|
||||||
|
|
||||||
if cross_entropy:
|
|
||||||
from transformers.loss.loss_utils import nn
|
|
||||||
|
|
||||||
nn.functional.cross_entropy = liger_cross_entropy
|
|
||||||
|
|
||||||
if fused_linear_cross_entropy:
|
|
||||||
modeling_mod.Qwen3_5MoeForCausalLM.forward = lce_forward
|
|
||||||
@@ -174,19 +174,6 @@ class LigerPlugin(BasePlugin):
|
|||||||
rms_norm=cfg.liger_rms_norm,
|
rms_norm=cfg.liger_rms_norm,
|
||||||
layer_norm=cfg.liger_layer_norm,
|
layer_norm=cfg.liger_layer_norm,
|
||||||
)
|
)
|
||||||
elif cfg.model_config_type == "qwen3_5":
|
|
||||||
from axolotl.integrations.liger.models.qwen3_5 import (
|
|
||||||
apply_liger_kernel_to_qwen3_5,
|
|
||||||
)
|
|
||||||
|
|
||||||
apply_liger_kernel_to_qwen3_5(
|
|
||||||
cross_entropy=cfg.liger_cross_entropy,
|
|
||||||
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
|
|
||||||
glu_activation=cfg.liger_glu_activation,
|
|
||||||
rms_norm=cfg.liger_rms_norm,
|
|
||||||
rms_norm_gated=getattr(cfg, "liger_rms_norm_gated", False),
|
|
||||||
layer_norm=cfg.liger_layer_norm,
|
|
||||||
)
|
|
||||||
elif cfg.model_config_type == "qwen3_moe":
|
elif cfg.model_config_type == "qwen3_moe":
|
||||||
from axolotl.integrations.liger.models.qwen3_moe import (
|
from axolotl.integrations.liger.models.qwen3_moe import (
|
||||||
apply_liger_kernel_to_qwen3_moe,
|
apply_liger_kernel_to_qwen3_moe,
|
||||||
@@ -199,19 +186,6 @@ class LigerPlugin(BasePlugin):
|
|||||||
rms_norm=cfg.liger_rms_norm,
|
rms_norm=cfg.liger_rms_norm,
|
||||||
layer_norm=cfg.liger_layer_norm,
|
layer_norm=cfg.liger_layer_norm,
|
||||||
)
|
)
|
||||||
elif cfg.model_config_type == "qwen3_5_moe":
|
|
||||||
from axolotl.integrations.liger.models.qwen3_5_moe import (
|
|
||||||
apply_liger_kernel_to_qwen3_5_moe,
|
|
||||||
)
|
|
||||||
|
|
||||||
apply_liger_kernel_to_qwen3_5_moe(
|
|
||||||
cross_entropy=cfg.liger_cross_entropy,
|
|
||||||
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
|
|
||||||
glu_activation=cfg.liger_glu_activation,
|
|
||||||
rms_norm=cfg.liger_rms_norm,
|
|
||||||
rms_norm_gated=getattr(cfg, "liger_rms_norm_gated", False),
|
|
||||||
layer_norm=cfg.liger_layer_norm,
|
|
||||||
)
|
|
||||||
elif cfg.model_config_type == "granitemoe":
|
elif cfg.model_config_type == "granitemoe":
|
||||||
from liger_kernel.transformers import apply_liger_kernel_to_granite
|
from liger_kernel.transformers import apply_liger_kernel_to_granite
|
||||||
|
|
||||||
|
|||||||
@@ -1,147 +0,0 @@
|
|||||||
"""
|
|
||||||
Triton kernels for DoRA (Weight-Decomposed Low-Rank Adaptation).
|
|
||||||
|
|
||||||
Fuses the weight norm computation and magnitude scaling to avoid
|
|
||||||
materializing the full [out_features, in_features] combined weight matrix.
|
|
||||||
The B@A product is computed row-by-row inside the kernel.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import triton
|
|
||||||
import triton.language as tl
|
|
||||||
|
|
||||||
from .quantize import dequantize
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def _dora_fused_norm_kernel(
|
|
||||||
# Pointers
|
|
||||||
W_ptr, # base weight [out, in] (dequantized, row-major)
|
|
||||||
B_ptr, # LoRA B [out, rank] (row-major)
|
|
||||||
A_ptr, # LoRA A [rank, in] (row-major)
|
|
||||||
mag_ptr, # magnitude vector [out]
|
|
||||||
out_ptr, # output mag_norm_scale [out]
|
|
||||||
# Shapes
|
|
||||||
out_features,
|
|
||||||
in_features,
|
|
||||||
rank,
|
|
||||||
# Scaling
|
|
||||||
lora_scale, # float scaling factor
|
|
||||||
# Block sizes
|
|
||||||
BLOCK_IN: tl.constexpr,
|
|
||||||
BLOCK_R: tl.constexpr, # >= rank, power of 2
|
|
||||||
):
|
|
||||||
"""Compute mag_norm_scale[i] = magnitude[i] / ||W[i,:] + s * (B[i,:] @ A)[:] ||_2
|
|
||||||
|
|
||||||
Each program handles one output row. B[row,:] is loaded once (small),
|
|
||||||
then we tile over in_features computing the dot product with A[:,tile]
|
|
||||||
and accumulating the squared norm.
|
|
||||||
|
|
||||||
This avoids materializing the full [out, in] B@A matrix.
|
|
||||||
"""
|
|
||||||
row = tl.program_id(0)
|
|
||||||
if row >= out_features:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Accumulate squared norm across tiles of in_features
|
|
||||||
norm_sq_acc = tl.zeros([BLOCK_IN], dtype=tl.float32)
|
|
||||||
|
|
||||||
for start in range(0, in_features, BLOCK_IN):
|
|
||||||
cols = start + tl.arange(0, BLOCK_IN)
|
|
||||||
col_mask = cols < in_features
|
|
||||||
|
|
||||||
# Load W[row, cols]
|
|
||||||
w_vals = tl.load(
|
|
||||||
W_ptr + row * in_features + cols,
|
|
||||||
mask=col_mask,
|
|
||||||
other=0.0,
|
|
||||||
).to(tl.float32)
|
|
||||||
|
|
||||||
# Compute (B[row,:] @ A[:, cols]) for this tile
|
|
||||||
# Load B[row, r] as scalar and A[r, cols] as vector for each r
|
|
||||||
ba_vals = tl.zeros([BLOCK_IN], dtype=tl.float32)
|
|
||||||
for r in tl.static_range(BLOCK_R):
|
|
||||||
# Load scalar B[row, r]
|
|
||||||
b_val = tl.load(
|
|
||||||
B_ptr + row * rank + r,
|
|
||||||
mask=(r < rank),
|
|
||||||
other=0.0,
|
|
||||||
).to(tl.float32)
|
|
||||||
# Load vector A[r, cols]
|
|
||||||
a_vals = tl.load(
|
|
||||||
A_ptr + r * in_features + cols,
|
|
||||||
mask=(col_mask & (r < rank)),
|
|
||||||
other=0.0,
|
|
||||||
).to(tl.float32)
|
|
||||||
ba_vals += b_val * a_vals
|
|
||||||
|
|
||||||
# Combined: W + s * (B @ A)
|
|
||||||
combined = w_vals + lora_scale * ba_vals
|
|
||||||
|
|
||||||
# Accumulate squared values
|
|
||||||
norm_sq_acc += tl.where(col_mask, combined * combined, 0.0)
|
|
||||||
|
|
||||||
# Reduce to scalar norm
|
|
||||||
norm_sq = tl.sum(norm_sq_acc, axis=0)
|
|
||||||
norm = tl.sqrt(norm_sq + 1e-12) # epsilon for numerical stability
|
|
||||||
|
|
||||||
# Load magnitude and compute scale
|
|
||||||
mag = tl.load(mag_ptr + row).to(tl.float32)
|
|
||||||
scale = mag / norm
|
|
||||||
|
|
||||||
tl.store(out_ptr + row, scale)
|
|
||||||
|
|
||||||
|
|
||||||
def triton_dora_scale(
|
|
||||||
W: torch.Tensor,
|
|
||||||
W_quant,
|
|
||||||
A: torch.Tensor,
|
|
||||||
B: torch.Tensor,
|
|
||||||
s: float,
|
|
||||||
magnitude: torch.Tensor,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""Compute DoRA mag_norm_scale using fused Triton kernel.
|
|
||||||
|
|
||||||
Computes B@A row-by-row inside the kernel, avoiding the full
|
|
||||||
[out_features, in_features] materialization.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
W: base weight [out, in] (possibly quantized)
|
|
||||||
W_quant: quantization state
|
|
||||||
A: LoRA A [rank, in]
|
|
||||||
B: LoRA B [out, rank]
|
|
||||||
s: LoRA scaling factor
|
|
||||||
magnitude: learned magnitude [out]
|
|
||||||
dtype: compute dtype
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
mag_norm_scale: [out] tensor = magnitude / ||W + s * B @ A||_2
|
|
||||||
"""
|
|
||||||
# Dequantize W to [out, in]
|
|
||||||
W_full = dequantize(W.t(), W_quant).t().contiguous().to(dtype)
|
|
||||||
|
|
||||||
out_features, in_features = W_full.shape
|
|
||||||
rank = A.shape[0]
|
|
||||||
|
|
||||||
out = torch.empty(out_features, dtype=dtype, device=W.device)
|
|
||||||
|
|
||||||
# Block sizes
|
|
||||||
BLOCK_IN = triton.next_power_of_2(min(in_features, 2048))
|
|
||||||
BLOCK_R = triton.next_power_of_2(rank)
|
|
||||||
|
|
||||||
_dora_fused_norm_kernel[(out_features,)](
|
|
||||||
W_full,
|
|
||||||
B.contiguous().to(dtype),
|
|
||||||
A.contiguous().to(dtype),
|
|
||||||
magnitude.contiguous(),
|
|
||||||
out,
|
|
||||||
out_features=out_features,
|
|
||||||
in_features=in_features,
|
|
||||||
rank=rank,
|
|
||||||
lora_scale=s,
|
|
||||||
BLOCK_IN=BLOCK_IN,
|
|
||||||
BLOCK_R=BLOCK_R,
|
|
||||||
)
|
|
||||||
|
|
||||||
return out.detach()
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -105,10 +105,6 @@ def dequantize(
|
|||||||
# Extract quantization state
|
# Extract quantization state
|
||||||
if not isinstance(quant_state, list):
|
if not isinstance(quant_state, list):
|
||||||
# New style quant_state class
|
# New style quant_state class
|
||||||
# Non-double-quantized models have offset=None and state2=None
|
|
||||||
if quant_state.offset is None or quant_state.state2 is None:
|
|
||||||
# Fall back to bitsandbytes standard dequantize
|
|
||||||
return bnb.functional.dequantize_4bit(W, quant_state, quant_type="nf4")
|
|
||||||
absmax = quant_state.absmax.to(target_device)
|
absmax = quant_state.absmax.to(target_device)
|
||||||
shape = quant_state.shape
|
shape = quant_state.shape
|
||||||
dtype = quant_state.dtype
|
dtype = quant_state.dtype
|
||||||
|
|||||||
@@ -1,333 +0,0 @@
|
|||||||
"""
|
|
||||||
Fused RMSNorm + SiLU Gate Triton kernel.
|
|
||||||
|
|
||||||
Computes: Y = (W + offset) * RMSNorm(X) * silu(G)
|
|
||||||
where RMSNorm(X) = X / sqrt(mean(X^2) + eps)
|
|
||||||
and silu(G) = G * sigmoid(G)
|
|
||||||
|
|
||||||
Used by Qwen3.5's GatedDeltaNet linear attention layers (Qwen3_5RMSNormGated).
|
|
||||||
"""
|
|
||||||
|
|
||||||
import math
|
|
||||||
import operator
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import triton
|
|
||||||
import triton.language as tl
|
|
||||||
from liger_kernel.ops.utils import (
|
|
||||||
calculate_settings,
|
|
||||||
compare_version,
|
|
||||||
ensure_contiguous,
|
|
||||||
torch_to_triton_dtype,
|
|
||||||
)
|
|
||||||
from liger_kernel.utils import is_npu_available
|
|
||||||
|
|
||||||
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
|
|
||||||
try:
|
|
||||||
from triton.language.extra.libdevice import rsqrt
|
|
||||||
except ModuleNotFoundError:
|
|
||||||
from triton.language.extra.cuda.libdevice import rsqrt
|
|
||||||
else:
|
|
||||||
from triton.language.math import rsqrt
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def _rms_norm_gated_forward_kernel(
|
|
||||||
Y_ptr,
|
|
||||||
Y_row_stride,
|
|
||||||
X_ptr,
|
|
||||||
X_row_stride,
|
|
||||||
G_ptr,
|
|
||||||
G_row_stride,
|
|
||||||
W_ptr,
|
|
||||||
W_row_stride,
|
|
||||||
RSTD_ptr,
|
|
||||||
RSTD_row_stride,
|
|
||||||
n_cols,
|
|
||||||
eps,
|
|
||||||
offset,
|
|
||||||
BLOCK_SIZE: tl.constexpr,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Y = (W + offset) * (X / RMS(X)) * silu(G)
|
|
||||||
|
|
||||||
All computation done in fp32 (Gemma-style), result cast to input dtype.
|
|
||||||
"""
|
|
||||||
row_idx = tl.program_id(0).to(tl.int64)
|
|
||||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
|
||||||
mask = col_offsets < n_cols
|
|
||||||
|
|
||||||
X_row = tl.load(X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0)
|
|
||||||
G_row = tl.load(G_ptr + row_idx * G_row_stride + col_offsets, mask=mask, other=0)
|
|
||||||
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
|
|
||||||
|
|
||||||
X_row_dtype = X_row.dtype
|
|
||||||
|
|
||||||
# Cast everything to fp32
|
|
||||||
X_fp32 = X_row.to(tl.float32)
|
|
||||||
G_fp32 = G_row.to(tl.float32)
|
|
||||||
W_fp32 = W_row.to(tl.float32)
|
|
||||||
|
|
||||||
# RMS norm
|
|
||||||
mean_sq = tl.sum(X_fp32 * X_fp32, axis=0) / n_cols
|
|
||||||
rstd = rsqrt(mean_sq + eps)
|
|
||||||
tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd)
|
|
||||||
|
|
||||||
X_norm = X_fp32 * rstd
|
|
||||||
|
|
||||||
# SiLU gate: silu(G) = G * sigmoid(G)
|
|
||||||
sig_G = tl.sigmoid(G_fp32)
|
|
||||||
silu_G = G_fp32 * sig_G
|
|
||||||
|
|
||||||
# Fused output
|
|
||||||
Y_row = (offset + W_fp32) * X_norm * silu_G
|
|
||||||
|
|
||||||
tl.store(
|
|
||||||
Y_ptr + row_idx * Y_row_stride + col_offsets,
|
|
||||||
Y_row.to(X_row_dtype),
|
|
||||||
mask=mask,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def _rms_norm_gated_backward_kernel(
|
|
||||||
dY_ptr,
|
|
||||||
dY_row_stride,
|
|
||||||
dX_ptr,
|
|
||||||
dX_row_stride,
|
|
||||||
dG_ptr,
|
|
||||||
dG_row_stride,
|
|
||||||
X_ptr,
|
|
||||||
X_row_stride,
|
|
||||||
X_dtype: tl.constexpr,
|
|
||||||
G_ptr,
|
|
||||||
G_row_stride,
|
|
||||||
W_ptr,
|
|
||||||
W_row_stride,
|
|
||||||
RSTD_ptr,
|
|
||||||
RSTD_row_stride,
|
|
||||||
dW_ptr,
|
|
||||||
dW_row_stride,
|
|
||||||
n_rows,
|
|
||||||
n_cols,
|
|
||||||
offset,
|
|
||||||
rows_per_program,
|
|
||||||
BLOCK_SIZE: tl.constexpr,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Backward for Y = (W + offset) * (X * RSTD) * silu(G)
|
|
||||||
|
|
||||||
dW = sum_batch(dY * X_norm * silu(G))
|
|
||||||
dG = dY * (W + offset) * X_norm * silu'(G)
|
|
||||||
where silu'(G) = sigmoid(G) * (1 + G * (1 - sigmoid(G)))
|
|
||||||
dX = RSTD * (m - (1/N) * RSTD^2 * dot(m, X) * X)
|
|
||||||
where m = dY * (W + offset) * silu(G)
|
|
||||||
"""
|
|
||||||
row_block_id = tl.program_id(0).to(tl.int64)
|
|
||||||
row_start = row_block_id * rows_per_program
|
|
||||||
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
|
|
||||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
|
||||||
mask = col_offsets < n_cols
|
|
||||||
|
|
||||||
dW_acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
|
||||||
|
|
||||||
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
|
|
||||||
W_row = W_row.to(tl.float32) + offset
|
|
||||||
|
|
||||||
for row_idx in range(row_start, row_end):
|
|
||||||
dY_row = tl.load(
|
|
||||||
dY_ptr + row_idx * dY_row_stride + col_offsets, mask=mask, other=0.0
|
|
||||||
)
|
|
||||||
X_row = tl.load(
|
|
||||||
X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0.0
|
|
||||||
)
|
|
||||||
G_row = tl.load(
|
|
||||||
G_ptr + row_idx * G_row_stride + col_offsets, mask=mask, other=0.0
|
|
||||||
)
|
|
||||||
rstd_row = tl.load(RSTD_ptr + row_idx * RSTD_row_stride)
|
|
||||||
|
|
||||||
# Cast to fp32
|
|
||||||
dY_fp32 = dY_row.to(tl.float32)
|
|
||||||
X_fp32 = X_row.to(tl.float32)
|
|
||||||
G_fp32 = G_row.to(tl.float32)
|
|
||||||
|
|
||||||
# Recompute intermediates
|
|
||||||
X_norm = X_fp32 * rstd_row
|
|
||||||
sig_G = tl.sigmoid(G_fp32)
|
|
||||||
silu_G = G_fp32 * sig_G
|
|
||||||
|
|
||||||
# dW: accumulate dY * X_norm * silu(G)
|
|
||||||
dW_acc += dY_fp32 * X_norm * silu_G
|
|
||||||
|
|
||||||
# dG: dY * (W + offset) * X_norm * silu'(G)
|
|
||||||
# silu'(G) = sigmoid(G) * (1 + G * (1 - sigmoid(G)))
|
|
||||||
silu_prime_G = sig_G * (1.0 + G_fp32 * (1.0 - sig_G))
|
|
||||||
dG_row = dY_fp32 * W_row * X_norm * silu_prime_G
|
|
||||||
tl.store(
|
|
||||||
dG_ptr + row_idx * dG_row_stride + col_offsets,
|
|
||||||
dG_row.to(X_dtype),
|
|
||||||
mask=mask,
|
|
||||||
)
|
|
||||||
|
|
||||||
# dX: standard RMSNorm backward with effective gradient m = dY * W * silu(G)
|
|
||||||
m = dY_fp32 * W_row * silu_G
|
|
||||||
dX_row = rstd_row * m
|
|
||||||
dX_row += rstd_row * (
|
|
||||||
-(1.0 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_fp32, axis=0) * X_fp32
|
|
||||||
)
|
|
||||||
tl.store(
|
|
||||||
dX_ptr + row_idx * dX_row_stride + col_offsets,
|
|
||||||
dX_row.to(X_dtype),
|
|
||||||
mask=mask,
|
|
||||||
)
|
|
||||||
|
|
||||||
tl.store(
|
|
||||||
dW_ptr + row_block_id * dW_row_stride + col_offsets,
|
|
||||||
dW_acc,
|
|
||||||
mask=mask,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def rms_norm_gated_forward(X, G, W, eps, offset):
|
|
||||||
shape = X.shape
|
|
||||||
dim = shape[-1]
|
|
||||||
X = X.view(-1, dim)
|
|
||||||
G = G.view(-1, dim)
|
|
||||||
n_rows, n_cols = X.shape
|
|
||||||
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
|
||||||
|
|
||||||
Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
|
|
||||||
RSTD = torch.empty(n_rows, dtype=torch.float32, device=X.device)
|
|
||||||
|
|
||||||
assert X.shape[1] == W.shape[0], (
|
|
||||||
f"Incompatible hidden size: X.shape[1]={X.shape[1]} vs W.shape[0]={W.shape[0]}"
|
|
||||||
)
|
|
||||||
assert X.shape == G.shape, (
|
|
||||||
f"X and G must have same shape, got {X.shape} and {G.shape}"
|
|
||||||
)
|
|
||||||
|
|
||||||
_rms_norm_gated_forward_kernel[(n_rows,)](
|
|
||||||
Y,
|
|
||||||
Y.stride(0),
|
|
||||||
X,
|
|
||||||
X.stride(0),
|
|
||||||
G,
|
|
||||||
G.stride(0),
|
|
||||||
W,
|
|
||||||
W.stride(0),
|
|
||||||
RSTD,
|
|
||||||
RSTD.stride(0),
|
|
||||||
n_cols,
|
|
||||||
eps,
|
|
||||||
offset,
|
|
||||||
BLOCK_SIZE=BLOCK_SIZE,
|
|
||||||
num_warps=num_warps,
|
|
||||||
)
|
|
||||||
return Y.view(*shape), X, G, RSTD, BLOCK_SIZE, num_warps
|
|
||||||
|
|
||||||
|
|
||||||
def rms_norm_gated_backward(dY, X, G, W, RSTD, offset, BLOCK_SIZE, num_warps):
|
|
||||||
shape = dY.shape
|
|
||||||
dim = shape[-1]
|
|
||||||
dY = dY.view(-1, dim)
|
|
||||||
n_rows, n_cols = dY.shape
|
|
||||||
|
|
||||||
sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
|
|
||||||
|
|
||||||
_dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
|
|
||||||
dX = torch.empty_like(dY)
|
|
||||||
dG = torch.empty_like(dY)
|
|
||||||
|
|
||||||
rows_per_program = math.ceil(n_rows / sm_count)
|
|
||||||
grid = (sm_count,)
|
|
||||||
|
|
||||||
_rms_norm_gated_backward_kernel[grid](
|
|
||||||
dY,
|
|
||||||
dY.stride(0),
|
|
||||||
dX,
|
|
||||||
dX.stride(0),
|
|
||||||
dG,
|
|
||||||
dG.stride(0),
|
|
||||||
X,
|
|
||||||
X.stride(0),
|
|
||||||
torch_to_triton_dtype[X.dtype],
|
|
||||||
G,
|
|
||||||
G.stride(0),
|
|
||||||
W,
|
|
||||||
W.stride(0),
|
|
||||||
RSTD,
|
|
||||||
RSTD.stride(0),
|
|
||||||
_dW,
|
|
||||||
_dW.stride(0),
|
|
||||||
n_rows,
|
|
||||||
n_cols,
|
|
||||||
offset,
|
|
||||||
rows_per_program,
|
|
||||||
BLOCK_SIZE=BLOCK_SIZE,
|
|
||||||
num_warps=num_warps,
|
|
||||||
)
|
|
||||||
|
|
||||||
dX = dX.view(*shape)
|
|
||||||
dG = dG.view(*shape)
|
|
||||||
dW = _dW.sum(dim=0).to(W.dtype)
|
|
||||||
return dX, dG, dW
|
|
||||||
|
|
||||||
|
|
||||||
class FusedRMSNormGatedFunction(torch.autograd.Function):
|
|
||||||
@staticmethod
|
|
||||||
@ensure_contiguous
|
|
||||||
def forward(ctx, X, G, W, eps, offset=0.0):
|
|
||||||
"""
|
|
||||||
X: (B, T, H) or (BxT, H) — input hidden states
|
|
||||||
G: (B, T, H) or (BxT, H) — gate tensor
|
|
||||||
W: (H,) — weight parameter
|
|
||||||
"""
|
|
||||||
Y, X, G, RSTD, BLOCK_SIZE, num_warps = rms_norm_gated_forward(
|
|
||||||
X, G, W, eps, offset
|
|
||||||
)
|
|
||||||
ctx.offset = offset
|
|
||||||
ctx.BLOCK_SIZE = BLOCK_SIZE
|
|
||||||
ctx.num_warps = num_warps
|
|
||||||
ctx.save_for_backward(X, G, W, RSTD)
|
|
||||||
return Y
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@ensure_contiguous
|
|
||||||
def backward(ctx, dY):
|
|
||||||
X, G, W, RSTD = ctx.saved_tensors
|
|
||||||
dX, dG, dW = rms_norm_gated_backward(
|
|
||||||
dY, X, G, W, RSTD, ctx.offset, ctx.BLOCK_SIZE, ctx.num_warps
|
|
||||||
)
|
|
||||||
return dX, dG, dW, None, None
|
|
||||||
|
|
||||||
|
|
||||||
class FusedRMSNormGated(torch.nn.Module):
|
|
||||||
"""
|
|
||||||
Fused RMSNorm + SiLU Gate.
|
|
||||||
|
|
||||||
Computes: Y = W * RMSNorm(X) * silu(G)
|
|
||||||
|
|
||||||
Drop-in replacement for Qwen3_5RMSNormGated with matching
|
|
||||||
init signature: __init__(hidden_size, eps=1e-6, **kwargs)
|
|
||||||
and forward signature: forward(hidden_states, gate=None)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, hidden_size, eps=1e-6, offset=0.0, **kwargs):
|
|
||||||
super().__init__()
|
|
||||||
self.weight = torch.nn.Parameter(torch.ones(hidden_size))
|
|
||||||
self.variance_epsilon = eps
|
|
||||||
self.offset = offset
|
|
||||||
|
|
||||||
def forward(self, hidden_states, gate=None):
|
|
||||||
if gate is None:
|
|
||||||
raise ValueError("FusedRMSNormGated requires a gate tensor")
|
|
||||||
if hidden_states.device.type != "cuda":
|
|
||||||
raise ValueError(
|
|
||||||
f"FusedRMSNormGated requires CUDA tensors, got device={hidden_states.device}"
|
|
||||||
)
|
|
||||||
return FusedRMSNormGatedFunction.apply(
|
|
||||||
hidden_states, gate, self.weight, self.variance_epsilon, self.offset
|
|
||||||
)
|
|
||||||
|
|
||||||
def extra_repr(self):
|
|
||||||
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
|
||||||
@@ -133,13 +133,6 @@ class PatchManager:
|
|||||||
patch_evaluation_loop()
|
patch_evaluation_loop()
|
||||||
patch_maybe_log_save_evaluate()
|
patch_maybe_log_save_evaluate()
|
||||||
|
|
||||||
if self.cfg.context_parallel_size > 1:
|
|
||||||
from axolotl.monkeypatch.transformers.trainer_context_parallel import (
|
|
||||||
patch_prepare_context_parallel_inputs,
|
|
||||||
)
|
|
||||||
|
|
||||||
patch_prepare_context_parallel_inputs()
|
|
||||||
|
|
||||||
def apply_post_model_build_patches(self, model: PreTrainedModel):
|
def apply_post_model_build_patches(self, model: PreTrainedModel):
|
||||||
"""Apply patches right after model build, before post-load setup."""
|
"""Apply patches right after model build, before post-load setup."""
|
||||||
self._finalize_moe_expert_quantization(model)
|
self._finalize_moe_expert_quantization(model)
|
||||||
@@ -571,6 +564,15 @@ class PatchManager:
|
|||||||
LOG.info("Patching with xformers attention...")
|
LOG.info("Patching with xformers attention...")
|
||||||
hijack_llama_attention()
|
hijack_llama_attention()
|
||||||
|
|
||||||
|
def _patch_llama_sample_packing(self):
|
||||||
|
"""Apply sample packing patches for LLaMA models."""
|
||||||
|
from axolotl.monkeypatch.llama_patch_multipack import (
|
||||||
|
hijack_llama_prepare_4d_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG.info("Patching llama _prepare_4d_causal_attention_mask*...")
|
||||||
|
hijack_llama_prepare_4d_mask()
|
||||||
|
|
||||||
def _patch_llama_derived_model(self):
|
def _patch_llama_derived_model(self):
|
||||||
"""Modify all llama derived models in one block."""
|
"""Modify all llama derived models in one block."""
|
||||||
if self.cfg.is_llama_derived_model and not (
|
if self.cfg.is_llama_derived_model and not (
|
||||||
@@ -582,6 +584,8 @@ class PatchManager:
|
|||||||
self._patch_llama_flash_attention()
|
self._patch_llama_flash_attention()
|
||||||
elif self.cfg.xformers_attention:
|
elif self.cfg.xformers_attention:
|
||||||
self._patch_llama_xformers_attention()
|
self._patch_llama_xformers_attention()
|
||||||
|
elif self.cfg.sample_packing:
|
||||||
|
self._patch_llama_sample_packing()
|
||||||
elif self.cfg.s2_attention:
|
elif self.cfg.s2_attention:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Shifted-sparse attention not currently implemented without flash attention."
|
"Shifted-sparse attention not currently implemented without flash attention."
|
||||||
|
|||||||
@@ -221,14 +221,6 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
|
|||||||
if getattr(tokenizer, attr_name) is None:
|
if getattr(tokenizer, attr_name) is None:
|
||||||
setattr(tokenizer, attr_name, "<|endoftext|>")
|
setattr(tokenizer, attr_name, "<|endoftext|>")
|
||||||
|
|
||||||
# Generic fallback: if tokenizer still has no pad_token, use eos_token
|
|
||||||
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
|
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
|
||||||
LOG.warning(
|
|
||||||
"Tokenizer does not have a pad_token, falling back to eos_token: %s",
|
|
||||||
tokenizer.eos_token,
|
|
||||||
)
|
|
||||||
|
|
||||||
additional_special_tokens = None
|
additional_special_tokens = None
|
||||||
if cfg.special_tokens:
|
if cfg.special_tokens:
|
||||||
special_tokens = cfg.special_tokens.to_dict()
|
special_tokens = cfg.special_tokens.to_dict()
|
||||||
|
|||||||
@@ -81,6 +81,7 @@ def patch_prepare_cp():
|
|||||||
import contextlib
|
import contextlib
|
||||||
|
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
|
from transformers import Trainer
|
||||||
|
|
||||||
def patched_prepare_cp(self, *args):
|
def patched_prepare_cp(self, *args):
|
||||||
if self.parallelism_config.cp_backend == "deepspeed":
|
if self.parallelism_config.cp_backend == "deepspeed":
|
||||||
@@ -95,4 +96,11 @@ def patch_prepare_cp():
|
|||||||
self._cp_context = _noop_cp_context
|
self._cp_context = _noop_cp_context
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
def _noop_prepare_context_parallel_inputs(self, model, inputs):
|
||||||
|
return contextlib.nullcontext, inputs
|
||||||
|
|
||||||
|
# prevent double CP partition
|
||||||
Accelerator._prepare_cp = patched_prepare_cp
|
Accelerator._prepare_cp = patched_prepare_cp
|
||||||
|
|
||||||
|
# remove unneeded calculation upstream
|
||||||
|
Trainer._prepare_context_parallel_inputs = _noop_prepare_context_parallel_inputs
|
||||||
|
|||||||
24
src/axolotl/monkeypatch/llama_expand_mask.py
Normal file
24
src/axolotl/monkeypatch/llama_expand_mask.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
"""
|
||||||
|
expands the binary attention mask per 3.2.2 of https://arxiv.org/pdf/2107.02027.pdf
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.utils import mask_2d_to_4d
|
||||||
|
|
||||||
|
|
||||||
|
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
||||||
|
masked_zero_one_mask = mask_2d_to_4d(mask, dtype, tgt_len)
|
||||||
|
inverted_mask = 1.0 - masked_zero_one_mask
|
||||||
|
|
||||||
|
return inverted_mask.masked_fill(
|
||||||
|
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def hijack_expand_mask():
|
||||||
|
import transformers
|
||||||
|
|
||||||
|
transformers.models.llama.modeling_llama._expand_mask = _expand_mask
|
||||||
26
src/axolotl/monkeypatch/llama_patch_multipack.py
Normal file
26
src/axolotl/monkeypatch/llama_patch_multipack.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
"""
|
||||||
|
Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
|
||||||
|
"""
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.utils import (
|
||||||
|
patched_prepare_4d_causal_attention_mask,
|
||||||
|
patched_prepare_4d_causal_attention_mask_for_sdpa,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def hijack_llama_prepare_4d_mask():
|
||||||
|
from transformers import modeling_attn_mask_utils
|
||||||
|
from transformers.models.llama import modeling_llama
|
||||||
|
|
||||||
|
modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = (
|
||||||
|
patched_prepare_4d_causal_attention_mask_for_sdpa
|
||||||
|
)
|
||||||
|
modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = (
|
||||||
|
patched_prepare_4d_causal_attention_mask_for_sdpa
|
||||||
|
)
|
||||||
|
modeling_llama._prepare_4d_causal_attention_mask = (
|
||||||
|
patched_prepare_4d_causal_attention_mask
|
||||||
|
)
|
||||||
|
modeling_attn_mask_utils._prepare_4d_causal_attention_mask = (
|
||||||
|
patched_prepare_4d_causal_attention_mask
|
||||||
|
)
|
||||||
@@ -12,7 +12,6 @@ from torch import nn
|
|||||||
from transformers import AutoConfig
|
from transformers import AutoConfig
|
||||||
|
|
||||||
from axolotl.kernels.lora import (
|
from axolotl.kernels.lora import (
|
||||||
apply_lora_embedding,
|
|
||||||
apply_lora_mlp_geglu,
|
apply_lora_mlp_geglu,
|
||||||
apply_lora_mlp_swiglu,
|
apply_lora_mlp_swiglu,
|
||||||
apply_lora_o,
|
apply_lora_o,
|
||||||
@@ -371,13 +370,13 @@ def apply_lora_kernel_patches(
|
|||||||
active_adapter = model.active_adapter
|
active_adapter = model.active_adapter
|
||||||
lora_config = model.model.peft_config[active_adapter]
|
lora_config = model.model.peft_config[active_adapter]
|
||||||
|
|
||||||
# Log what features are active
|
# Only patch if conditions are met
|
||||||
if lora_config.lora_dropout > 0:
|
can_patch = lora_config.lora_dropout == 0 and lora_config.bias == "none"
|
||||||
LOG.info(f"LoRA kernels: dropout={lora_config.lora_dropout} enabled")
|
|
||||||
if lora_config.bias != "none":
|
if not can_patch:
|
||||||
LOG.info(f"LoRA kernels: bias={lora_config.bias} enabled")
|
LOG.warning("Cannot patch layers - requires no dropout and no bias")
|
||||||
if lora_config.use_dora:
|
LOG.warning("Please specify `lora_dropout: 0` in your axolotl config file")
|
||||||
LOG.info("LoRA kernels: DoRA enabled")
|
return model
|
||||||
|
|
||||||
# This needs to be reset after patching
|
# This needs to be reset after patching
|
||||||
original_level = LOG.getEffectiveLevel()
|
original_level = LOG.getEffectiveLevel()
|
||||||
@@ -420,33 +419,44 @@ def apply_lora_kernel_patches(
|
|||||||
for linear_proj in ["q_proj", "k_proj", "v_proj"]
|
for linear_proj in ["q_proj", "k_proj", "v_proj"]
|
||||||
]
|
]
|
||||||
can_patch_qkv = all(
|
can_patch_qkv = all(
|
||||||
hasattr(module, "lora_A") for module in layer_modules
|
hasattr(module, "lora_A")
|
||||||
|
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
||||||
|
for module in layer_modules
|
||||||
)
|
)
|
||||||
|
|
||||||
if can_patch_qkv:
|
if can_patch_qkv:
|
||||||
|
# Add optimized implementation
|
||||||
self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn)
|
self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn)
|
||||||
else:
|
else:
|
||||||
LOG.warning_once(
|
LOG.warning_once(
|
||||||
"Cannot patch some attention QKV projections - requires LoRA adapters"
|
"Cannot patch some attention QKV projections - requires LoRA "
|
||||||
|
"adapters and no lora_magnitude_vector (DoRA)"
|
||||||
)
|
)
|
||||||
if cfg.lora_o_kernel:
|
if cfg.lora_o_kernel:
|
||||||
# Output patching
|
# Output patching
|
||||||
layer_modules = [
|
layer_modules = [
|
||||||
getattr(self_attn, linear_proj) for linear_proj in ["o_proj"]
|
getattr(self_attn, linear_proj) for linear_proj in ["o_proj"]
|
||||||
]
|
]
|
||||||
can_patch_o = all(hasattr(module, "lora_A") for module in layer_modules)
|
can_patch_o = all(
|
||||||
|
hasattr(module, "lora_A")
|
||||||
|
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
||||||
|
for module in layer_modules
|
||||||
|
)
|
||||||
|
|
||||||
if can_patch_o:
|
if can_patch_o:
|
||||||
self_attn.apply_o = types.MethodType(apply_lora_o, self_attn)
|
self_attn.apply_o = types.MethodType(apply_lora_o, self_attn)
|
||||||
else:
|
else:
|
||||||
LOG.warning_once(
|
LOG.warning_once(
|
||||||
"Cannot patch some attention output projection - requires LoRA adapters"
|
"Cannot patch some attention output projection - requires LoRA "
|
||||||
|
"adapters and no lora_magnitude_vector (DoRA)"
|
||||||
)
|
)
|
||||||
for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer):
|
for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer):
|
||||||
if cfg.lora_mlp_kernel:
|
if cfg.lora_mlp_kernel:
|
||||||
# MLP patching
|
# MLP patching
|
||||||
can_patch_mlp = all(
|
can_patch_mlp = all(
|
||||||
hasattr(proj, "lora_A") for proj in (gate_proj, up_proj, down_proj)
|
hasattr(proj, "lora_A")
|
||||||
|
and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0
|
||||||
|
for proj in (gate_proj, up_proj, down_proj)
|
||||||
)
|
)
|
||||||
|
|
||||||
if can_patch_mlp:
|
if can_patch_mlp:
|
||||||
@@ -454,50 +464,15 @@ def apply_lora_kernel_patches(
|
|||||||
layer.mlp.forward = types.MethodType(apply_fn, mlp)
|
layer.mlp.forward = types.MethodType(apply_fn, mlp)
|
||||||
else:
|
else:
|
||||||
LOG.warning_once(
|
LOG.warning_once(
|
||||||
"Cannot patch some MLP layers - requires LoRA adapters"
|
"Cannot patch some MLP layers - requires LoRA adapters and no "
|
||||||
|
"lora_magnitude_vector (DoRA)"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Patch embedding layers (model-level, not per-layer)
|
|
||||||
if cfg.lora_embedding_kernel:
|
|
||||||
_patch_embedding_layers(model, cfg)
|
|
||||||
|
|
||||||
LOG.setLevel(original_level)
|
LOG.setLevel(original_level)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def _patch_embedding_layers(model: PeftModelForCausalLM, cfg: DictDefault):
|
|
||||||
"""Patch embedding layers with fused LoRA kernel.
|
|
||||||
|
|
||||||
Handles both embed_tokens (nn.Embedding with lora_embedding_A/B) and
|
|
||||||
lm_head (nn.Linear with lora_A/B, used when tied embeddings are untied by PEFT).
|
|
||||||
"""
|
|
||||||
pretrained_model = model.model
|
|
||||||
patched = 0
|
|
||||||
|
|
||||||
# Find embedding modules - check common locations
|
|
||||||
for attr_path in [
|
|
||||||
("model", "embed_tokens"),
|
|
||||||
("model", "language_model", "embed_tokens"),
|
|
||||||
]:
|
|
||||||
parent = pretrained_model
|
|
||||||
for attr in attr_path:
|
|
||||||
parent = getattr(parent, attr, None)
|
|
||||||
if parent is None:
|
|
||||||
break
|
|
||||||
if parent is not None and hasattr(parent, "lora_embedding_A"):
|
|
||||||
LOG.info(f"Patching embedding layer: {'.'.join(attr_path)}")
|
|
||||||
parent.forward = types.MethodType(apply_lora_embedding, parent)
|
|
||||||
patched += 1
|
|
||||||
|
|
||||||
# lm_head with LoRA is a Linear layer - already handled by LoRA_O/LoRA_W kernels
|
|
||||||
# when included in target_modules. No special embedding handling needed since
|
|
||||||
# PEFT wraps it as a Linear (not Embedding) even for tied models.
|
|
||||||
|
|
||||||
if not patched:
|
|
||||||
LOG.debug("No embedding layers with LoRA found to patch")
|
|
||||||
|
|
||||||
|
|
||||||
class FakeMLP(nn.Module):
|
class FakeMLP(nn.Module):
|
||||||
"""
|
"""
|
||||||
placeholder MLP for triton patching
|
placeholder MLP for triton patching
|
||||||
|
|||||||
@@ -1,72 +0,0 @@
|
|||||||
"""Monkey patch to allow context parallelism with FlashAttention in HF Trainer."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import importlib
|
|
||||||
import inspect
|
|
||||||
|
|
||||||
from transformers import Trainer
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import detab_code
|
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
|
||||||
|
|
||||||
GUARD_PATTERN = 'if model.config._attn_implementation != "sdpa":'
|
|
||||||
PATCHED_GUARD = 'if (attn_impl := (getattr(model.config, "_attn_implementation", None) or getattr(model.model.config, "_attn_implementation", None))) and attn_impl not in ("sdpa", "flash_attention_2"):'
|
|
||||||
|
|
||||||
|
|
||||||
def patch_prepare_context_parallel_inputs() -> None:
|
|
||||||
"""Relax the SDPA-only guard when running context parallelism with FlashAttention."""
|
|
||||||
if getattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched", False):
|
|
||||||
LOG.debug("Trainer._prepare_context_parallel_inputs already patched")
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
original_source = inspect.getsource(Trainer._prepare_context_parallel_inputs)
|
|
||||||
except OSError as exc: # pragma: no cover - occurs when source is unavailable
|
|
||||||
LOG.warning("Unable to patch Trainer._prepare_context_parallel_inputs: %s", exc)
|
|
||||||
return
|
|
||||||
|
|
||||||
if GUARD_PATTERN not in original_source:
|
|
||||||
LOG.warning(
|
|
||||||
"Expected guard not found in Trainer._prepare_context_parallel_inputs; \n"
|
|
||||||
"skipping FlashAttention context parallelism patch"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
patched_source = original_source.replace(GUARD_PATTERN, PATCHED_GUARD)
|
|
||||||
patched_source, _ = detab_code(patched_source)
|
|
||||||
patched_source = patched_source.replace(
|
|
||||||
"def _prepare_context_parallel_inputs(",
|
|
||||||
"def axolotl_prepare_context_parallel_inputs(",
|
|
||||||
1,
|
|
||||||
)
|
|
||||||
|
|
||||||
module_name = Trainer.__module__
|
|
||||||
module = importlib.import_module(module_name)
|
|
||||||
|
|
||||||
# import symbols referenced in the method so exec can succeed
|
|
||||||
items_to_import = []
|
|
||||||
for item in dir(module):
|
|
||||||
if item in patched_source:
|
|
||||||
items_to_import.append(item)
|
|
||||||
|
|
||||||
# Use a separate namespace to capture the exec'd function
|
|
||||||
namespace = {}
|
|
||||||
exec(f"from {module_name} import ({', '.join(items_to_import)})", namespace)
|
|
||||||
exec(patched_source, namespace)
|
|
||||||
|
|
||||||
# Explicitly get the function from the namespace
|
|
||||||
axolotl_prepare_context_parallel_inputs = namespace[
|
|
||||||
"axolotl_prepare_context_parallel_inputs"
|
|
||||||
]
|
|
||||||
Trainer._original_prepare_context_parallel_inputs = (
|
|
||||||
Trainer._prepare_context_parallel_inputs
|
|
||||||
)
|
|
||||||
Trainer._prepare_context_parallel_inputs = axolotl_prepare_context_parallel_inputs
|
|
||||||
Trainer._axolotl_prepare_context_parallel_inputs_source = patched_source
|
|
||||||
Trainer._axolotl_prepare_context_parallel_inputs_patched = True
|
|
||||||
LOG.debug(
|
|
||||||
"Patched Trainer._prepare_context_parallel_inputs for FlashAttention + CP"
|
|
||||||
)
|
|
||||||
@@ -3,10 +3,15 @@ Shared utils for the monkeypatches
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from typing import Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from transformers.modeling_attn_mask_utils import (
|
||||||
|
_prepare_4d_causal_attention_mask,
|
||||||
|
_prepare_4d_causal_attention_mask_for_sdpa,
|
||||||
|
)
|
||||||
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
|
|
||||||
@torch.jit.script
|
@torch.jit.script
|
||||||
@@ -165,6 +170,65 @@ def set_module_name(model, name, value):
|
|||||||
setattr(parent, child_name, value)
|
setattr(parent, child_name, value)
|
||||||
|
|
||||||
|
|
||||||
|
def mask_2d_to_4d(
|
||||||
|
mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||||
|
This expansion handles packed sequences so that sequences share the same attention mask integer value
|
||||||
|
when they attend to each other within that sequence.
|
||||||
|
This expansion transforms the mask to lower triangular form to prevent future peeking.
|
||||||
|
"""
|
||||||
|
bsz, src_len = mask.size()
|
||||||
|
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||||
|
|
||||||
|
mask = mask.unsqueeze(1).unsqueeze(2)
|
||||||
|
mask = mask.expand(bsz, 1, tgt_len, src_len)
|
||||||
|
|
||||||
|
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
|
||||||
|
binary_mask = torch.where(
|
||||||
|
mask != 0,
|
||||||
|
torch.tensor(1, device=mask.device).to(dtype),
|
||||||
|
torch.tensor(0, device=mask.device).to(dtype),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a block-diagonal mask.
|
||||||
|
# we multiply by the binary mask so that 0's in the original mask are correctly excluded
|
||||||
|
zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
|
||||||
|
|
||||||
|
# Now let's create a lower triangular mask of ones that will zero out the upper triangular part
|
||||||
|
lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
|
||||||
|
mask.device
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
|
||||||
|
masked_zero_one_mask = zero_one_mask * lower_triangular_ones
|
||||||
|
|
||||||
|
return masked_zero_one_mask
|
||||||
|
|
||||||
|
|
||||||
|
def patched_prepare_4d_causal_attention_mask(
|
||||||
|
attention_mask: Optional[torch.Tensor],
|
||||||
|
*args,
|
||||||
|
):
|
||||||
|
dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
|
||||||
|
return _prepare_4d_causal_attention_mask(
|
||||||
|
mask_2d_to_4d(attention_mask, dtype=dtype),
|
||||||
|
*args,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def patched_prepare_4d_causal_attention_mask_for_sdpa(
|
||||||
|
attention_mask: Optional[torch.Tensor],
|
||||||
|
*args,
|
||||||
|
):
|
||||||
|
dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
|
||||||
|
return _prepare_4d_causal_attention_mask_for_sdpa(
|
||||||
|
mask_2d_to_4d(attention_mask, dtype=dtype),
|
||||||
|
*args,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def detab_code(code: str) -> Tuple[str, str]:
|
def detab_code(code: str) -> Tuple[str, str]:
|
||||||
try:
|
try:
|
||||||
spaces = re.match(r"([\s\t]{1,})", code).group(0)
|
spaces = re.match(r"([\s\t]{1,})", code).group(0)
|
||||||
|
|||||||
@@ -1,96 +0,0 @@
|
|||||||
"""
|
|
||||||
Synthetic dataset generator for benchmarking and testing.
|
|
||||||
|
|
||||||
Generates datasets with configurable sequence length, dataset size, and token ID ranges.
|
|
||||||
Useful for benchmarking memory usage and speed by sequence length, and for validating
|
|
||||||
weighted dataset mixes.
|
|
||||||
|
|
||||||
YAML configuration example:
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: synthetic
|
|
||||||
type: _synthetic
|
|
||||||
length: 1000
|
|
||||||
sequence_length: 2048
|
|
||||||
min_input_id: 100
|
|
||||||
max_input_id: 32000
|
|
||||||
seed: 42
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Any, Dict, Optional
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from datasets import Dataset
|
|
||||||
|
|
||||||
from axolotl.prompt_tokenizers import DatasetWrappingStrategy
|
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class SyntheticDatasetStrategy(DatasetWrappingStrategy):
|
|
||||||
"""Strategy that generates synthetic tokenized data, ignoring the source dataset."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
sequence_length: int = 2048,
|
|
||||||
length: int = 1000,
|
|
||||||
min_input_id: int = 100,
|
|
||||||
max_input_id: int = 32000,
|
|
||||||
seed: Optional[int] = None,
|
|
||||||
):
|
|
||||||
self.sequence_length = sequence_length
|
|
||||||
self.length = length
|
|
||||||
self.min_input_id = min_input_id
|
|
||||||
self.max_input_id = max_input_id
|
|
||||||
self.seed = seed
|
|
||||||
|
|
||||||
def wrap_dataset(
|
|
||||||
self,
|
|
||||||
dataset,
|
|
||||||
process_count: int | None = None,
|
|
||||||
keep_in_memory: bool | None = False,
|
|
||||||
**kwargs,
|
|
||||||
) -> Dataset:
|
|
||||||
LOG.info(
|
|
||||||
f"Generating synthetic dataset: {self.length} samples, "
|
|
||||||
f"sequence_length={self.sequence_length}, "
|
|
||||||
f"input_id_range=[{self.min_input_id}, {self.max_input_id})"
|
|
||||||
)
|
|
||||||
|
|
||||||
rng = np.random.default_rng(self.seed)
|
|
||||||
input_ids = rng.integers(
|
|
||||||
low=self.min_input_id,
|
|
||||||
high=self.max_input_id,
|
|
||||||
size=(self.length, self.sequence_length),
|
|
||||||
).tolist()
|
|
||||||
|
|
||||||
attention_mask = [[1] * self.sequence_length] * self.length
|
|
||||||
# labels == input_ids means we train on all tokens
|
|
||||||
labels = [row[:] for row in input_ids]
|
|
||||||
|
|
||||||
return Dataset.from_dict(
|
|
||||||
{
|
|
||||||
"input_ids": input_ids,
|
|
||||||
"attention_mask": attention_mask,
|
|
||||||
"labels": labels,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|
||||||
ds_cfg = ds_cfg or {}
|
|
||||||
|
|
||||||
sequence_length = ds_cfg.get("sequence_length", cfg.sequence_len)
|
|
||||||
length = ds_cfg.get("length", 1000)
|
|
||||||
min_input_id = ds_cfg.get("min_input_id", 100)
|
|
||||||
max_input_id = ds_cfg.get("max_input_id", tokenizer.vocab_size)
|
|
||||||
seed = ds_cfg.get("seed", None)
|
|
||||||
|
|
||||||
return SyntheticDatasetStrategy(
|
|
||||||
sequence_length=sequence_length,
|
|
||||||
length=length,
|
|
||||||
min_input_id=min_input_id,
|
|
||||||
max_input_id=max_input_id,
|
|
||||||
seed=seed,
|
|
||||||
)
|
|
||||||
@@ -9,6 +9,7 @@ import os
|
|||||||
import shutil
|
import shutil
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
|
import typing
|
||||||
import weakref
|
import weakref
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from contextlib import ExitStack
|
from contextlib import ExitStack
|
||||||
@@ -41,6 +42,9 @@ from axolotl.utils.schemas.enums import RLType
|
|||||||
from axolotl.utils.train import determine_last_checkpoint
|
from axolotl.utils.train import determine_last_checkpoint
|
||||||
from axolotl.utils.trainer import setup_trainer
|
from axolotl.utils.trainer import setup_trainer
|
||||||
|
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
TELEMETRY_MANAGER = TelemetryManager.get_instance()
|
TELEMETRY_MANAGER = TelemetryManager.get_instance()
|
||||||
@@ -78,7 +82,7 @@ def setup_model_and_tokenizer(
|
|||||||
|
|
||||||
model_loader = ModelLoader(cfg, tokenizer, processor=processor)
|
model_loader = ModelLoader(cfg, tokenizer, processor=processor)
|
||||||
model, peft_config = model_loader.load()
|
model, peft_config = model_loader.load()
|
||||||
if getattr(model, "generation_config", None) is not None:
|
if model.generation_config is not None:
|
||||||
model.generation_config.do_sample = True
|
model.generation_config.do_sample = True
|
||||||
|
|
||||||
model_properties = model.config.to_dict()
|
model_properties = model.config.to_dict()
|
||||||
@@ -483,7 +487,7 @@ def handle_untrained_tokens_fix(
|
|||||||
def setup_model_and_trainer(
|
def setup_model_and_trainer(
|
||||||
cfg: DictDefault, dataset_meta: TrainDatasetMeta
|
cfg: DictDefault, dataset_meta: TrainDatasetMeta
|
||||||
) -> tuple[
|
) -> tuple[
|
||||||
Trainer,
|
"HFRLTrainerBuilder" | "HFCausalTrainerBuilder",
|
||||||
PeftModel | PreTrainedModel,
|
PeftModel | PreTrainedModel,
|
||||||
PreTrainedTokenizer,
|
PreTrainedTokenizer,
|
||||||
PeftConfig | None,
|
PeftConfig | None,
|
||||||
@@ -550,36 +554,6 @@ def setup_model_and_trainer(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _is_tui_enabled(cfg: DictDefault) -> bool:
|
|
||||||
"""Check if TUI is enabled via config or environment variable."""
|
|
||||||
if os.environ.get("AXOLOTL_TUI", "").lower() in ("1", "true", "yes"):
|
|
||||||
return True
|
|
||||||
tui = cfg.get("tui")
|
|
||||||
if tui is None:
|
|
||||||
return False
|
|
||||||
if isinstance(tui, bool):
|
|
||||||
return tui
|
|
||||||
if isinstance(tui, dict):
|
|
||||||
return tui.get("enabled", False)
|
|
||||||
if hasattr(tui, "enabled"):
|
|
||||||
return tui.enabled
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _get_tui_config(cfg: DictDefault) -> dict:
|
|
||||||
"""Extract TUI config dict from cfg."""
|
|
||||||
tui = cfg.get("tui")
|
|
||||||
if tui is None or isinstance(tui, bool):
|
|
||||||
return {"enabled": True}
|
|
||||||
if isinstance(tui, dict):
|
|
||||||
return {**tui, "enabled": True}
|
|
||||||
if hasattr(tui, "model_dump"):
|
|
||||||
d = tui.model_dump()
|
|
||||||
d["enabled"] = True
|
|
||||||
return d
|
|
||||||
return {"enabled": True}
|
|
||||||
|
|
||||||
|
|
||||||
@send_errors
|
@send_errors
|
||||||
def train(
|
def train(
|
||||||
cfg: DictDefault, dataset_meta: TrainDatasetMeta
|
cfg: DictDefault, dataset_meta: TrainDatasetMeta
|
||||||
@@ -603,37 +577,6 @@ def train(
|
|||||||
processor,
|
processor,
|
||||||
) = setup_model_and_trainer(cfg, dataset_meta)
|
) = setup_model_and_trainer(cfg, dataset_meta)
|
||||||
|
|
||||||
# Register TUI callback if enabled and rank 0
|
|
||||||
tui_enabled = _is_tui_enabled(cfg)
|
|
||||||
if tui_enabled and cfg.local_rank == 0:
|
|
||||||
from axolotl.tui import AxolotlTUICallback
|
|
||||||
from axolotl.tui.config import TUIConfig
|
|
||||||
|
|
||||||
tui_config = _get_tui_config(cfg)
|
|
||||||
tui_config_obj = (
|
|
||||||
TUIConfig(**tui_config) if isinstance(tui_config, dict) else tui_config
|
|
||||||
)
|
|
||||||
|
|
||||||
# Reuse the early-started renderer if available (started in do_train)
|
|
||||||
early_renderer = getattr(cfg, "_tui_renderer", None)
|
|
||||||
early_queue = getattr(cfg, "_tui_queue", None)
|
|
||||||
|
|
||||||
tui_callback = AxolotlTUICallback(config=tui_config_obj)
|
|
||||||
if early_renderer is not None and early_queue is not None:
|
|
||||||
# Reuse the already-running renderer and queue
|
|
||||||
tui_callback._renderer = early_renderer
|
|
||||||
tui_callback._queue = early_queue
|
|
||||||
tui_callback._renderer_started_early = True
|
|
||||||
trainer.add_callback(tui_callback)
|
|
||||||
|
|
||||||
# Stash model info so on_train_begin can emit a single unified run_info event
|
|
||||||
tui_callback._pending_run_info = {
|
|
||||||
"model_name": cfg.base_model or "",
|
|
||||||
"training_mode": str(cfg.rl) if cfg.rl else "sft",
|
|
||||||
"world_size": int(os.environ.get("WORLD_SIZE", 1)),
|
|
||||||
}
|
|
||||||
LOG.info("TUI dashboard enabled")
|
|
||||||
|
|
||||||
# Handle untrained tokens if configured
|
# Handle untrained tokens if configured
|
||||||
train_dataset = dataset_meta.train_dataset
|
train_dataset = dataset_meta.train_dataset
|
||||||
handle_untrained_tokens_fix(cfg, model, tokenizer, train_dataset)
|
handle_untrained_tokens_fix(cfg, model, tokenizer, train_dataset)
|
||||||
|
|||||||
@@ -1,17 +0,0 @@
|
|||||||
"""Axolotl Training TUI — rich-based terminal dashboard for monitoring training runs."""
|
|
||||||
|
|
||||||
from axolotl.tui.callback import AxolotlTUICallback
|
|
||||||
from axolotl.tui.config import TUIConfig
|
|
||||||
from axolotl.tui.io_capture import LineParser, register_parser
|
|
||||||
from axolotl.tui.panels import BasePanel, register_panel
|
|
||||||
from axolotl.tui.state import TUIState
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"AxolotlTUICallback",
|
|
||||||
"BasePanel",
|
|
||||||
"LineParser",
|
|
||||||
"TUIConfig",
|
|
||||||
"TUIState",
|
|
||||||
"register_panel",
|
|
||||||
"register_parser",
|
|
||||||
]
|
|
||||||
@@ -1,142 +0,0 @@
|
|||||||
"""AxolotlTUICallback — HF TrainerCallback that feeds metrics to the TUI."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import queue
|
|
||||||
|
|
||||||
from transformers.trainer_callback import TrainerCallback
|
|
||||||
|
|
||||||
from axolotl.tui.config import TUIConfig
|
|
||||||
from axolotl.tui.renderer import TUIRenderer
|
|
||||||
|
|
||||||
|
|
||||||
class _TUILogHandler(logging.Handler):
|
|
||||||
"""Logging handler that pushes log records into the TUI metric queue."""
|
|
||||||
|
|
||||||
_LEVEL_MAP = {
|
|
||||||
logging.DEBUG: "debug",
|
|
||||||
logging.INFO: "info",
|
|
||||||
logging.WARNING: "warning",
|
|
||||||
logging.ERROR: "error",
|
|
||||||
logging.CRITICAL: "error",
|
|
||||||
}
|
|
||||||
|
|
||||||
def __init__(self, metric_queue: queue.Queue, min_level: str = "info"):
|
|
||||||
super().__init__()
|
|
||||||
level_name = min_level.upper()
|
|
||||||
self.setLevel(getattr(logging, level_name, logging.INFO))
|
|
||||||
self._queue = metric_queue
|
|
||||||
|
|
||||||
def emit(self, record: logging.LogRecord) -> None:
|
|
||||||
try:
|
|
||||||
level = self._LEVEL_MAP.get(record.levelno, "info")
|
|
||||||
msg = self.format(record)
|
|
||||||
self._queue.put_nowait(
|
|
||||||
{
|
|
||||||
"type": "log_line",
|
|
||||||
"level": level,
|
|
||||||
"message": msg,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
except queue.Full:
|
|
||||||
pass
|
|
||||||
except Exception:
|
|
||||||
self.handleError(record)
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlTUICallback(TrainerCallback):
|
|
||||||
"""Pushes training metrics into a queue for the TUI renderer.
|
|
||||||
|
|
||||||
The callback never blocks on the render thread. The queue is bounded
|
|
||||||
(maxsize=512) with put_nowait; overflow is silently dropped.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, config: TUIConfig):
|
|
||||||
self._config = config
|
|
||||||
self._queue: queue.Queue = queue.Queue(maxsize=4096)
|
|
||||||
self._renderer = TUIRenderer(config=config, metric_queue=self._queue)
|
|
||||||
self._log_handler: _TUILogHandler | None = None
|
|
||||||
self._renderer_started_early: bool = False
|
|
||||||
self._pending_run_info: dict | None = None
|
|
||||||
|
|
||||||
def _put(self, event: dict) -> None:
|
|
||||||
try:
|
|
||||||
self._queue.put_nowait(event)
|
|
||||||
except queue.Full:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def on_train_begin(self, args, state, control, model=None, **kwargs):
|
|
||||||
# Send a single unified run_info event with all fields
|
|
||||||
run_info = {
|
|
||||||
"type": "run_info",
|
|
||||||
"run_name": getattr(args, "run_name", "") or "",
|
|
||||||
"total_steps": state.max_steps,
|
|
||||||
"total_epochs": float(args.num_train_epochs)
|
|
||||||
if args.num_train_epochs
|
|
||||||
else 1.0,
|
|
||||||
}
|
|
||||||
# Merge in model_name/training_mode/world_size if stashed by train.py
|
|
||||||
if self._pending_run_info:
|
|
||||||
run_info.update(self._pending_run_info)
|
|
||||||
self._pending_run_info = None
|
|
||||||
self._put(run_info)
|
|
||||||
|
|
||||||
if not self._renderer_started_early:
|
|
||||||
# Attach a logging handler to feed log messages into the events panel
|
|
||||||
self._log_handler = _TUILogHandler(
|
|
||||||
self._queue, min_level=self._config.log_level
|
|
||||||
)
|
|
||||||
self._log_handler.setFormatter(logging.Formatter("[%(name)s] %(message)s"))
|
|
||||||
# Attach to both root and axolotl loggers (axolotl has propagate=False)
|
|
||||||
logging.getLogger().addHandler(self._log_handler)
|
|
||||||
logging.getLogger("axolotl").addHandler(self._log_handler)
|
|
||||||
|
|
||||||
# Start the renderer background thread
|
|
||||||
self._renderer.start()
|
|
||||||
|
|
||||||
def on_log(self, args, state, control, logs=None, **kwargs):
|
|
||||||
if logs is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Filter out non-numeric keys and internal keys
|
|
||||||
filtered = {}
|
|
||||||
for key, value in logs.items():
|
|
||||||
if key.startswith("_"):
|
|
||||||
continue
|
|
||||||
if isinstance(value, (int, float)):
|
|
||||||
filtered[key] = value
|
|
||||||
elif isinstance(value, str):
|
|
||||||
# HF Trainer sometimes passes string-encoded numbers
|
|
||||||
try:
|
|
||||||
filtered[key] = float(value)
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
if filtered:
|
|
||||||
self._put({"type": "metrics", "logs": filtered})
|
|
||||||
|
|
||||||
def on_step_end(self, args, state, control, **kwargs):
|
|
||||||
self._put(
|
|
||||||
{
|
|
||||||
"type": "step",
|
|
||||||
"step": state.global_step,
|
|
||||||
"total_steps": state.max_steps,
|
|
||||||
"epoch": state.epoch if state.epoch else 0,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
def on_prediction_step(self, args, state, control, **kwargs):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def on_train_end(self, args, state, control, **kwargs):
|
|
||||||
self._put({"type": "done"})
|
|
||||||
# If renderer was started early, do_train's finally block handles stop
|
|
||||||
if not self._renderer_started_early:
|
|
||||||
self._renderer.stop()
|
|
||||||
|
|
||||||
# Remove the logging handler (only if we added it)
|
|
||||||
if self._log_handler:
|
|
||||||
logging.getLogger().removeHandler(self._log_handler)
|
|
||||||
logging.getLogger("axolotl").removeHandler(self._log_handler)
|
|
||||||
self._log_handler = None
|
|
||||||
@@ -1,38 +0,0 @@
|
|||||||
"""TUI configuration — Pydantic model for TUI settings."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
|
|
||||||
class TUIConfig(BaseModel):
|
|
||||||
"""Configuration for the Axolotl Training TUI dashboard."""
|
|
||||||
|
|
||||||
enabled: bool = Field(
|
|
||||||
default=False,
|
|
||||||
json_schema_extra={"description": "Enable the TUI dashboard"},
|
|
||||||
)
|
|
||||||
refresh_rate: int = Field(
|
|
||||||
default=4,
|
|
||||||
json_schema_extra={"description": "Renders per second"},
|
|
||||||
)
|
|
||||||
log_level: str = Field(
|
|
||||||
default="debug",
|
|
||||||
json_schema_extra={"description": "Minimum log level shown in events panel"},
|
|
||||||
)
|
|
||||||
panels: list[str] = Field(
|
|
||||||
default_factory=lambda: ["progress", "training", "hardware", "events", "debug"],
|
|
||||||
json_schema_extra={"description": "Ordered list of panels to display"},
|
|
||||||
)
|
|
||||||
hardware_poll_interval: int = Field(
|
|
||||||
default=2,
|
|
||||||
json_schema_extra={"description": "Seconds between pynvml GPU queries"},
|
|
||||||
)
|
|
||||||
stdout_log_path: str = Field(
|
|
||||||
default="axolotl_stdout.log",
|
|
||||||
json_schema_extra={"description": "File path for captured stdout/stderr log"},
|
|
||||||
)
|
|
||||||
parser_plugins: list[str] = Field(
|
|
||||||
default_factory=list,
|
|
||||||
json_schema_extra={"description": "List of extra parser classes to load"},
|
|
||||||
)
|
|
||||||
@@ -1,72 +0,0 @@
|
|||||||
"""GPU polling wrapper around pynvml with graceful fallback."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from axolotl.tui.state import GPUStats
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
_nvml_available = False
|
|
||||||
try:
|
|
||||||
import pynvml
|
|
||||||
|
|
||||||
pynvml.nvmlInit()
|
|
||||||
_nvml_available = True
|
|
||||||
except Exception:
|
|
||||||
LOG.debug("pynvml unavailable — GPU stats will not be shown")
|
|
||||||
|
|
||||||
|
|
||||||
class GPUPoller:
|
|
||||||
"""Polls local GPU stats via pynvml. Falls back gracefully if unavailable."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self._device_count = 0
|
|
||||||
if _nvml_available:
|
|
||||||
try:
|
|
||||||
self._device_count = pynvml.nvmlDeviceGetCount()
|
|
||||||
except Exception:
|
|
||||||
self._device_count = 0
|
|
||||||
|
|
||||||
@property
|
|
||||||
def available(self) -> bool:
|
|
||||||
return _nvml_available and self._device_count > 0
|
|
||||||
|
|
||||||
def poll(self) -> list[GPUStats]:
|
|
||||||
if not self.available:
|
|
||||||
return []
|
|
||||||
|
|
||||||
stats = []
|
|
||||||
for i in range(self._device_count):
|
|
||||||
try:
|
|
||||||
handle = pynvml.nvmlDeviceGetHandleByIndex(i)
|
|
||||||
name = pynvml.nvmlDeviceGetName(handle)
|
|
||||||
if isinstance(name, bytes):
|
|
||||||
name = name.decode("utf-8")
|
|
||||||
|
|
||||||
util = pynvml.nvmlDeviceGetUtilizationRates(handle)
|
|
||||||
mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
|
||||||
temp = pynvml.nvmlDeviceGetTemperature(
|
|
||||||
handle, pynvml.NVML_TEMPERATURE_GPU
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
power = pynvml.nvmlDeviceGetPowerUsage(handle) / 1000.0
|
|
||||||
except Exception:
|
|
||||||
power = None
|
|
||||||
|
|
||||||
stats.append(
|
|
||||||
GPUStats(
|
|
||||||
id=i,
|
|
||||||
name=name,
|
|
||||||
util_pct=util.gpu,
|
|
||||||
vram_used_gb=mem.used / (1024**3),
|
|
||||||
vram_total_gb=mem.total / (1024**3),
|
|
||||||
temp_c=temp,
|
|
||||||
power_w=power,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
LOG.debug("Error polling GPU device %d", i, exc_info=True)
|
|
||||||
return stats
|
|
||||||
@@ -1,196 +0,0 @@
|
|||||||
"""I/O capture: OS-level stdout/stderr redirect, line parser chain, and parser registry."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import queue
|
|
||||||
import sys
|
|
||||||
import threading
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import IO
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Parser registry
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
_parser_registry: list[type[LineParser]] = []
|
|
||||||
|
|
||||||
|
|
||||||
def register_parser(cls: type[LineParser]) -> type[LineParser]:
|
|
||||||
"""Decorator to register a LineParser subclass."""
|
|
||||||
if cls not in _parser_registry:
|
|
||||||
_parser_registry.append(cls)
|
|
||||||
return cls
|
|
||||||
|
|
||||||
|
|
||||||
def get_registered_parsers() -> list[type[LineParser]]:
|
|
||||||
return list(_parser_registry)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Base LineParser
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class LineParser(ABC):
|
|
||||||
"""Base class for stdout/stderr line parsers."""
|
|
||||||
|
|
||||||
priority: int = 50
|
|
||||||
name: str = ""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def parse(self, line: str, source: str) -> list[dict]:
|
|
||||||
"""Parse a single captured line.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
line: one line of captured output, trailing newline stripped.
|
|
||||||
source: "stdout" or "stderr".
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of event dicts to push onto the metric queue.
|
|
||||||
Return [] if this line is not relevant.
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# ParserChain
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class ParserChain:
|
|
||||||
def __init__(self):
|
|
||||||
self._parsers: list[LineParser] = []
|
|
||||||
|
|
||||||
def register(self, parser: LineParser) -> None:
|
|
||||||
self._parsers.append(parser)
|
|
||||||
self._parsers.sort(key=lambda p: p.priority)
|
|
||||||
|
|
||||||
def parse(self, line: str, source: str = "stdout") -> list[dict]:
|
|
||||||
events: list[dict] = []
|
|
||||||
for parser in self._parsers:
|
|
||||||
events.extend(parser.parse(line, source))
|
|
||||||
return events
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# IOCapture — OS-level fd redirect to pipe
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class IOCapture:
|
|
||||||
"""Redirects fd 1 and fd 2 into an OS pipe, drains via a reader thread,
|
|
||||||
passes lines through a ParserChain, and tees to a log file."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, log_path: str, parser_chain: ParserChain, metric_queue: queue.Queue
|
|
||||||
):
|
|
||||||
self._parser_chain = parser_chain
|
|
||||||
self._queue = metric_queue
|
|
||||||
self._log_path = log_path
|
|
||||||
self._log_file: IO[str] | None = None
|
|
||||||
self._thread: threading.Thread | None = None
|
|
||||||
self._read_fd: int | None = None
|
|
||||||
self._write_fd: int | None = None
|
|
||||||
self._saved_stdout_fd: int | None = None
|
|
||||||
self._saved_stderr_fd: int | None = None
|
|
||||||
|
|
||||||
def start(self) -> None:
|
|
||||||
# Write run-start separator
|
|
||||||
self._log_file = open(self._log_path, "a", buffering=1) # noqa: SIM115
|
|
||||||
self._log_file.write(
|
|
||||||
f"\n=== axolotl run started {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ===\n"
|
|
||||||
)
|
|
||||||
self._log_file.flush()
|
|
||||||
|
|
||||||
# OS-level pipe
|
|
||||||
self._read_fd, self._write_fd = os.pipe()
|
|
||||||
|
|
||||||
# Save originals
|
|
||||||
self._saved_stdout_fd = os.dup(1)
|
|
||||||
self._saved_stderr_fd = os.dup(2)
|
|
||||||
|
|
||||||
# Redirect both stdout and stderr into the write end
|
|
||||||
os.dup2(self._write_fd, 1)
|
|
||||||
os.dup2(self._write_fd, 2)
|
|
||||||
os.close(self._write_fd) # write end now held by fds 1 and 2
|
|
||||||
|
|
||||||
# Also redirect Python-level handles
|
|
||||||
sys.stdout = open(1, "w", buffering=1, closefd=False) # noqa: SIM115
|
|
||||||
sys.stderr = open(2, "w", buffering=1, closefd=False) # noqa: SIM115
|
|
||||||
|
|
||||||
# Drain thread
|
|
||||||
self._thread = threading.Thread(target=self._drain, daemon=True)
|
|
||||||
self._thread.start()
|
|
||||||
|
|
||||||
def stop(self) -> None:
|
|
||||||
# Restore fds — closes the write end, causing reader to see EOF
|
|
||||||
if self._saved_stdout_fd is not None and self._saved_stderr_fd is not None:
|
|
||||||
sys.stdout = sys.__stdout__
|
|
||||||
sys.stderr = sys.__stderr__
|
|
||||||
os.dup2(self._saved_stdout_fd, 1)
|
|
||||||
os.dup2(self._saved_stderr_fd, 2)
|
|
||||||
os.close(self._saved_stdout_fd)
|
|
||||||
os.close(self._saved_stderr_fd)
|
|
||||||
self._saved_stdout_fd = None
|
|
||||||
self._saved_stderr_fd = None
|
|
||||||
|
|
||||||
if self._thread is not None:
|
|
||||||
self._thread.join(timeout=2.0)
|
|
||||||
if self._thread.is_alive():
|
|
||||||
logging.getLogger(__name__).warning(
|
|
||||||
"IO capture thread did not exit after 2s"
|
|
||||||
)
|
|
||||||
self._thread = None
|
|
||||||
|
|
||||||
if self._log_file is not None:
|
|
||||||
self._log_file.close()
|
|
||||||
self._log_file = None
|
|
||||||
|
|
||||||
def _drain(self) -> None:
|
|
||||||
# Read raw bytes and split on both \n and \r to handle tqdm progress bars
|
|
||||||
# which use \r for in-place updates without \n
|
|
||||||
assert self._read_fd is not None, "_drain called before start()"
|
|
||||||
with os.fdopen(self._read_fd, "rb") as pipe:
|
|
||||||
buf = b""
|
|
||||||
while True:
|
|
||||||
chunk = pipe.read(4096)
|
|
||||||
if not chunk:
|
|
||||||
# EOF — process remaining buffer
|
|
||||||
if buf:
|
|
||||||
self._process_line(buf.decode("utf-8", errors="replace"))
|
|
||||||
break
|
|
||||||
buf += chunk
|
|
||||||
# Split on \n or \r
|
|
||||||
while b"\n" in buf or b"\r" in buf:
|
|
||||||
# Find the earliest delimiter
|
|
||||||
idx_n = buf.find(b"\n")
|
|
||||||
idx_r = buf.find(b"\r")
|
|
||||||
if idx_n == -1:
|
|
||||||
idx = idx_r
|
|
||||||
elif idx_r == -1:
|
|
||||||
idx = idx_n
|
|
||||||
else:
|
|
||||||
idx = min(idx_n, idx_r)
|
|
||||||
line = buf[:idx].decode("utf-8", errors="replace")
|
|
||||||
buf = buf[idx + 1 :]
|
|
||||||
# Handle \r\n as single delimiter
|
|
||||||
if buf.startswith(b"\n"):
|
|
||||||
buf = buf[1:]
|
|
||||||
if line:
|
|
||||||
self._process_line(line)
|
|
||||||
|
|
||||||
def _process_line(self, line: str) -> None:
|
|
||||||
line = line.rstrip()
|
|
||||||
if not line:
|
|
||||||
return
|
|
||||||
if self._log_file:
|
|
||||||
self._log_file.write(line + "\n")
|
|
||||||
self._log_file.flush()
|
|
||||||
for event in self._parser_chain.parse(line):
|
|
||||||
try:
|
|
||||||
self._queue.put_nowait(event)
|
|
||||||
except queue.Full:
|
|
||||||
pass
|
|
||||||
@@ -1,63 +0,0 @@
|
|||||||
"""Panel registry and base class for TUI panels."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
|
|
||||||
from rich.console import RenderableType
|
|
||||||
|
|
||||||
from axolotl.tui.state import TUIState
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Panel registry
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
_panel_registry: dict[str, type[BasePanel]] = {}
|
|
||||||
|
|
||||||
|
|
||||||
def register_panel(position: str = "bottom", weight: int = 50):
|
|
||||||
"""Decorator to register a panel class with position and weight."""
|
|
||||||
|
|
||||||
def decorator(cls: type[BasePanel]) -> type[BasePanel]:
|
|
||||||
cls.position = position
|
|
||||||
cls.weight = weight
|
|
||||||
_panel_registry[cls.name] = cls
|
|
||||||
return cls
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
|
|
||||||
def get_registered_panels() -> dict[str, type[BasePanel]]:
|
|
||||||
return dict(_panel_registry)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# BasePanel
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class BasePanel(ABC):
|
|
||||||
name: str = ""
|
|
||||||
position: str = "bottom"
|
|
||||||
weight: int = 50
|
|
||||||
min_height: int = 4
|
|
||||||
max_height: int | None = None
|
|
||||||
modes: list[str] = ["*"]
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def render(self, state: TUIState) -> RenderableType:
|
|
||||||
"""Return a rich renderable. Called every tick."""
|
|
||||||
...
|
|
||||||
|
|
||||||
def on_event(self, event: dict) -> None: # noqa: B027
|
|
||||||
"""Optional: react to raw metric events before state is merged."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
# Auto-import built-in panels to trigger registration
|
|
||||||
from axolotl.tui.panels.completions import CompletionsPanel # noqa: E402, F401
|
|
||||||
from axolotl.tui.panels.debug import DebugPanel # noqa: E402, F401
|
|
||||||
from axolotl.tui.panels.events import EventsPanel # noqa: E402, F401
|
|
||||||
from axolotl.tui.panels.hardware import HardwarePanel # noqa: E402, F401
|
|
||||||
from axolotl.tui.panels.progress import ProgressPanel # noqa: E402, F401
|
|
||||||
from axolotl.tui.panels.training import TrainingPanel # noqa: E402, F401
|
|
||||||
@@ -1,61 +0,0 @@
|
|||||||
"""CompletionsPanel — shows recent RL/log_completions samples."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from rich.console import RenderableType
|
|
||||||
from rich.panel import Panel
|
|
||||||
from rich.table import Table
|
|
||||||
from rich.text import Text
|
|
||||||
|
|
||||||
from axolotl.tui.panels import BasePanel, register_panel
|
|
||||||
from axolotl.tui.state import TUIState
|
|
||||||
|
|
||||||
|
|
||||||
def _truncate(s: str, maxlen: int = 60) -> str:
|
|
||||||
return s[:maxlen] + "…" if len(s) > maxlen else s
|
|
||||||
|
|
||||||
|
|
||||||
@register_panel(position="bottom", weight=20)
|
|
||||||
class CompletionsPanel(BasePanel):
|
|
||||||
name = "completions"
|
|
||||||
min_height = 6
|
|
||||||
modes = ["grpo", "dpo"]
|
|
||||||
|
|
||||||
def render(self, state: TUIState) -> RenderableType:
|
|
||||||
if "*" not in self.modes and state.training_mode not in self.modes:
|
|
||||||
return Text("")
|
|
||||||
|
|
||||||
if not state.completions:
|
|
||||||
return Panel(
|
|
||||||
Text("No completions yet...", style="dim"),
|
|
||||||
title="Completions",
|
|
||||||
border_style="magenta",
|
|
||||||
)
|
|
||||||
|
|
||||||
table = Table(
|
|
||||||
show_header=True,
|
|
||||||
header_style="bold",
|
|
||||||
expand=True,
|
|
||||||
box=None,
|
|
||||||
pad_edge=False,
|
|
||||||
)
|
|
||||||
table.add_column("step", justify="right", width=6)
|
|
||||||
table.add_column("prompt", no_wrap=False, max_width=40)
|
|
||||||
table.add_column("completion", no_wrap=False, max_width=40)
|
|
||||||
table.add_column("reward", justify="right", width=8)
|
|
||||||
table.add_column("adv", justify="right", width=8)
|
|
||||||
|
|
||||||
for sample in list(state.completions)[-5:]:
|
|
||||||
reward_str = f"{sample.reward:.2f}" if sample.reward is not None else "--"
|
|
||||||
adv_str = (
|
|
||||||
f"{sample.advantage:+.2f}" if sample.advantage is not None else "--"
|
|
||||||
)
|
|
||||||
table.add_row(
|
|
||||||
str(sample.step),
|
|
||||||
_truncate(sample.prompt),
|
|
||||||
_truncate(sample.completion),
|
|
||||||
reward_str,
|
|
||||||
adv_str,
|
|
||||||
)
|
|
||||||
|
|
||||||
return Panel(table, title="Completions", border_style="magenta")
|
|
||||||
@@ -1,34 +0,0 @@
|
|||||||
"""DebugPanel — scrolling log of debug-level messages, separate from main events."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from rich.console import RenderableType
|
|
||||||
from rich.panel import Panel
|
|
||||||
from rich.text import Text
|
|
||||||
|
|
||||||
from axolotl.tui.panels import BasePanel, register_panel
|
|
||||||
from axolotl.tui.state import TUIState
|
|
||||||
|
|
||||||
|
|
||||||
@register_panel(position="bottom", weight=30)
|
|
||||||
class DebugPanel(BasePanel):
|
|
||||||
name = "debug"
|
|
||||||
min_height = 6
|
|
||||||
max_height = 10
|
|
||||||
|
|
||||||
def render(self, state: TUIState) -> RenderableType:
|
|
||||||
lines = Text()
|
|
||||||
# Show last 8 debug-level log lines
|
|
||||||
debug_lines = [
|
|
||||||
log_entry for log_entry in state.log_lines if log_entry.level == "debug"
|
|
||||||
][-8:]
|
|
||||||
for log_line in debug_lines:
|
|
||||||
ts = log_line.timestamp.strftime("%H:%M:%S")
|
|
||||||
lines.append(f"[{ts}] ", style="dim")
|
|
||||||
lines.append(log_line.message[:200], style="dim")
|
|
||||||
lines.append("\n")
|
|
||||||
|
|
||||||
if not debug_lines:
|
|
||||||
lines = Text("No debug messages yet...", style="dim")
|
|
||||||
|
|
||||||
return Panel(lines, title="Debug", border_style="dim")
|
|
||||||
@@ -1,45 +0,0 @@
|
|||||||
"""EventsPanel — scrolling log of recent events, color-coded by level."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from rich.console import RenderableType
|
|
||||||
from rich.panel import Panel
|
|
||||||
from rich.text import Text
|
|
||||||
|
|
||||||
from axolotl.tui.panels import BasePanel, register_panel
|
|
||||||
from axolotl.tui.state import TUIState
|
|
||||||
|
|
||||||
_LEVEL_STYLES = {
|
|
||||||
"debug": "dim",
|
|
||||||
"info": "",
|
|
||||||
"warning": "yellow",
|
|
||||||
"error": "red bold",
|
|
||||||
"critical": "red bold",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@register_panel(position="bottom", weight=10)
|
|
||||||
class EventsPanel(BasePanel):
|
|
||||||
name = "events"
|
|
||||||
min_height = 8
|
|
||||||
max_height = 20
|
|
||||||
|
|
||||||
def render(self, state: TUIState) -> RenderableType:
|
|
||||||
lines = Text()
|
|
||||||
# Show last 15 non-debug log lines (debug goes to DebugPanel)
|
|
||||||
recent = [
|
|
||||||
log_entry for log_entry in state.log_lines if log_entry.level != "debug"
|
|
||||||
][-15:]
|
|
||||||
for log_line in recent:
|
|
||||||
ts = log_line.timestamp.strftime("%H:%M:%S")
|
|
||||||
level = log_line.level.upper()
|
|
||||||
style = _LEVEL_STYLES.get(log_line.level, "")
|
|
||||||
lines.append(f"[{ts}] ", style="dim")
|
|
||||||
lines.append(f"[{level}] ", style=style or "")
|
|
||||||
lines.append(log_line.message[:200], style=style or "")
|
|
||||||
lines.append("\n")
|
|
||||||
|
|
||||||
if not recent:
|
|
||||||
lines = Text("No events yet...", style="dim")
|
|
||||||
|
|
||||||
return Panel(lines, title="Events", border_style="yellow")
|
|
||||||
@@ -1,80 +0,0 @@
|
|||||||
"""HardwarePanel — per-GPU stats via pynvml."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from rich.console import RenderableType
|
|
||||||
from rich.panel import Panel
|
|
||||||
from rich.table import Table
|
|
||||||
from rich.text import Text
|
|
||||||
|
|
||||||
from axolotl.tui.panels import BasePanel, register_panel
|
|
||||||
from axolotl.tui.state import TUIState
|
|
||||||
|
|
||||||
_BAR_FULL = "█"
|
|
||||||
_BAR_EMPTY = "░"
|
|
||||||
|
|
||||||
|
|
||||||
def _util_bar(pct: float, width: int = 6) -> Text:
|
|
||||||
filled = int(pct / 100 * width)
|
|
||||||
bar = _BAR_FULL * filled + _BAR_EMPTY * (width - filled)
|
|
||||||
color = "green" if pct < 70 else ("yellow" if pct < 90 else "red")
|
|
||||||
return Text.assemble((bar, color), f" {pct:3.0f}%")
|
|
||||||
|
|
||||||
|
|
||||||
@register_panel(position="right", weight=10)
|
|
||||||
class HardwarePanel(BasePanel):
|
|
||||||
name = "hardware"
|
|
||||||
min_height = 6
|
|
||||||
|
|
||||||
def render(self, state: TUIState) -> RenderableType:
|
|
||||||
if not state.gpus:
|
|
||||||
return Panel(
|
|
||||||
Text("GPU stats unavailable", style="dim"),
|
|
||||||
title="Hardware",
|
|
||||||
border_style="green",
|
|
||||||
)
|
|
||||||
|
|
||||||
table = Table(
|
|
||||||
show_header=True,
|
|
||||||
header_style="bold",
|
|
||||||
expand=True,
|
|
||||||
box=None,
|
|
||||||
pad_edge=False,
|
|
||||||
)
|
|
||||||
table.add_column("id", justify="right", width=3)
|
|
||||||
table.add_column("util", no_wrap=True)
|
|
||||||
table.add_column("vram", no_wrap=True)
|
|
||||||
table.add_column("°C", justify="right", width=4)
|
|
||||||
table.add_column("W", justify="right", width=5)
|
|
||||||
|
|
||||||
total_vram_used = 0.0
|
|
||||||
total_vram_total = 0.0
|
|
||||||
total_util = 0.0
|
|
||||||
|
|
||||||
for gpu in state.gpus:
|
|
||||||
total_vram_used += gpu.vram_used_gb
|
|
||||||
total_vram_total += gpu.vram_total_gb
|
|
||||||
total_util += gpu.util_pct
|
|
||||||
|
|
||||||
power_str = f"{gpu.power_w:.0f}" if gpu.power_w is not None else "--"
|
|
||||||
table.add_row(
|
|
||||||
str(gpu.id),
|
|
||||||
_util_bar(gpu.util_pct),
|
|
||||||
f"{gpu.vram_used_gb:.1f}/{gpu.vram_total_gb:.1f} GB",
|
|
||||||
str(gpu.temp_c),
|
|
||||||
power_str,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Footer with aggregates
|
|
||||||
n = len(state.gpus)
|
|
||||||
if n > 1:
|
|
||||||
avg_util = total_util / n
|
|
||||||
table.add_row(
|
|
||||||
"Σ",
|
|
||||||
Text(f"avg {avg_util:.0f}%", style="dim"),
|
|
||||||
Text(f"{total_vram_used:.1f}/{total_vram_total:.1f} GB", style="dim"),
|
|
||||||
"",
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
|
|
||||||
return Panel(table, title="Hardware", border_style="green")
|
|
||||||
@@ -1,73 +0,0 @@
|
|||||||
"""ProgressPanel — top-bar progress display with step count, elapsed, ETA."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from rich.console import RenderableType
|
|
||||||
from rich.progress import BarColumn, Progress, TextColumn
|
|
||||||
from rich.table import Table
|
|
||||||
from rich.text import Text
|
|
||||||
|
|
||||||
from axolotl.tui.panels import BasePanel, register_panel
|
|
||||||
from axolotl.tui.state import TUIState
|
|
||||||
|
|
||||||
|
|
||||||
def _fmt_time(seconds: float | None) -> str:
|
|
||||||
if seconds is None or seconds < 0:
|
|
||||||
return "--:--:--"
|
|
||||||
h = int(seconds) // 3600
|
|
||||||
m = (int(seconds) % 3600) // 60
|
|
||||||
s = int(seconds) % 60
|
|
||||||
return f"{h}:{m:02d}:{s:02d}"
|
|
||||||
|
|
||||||
|
|
||||||
def _fmt_eta(seconds: float | None) -> str:
|
|
||||||
if seconds is None or seconds < 0:
|
|
||||||
return "eta --"
|
|
||||||
h = int(seconds) // 3600
|
|
||||||
m = (int(seconds) % 3600) // 60
|
|
||||||
if h > 0:
|
|
||||||
return f"eta {h}h{m:02d}m"
|
|
||||||
return f"eta {m}m{int(seconds) % 60:02d}s"
|
|
||||||
|
|
||||||
|
|
||||||
@register_panel(position="top", weight=10)
|
|
||||||
class ProgressPanel(BasePanel):
|
|
||||||
name = "progress"
|
|
||||||
min_height = 3
|
|
||||||
max_height = 3
|
|
||||||
|
|
||||||
def render(self, state: TUIState) -> RenderableType:
|
|
||||||
pct = (
|
|
||||||
(state.current_step / state.total_steps * 100)
|
|
||||||
if state.total_steps > 0
|
|
||||||
else 0
|
|
||||||
)
|
|
||||||
|
|
||||||
# Header line
|
|
||||||
mode_upper = state.training_mode.upper() if state.training_mode else "SFT"
|
|
||||||
model_short = state.model_name.split("/")[-1] if state.model_name else "model"
|
|
||||||
header = Text.assemble(
|
|
||||||
("● ", "bold green"),
|
|
||||||
("AXOLOTL", "bold cyan"),
|
|
||||||
f" {mode_upper} · {model_short} ",
|
|
||||||
(
|
|
||||||
f"{state.current_step} / {state.total_steps}",
|
|
||||||
"bold",
|
|
||||||
),
|
|
||||||
f" · {_fmt_time(state.elapsed_seconds)} elapsed · {_fmt_eta(state.eta_seconds)} · {pct:.1f}%",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Progress bar
|
|
||||||
progress = Progress(
|
|
||||||
TextColumn(""),
|
|
||||||
BarColumn(bar_width=None),
|
|
||||||
TextColumn("{task.percentage:>3.0f}%"),
|
|
||||||
expand=True,
|
|
||||||
)
|
|
||||||
task = progress.add_task("", total=state.total_steps or 1)
|
|
||||||
progress.update(task, completed=state.current_step)
|
|
||||||
|
|
||||||
table = Table.grid(expand=True)
|
|
||||||
table.add_row(header)
|
|
||||||
table.add_row(progress)
|
|
||||||
return table
|
|
||||||
@@ -1,97 +0,0 @@
|
|||||||
"""TrainingPanel — live scalar metrics table with loss sparkline."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from rich.console import RenderableType
|
|
||||||
from rich.panel import Panel
|
|
||||||
from rich.table import Table
|
|
||||||
from rich.text import Text
|
|
||||||
|
|
||||||
from axolotl.tui.panels import BasePanel, register_panel
|
|
||||||
from axolotl.tui.state import TUIState
|
|
||||||
|
|
||||||
# Braille sparkline characters (8 levels)
|
|
||||||
_SPARK_CHARS = "▁▂▃▄▅▆▇█"
|
|
||||||
|
|
||||||
|
|
||||||
def _sparkline(values: list[float] | None, width: int = 20) -> str:
|
|
||||||
if not values or len(values) < 2:
|
|
||||||
return ""
|
|
||||||
vals = list(values)[-width:]
|
|
||||||
lo, hi = min(vals), max(vals)
|
|
||||||
rng = hi - lo if hi != lo else 1.0
|
|
||||||
return "".join(_SPARK_CHARS[min(int((v - lo) / rng * 7), 7)] for v in vals)
|
|
||||||
|
|
||||||
|
|
||||||
# Known key ordering and formatting
|
|
||||||
_KNOWN_KEYS: list[tuple[str, str, str]] = [
|
|
||||||
("loss", "loss", ".4f"),
|
|
||||||
("grad_norm", "grad norm", ".3f"),
|
|
||||||
("learning_rate", "lr", ".2e"),
|
|
||||||
("tokens_per_second", "tok/s", ".1f"),
|
|
||||||
("samples_per_second", "samples/s", ".1f"),
|
|
||||||
("mfu", "MFU", ".1f"),
|
|
||||||
# RL-specific
|
|
||||||
("rewards_mean", "rewards/mean", ".4f"),
|
|
||||||
("rewards_std", "rewards/std", ".4f"),
|
|
||||||
("kl_divergence", "KL", ".4f"),
|
|
||||||
("clip_ratio", "clip ratio", ".3f"),
|
|
||||||
("queue_size", "queue", "d"),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@register_panel(position="left", weight=10)
|
|
||||||
class TrainingPanel(BasePanel):
|
|
||||||
name = "training"
|
|
||||||
min_height = 8
|
|
||||||
|
|
||||||
def render(self, state: TUIState) -> RenderableType:
|
|
||||||
table = Table(
|
|
||||||
show_header=True,
|
|
||||||
header_style="bold",
|
|
||||||
expand=True,
|
|
||||||
box=None,
|
|
||||||
pad_edge=False,
|
|
||||||
)
|
|
||||||
table.add_column("metric", style="cyan", no_wrap=True)
|
|
||||||
table.add_column("value", justify="right")
|
|
||||||
table.add_column("trend", justify="left", no_wrap=True)
|
|
||||||
|
|
||||||
for attr, label, fmt in _KNOWN_KEYS:
|
|
||||||
val = getattr(state, attr, None)
|
|
||||||
if val is None:
|
|
||||||
# Also check extra dict
|
|
||||||
val = state.extra.get(attr)
|
|
||||||
if val is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
formatted = f"{val:{fmt}}"
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
formatted = str(val)
|
|
||||||
|
|
||||||
trend = ""
|
|
||||||
if attr == "loss":
|
|
||||||
trend = _sparkline(list(state.loss_history))
|
|
||||||
|
|
||||||
table.add_row(label, formatted, trend)
|
|
||||||
|
|
||||||
# Any extra keys not in _KNOWN_KEYS
|
|
||||||
known_attrs = {k for k, _, _ in _KNOWN_KEYS}
|
|
||||||
for key, val in sorted(state.extra.items()):
|
|
||||||
if key in known_attrs or val is None:
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
formatted = f"{val:.4f}"
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
formatted = str(val)
|
|
||||||
table.add_row(key, formatted, "")
|
|
||||||
|
|
||||||
if table.row_count == 0:
|
|
||||||
return Panel(
|
|
||||||
Text("Waiting for first log step...", style="dim"),
|
|
||||||
title="Training",
|
|
||||||
border_style="blue",
|
|
||||||
)
|
|
||||||
|
|
||||||
return Panel(table, title="Training", border_style="blue")
|
|
||||||
@@ -1,7 +0,0 @@
|
|||||||
"""Built-in line parsers — auto-imported to trigger @register_parser decorators."""
|
|
||||||
|
|
||||||
from axolotl.tui.parsers.deepspeed import DeepSpeedParser # noqa: F401
|
|
||||||
from axolotl.tui.parsers.nccl import NCCLErrorParser # noqa: F401
|
|
||||||
from axolotl.tui.parsers.raw_log import RawLogParser # noqa: F401
|
|
||||||
from axolotl.tui.parsers.torch_compile import TorchCompileParser # noqa: F401
|
|
||||||
from axolotl.tui.parsers.tqdm import TqdmParser # noqa: F401
|
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
"""DeepSpeedParser — extracts DeepSpeed stage info and throughput metrics."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import re
|
|
||||||
|
|
||||||
from axolotl.tui.io_capture import LineParser, register_parser
|
|
||||||
|
|
||||||
|
|
||||||
@register_parser
|
|
||||||
class DeepSpeedParser(LineParser):
|
|
||||||
priority = 20
|
|
||||||
name = "deepspeed"
|
|
||||||
|
|
||||||
_SAMPLES_RE = re.compile(r"samples/sec=([0-9.]+)")
|
|
||||||
_STAGE_RE = re.compile(r"ZeRO Stage (\d)")
|
|
||||||
|
|
||||||
def parse(self, line: str, source: str) -> list[dict]:
|
|
||||||
events: list[dict] = []
|
|
||||||
if m := self._SAMPLES_RE.search(line):
|
|
||||||
events.append(
|
|
||||||
{
|
|
||||||
"type": "metrics",
|
|
||||||
"logs": {"samples_per_second": float(m.group(1))},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
if m := self._STAGE_RE.search(line):
|
|
||||||
events.append({"type": "run_info", "zero_stage": int(m.group(1))})
|
|
||||||
return events
|
|
||||||
@@ -1,27 +0,0 @@
|
|||||||
"""NCCLErrorParser — surfaces NCCL errors as red alert events."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import re
|
|
||||||
|
|
||||||
from axolotl.tui.io_capture import LineParser, register_parser
|
|
||||||
|
|
||||||
|
|
||||||
@register_parser
|
|
||||||
class NCCLErrorParser(LineParser):
|
|
||||||
priority = 10
|
|
||||||
name = "nccl_error"
|
|
||||||
|
|
||||||
_RE = re.compile(r"NCCL error|Unhandled NCCL", re.IGNORECASE)
|
|
||||||
|
|
||||||
def parse(self, line: str, source: str) -> list[dict]:
|
|
||||||
if self._RE.search(line):
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"type": "log_line",
|
|
||||||
"level": "error",
|
|
||||||
"message": f"⚠ NCCL: {line}",
|
|
||||||
},
|
|
||||||
{"type": "alert", "severity": "error", "message": line},
|
|
||||||
]
|
|
||||||
return []
|
|
||||||
@@ -1,37 +0,0 @@
|
|||||||
"""RawLogParser — catches every line as a log_line event."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import re
|
|
||||||
|
|
||||||
from axolotl.tui.io_capture import LineParser, register_parser
|
|
||||||
|
|
||||||
|
|
||||||
@register_parser
|
|
||||||
class RawLogParser(LineParser):
|
|
||||||
priority = 99
|
|
||||||
name = "raw_log"
|
|
||||||
|
|
||||||
_LOG_RE = re.compile(
|
|
||||||
r"^(?P<ts>\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}[,\.]\d+)"
|
|
||||||
r"\s*[-]\s*(?P<level>DEBUG|INFO|WARNING|ERROR|CRITICAL)"
|
|
||||||
r"\s*[-]\s*(?P<msg>.+)$",
|
|
||||||
re.IGNORECASE,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Filter out tqdm progress bar lines and other noisy output
|
|
||||||
_TQDM_RE = re.compile(r"^\s*\d+%\|.*\|")
|
|
||||||
_EMPTY_RE = re.compile(r"^\s*$")
|
|
||||||
|
|
||||||
def parse(self, line: str, source: str) -> list[dict]:
|
|
||||||
# Skip empty lines and tqdm progress bar updates
|
|
||||||
if self._EMPTY_RE.match(line) or self._TQDM_RE.match(line):
|
|
||||||
return []
|
|
||||||
|
|
||||||
m = self._LOG_RE.match(line)
|
|
||||||
level = (
|
|
||||||
m.group("level").lower()
|
|
||||||
if m
|
|
||||||
else ("error" if source == "stderr" else "info")
|
|
||||||
)
|
|
||||||
return [{"type": "log_line", "level": level, "message": line}]
|
|
||||||
@@ -1,26 +0,0 @@
|
|||||||
"""TorchCompileParser — detects torch.compile graph breaks and recompilations."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import re
|
|
||||||
|
|
||||||
from axolotl.tui.io_capture import LineParser, register_parser
|
|
||||||
|
|
||||||
|
|
||||||
@register_parser
|
|
||||||
class TorchCompileParser(LineParser):
|
|
||||||
priority = 20
|
|
||||||
name = "torch_compile"
|
|
||||||
|
|
||||||
_RE = re.compile(r"Graph break|Recompiling|torch\.compile", re.IGNORECASE)
|
|
||||||
|
|
||||||
def parse(self, line: str, source: str) -> list[dict]:
|
|
||||||
if self._RE.search(line):
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"type": "log_line",
|
|
||||||
"level": "warning",
|
|
||||||
"message": f"⚡ compile: {line}",
|
|
||||||
}
|
|
||||||
]
|
|
||||||
return []
|
|
||||||
@@ -1,86 +0,0 @@
|
|||||||
"""TqdmParser — captures tqdm progress bar output and surfaces as structured events."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import re
|
|
||||||
|
|
||||||
from axolotl.tui.io_capture import LineParser, register_parser
|
|
||||||
|
|
||||||
|
|
||||||
@register_parser
|
|
||||||
class TqdmParser(LineParser):
|
|
||||||
priority = 15
|
|
||||||
name = "tqdm"
|
|
||||||
|
|
||||||
# Match tqdm-style progress lines, e.g.:
|
|
||||||
# Tokenizing Prompts (num_proc=24): 35%|███▍ | 19008/54568 [00:02<00:02, 17417.65 examples/s]
|
|
||||||
# Loading weights: 53%|█████▎ | 77/146 [00:00<00:00, 396.39it/s]
|
|
||||||
# 0%| | 0/30 [00:00<?, ?it/s]
|
|
||||||
_TQDM_RE = re.compile(
|
|
||||||
r"(?P<desc>.*?)\s*"
|
|
||||||
r"(?P<pct>\d+)%\|[▏▎▍▌▋▊▉█░▓▒# ]*\|\s*"
|
|
||||||
r"(?P<current>[\d,]+)/(?P<total>[\d,]+)"
|
|
||||||
r"\s*\[(?P<elapsed>[^\]]*)\]"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Also match simpler forms like:
|
|
||||||
# Fetching 0 files: 0it [00:00, ?it/s]
|
|
||||||
_FETCH_RE = re.compile(r"(?P<desc>[\w\s]+):\s*(?P<current>\d+)(?:it)?\s*\[.*?\]")
|
|
||||||
|
|
||||||
def parse(self, line: str, source: str) -> list[dict]:
|
|
||||||
m = self._TQDM_RE.search(line)
|
|
||||||
if m:
|
|
||||||
desc = m.group("desc").strip().rstrip(":")
|
|
||||||
pct = int(m.group("pct"))
|
|
||||||
current = int(m.group("current").replace(",", ""))
|
|
||||||
total = int(m.group("total").replace(",", ""))
|
|
||||||
|
|
||||||
events: list[dict] = []
|
|
||||||
|
|
||||||
# Surface as a log line with progress info
|
|
||||||
if pct == 100 or pct == 0 or pct % 25 == 0:
|
|
||||||
msg = (
|
|
||||||
f"[{desc}] {pct}% ({current}/{total})"
|
|
||||||
if desc
|
|
||||||
else f"{pct}% ({current}/{total})"
|
|
||||||
)
|
|
||||||
events.append(
|
|
||||||
{
|
|
||||||
"type": "log_line",
|
|
||||||
"level": "info",
|
|
||||||
"message": msg,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Also emit as a progress metric
|
|
||||||
cleaned_desc = desc.strip().lower().replace(" ", "_")
|
|
||||||
if not cleaned_desc:
|
|
||||||
cleaned_desc = "progress"
|
|
||||||
events.append(
|
|
||||||
{
|
|
||||||
"type": "metrics",
|
|
||||||
"logs": {
|
|
||||||
f"progress/{cleaned_desc}": pct / 100.0,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return events
|
|
||||||
|
|
||||||
# Fallback: try simpler fetch-style progress lines
|
|
||||||
m = self._FETCH_RE.search(line)
|
|
||||||
if m:
|
|
||||||
desc = m.group("desc").strip().rstrip(":")
|
|
||||||
current = int(m.group("current"))
|
|
||||||
cleaned_desc = desc.strip().lower().replace(" ", "_")
|
|
||||||
if not cleaned_desc:
|
|
||||||
cleaned_desc = "fetch"
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"type": "log_line",
|
|
||||||
"level": "info",
|
|
||||||
"message": f"[{desc}] {current}" if desc else f"{current}",
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
return []
|
|
||||||
@@ -1,449 +0,0 @@
|
|||||||
"""TUIRenderer — background daemon thread that drives the rich.live.Live display."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import queue
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from rich.console import Console
|
|
||||||
from rich.layout import Layout
|
|
||||||
from rich.live import Live
|
|
||||||
|
|
||||||
from axolotl.tui.config import TUIConfig
|
|
||||||
from axolotl.tui.gpu import GPUPoller
|
|
||||||
from axolotl.tui.io_capture import (
|
|
||||||
IOCapture,
|
|
||||||
ParserChain,
|
|
||||||
get_registered_parsers,
|
|
||||||
)
|
|
||||||
from axolotl.tui.panels import BasePanel, get_registered_panels
|
|
||||||
from axolotl.tui.state import CompletionSample, LogLine, TUIState
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class TUIRenderer:
|
|
||||||
"""Background thread that renders the TUI dashboard using rich.live.Live."""
|
|
||||||
|
|
||||||
def __init__(self, config: TUIConfig, metric_queue: queue.Queue):
|
|
||||||
self._config = config
|
|
||||||
self._queue = metric_queue
|
|
||||||
self._state = TUIState()
|
|
||||||
self._gpu_poller = GPUPoller()
|
|
||||||
self._panels: list[BasePanel] = []
|
|
||||||
self._thread: threading.Thread | None = None
|
|
||||||
self._stop_event = threading.Event()
|
|
||||||
self._io_capture: IOCapture | None = None
|
|
||||||
self._parser_chain: ParserChain | None = None
|
|
||||||
|
|
||||||
def _init_panels(self) -> None:
|
|
||||||
registry = get_registered_panels()
|
|
||||||
for panel_name in self._config.panels:
|
|
||||||
if panel_name in registry:
|
|
||||||
self._panels.append(registry[panel_name]())
|
|
||||||
|
|
||||||
def _init_parser_chain(self) -> None:
|
|
||||||
# Ensure built-in parsers are imported so @register_parser decorators fire
|
|
||||||
import axolotl.tui.parsers # noqa: F401
|
|
||||||
|
|
||||||
self._parser_chain = ParserChain()
|
|
||||||
# Register all built-in parsers
|
|
||||||
for parser_cls in get_registered_parsers():
|
|
||||||
self._parser_chain.register(parser_cls())
|
|
||||||
|
|
||||||
# Load plugin parsers
|
|
||||||
for plugin_spec in self._config.parser_plugins:
|
|
||||||
try:
|
|
||||||
if "::" in plugin_spec:
|
|
||||||
# file path :: class name
|
|
||||||
file_path, class_name = plugin_spec.split("::", 1)
|
|
||||||
import importlib.util
|
|
||||||
|
|
||||||
spec = importlib.util.spec_from_file_location(
|
|
||||||
"custom_parser", file_path
|
|
||||||
)
|
|
||||||
if spec is None or spec.loader is None:
|
|
||||||
raise ImportError(f"Cannot load spec for {file_path}")
|
|
||||||
mod = importlib.util.module_from_spec(spec)
|
|
||||||
spec.loader.exec_module(mod)
|
|
||||||
parser_cls = getattr(mod, class_name)
|
|
||||||
else:
|
|
||||||
# dotted module path
|
|
||||||
module_path, class_name = plugin_spec.rsplit(".", 1)
|
|
||||||
mod = importlib.import_module(module_path)
|
|
||||||
parser_cls = getattr(mod, class_name)
|
|
||||||
self._parser_chain.register(parser_cls())
|
|
||||||
except Exception as exc:
|
|
||||||
LOG.warning(f"Failed to load parser plugin {plugin_spec}: {exc}")
|
|
||||||
|
|
||||||
def _build_layout(self) -> Layout:
|
|
||||||
layout = Layout()
|
|
||||||
|
|
||||||
top_panels = [p for p in self._panels if p.position == "top"]
|
|
||||||
left_panels = [p for p in self._panels if p.position == "left"]
|
|
||||||
right_panels = [p for p in self._panels if p.position == "right"]
|
|
||||||
bottom_panels = [p for p in self._panels if p.position == "bottom"]
|
|
||||||
|
|
||||||
sections = []
|
|
||||||
|
|
||||||
if top_panels:
|
|
||||||
layout_top = Layout(name="top", size=3)
|
|
||||||
sections.append(layout_top)
|
|
||||||
|
|
||||||
if left_panels or right_panels:
|
|
||||||
layout_middle = Layout(name="middle", ratio=3)
|
|
||||||
middle_parts = []
|
|
||||||
if left_panels:
|
|
||||||
middle_parts.append(Layout(name="left", ratio=1))
|
|
||||||
if right_panels:
|
|
||||||
middle_parts.append(Layout(name="right", ratio=1))
|
|
||||||
if middle_parts:
|
|
||||||
layout_middle.split_row(*middle_parts)
|
|
||||||
sections.append(layout_middle)
|
|
||||||
|
|
||||||
if bottom_panels:
|
|
||||||
layout_bottom = Layout(name="bottom", ratio=2)
|
|
||||||
if len(bottom_panels) > 1:
|
|
||||||
layout_bottom.split_row(
|
|
||||||
*[
|
|
||||||
Layout(name=f"bottom_{i}", ratio=1)
|
|
||||||
for i in range(len(bottom_panels))
|
|
||||||
]
|
|
||||||
)
|
|
||||||
sections.append(layout_bottom)
|
|
||||||
|
|
||||||
if sections:
|
|
||||||
layout.split_column(*sections)
|
|
||||||
|
|
||||||
return layout
|
|
||||||
|
|
||||||
def _update_layout(self, layout: Layout) -> None:
|
|
||||||
top_panels = [p for p in self._panels if p.position == "top"]
|
|
||||||
left_panels = [p for p in self._panels if p.position == "left"]
|
|
||||||
right_panels = [p for p in self._panels if p.position == "right"]
|
|
||||||
bottom_panels = [p for p in self._panels if p.position == "bottom"]
|
|
||||||
|
|
||||||
if top_panels:
|
|
||||||
layout["top"].update(top_panels[0].render(self._state))
|
|
||||||
|
|
||||||
if left_panels:
|
|
||||||
layout["left"].update(left_panels[0].render(self._state))
|
|
||||||
|
|
||||||
if right_panels:
|
|
||||||
layout["right"].update(right_panels[0].render(self._state))
|
|
||||||
|
|
||||||
if bottom_panels:
|
|
||||||
if len(bottom_panels) == 1:
|
|
||||||
layout["bottom"].update(bottom_panels[0].render(self._state))
|
|
||||||
else:
|
|
||||||
for i, panel in enumerate(bottom_panels):
|
|
||||||
layout[f"bottom_{i}"].update(panel.render(self._state))
|
|
||||||
|
|
||||||
def _drain_queue(self) -> None:
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
event = self._queue.get_nowait()
|
|
||||||
except queue.Empty:
|
|
||||||
break
|
|
||||||
|
|
||||||
# Dispatch event to panels first
|
|
||||||
for panel in self._panels:
|
|
||||||
panel.on_event(event)
|
|
||||||
|
|
||||||
event_type = event.get("type")
|
|
||||||
|
|
||||||
if event_type == "metrics":
|
|
||||||
logs = event.get("logs", {})
|
|
||||||
self._apply_metrics(logs)
|
|
||||||
|
|
||||||
elif event_type == "step":
|
|
||||||
self._state.current_step = event.get("step", self._state.current_step)
|
|
||||||
self._state.total_steps = event.get(
|
|
||||||
"total_steps", self._state.total_steps
|
|
||||||
)
|
|
||||||
self._state.current_epoch = event.get(
|
|
||||||
"epoch", self._state.current_epoch
|
|
||||||
)
|
|
||||||
now = time.time()
|
|
||||||
self._state.elapsed_seconds = now - self._state.start_time.timestamp()
|
|
||||||
if self._state.current_step > 0 and self._state.total_steps > 0:
|
|
||||||
rate = self._state.elapsed_seconds / self._state.current_step
|
|
||||||
remaining = self._state.total_steps - self._state.current_step
|
|
||||||
self._state.eta_seconds = rate * remaining
|
|
||||||
|
|
||||||
elif event_type == "log_line":
|
|
||||||
level = event.get("level", "info")
|
|
||||||
message = event.get("message", "")
|
|
||||||
self._state.log_lines.append(
|
|
||||||
LogLine(
|
|
||||||
timestamp=datetime.now(),
|
|
||||||
level=level,
|
|
||||||
message=message,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
elif event_type == "completion":
|
|
||||||
self._state.completions.append(
|
|
||||||
CompletionSample(
|
|
||||||
step=event.get("step", 0),
|
|
||||||
prompt=event.get("prompt", ""),
|
|
||||||
completion=event.get("completion", ""),
|
|
||||||
reward=event.get("reward"),
|
|
||||||
advantage=event.get("advantage"),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
elif event_type == "run_info":
|
|
||||||
if "run_name" in event:
|
|
||||||
self._state.run_name = event["run_name"]
|
|
||||||
if "model_name" in event:
|
|
||||||
self._state.model_name = event["model_name"]
|
|
||||||
if "training_mode" in event:
|
|
||||||
self._state.training_mode = event["training_mode"]
|
|
||||||
if "world_size" in event:
|
|
||||||
self._state.world_size = event["world_size"]
|
|
||||||
if "total_steps" in event:
|
|
||||||
self._state.total_steps = event["total_steps"]
|
|
||||||
if "total_epochs" in event:
|
|
||||||
self._state.total_epochs = event["total_epochs"]
|
|
||||||
if "zero_stage" in event:
|
|
||||||
self._state.zero_stage = event["zero_stage"]
|
|
||||||
|
|
||||||
elif event_type == "done":
|
|
||||||
self._stop_event.set()
|
|
||||||
|
|
||||||
def _apply_metrics(self, logs: dict[str, Any]) -> None:
|
|
||||||
metric_map = {
|
|
||||||
"loss": "loss",
|
|
||||||
"grad_norm": "grad_norm",
|
|
||||||
"learning_rate": "learning_rate",
|
|
||||||
"tokens_per_second": "tokens_per_second",
|
|
||||||
"samples_per_second": "samples_per_second",
|
|
||||||
"mfu": "mfu",
|
|
||||||
"rewards/mean": "rewards_mean",
|
|
||||||
"rewards_mean": "rewards_mean",
|
|
||||||
"rewards/std": "rewards_std",
|
|
||||||
"rewards_std": "rewards_std",
|
|
||||||
"kl": "kl_divergence",
|
|
||||||
"kl_divergence": "kl_divergence",
|
|
||||||
"clip_ratio": "clip_ratio",
|
|
||||||
"queue_size": "queue_size",
|
|
||||||
}
|
|
||||||
|
|
||||||
for key, value in logs.items():
|
|
||||||
if key in metric_map:
|
|
||||||
setattr(self._state, metric_map[key], value)
|
|
||||||
else:
|
|
||||||
self._state.extra[key] = value
|
|
||||||
|
|
||||||
if "loss" in logs and logs["loss"] is not None:
|
|
||||||
self._state.loss_history.append(logs["loss"])
|
|
||||||
|
|
||||||
def start(self) -> None:
|
|
||||||
self._init_panels()
|
|
||||||
self._init_parser_chain()
|
|
||||||
|
|
||||||
# Set up I/O capture
|
|
||||||
assert self._parser_chain is not None, "_init_parser_chain must be called first"
|
|
||||||
self._io_capture = IOCapture(
|
|
||||||
log_path=self._config.stdout_log_path,
|
|
||||||
parser_chain=self._parser_chain,
|
|
||||||
metric_queue=self._queue,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Monkeypatch tqdm to suppress terminal output and route through our queue.
|
|
||||||
# This prevents tqdm progress bars from flickering through the TUI and
|
|
||||||
# ensures all progress events appear in the Events panel.
|
|
||||||
self._install_tqdm_hook()
|
|
||||||
|
|
||||||
self._io_capture_ready = threading.Event()
|
|
||||||
self._thread = threading.Thread(target=self._run, daemon=True)
|
|
||||||
self._thread.start()
|
|
||||||
self._io_capture_ready.wait(timeout=5.0)
|
|
||||||
|
|
||||||
def _install_tqdm_hook(self) -> None:
|
|
||||||
"""Replace tqdm's display method to route updates through TUI queue."""
|
|
||||||
try:
|
|
||||||
import io
|
|
||||||
|
|
||||||
import tqdm
|
|
||||||
import tqdm.auto
|
|
||||||
|
|
||||||
q = self._queue
|
|
||||||
self._tqdm_parser = None
|
|
||||||
# Find our tqdm parser in the chain
|
|
||||||
for p in self._parser_chain._parsers if self._parser_chain else []:
|
|
||||||
if p.name == "tqdm":
|
|
||||||
self._tqdm_parser = p
|
|
||||||
break
|
|
||||||
|
|
||||||
# Save originals for restore
|
|
||||||
self._orig_tqdm_class_auto = tqdm.auto.tqdm
|
|
||||||
self._orig_tqdm_class_tqdm = tqdm.tqdm
|
|
||||||
self._orig_tqdm_class_std = tqdm.std.tqdm
|
|
||||||
|
|
||||||
class TUITqdm(tqdm.tqdm):
|
|
||||||
"""tqdm subclass that sends progress to TUI instead of terminal."""
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
# Force output to devnull so nothing reaches the terminal
|
|
||||||
kwargs["file"] = io.StringIO()
|
|
||||||
kwargs["dynamic_ncols"] = False
|
|
||||||
kwargs["ncols"] = 80
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
def display(self, msg=None, pos=None):
|
|
||||||
# Build a progress string and push to queue
|
|
||||||
if self.total and self.total > 0:
|
|
||||||
pct = self.n / self.total * 100
|
|
||||||
desc = self.desc.rstrip(": ") if self.desc else ""
|
|
||||||
# Emit events at milestones or at low frequency
|
|
||||||
is_milestone = (
|
|
||||||
self.n == 0 or self.n >= self.total or int(pct) % 25 == 0
|
|
||||||
)
|
|
||||||
if is_milestone:
|
|
||||||
try:
|
|
||||||
q.put_nowait(
|
|
||||||
{
|
|
||||||
"type": "log_line",
|
|
||||||
"level": "info",
|
|
||||||
"message": f"[{desc}] {pct:.0f}% ({self.n}/{self.total})"
|
|
||||||
if desc
|
|
||||||
else f"{pct:.0f}% ({self.n}/{self.total})",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
|
||||||
metric_key = (
|
|
||||||
f"progress/{desc.lower().replace(' ', '_')}"
|
|
||||||
if desc
|
|
||||||
else "progress/unknown"
|
|
||||||
)
|
|
||||||
q.put_nowait(
|
|
||||||
{
|
|
||||||
"type": "metrics",
|
|
||||||
"logs": {metric_key: pct / 100.0},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
# Emit final completion event
|
|
||||||
if self.total and self.total > 0 and self.n > 0:
|
|
||||||
desc = self.desc.rstrip(": ") if self.desc else ""
|
|
||||||
try:
|
|
||||||
q.put_nowait(
|
|
||||||
{
|
|
||||||
"type": "log_line",
|
|
||||||
"level": "info",
|
|
||||||
"message": f"[{desc}] 100% ({self.total}/{self.total}) done"
|
|
||||||
if desc
|
|
||||||
else f"100% ({self.total}/{self.total}) done",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
super().close()
|
|
||||||
|
|
||||||
# Replace tqdm globally
|
|
||||||
tqdm.auto.tqdm = TUITqdm
|
|
||||||
tqdm.tqdm = TUITqdm
|
|
||||||
# Also patch tqdm.std which some libraries use directly
|
|
||||||
tqdm.std.tqdm = TUITqdm
|
|
||||||
self._tui_tqdm_cls = TUITqdm
|
|
||||||
|
|
||||||
except Exception as exc:
|
|
||||||
LOG.debug(f"Failed to install tqdm hook: {exc}")
|
|
||||||
|
|
||||||
def _uninstall_tqdm_hook(self) -> None:
|
|
||||||
"""Restore original tqdm."""
|
|
||||||
try:
|
|
||||||
import tqdm
|
|
||||||
import tqdm.auto
|
|
||||||
|
|
||||||
if hasattr(self, "_orig_tqdm_class_auto"):
|
|
||||||
tqdm.auto.tqdm = self._orig_tqdm_class_auto
|
|
||||||
if hasattr(self, "_orig_tqdm_class_tqdm"):
|
|
||||||
tqdm.tqdm = self._orig_tqdm_class_tqdm
|
|
||||||
if hasattr(self, "_orig_tqdm_class_std"):
|
|
||||||
tqdm.std.tqdm = self._orig_tqdm_class_std
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def stop(self) -> None:
|
|
||||||
self._stop_event.set()
|
|
||||||
self._uninstall_tqdm_hook()
|
|
||||||
if self._thread is not None:
|
|
||||||
self._thread.join(timeout=5.0)
|
|
||||||
|
|
||||||
def _run(self) -> None:
|
|
||||||
import os
|
|
||||||
|
|
||||||
# Save a handle to the REAL terminal BEFORE IO capture redirects fds.
|
|
||||||
# This ensures rich.live.Live writes to the terminal, not the pipe.
|
|
||||||
saved_tty_fd = os.dup(1)
|
|
||||||
tty_file = os.fdopen(saved_tty_fd, "w", buffering=1, closefd=True)
|
|
||||||
console = Console(file=tty_file)
|
|
||||||
|
|
||||||
layout = self._build_layout()
|
|
||||||
tick_interval = 1.0 / max(self._config.refresh_rate, 1)
|
|
||||||
gpu_poll_counter = 0
|
|
||||||
gpu_poll_ticks = max(
|
|
||||||
1, int(self._config.hardware_poll_interval / tick_interval)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Start I/O capture — redirects fd 1/2 to pipe AFTER we saved the tty fd
|
|
||||||
if self._io_capture:
|
|
||||||
self._io_capture.start()
|
|
||||||
|
|
||||||
# Signal that IO capture is live so start() can return
|
|
||||||
if hasattr(self, "_io_capture_ready"):
|
|
||||||
self._io_capture_ready.set()
|
|
||||||
|
|
||||||
try:
|
|
||||||
with Live(
|
|
||||||
layout,
|
|
||||||
console=console,
|
|
||||||
refresh_per_second=self._config.refresh_rate,
|
|
||||||
screen=True,
|
|
||||||
redirect_stdout=False,
|
|
||||||
redirect_stderr=False,
|
|
||||||
) as live:
|
|
||||||
while not self._stop_event.is_set():
|
|
||||||
self._drain_queue()
|
|
||||||
|
|
||||||
# Poll GPU stats periodically
|
|
||||||
gpu_poll_counter += 1
|
|
||||||
if gpu_poll_counter >= gpu_poll_ticks:
|
|
||||||
gpu_poll_counter = 0
|
|
||||||
if self._gpu_poller.available:
|
|
||||||
self._state.gpus = self._gpu_poller.poll()
|
|
||||||
|
|
||||||
# Update elapsed time
|
|
||||||
self._state.elapsed_seconds = (
|
|
||||||
time.time() - self._state.start_time.timestamp()
|
|
||||||
)
|
|
||||||
|
|
||||||
self._update_layout(layout)
|
|
||||||
live.update(layout)
|
|
||||||
|
|
||||||
time.sleep(tick_interval)
|
|
||||||
|
|
||||||
# Final drain
|
|
||||||
self._drain_queue()
|
|
||||||
self._update_layout(layout)
|
|
||||||
live.update(layout)
|
|
||||||
finally:
|
|
||||||
if self._io_capture:
|
|
||||||
self._io_capture.stop()
|
|
||||||
try:
|
|
||||||
tty_file.close()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
@@ -1,88 +0,0 @@
|
|||||||
"""TUI shared data model — dataclasses for the dashboard state."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from collections import deque
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class GPUStats:
|
|
||||||
id: int
|
|
||||||
name: str
|
|
||||||
util_pct: float
|
|
||||||
vram_used_gb: float
|
|
||||||
vram_total_gb: float
|
|
||||||
temp_c: int
|
|
||||||
power_w: float | None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class LogLine:
|
|
||||||
timestamp: datetime
|
|
||||||
level: str # "info" | "debug" | "warning" | "error"
|
|
||||||
message: str
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class CompletionSample:
|
|
||||||
step: int
|
|
||||||
prompt: str
|
|
||||||
completion: str
|
|
||||||
reward: float | None
|
|
||||||
advantage: float | None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TUIState:
|
|
||||||
# Run metadata
|
|
||||||
run_name: str = ""
|
|
||||||
model_name: str = ""
|
|
||||||
training_mode: str = "sft"
|
|
||||||
world_size: int = 1
|
|
||||||
start_time: datetime = field(default_factory=datetime.now)
|
|
||||||
|
|
||||||
# Progress
|
|
||||||
current_step: int = 0
|
|
||||||
total_steps: int = 0
|
|
||||||
current_epoch: float = 0.0
|
|
||||||
total_epochs: float = 1.0
|
|
||||||
elapsed_seconds: float = 0.0
|
|
||||||
eta_seconds: float | None = None
|
|
||||||
|
|
||||||
# Training metrics (rolling window + current)
|
|
||||||
loss: float | None = None
|
|
||||||
grad_norm: float | None = None
|
|
||||||
learning_rate: float | None = None
|
|
||||||
tokens_per_second: float | None = None
|
|
||||||
samples_per_second: float | None = None
|
|
||||||
mfu: float | None = None
|
|
||||||
|
|
||||||
# RL-specific (None for non-RL modes)
|
|
||||||
rewards_mean: float | None = None
|
|
||||||
rewards_std: float | None = None
|
|
||||||
kl_divergence: float | None = None
|
|
||||||
clip_ratio: float | None = None
|
|
||||||
queue_size: int | None = None
|
|
||||||
|
|
||||||
# Per-GPU hardware (list indexed by local rank)
|
|
||||||
gpus: list[GPUStats] = field(default_factory=list)
|
|
||||||
|
|
||||||
# Recent log lines
|
|
||||||
log_lines: deque[LogLine] = field(default_factory=lambda: deque(maxlen=200))
|
|
||||||
|
|
||||||
# Recent completions (GRPO/SFT with log_completions)
|
|
||||||
completions: deque[CompletionSample] = field(
|
|
||||||
default_factory=lambda: deque(maxlen=20)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Loss history for sparkline
|
|
||||||
loss_history: deque[float] = field(default_factory=lambda: deque(maxlen=50))
|
|
||||||
|
|
||||||
# DeepSpeed zero stage (None if not using DeepSpeed)
|
|
||||||
zero_stage: int | None = None
|
|
||||||
|
|
||||||
# Arbitrary plugin state
|
|
||||||
extra: dict[str, Any] = field(default_factory=dict)
|
|
||||||
@@ -25,11 +25,9 @@ def toggle_fake_quant(mod: nn.Module, enable: bool):
|
|||||||
if (
|
if (
|
||||||
isinstance(mod, FakeQuantizedLinear)
|
isinstance(mod, FakeQuantizedLinear)
|
||||||
and mod.activation_fake_quantizer is not None
|
and mod.activation_fake_quantizer is not None
|
||||||
and hasattr(mod.activation_fake_quantizer, "enabled")
|
|
||||||
):
|
):
|
||||||
mod.activation_fake_quantizer.enabled = enable
|
mod.activation_fake_quantizer.enabled = enable
|
||||||
if hasattr(mod.weight_fake_quantizer, "enabled"):
|
mod.weight_fake_quantizer.enabled = enable
|
||||||
mod.weight_fake_quantizer.enabled = enable
|
|
||||||
|
|
||||||
|
|
||||||
class QATCallback(TrainerCallback):
|
class QATCallback(TrainerCallback):
|
||||||
|
|||||||
@@ -12,11 +12,12 @@ from transformers import (
|
|||||||
TrainingArguments,
|
TrainingArguments,
|
||||||
)
|
)
|
||||||
|
|
||||||
from axolotl.core.trainers.constants import TOKENS_STATE_FILE
|
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
TOKENS_STATE_FILE = "tokens_state.json"
|
||||||
|
|
||||||
|
|
||||||
class TokensPerSecondCallback(TrainerCallback):
|
class TokensPerSecondCallback(TrainerCallback):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -22,12 +22,7 @@ from axolotl.utils.schemas.config import (
|
|||||||
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
||||||
AxolotlInputConfig as AxolotlInputConfigBase,
|
AxolotlInputConfig as AxolotlInputConfigBase,
|
||||||
)
|
)
|
||||||
from axolotl.utils.schemas.datasets import (
|
from axolotl.utils.schemas.datasets import DPODataset, KTODataset, SFTDataset
|
||||||
DPODataset,
|
|
||||||
KTODataset,
|
|
||||||
SFTDataset,
|
|
||||||
SyntheticDataset,
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
@@ -313,14 +308,6 @@ def validate_config(
|
|||||||
cfg["datasets"][idx] = DPODataset(**ds_cfg)
|
cfg["datasets"][idx] = DPODataset(**ds_cfg)
|
||||||
elif cfg.get("rl") == "kto" and not isinstance(ds_cfg, KTODataset):
|
elif cfg.get("rl") == "kto" and not isinstance(ds_cfg, KTODataset):
|
||||||
cfg["datasets"][idx] = KTODataset(**dict(ds_cfg))
|
cfg["datasets"][idx] = KTODataset(**dict(ds_cfg))
|
||||||
elif (
|
|
||||||
ds_cfg.get("type")
|
|
||||||
if isinstance(ds_cfg, dict)
|
|
||||||
else getattr(ds_cfg, "type", None)
|
|
||||||
) == "_synthetic" and not isinstance(ds_cfg, SyntheticDataset):
|
|
||||||
cfg["datasets"][idx] = SyntheticDataset(
|
|
||||||
**(ds_cfg if isinstance(ds_cfg, dict) else dict(ds_cfg))
|
|
||||||
)
|
|
||||||
elif not isinstance(ds_cfg, SFTDataset):
|
elif not isinstance(ds_cfg, SFTDataset):
|
||||||
cfg["datasets"][idx] = SFTDataset(**dict(ds_cfg))
|
cfg["datasets"][idx] = SFTDataset(**dict(ds_cfg))
|
||||||
|
|
||||||
|
|||||||
@@ -376,14 +376,10 @@ def _load_and_process_single_dataset(
|
|||||||
streaming: bool = False,
|
streaming: bool = False,
|
||||||
) -> tuple[Dataset | IterableDataset, Prompter | None]:
|
) -> tuple[Dataset | IterableDataset, Prompter | None]:
|
||||||
"""Load and process a single dataset based on the passed config."""
|
"""Load and process a single dataset based on the passed config."""
|
||||||
# For synthetic datasets, create a minimal placeholder instead of loading from path
|
# Load the dataset
|
||||||
if dataset_config.type == "_synthetic":
|
dataset = load_dataset_with_config(
|
||||||
dataset = Dataset.from_dict({"text": [""]})
|
dataset_config, cfg.hf_use_auth_token, streaming=streaming
|
||||||
else:
|
)
|
||||||
# Load the dataset
|
|
||||||
dataset = load_dataset_with_config(
|
|
||||||
dataset_config, cfg.hf_use_auth_token, streaming=streaming
|
|
||||||
)
|
|
||||||
|
|
||||||
# Parse dataset type
|
# Parse dataset type
|
||||||
d_base_type, d_prompt_style = _parse_dataset_type(dataset_config.type)
|
d_base_type, d_prompt_style = _parse_dataset_type(dataset_config.type)
|
||||||
|
|||||||
@@ -10,11 +10,9 @@ from torchao.quantization import quantize_
|
|||||||
from torchao.quantization.qat import (
|
from torchao.quantization.qat import (
|
||||||
QATConfig,
|
QATConfig,
|
||||||
)
|
)
|
||||||
from torchao.quantization.qat.fake_quantize_config import Int4WeightFakeQuantizeConfig
|
|
||||||
from torchao.quantization.quant_api import (
|
from torchao.quantization.quant_api import (
|
||||||
Float8DynamicActivationFloat8WeightConfig,
|
Float8DynamicActivationFloat8WeightConfig,
|
||||||
Float8DynamicActivationInt4WeightConfig,
|
Float8DynamicActivationInt4WeightConfig,
|
||||||
Int4WeightOnlyConfig,
|
|
||||||
Int8DynamicActivationInt4WeightConfig,
|
Int8DynamicActivationInt4WeightConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -175,70 +173,6 @@ def quantize_model(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _make_qat_config(
|
|
||||||
base_config: AOBaseConfig,
|
|
||||||
weight_dtype: TorchAOQuantDType,
|
|
||||||
activation_dtype: TorchAOQuantDType | None,
|
|
||||||
group_size: int | None,
|
|
||||||
) -> QATConfig:
|
|
||||||
"""Build a QATConfig, explicitly constructing fake quantize configs to ensure
|
|
||||||
group_size and other params are properly propagated (torchao's QATConfig(base_config)
|
|
||||||
does not always map these correctly)."""
|
|
||||||
from torchao.quantization.qat.fake_quantize_config import (
|
|
||||||
Float8FakeQuantizeConfig,
|
|
||||||
IntxFakeQuantizeConfig,
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(base_config, MXFakeQuantizeConfig):
|
|
||||||
return QATConfig(
|
|
||||||
activation_config=base_config,
|
|
||||||
weight_config=base_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Build explicit weight config
|
|
||||||
weight_fq_config: (
|
|
||||||
Int4WeightFakeQuantizeConfig
|
|
||||||
| IntxFakeQuantizeConfig
|
|
||||||
| Float8FakeQuantizeConfig
|
|
||||||
| None
|
|
||||||
) = None
|
|
||||||
if weight_dtype == TorchAOQuantDType.int4:
|
|
||||||
gs = (
|
|
||||||
group_size
|
|
||||||
if group_size is not None
|
|
||||||
else getattr(base_config, "group_size", 128)
|
|
||||||
)
|
|
||||||
activation_dt = None
|
|
||||||
if activation_dtype == TorchAOQuantDType.int8:
|
|
||||||
activation_dt = torch.bfloat16
|
|
||||||
elif activation_dtype == TorchAOQuantDType.float8_e4m3fn:
|
|
||||||
activation_dt = torch.float8_e4m3fn
|
|
||||||
kwargs = {"group_size": gs}
|
|
||||||
if activation_dt is not None:
|
|
||||||
kwargs["activation_dtype"] = activation_dt
|
|
||||||
weight_fq_config = Int4WeightFakeQuantizeConfig(**kwargs)
|
|
||||||
elif weight_dtype == TorchAOQuantDType.float8_e4m3fn:
|
|
||||||
weight_fq_config = Float8FakeQuantizeConfig(dtype=torch.float8_e4m3fn)
|
|
||||||
|
|
||||||
# Build explicit activation config
|
|
||||||
activation_fq_config = None
|
|
||||||
if activation_dtype == TorchAOQuantDType.int8:
|
|
||||||
activation_fq_config = IntxFakeQuantizeConfig(
|
|
||||||
dtype=torch.int8, granularity="per_token", is_symmetric=False
|
|
||||||
)
|
|
||||||
elif activation_dtype == TorchAOQuantDType.float8_e4m3fn:
|
|
||||||
activation_fq_config = Float8FakeQuantizeConfig(dtype=torch.float8_e4m3fn)
|
|
||||||
|
|
||||||
if weight_fq_config is not None:
|
|
||||||
return QATConfig(
|
|
||||||
weight_config=weight_fq_config,
|
|
||||||
activation_config=activation_fq_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Fallback to base_config for unhandled combos
|
|
||||||
return QATConfig(base_config)
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_model_for_qat(
|
def prepare_model_for_qat(
|
||||||
model,
|
model,
|
||||||
weight_dtype: TorchAOQuantDType,
|
weight_dtype: TorchAOQuantDType,
|
||||||
@@ -266,9 +200,13 @@ def prepare_model_for_qat(
|
|||||||
activation_dtype=activation_dtype,
|
activation_dtype=activation_dtype,
|
||||||
group_size=group_size,
|
group_size=group_size,
|
||||||
)
|
)
|
||||||
qat_config = _make_qat_config(
|
if isinstance(base_config, MXFakeQuantizeConfig):
|
||||||
base_config, weight_dtype, activation_dtype, group_size
|
qat_config = QATConfig(
|
||||||
)
|
activation_config=base_config,
|
||||||
|
weight_config=base_config,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
qat_config = QATConfig(base_config)
|
||||||
quantize_(model, qat_config)
|
quantize_(model, qat_config)
|
||||||
if quantize_embedding:
|
if quantize_embedding:
|
||||||
# activation fake quantization is not supported for embedding layers
|
# activation fake quantization is not supported for embedding layers
|
||||||
@@ -277,9 +215,12 @@ def prepare_model_for_qat(
|
|||||||
activation_dtype=None,
|
activation_dtype=None,
|
||||||
group_size=group_size,
|
group_size=group_size,
|
||||||
)
|
)
|
||||||
embedding_qat_config = _make_qat_config(
|
if isinstance(embedding_base_config, MXFakeQuantizeConfig):
|
||||||
embedding_base_config, weight_dtype, None, group_size
|
embedding_qat_config = QATConfig(
|
||||||
)
|
weight_config=embedding_base_config,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
embedding_qat_config = QATConfig(embedding_base_config)
|
||||||
quantize_(
|
quantize_(
|
||||||
model,
|
model,
|
||||||
embedding_qat_config,
|
embedding_qat_config,
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any, Sequence
|
from typing import Sequence
|
||||||
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
@@ -340,19 +340,3 @@ class JaggedLRRestartScheduler(LRScheduler):
|
|||||||
return [lr * scale for lr in original]
|
return [lr * scale for lr in original]
|
||||||
|
|
||||||
return original * scale
|
return original * scale
|
||||||
|
|
||||||
def state_dict(self) -> dict[str, Any]:
|
|
||||||
"""Return serializable state, saving inner_schedule as its own state_dict."""
|
|
||||||
state = {
|
|
||||||
key: value
|
|
||||||
for key, value in self.__dict__.items()
|
|
||||||
if key not in ("optimizer", "inner_schedule")
|
|
||||||
}
|
|
||||||
state["inner_schedule_state"] = self.inner_schedule.state_dict()
|
|
||||||
return state
|
|
||||||
|
|
||||||
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
|
|
||||||
"""Restore state, including inner_schedule."""
|
|
||||||
inner_state = state_dict.pop("inner_schedule_state")
|
|
||||||
self.__dict__.update(state_dict)
|
|
||||||
self.inner_schedule.load_state_dict(inner_state)
|
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ from pydantic import (
|
|||||||
model_validator,
|
model_validator,
|
||||||
)
|
)
|
||||||
|
|
||||||
from axolotl.tui.config import TUIConfig
|
|
||||||
from axolotl.utils.datasets import get_default_process_count
|
from axolotl.utils.datasets import get_default_process_count
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.schemas.datasets import (
|
from axolotl.utils.schemas.datasets import (
|
||||||
@@ -23,7 +22,6 @@ from axolotl.utils.schemas.datasets import (
|
|||||||
PretrainingDataset,
|
PretrainingDataset,
|
||||||
SFTDataset,
|
SFTDataset,
|
||||||
StepwiseSupervisedDataset,
|
StepwiseSupervisedDataset,
|
||||||
SyntheticDataset,
|
|
||||||
)
|
)
|
||||||
from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
|
from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
|
||||||
from axolotl.utils.schemas.dynamic_checkpoint import DynamicCheckpointConfig
|
from axolotl.utils.schemas.dynamic_checkpoint import DynamicCheckpointConfig
|
||||||
@@ -141,12 +139,6 @@ class AxolotlInputConfig(
|
|||||||
vllm: VllmConfig | None = Field(
|
vllm: VllmConfig | None = Field(
|
||||||
default_factory=lambda: VllmConfig(),
|
default_factory=lambda: VllmConfig(),
|
||||||
)
|
)
|
||||||
tui: TUIConfig | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "TUI dashboard configuration. Set enabled: true to activate."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
qat: QATConfig | None = None
|
qat: QATConfig | None = None
|
||||||
quantization: PTQConfig | None = None
|
quantization: PTQConfig | None = None
|
||||||
reward_model: bool | None = Field(
|
reward_model: bool | None = Field(
|
||||||
@@ -193,13 +185,7 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
datasets: (
|
datasets: (
|
||||||
Annotated[
|
Annotated[
|
||||||
list[
|
list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset],
|
||||||
SFTDataset
|
|
||||||
| DPODataset
|
|
||||||
| KTODataset
|
|
||||||
| StepwiseSupervisedDataset
|
|
||||||
| SyntheticDataset
|
|
||||||
],
|
|
||||||
MinLen(1),
|
MinLen(1),
|
||||||
]
|
]
|
||||||
| None
|
| None
|
||||||
@@ -212,13 +198,7 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
test_datasets: (
|
test_datasets: (
|
||||||
Annotated[
|
Annotated[
|
||||||
list[
|
list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset],
|
||||||
SFTDataset
|
|
||||||
| DPODataset
|
|
||||||
| KTODataset
|
|
||||||
| StepwiseSupervisedDataset
|
|
||||||
| SyntheticDataset
|
|
||||||
],
|
|
||||||
MinLen(1),
|
MinLen(1),
|
||||||
]
|
]
|
||||||
| None
|
| None
|
||||||
@@ -453,12 +433,6 @@ class AxolotlInputConfig(
|
|||||||
"description": "Whether to offload activations. Available options are: true, false, 'legacy', 'disk'."
|
"description": "Whether to offload activations. Available options are: true, false, 'legacy', 'disk'."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
layer_offloading: bool | None = Field(
|
|
||||||
default=False,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "Offload model layer parameters to CPU during forward, prefetch back during backward."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
unfrozen_parameters: list[str] | None = Field(
|
unfrozen_parameters: list[str] | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
@@ -710,12 +684,6 @@ class AxolotlInputConfig(
|
|||||||
"description": "Apply custom LoRA autograd functions and activation function Triton kernels for speed and memory savings. See: https://docs.axolotl.ai/docs/lora_optims.html"
|
"description": "Apply custom LoRA autograd functions and activation function Triton kernels for speed and memory savings. See: https://docs.axolotl.ai/docs/lora_optims.html"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
lora_embedding_kernel: bool | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "Apply custom LoRA autograd function for embedding layers. See: https://docs.axolotl.ai/docs/lora_optims.html"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
chunked_cross_entropy: bool | None = Field(
|
chunked_cross_entropy: bool | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
@@ -1326,7 +1294,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
data.get("lora_mlp_kernel")
|
data.get("lora_mlp_kernel")
|
||||||
or data.get("lora_qkv_kernel")
|
or data.get("lora_qkv_kernel")
|
||||||
or data.get("lora_o_kernel")
|
or data.get("lora_o_kernel")
|
||||||
or data.get("lora_embedding_kernel")
|
|
||||||
):
|
):
|
||||||
capabilities = data.get("capabilities")
|
capabilities = data.get("capabilities")
|
||||||
is_fsdp = data.get("fsdp_config") is not None
|
is_fsdp = data.get("fsdp_config") is not None
|
||||||
@@ -1374,12 +1341,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
if data.get("adapter") in ["lora", "qlora"]:
|
if data.get("adapter") in ["lora", "qlora"]:
|
||||||
# Skip if already set, using unsloth optimizations, or using 8-bit
|
# Skip if already set, using unsloth optimizations, or using 8-bit
|
||||||
unsloth_fields = ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]
|
unsloth_fields = ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]
|
||||||
kernel_fields = [
|
kernel_fields = ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"]
|
||||||
"lora_mlp_kernel",
|
|
||||||
"lora_qkv_kernel",
|
|
||||||
"lora_o_kernel",
|
|
||||||
"lora_embedding_kernel",
|
|
||||||
]
|
|
||||||
if (
|
if (
|
||||||
any(data.get(k) is not None for k in kernel_fields)
|
any(data.get(k) is not None for k in kernel_fields)
|
||||||
or any(data.get(k) for k in unsloth_fields)
|
or any(data.get(k) for k in unsloth_fields)
|
||||||
@@ -1392,6 +1354,10 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
if data.get("trust_remote_code"):
|
if data.get("trust_remote_code"):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
# Skip if dropout is not 0, as auto enabling it would just disable it during runtime patch checks
|
||||||
|
if data.get("lora_dropout") != 0:
|
||||||
|
return data
|
||||||
|
|
||||||
# Check multi-GPU compatibility
|
# Check multi-GPU compatibility
|
||||||
capabilities = data.get("capabilities")
|
capabilities = data.get("capabilities")
|
||||||
is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1
|
is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1
|
||||||
@@ -1413,9 +1379,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
if data.get("lora_o_kernel") is None:
|
if data.get("lora_o_kernel") is None:
|
||||||
data["lora_o_kernel"] = True
|
data["lora_o_kernel"] = True
|
||||||
|
|
||||||
if data.get("lora_embedding_kernel") is None:
|
|
||||||
data["lora_embedding_kernel"] = True
|
|
||||||
|
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"Auto-enabling LoRA kernel optimizations for faster training. "
|
"Auto-enabling LoRA kernel optimizations for faster training. "
|
||||||
+ "Please explicitly set `lora_*_kernel` config values to `false` to disable. "
|
+ "Please explicitly set `lora_*_kernel` config values to `false` to disable. "
|
||||||
|
|||||||
@@ -296,42 +296,4 @@ class KTODataset(BaseModel):
|
|||||||
revision: str | None = None
|
revision: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class SyntheticDataset(BaseModel):
|
DatasetConfig = SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset
|
||||||
"""Synthetic dataset configuration for benchmarking and testing.
|
|
||||||
|
|
||||||
Generates datasets with configurable sequence length, dataset size, and token ID
|
|
||||||
ranges. Useful for benchmarking memory usage and speed by sequence length, and for
|
|
||||||
validating weighted dataset mixes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
path: Literal["synthetic"] = "synthetic"
|
|
||||||
type: Literal["_synthetic"] = "_synthetic"
|
|
||||||
length: int = Field(
|
|
||||||
default=1000,
|
|
||||||
json_schema_extra={"description": "Number of rows to generate"},
|
|
||||||
)
|
|
||||||
sequence_length: int | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "Sequence length per row (defaults to sequence_len from config)"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
min_input_id: int = Field(
|
|
||||||
default=100,
|
|
||||||
json_schema_extra={"description": "Minimum token ID for generation"},
|
|
||||||
)
|
|
||||||
max_input_id: int | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "Maximum token ID for generation (defaults to tokenizer vocab_size)"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
seed: int | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={"description": "Random seed for reproducibility"},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
DatasetConfig = (
|
|
||||||
SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset | SyntheticDataset
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -87,11 +87,6 @@ class CustomSupportedOptimizers(str, Enum):
|
|||||||
came_pytorch = "came_pytorch"
|
came_pytorch = "came_pytorch"
|
||||||
muon = "muon"
|
muon = "muon"
|
||||||
dion = "dion"
|
dion = "dion"
|
||||||
flash_adamw = "flash_adamw"
|
|
||||||
flash_adam = "flash_adam"
|
|
||||||
flash_sgd = "flash_sgd"
|
|
||||||
flash_sgdw = "flash_sgdw"
|
|
||||||
flash_lion = "flash_lion"
|
|
||||||
|
|
||||||
|
|
||||||
class RingAttnFunc(str, Enum):
|
class RingAttnFunc(str, Enum):
|
||||||
|
|||||||
@@ -681,7 +681,15 @@ class LoRAValidationMixin:
|
|||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_lora_kernels_dora(cls, data):
|
def check_lora_kernels_dora(cls, data):
|
||||||
# DoRA is now supported by lora kernels
|
if (
|
||||||
|
data.get("lora_mlp_kernel")
|
||||||
|
or data.get("lora_qkv_kernel")
|
||||||
|
or data.get("lora_o_kernel")
|
||||||
|
) and data.get("peft_use_dora"):
|
||||||
|
raise ValueError(
|
||||||
|
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
|
||||||
|
"compatible with DoRA at the moment."
|
||||||
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@@ -782,14 +790,6 @@ class OptimizationValidationMixin:
|
|||||||
LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
|
LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _resolve_fsdp_version(data):
|
|
||||||
"""Resolve FSDP version from top-level fsdp_version or fsdp_config.fsdp_version."""
|
|
||||||
fsdp_version = data.get("fsdp_version")
|
|
||||||
if fsdp_version is None:
|
|
||||||
fsdp_version = data.get("fsdp_config", {}).get("fsdp_version", 1)
|
|
||||||
return fsdp_version
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_muon_deepspeed_fsdp(cls, data):
|
def check_muon_deepspeed_fsdp(cls, data):
|
||||||
@@ -799,32 +799,15 @@ class OptimizationValidationMixin:
|
|||||||
"Muon optimizer is currently incompatible with DeepSpeed"
|
"Muon optimizer is currently incompatible with DeepSpeed"
|
||||||
)
|
)
|
||||||
if data.get("fsdp") or data.get("fsdp_config"):
|
if data.get("fsdp") or data.get("fsdp_config"):
|
||||||
fsdp_version = cls._resolve_fsdp_version(data)
|
fsdp_version = data.get("fsdp_version")
|
||||||
|
if fsdp_version is None:
|
||||||
|
fsdp_version = data.get("fsdp_config", {}).get("fsdp_version", 1)
|
||||||
if str(fsdp_version) != "2":
|
if str(fsdp_version) != "2":
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Muon optimizer is only compatible with FSDP2. Set fsdp_version: 2 to use Muon with FSDP."
|
"Muon optimizer is only compatible with FSDP2. Set fsdp_version: 2 to use Muon with FSDP."
|
||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def check_flashoptim_deepspeed_fsdp(cls, data):
|
|
||||||
optimizer = data.get("optimizer") or ""
|
|
||||||
if str(optimizer).startswith("flash_"):
|
|
||||||
if data.get("deepspeed"):
|
|
||||||
raise ValueError(
|
|
||||||
f"{optimizer} optimizer is incompatible with DeepSpeed. "
|
|
||||||
"Flash optimizers only support DDP and FSDP2."
|
|
||||||
)
|
|
||||||
if data.get("fsdp") or data.get("fsdp_config"):
|
|
||||||
fsdp_version = cls._resolve_fsdp_version(data)
|
|
||||||
if str(fsdp_version) != "2":
|
|
||||||
raise ValueError(
|
|
||||||
f"{optimizer} optimizer is only compatible with FSDP2. "
|
|
||||||
"Set fsdp_version: 2 to use flash optimizers with FSDP."
|
|
||||||
)
|
|
||||||
return data
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_batch_flattening_fa(cls, data):
|
def check_batch_flattening_fa(cls, data):
|
||||||
|
|||||||
@@ -15,8 +15,6 @@ import datasets
|
|||||||
import pytest
|
import pytest
|
||||||
import requests
|
import requests
|
||||||
import torch
|
import torch
|
||||||
import transformers.utils as _transformers_utils
|
|
||||||
import transformers.utils.import_utils as _import_utils
|
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from huggingface_hub.errors import LocalEntryNotFoundError
|
from huggingface_hub.errors import LocalEntryNotFoundError
|
||||||
from tokenizers import AddedToken
|
from tokenizers import AddedToken
|
||||||
@@ -31,26 +29,6 @@ from tests.hf_offline_utils import (
|
|||||||
|
|
||||||
logging.getLogger("filelock").setLevel(logging.CRITICAL)
|
logging.getLogger("filelock").setLevel(logging.CRITICAL)
|
||||||
|
|
||||||
# Shim for deepseek v3
|
|
||||||
if not hasattr(_import_utils, "is_torch_fx_available"):
|
|
||||||
|
|
||||||
def _is_torch_fx_available():
|
|
||||||
try:
|
|
||||||
import torch.fx # noqa: F401 # pylint: disable=unused-import
|
|
||||||
|
|
||||||
return True
|
|
||||||
except ImportError:
|
|
||||||
return False
|
|
||||||
|
|
||||||
_import_utils.is_torch_fx_available = _is_torch_fx_available
|
|
||||||
|
|
||||||
if not hasattr(_transformers_utils, "is_flash_attn_greater_or_equal_2_10"):
|
|
||||||
from transformers.utils import is_flash_attn_greater_or_equal as _is_flash_attn_gte
|
|
||||||
|
|
||||||
_transformers_utils.is_flash_attn_greater_or_equal_2_10 = lambda: (
|
|
||||||
_is_flash_attn_gte("2.10")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def retry_on_request_exceptions(max_retries=3, delay=1):
|
def retry_on_request_exceptions(max_retries=3, delay=1):
|
||||||
def decorator(func):
|
def decorator(func):
|
||||||
|
|||||||
@@ -153,7 +153,7 @@ class TestLoraFP8Guard(unittest.TestCase):
|
|||||||
|
|
||||||
proj.base_layer = base_layer
|
proj.base_layer = base_layer
|
||||||
|
|
||||||
W, b, quant_state, A, B, s, *_ = get_lora_parameters(proj)
|
W, b, quant_state, A, B, s = get_lora_parameters(proj)
|
||||||
# quant_state should be None since weight is bf16, not FP8
|
# quant_state should be None since weight is bf16, not FP8
|
||||||
self.assertIsNone(quant_state)
|
self.assertIsNone(quant_state)
|
||||||
|
|
||||||
@@ -174,7 +174,7 @@ class TestLoraFP8Guard(unittest.TestCase):
|
|||||||
scale_inv = torch.ones(1)
|
scale_inv = torch.ones(1)
|
||||||
base_layer.weight_scale_inv = scale_inv
|
base_layer.weight_scale_inv = scale_inv
|
||||||
|
|
||||||
W, b, quant_state, A, B, s, *_ = get_lora_parameters(proj)
|
W, b, quant_state, A, B, s = get_lora_parameters(proj)
|
||||||
self.assertIs(quant_state, scale_inv)
|
self.assertIs(quant_state, scale_inv)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ Test strategy:
|
|||||||
- Tolerances account for tf32 accumulation in Triton kernels
|
- Tolerances account for tf32 accumulation in Triton kernels
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from functools import wraps
|
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -35,21 +34,6 @@ pytestmark = pytest.mark.skipif(
|
|||||||
_SMOE = "axolotl.integrations.kernels.libs.scattermoe_lora"
|
_SMOE = "axolotl.integrations.kernels.libs.scattermoe_lora"
|
||||||
|
|
||||||
|
|
||||||
def skip_on_out_of_resources(func):
|
|
||||||
"""Skip test if Triton kernel exceeds GPU shared memory limits."""
|
|
||||||
|
|
||||||
@wraps(func)
|
|
||||||
def wrapper(*args, **kwargs):
|
|
||||||
try:
|
|
||||||
return func(*args, **kwargs)
|
|
||||||
except Exception as exc: # pylint: disable=broad-except
|
|
||||||
if "OutOfResources" in type(exc).__name__:
|
|
||||||
pytest.skip(f"GPU shared memory too small: {exc}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Helpers
|
# Helpers
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -225,7 +209,6 @@ def make_test_data(
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
|
||||||
class TestForwardPass:
|
class TestForwardPass:
|
||||||
"""Test forward pass of fused scatter2scatter_lora kernel."""
|
"""Test forward pass of fused scatter2scatter_lora kernel."""
|
||||||
|
|
||||||
@@ -305,7 +288,6 @@ class TestForwardPass:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
|
||||||
class TestForwardGrouped:
|
class TestForwardGrouped:
|
||||||
"""Test forward pass with grouped_in/grouped_out configurations."""
|
"""Test forward pass with grouped_in/grouped_out configurations."""
|
||||||
|
|
||||||
@@ -395,7 +377,6 @@ class TestForwardGrouped:
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
|
||||||
class TestLoRAGradients:
|
class TestLoRAGradients:
|
||||||
"""Test backward LoRA gradient computation (dA, dB)."""
|
"""Test backward LoRA gradient computation (dA, dB)."""
|
||||||
|
|
||||||
@@ -471,7 +452,6 @@ class TestLoRAGradients:
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
|
||||||
class TestAutograd:
|
class TestAutograd:
|
||||||
"""Test full autograd integration through ScatterMoELoRA."""
|
"""Test full autograd integration through ScatterMoELoRA."""
|
||||||
|
|
||||||
@@ -640,7 +620,6 @@ class TestAutograd:
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
|
||||||
class TestBaseEquivalence:
|
class TestBaseEquivalence:
|
||||||
"""When scaling=0, fused kernel should match base scatter2scatter."""
|
"""When scaling=0, fused kernel should match base scatter2scatter."""
|
||||||
|
|
||||||
@@ -713,7 +692,6 @@ class TestBaseEquivalence:
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
|
||||||
class TestLoRAAdditivity:
|
class TestLoRAAdditivity:
|
||||||
"""Test that the LoRA component is correctly additive."""
|
"""Test that the LoRA component is correctly additive."""
|
||||||
|
|
||||||
@@ -771,7 +749,6 @@ class TestLoRAAdditivity:
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
|
||||||
class TestParallelExpertsModule:
|
class TestParallelExpertsModule:
|
||||||
"""Test the ParallelExperts module with LoRA."""
|
"""Test the ParallelExperts module with LoRA."""
|
||||||
|
|
||||||
@@ -839,7 +816,6 @@ class TestParallelExpertsModule:
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
|
||||||
class TestEdgeCases:
|
class TestEdgeCases:
|
||||||
"""Edge cases and boundary conditions."""
|
"""Edge cases and boundary conditions."""
|
||||||
|
|
||||||
@@ -937,7 +913,6 @@ class TestEdgeCases:
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
|
||||||
class TestFusedDX:
|
class TestFusedDX:
|
||||||
"""Test fused backward dX kernel: dX = dY @ W^T + scaling * (dY @ B) @ A."""
|
"""Test fused backward dX kernel: dX = dY @ W^T + scaling * (dY @ B) @ A."""
|
||||||
|
|
||||||
@@ -1005,7 +980,6 @@ class TestFusedDX:
|
|||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
self._run_fused_dX_test(M=32, K=64, N=128, E=4, R=8, k=2)
|
self._run_fused_dX_test(M=32, K=64, N=128, E=4, R=8, k=2)
|
||||||
|
|
||||||
@skip_on_out_of_resources
|
|
||||||
def test_large(self):
|
def test_large(self):
|
||||||
self._run_fused_dX_test(M=256, K=256, N=512, E=8, R=16, k=2)
|
self._run_fused_dX_test(M=256, K=256, N=512, E=8, R=16, k=2)
|
||||||
|
|
||||||
@@ -1148,7 +1122,6 @@ class TestFusedDX:
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
|
||||||
class TestFusedGatherBackward:
|
class TestFusedGatherBackward:
|
||||||
"""Test fused gather + backward dA/dB kernel."""
|
"""Test fused gather + backward dA/dB kernel."""
|
||||||
|
|
||||||
@@ -1201,7 +1174,6 @@ class TestFusedGatherBackward:
|
|||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
self._run_fused_gather_test(M=32, K=64, N=128, E=4, R=8, k=2)
|
self._run_fused_gather_test(M=32, K=64, N=128, E=4, R=8, k=2)
|
||||||
|
|
||||||
@skip_on_out_of_resources
|
|
||||||
def test_large(self):
|
def test_large(self):
|
||||||
self._run_fused_gather_test(M=256, K=256, N=512, E=8, R=16, k=2)
|
self._run_fused_gather_test(M=256, K=256, N=512, E=8, R=16, k=2)
|
||||||
|
|
||||||
@@ -1211,7 +1183,6 @@ class TestFusedGatherBackward:
|
|||||||
def test_k1(self):
|
def test_k1(self):
|
||||||
self._run_fused_gather_test(M=64, K=64, N=128, E=4, R=8, k=1)
|
self._run_fused_gather_test(M=64, K=64, N=128, E=4, R=8, k=1)
|
||||||
|
|
||||||
@skip_on_out_of_resources
|
|
||||||
def test_many_experts(self):
|
def test_many_experts(self):
|
||||||
self._run_fused_gather_test(M=128, K=64, N=128, E=16, R=8, k=4)
|
self._run_fused_gather_test(M=128, K=64, N=128, E=16, R=8, k=4)
|
||||||
|
|
||||||
@@ -1298,8 +1269,6 @@ class TestFusedGatherBackward:
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
|
||||||
@pytest.mark.xfail(reason="flaky", strict=False)
|
|
||||||
class TestTokenRounding:
|
class TestTokenRounding:
|
||||||
"""Test token rounding utility and its integration with backward kernels."""
|
"""Test token rounding utility and its integration with backward kernels."""
|
||||||
|
|
||||||
@@ -1346,7 +1315,6 @@ class TestTokenRounding:
|
|||||||
)
|
)
|
||||||
prev = padded_offsets[e].item()
|
prev = padded_offsets[e].item()
|
||||||
|
|
||||||
@skip_on_out_of_resources
|
|
||||||
def test_round_with_fused_gather(self):
|
def test_round_with_fused_gather(self):
|
||||||
"""Token rounding + fused gather gives same result as plain fused gather."""
|
"""Token rounding + fused gather gives same result as plain fused gather."""
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
@@ -1446,7 +1414,6 @@ class TestTokenRounding:
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
|
||||||
class TestCombinedOptimizations:
|
class TestCombinedOptimizations:
|
||||||
"""Test all optimizations together."""
|
"""Test all optimizations together."""
|
||||||
|
|
||||||
@@ -1616,7 +1583,6 @@ def _make_mock_sigmoid_moe_block(
|
|||||||
return moe_block, T, H, FF, E, K
|
return moe_block, T, H, FF, E, K
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
|
||||||
class TestHFScatterMoESigmoidRouting:
|
class TestHFScatterMoESigmoidRouting:
|
||||||
"""Test HFScatterMoEGatedMLP forward with sigmoid routing on GPU."""
|
"""Test HFScatterMoEGatedMLP forward with sigmoid routing on GPU."""
|
||||||
|
|
||||||
@@ -1758,7 +1724,6 @@ class TestHFScatterMoESigmoidRouting:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
|
||||||
class TestHFScatterMoESigmoidWithSharedExperts:
|
class TestHFScatterMoESigmoidWithSharedExperts:
|
||||||
"""Test HFScatterMoEGatedMLP with sigmoid routing + shared experts."""
|
"""Test HFScatterMoEGatedMLP with sigmoid routing + shared experts."""
|
||||||
|
|
||||||
|
|||||||
@@ -933,7 +933,7 @@ class TestKernelizeIntegration:
|
|||||||
def _get_repo_path():
|
def _get_repo_path():
|
||||||
"""Get the path to scattermoe_lora within axolotl's plugin."""
|
"""Get the path to scattermoe_lora within axolotl's plugin."""
|
||||||
return (
|
return (
|
||||||
Path(__file__).parent.parent.parent.parent
|
Path(__file__).parent.parent.parent
|
||||||
/ "src"
|
/ "src"
|
||||||
/ "axolotl"
|
/ "axolotl"
|
||||||
/ "integrations"
|
/ "integrations"
|
||||||
@@ -1219,7 +1219,7 @@ class TestSharedExpertHandling:
|
|||||||
|
|
||||||
# Kernelize
|
# Kernelize
|
||||||
repo_path = (
|
repo_path = (
|
||||||
Path(__file__).parent.parent.parent.parent
|
Path(__file__).parent.parent.parent
|
||||||
/ "src"
|
/ "src"
|
||||||
/ "axolotl"
|
/ "axolotl"
|
||||||
/ "integrations"
|
/ "integrations"
|
||||||
|
|||||||
@@ -102,7 +102,7 @@ def mock_proj():
|
|||||||
def test_get_lora_parameters(mock_proj):
|
def test_get_lora_parameters(mock_proj):
|
||||||
"""Tests get_lora_parameters function"""
|
"""Tests get_lora_parameters function"""
|
||||||
# Test with LoRA enabled
|
# Test with LoRA enabled
|
||||||
W, b, _, A, B, s, *_ = get_lora_parameters(mock_proj)
|
W, b, _, A, B, s = get_lora_parameters(mock_proj)
|
||||||
|
|
||||||
assert isinstance(W, torch.Tensor)
|
assert isinstance(W, torch.Tensor)
|
||||||
assert W.shape == (128, 64)
|
assert W.shape == (128, 64)
|
||||||
@@ -113,13 +113,13 @@ def test_get_lora_parameters(mock_proj):
|
|||||||
|
|
||||||
# Test with LoRA disabled
|
# Test with LoRA disabled
|
||||||
mock_proj.disable_adapters = True
|
mock_proj.disable_adapters = True
|
||||||
W, b, _, A, B, s, *_ = get_lora_parameters(mock_proj)
|
W, b, _, A, B, s = get_lora_parameters(mock_proj)
|
||||||
assert A is None and B is None and s is None
|
assert A is None and B is None and s is None
|
||||||
|
|
||||||
# Test with merged state
|
# Test with merged state
|
||||||
mock_proj.disable_adapters = False
|
mock_proj.disable_adapters = False
|
||||||
mock_proj.merged = True
|
mock_proj.merged = True
|
||||||
W, b, _, A, B, s, *_ = get_lora_parameters(mock_proj)
|
W, b, _, A, B, s = get_lora_parameters(mock_proj)
|
||||||
assert A is None and B is None and s is None
|
assert A is None and B is None and s is None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,120 +0,0 @@
|
|||||||
"""Test LoRA kernels under FSDP2 multi-GPU training.
|
|
||||||
|
|
||||||
Verifies that lora_qkv_kernel, lora_o_kernel, lora_mlp_kernel, and
|
|
||||||
lora_embedding_kernel work correctly with FSDP2 sharding, including
|
|
||||||
with bias, dropout, and DoRA enabled.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import yaml
|
|
||||||
from accelerate.test_utils import execute_subprocess_async
|
|
||||||
from transformers.testing_utils import get_torch_dist_unique_port
|
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
|
|
||||||
from tests.e2e.utils import require_torch_2_7_0
|
|
||||||
|
|
||||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
|
||||||
|
|
||||||
|
|
||||||
def _run_training(temp_dir, cfg):
|
|
||||||
"""Write config and launch multi-GPU training."""
|
|
||||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
|
||||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
|
||||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
|
||||||
|
|
||||||
execute_subprocess_async(
|
|
||||||
[
|
|
||||||
"axolotl",
|
|
||||||
"train",
|
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
|
||||||
"--num-processes",
|
|
||||||
"2",
|
|
||||||
"--main-process-port",
|
|
||||||
f"{get_torch_dist_unique_port()}",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _base_lora_fsdp2_config(temp_dir, **overrides):
|
|
||||||
"""Base config for LoRA + FSDP2 + kernel tests."""
|
|
||||||
cfg = {
|
|
||||||
"base_model": "Qwen/Qwen3-0.6B",
|
|
||||||
"sequence_len": 512,
|
|
||||||
"val_set_size": 0.0,
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "tatsu-lab/alpaca",
|
|
||||||
"type": "alpaca",
|
|
||||||
"split": "train[:1%]",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"adapter": "lora",
|
|
||||||
"lora_r": 8,
|
|
||||||
"lora_alpha": 16,
|
|
||||||
"lora_target_linear": True,
|
|
||||||
"num_epochs": 1,
|
|
||||||
"max_steps": 3,
|
|
||||||
"micro_batch_size": 1,
|
|
||||||
"gradient_accumulation_steps": 1,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 1e-4,
|
|
||||||
"optimizer": "adamw_torch_fused",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"flash_attention": True,
|
|
||||||
"bf16": True,
|
|
||||||
"fsdp_version": 2,
|
|
||||||
"fsdp_config": {
|
|
||||||
"offload_params": False,
|
|
||||||
"cpu_ram_efficient_loading": False,
|
|
||||||
"transformer_layer_cls_to_wrap": "Qwen3DecoderLayer",
|
|
||||||
"state_dict_type": "FULL_STATE_DICT",
|
|
||||||
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
|
||||||
"reshard_after_forward": True,
|
|
||||||
},
|
|
||||||
# Enable all LoRA kernels
|
|
||||||
"lora_mlp_kernel": True,
|
|
||||||
"lora_qkv_kernel": True,
|
|
||||||
"lora_o_kernel": True,
|
|
||||||
"lora_embedding_kernel": True,
|
|
||||||
"save_safetensors": True,
|
|
||||||
}
|
|
||||||
cfg.update(overrides)
|
|
||||||
return DictDefault(cfg)
|
|
||||||
|
|
||||||
|
|
||||||
class TestFSDP2LoRAKernels:
|
|
||||||
"""Test LoRA kernels under FSDP2."""
|
|
||||||
|
|
||||||
@require_torch_2_7_0
|
|
||||||
def test_lora_kernels_basic(self, temp_dir):
|
|
||||||
"""Basic LoRA + kernels + FSDP2: no dropout, no bias, no DoRA."""
|
|
||||||
cfg = _base_lora_fsdp2_config(temp_dir)
|
|
||||||
_run_training(temp_dir, cfg)
|
|
||||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
|
||||||
|
|
||||||
@require_torch_2_7_0
|
|
||||||
def test_lora_kernels_with_dropout(self, temp_dir):
|
|
||||||
"""LoRA kernels + dropout + FSDP2."""
|
|
||||||
cfg = _base_lora_fsdp2_config(temp_dir, lora_dropout=0.1)
|
|
||||||
_run_training(temp_dir, cfg)
|
|
||||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
|
||||||
|
|
||||||
@require_torch_2_7_0
|
|
||||||
def test_lora_kernels_with_dora(self, temp_dir):
|
|
||||||
"""LoRA kernels + DoRA + FSDP2."""
|
|
||||||
cfg = _base_lora_fsdp2_config(temp_dir, peft_use_dora=True)
|
|
||||||
_run_training(temp_dir, cfg)
|
|
||||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
|
||||||
|
|
||||||
@require_torch_2_7_0
|
|
||||||
def test_lora_kernels_with_dora_and_dropout(self, temp_dir):
|
|
||||||
"""LoRA kernels + DoRA + dropout + FSDP2."""
|
|
||||||
cfg = _base_lora_fsdp2_config(
|
|
||||||
temp_dir,
|
|
||||||
peft_use_dora=True,
|
|
||||||
lora_dropout=0.05,
|
|
||||||
)
|
|
||||||
_run_training(temp_dir, cfg)
|
|
||||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
|
||||||
@@ -222,9 +222,9 @@ def test_model_specific_activation(model_name, expected_activation):
|
|||||||
|
|
||||||
|
|
||||||
def test_kernel_patch_conditions():
|
def test_kernel_patch_conditions():
|
||||||
"""Test that kernels ARE patched even with dropout and bias (now supported)."""
|
"""Test various conditions that should prevent kernel patching."""
|
||||||
test_configs = [
|
test_configs = [
|
||||||
# Dropout — kernels now support this
|
# Dropout prevents patching
|
||||||
{
|
{
|
||||||
"peft_type": "LORA",
|
"peft_type": "LORA",
|
||||||
"task_type": "CAUSAL_LM",
|
"task_type": "CAUSAL_LM",
|
||||||
@@ -234,7 +234,7 @@ def test_kernel_patch_conditions():
|
|||||||
"lora_dropout": 0.1,
|
"lora_dropout": 0.1,
|
||||||
"bias": "none",
|
"bias": "none",
|
||||||
},
|
},
|
||||||
# Bias — kernels now support this
|
# Bias prevents patching
|
||||||
{
|
{
|
||||||
"peft_type": "LORA",
|
"peft_type": "LORA",
|
||||||
"task_type": "CAUSAL_LM",
|
"task_type": "CAUSAL_LM",
|
||||||
@@ -252,14 +252,13 @@ def test_kernel_patch_conditions():
|
|||||||
model = PeftModelForCausalLM(model, peft_config)
|
model = PeftModelForCausalLM(model, peft_config)
|
||||||
cfg = DictDefault({"lora_mlp_kernel": True})
|
cfg = DictDefault({"lora_mlp_kernel": True})
|
||||||
|
|
||||||
|
# Should not patch
|
||||||
patched_model = apply_lora_kernel_patches(model, cfg)
|
patched_model = apply_lora_kernel_patches(model, cfg)
|
||||||
layer = patched_model.model.model.layers[0].mlp
|
layer = patched_model.model.model.layers[0].mlp
|
||||||
|
|
||||||
# Verify patches ARE applied (dropout and bias are now supported)
|
# Verify no patches applied
|
||||||
assert (
|
assert layer.forward.__func__ is not apply_lora_mlp_swiglu
|
||||||
layer.forward.__func__ is apply_lora_mlp_swiglu
|
assert layer.forward.__func__ is not apply_lora_mlp_geglu
|
||||||
or layer.forward.__func__ is apply_lora_mlp_geglu
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_kernel_config_options():
|
def test_kernel_config_options():
|
||||||
@@ -512,7 +511,7 @@ def test_kernel_training_integration_auto_enable(temp_dir):
|
|||||||
|
|
||||||
|
|
||||||
def test_kernel_training_integration_dropout_non_zero(temp_dir):
|
def test_kernel_training_integration_dropout_non_zero(temp_dir):
|
||||||
"""Test model loading with dropout non-zero DOES patch (now supported)."""
|
"""Test model loading with dropout non-zero should not patch."""
|
||||||
|
|
||||||
from axolotl.cli.utils import load_model_and_tokenizer
|
from axolotl.cli.utils import load_model_and_tokenizer
|
||||||
|
|
||||||
@@ -547,18 +546,31 @@ def test_kernel_training_integration_dropout_non_zero(temp_dir):
|
|||||||
# Load config
|
# Load config
|
||||||
cfg = load_cfg(str(path))
|
cfg = load_cfg(str(path))
|
||||||
|
|
||||||
|
# Get original attention class
|
||||||
|
attention_cls = get_attention_cls_from_config(cfg)
|
||||||
|
|
||||||
|
# Store original state before patching
|
||||||
|
original_forward_method = attention_cls.forward
|
||||||
|
|
||||||
# Load model
|
# Load model
|
||||||
model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg)
|
model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg)
|
||||||
|
|
||||||
|
# We call modelloader as that's where the patches are applied
|
||||||
|
# despite the fact that we're not using it to load the model
|
||||||
model_loader = ModelLoader(cfg, tokenizer)
|
model_loader = ModelLoader(cfg, tokenizer)
|
||||||
|
|
||||||
# Apply patches — should succeed even with dropout > 0
|
# Apply patch
|
||||||
model_loader.patch_manager._apply_self_attention_lora_patch()
|
model_loader.patch_manager._apply_self_attention_lora_patch()
|
||||||
|
|
||||||
|
# Verify patch was not applied
|
||||||
|
assert attention_cls.forward == original_forward_method
|
||||||
|
|
||||||
|
# Apply apply_lora_kernel_patches
|
||||||
model_loader.patch_manager._apply_lora_kernel_patch(model)
|
model_loader.patch_manager._apply_lora_kernel_patch(model)
|
||||||
|
|
||||||
# Verify patches WERE applied (dropout is now supported by kernels)
|
# Verify patch was not applied
|
||||||
layers = get_layers(model)
|
layers = get_layers(model)
|
||||||
for layer in layers:
|
for layer in layers:
|
||||||
for self_attn in find_self_attn_in_layer(layer):
|
for self_attn in find_self_attn_in_layer(layer):
|
||||||
assert hasattr(self_attn, "apply_qkv")
|
assert not hasattr(self_attn, "apply_qkv")
|
||||||
assert hasattr(self_attn, "apply_o")
|
assert not hasattr(self_attn, "apply_o")
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ import subprocess
|
|||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.core.trainers.constants import TOKENS_STATE_FILE
|
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
|
from axolotl.utils.callbacks.tokens_per_second import TOKENS_STATE_FILE
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|||||||
@@ -14,9 +14,6 @@ from axolotl.utils.dict import DictDefault
|
|||||||
from tests.hf_offline_utils import enable_hf_offline
|
from tests.hf_offline_utils import enable_hf_offline
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(
|
|
||||||
reason="DeepSeek-V3-11M remote model code needs _supports_flash_attn=True for newer transformers"
|
|
||||||
)
|
|
||||||
class TestDeepseekV3:
|
class TestDeepseekV3:
|
||||||
"""
|
"""
|
||||||
Test case for DeepseekV3 models
|
Test case for DeepseekV3 models
|
||||||
|
|||||||
@@ -262,7 +262,6 @@ class TestDPOLlamaLora(unittest.TestCase):
|
|||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
||||||
|
|
||||||
@pytest.mark.skip(reason="TRL ORPO trainer has internal zip() length mismatch bug")
|
|
||||||
@with_temp_dir
|
@with_temp_dir
|
||||||
def test_orpo_lora(self, temp_dir):
|
def test_orpo_lora(self, temp_dir):
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ class TestMixtral(unittest.TestCase):
|
|||||||
|
|
||||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
assert (
|
assert (
|
||||||
model.base_model.model.model.layers[0].mlp.gate.weight.dtype
|
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
||||||
== torch.float32
|
== torch.float32
|
||||||
)
|
)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
@@ -125,7 +125,7 @@ class TestMixtral(unittest.TestCase):
|
|||||||
|
|
||||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
assert (
|
assert (
|
||||||
model.base_model.model.model.layers[0].mlp.gate.weight.dtype
|
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
||||||
== torch.float32
|
== torch.float32
|
||||||
)
|
)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
@@ -183,7 +183,7 @@ class TestMixtral(unittest.TestCase):
|
|||||||
|
|
||||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
assert (
|
assert (
|
||||||
model.base_model.model.model.layers[0].mlp.gate.weight.dtype
|
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
||||||
== torch.float32
|
== torch.float32
|
||||||
)
|
)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
@@ -241,7 +241,7 @@ class TestMixtral(unittest.TestCase):
|
|||||||
|
|
||||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
assert (
|
assert (
|
||||||
model.base_model.model.model.layers[0].mlp.gate.weight.dtype
|
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
||||||
== torch.float32
|
== torch.float32
|
||||||
)
|
)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|||||||
@@ -4,8 +4,6 @@ E2E tests for custom optimizers using Llama
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
@@ -284,60 +282,3 @@ class TestCustomOptimizers(unittest.TestCase):
|
|||||||
|
|
||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
|
|
||||||
@require_torch_2_7_0
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"optimizer_name,expected_class,learning_rate",
|
|
||||||
[
|
|
||||||
("flash_adamw", "FlashAdamW", 0.00001),
|
|
||||||
("flash_adam", "FlashAdam", 0.00001),
|
|
||||||
("flash_sgd", "FlashSGD", 0.01),
|
|
||||||
("flash_sgdw", "FlashSGDW", 0.01),
|
|
||||||
("flash_lion", "FlashLion", 0.0001),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_flash_optimizers(tmp_path, optimizer_name, expected_class, learning_rate):
|
|
||||||
pytest.importorskip("flashoptim")
|
|
||||||
temp_dir = str(tmp_path)
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
|
||||||
"model_type": "AutoModelForCausalLM",
|
|
||||||
"tokenizer_type": "AutoTokenizer",
|
|
||||||
"sequence_len": 1024,
|
|
||||||
"load_in_8bit": True,
|
|
||||||
"adapter": "lora",
|
|
||||||
"lora_r": 8,
|
|
||||||
"lora_alpha": 16,
|
|
||||||
"lora_dropout": 0.05,
|
|
||||||
"lora_target_linear": True,
|
|
||||||
"val_set_size": 0.02,
|
|
||||||
"special_tokens": {
|
|
||||||
"pad_token": "<|endoftext|>",
|
|
||||||
},
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "mhenrichsen/alpaca_2k_test",
|
|
||||||
"type": "alpaca",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"num_epochs": 1,
|
|
||||||
"micro_batch_size": 8,
|
|
||||||
"gradient_accumulation_steps": 1,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": learning_rate,
|
|
||||||
"optimizer": optimizer_name,
|
|
||||||
"max_steps": 5,
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"save_first_step": False,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
cfg = validate_config(cfg)
|
|
||||||
normalize_config(cfg)
|
|
||||||
dataset_meta = load_datasets(cfg=cfg)
|
|
||||||
|
|
||||||
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
|
||||||
check_model_output_exists(temp_dir, cfg)
|
|
||||||
assert trainer.optimizer.optimizer.__class__.__name__ == expected_class
|
|
||||||
|
|||||||
@@ -35,14 +35,6 @@ from tests.e2e.utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_fake_quant_config_dtype(config):
|
|
||||||
"""Get the weight dtype from a fake quantize config, handling different config types."""
|
|
||||||
if hasattr(config, "dtype"):
|
|
||||||
return config.dtype
|
|
||||||
# Int4WeightFakeQuantizeConfig doesn't have .dtype — weight is always int4
|
|
||||||
return torch.int4
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def model():
|
def model():
|
||||||
dummy_model = AutoModelForCausalLM.from_pretrained(
|
dummy_model = AutoModelForCausalLM.from_pretrained(
|
||||||
@@ -165,18 +157,6 @@ class TestQuantization:
|
|||||||
expected_exception,
|
expected_exception,
|
||||||
expected_tensor_class,
|
expected_tensor_class,
|
||||||
):
|
):
|
||||||
# TODO: add mslk-cuda as a CI dependency once pytorch 2.10.x is available
|
|
||||||
# (see https://pypi.org/project/mslk-cuda/)
|
|
||||||
if expected_tensor_class is Int4Tensor and activation_dtype is None:
|
|
||||||
try:
|
|
||||||
from torchao.quantization.quantize_.workflows.int4.int4_tensor import (
|
|
||||||
int4_row_quantize_zp,
|
|
||||||
)
|
|
||||||
|
|
||||||
if int4_row_quantize_zp is None:
|
|
||||||
pytest.skip("Int4Tensor requires mslk >= 1.0.0")
|
|
||||||
except ImportError:
|
|
||||||
pytest.skip("Int4Tensor requires mslk >= 1.0.0")
|
|
||||||
if expected_exception:
|
if expected_exception:
|
||||||
with pytest.raises(expected_exception):
|
with pytest.raises(expected_exception):
|
||||||
quantize_model(
|
quantize_model(
|
||||||
@@ -272,24 +252,28 @@ class TestQuantization:
|
|||||||
if quantize_embedding:
|
if quantize_embedding:
|
||||||
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
|
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
|
||||||
assert hasattr(model.model.embed_tokens, "weight_fake_quantizer")
|
assert hasattr(model.model.embed_tokens, "weight_fake_quantizer")
|
||||||
embed_config = model.model.embed_tokens.weight_fake_quantizer.config
|
assert (
|
||||||
assert _get_fake_quant_config_dtype(embed_config) == weight_dtype.value
|
model.model.embed_tokens.weight_fake_quantizer.config.dtype
|
||||||
|
== weight_dtype.value
|
||||||
|
)
|
||||||
if group_size:
|
if group_size:
|
||||||
assert embed_config.group_size == group_size
|
assert (
|
||||||
|
model.model.embed_tokens.weight_fake_quantizer.config.group_size
|
||||||
|
== group_size
|
||||||
|
)
|
||||||
|
|
||||||
for child in list(model.children()):
|
for child in list(model.children()):
|
||||||
if isinstance(child, torch.nn.Linear):
|
if isinstance(child, torch.nn.Linear):
|
||||||
assert isinstance(child, FakeQuantizedLinear)
|
assert isinstance(child, FakeQuantizedLinear)
|
||||||
assert hasattr(child, "weight_fake_quantizer")
|
assert hasattr(child, "weight_fake_quantizer")
|
||||||
w_config = child.weight_fake_quantizer.config
|
assert child.weight_fake_quantizer.config.dtype == weight_dtype.value
|
||||||
assert _get_fake_quant_config_dtype(w_config) == weight_dtype.value
|
|
||||||
if group_size:
|
if group_size:
|
||||||
assert w_config.group_size == group_size
|
assert child.weight_fake_quantizer.config.group_size == group_size
|
||||||
if activation_dtype:
|
if activation_dtype:
|
||||||
assert hasattr(child, "activation_fake_quantizer")
|
assert hasattr(child, "activation_fake_quantizer")
|
||||||
a_config = child.activation_fake_quantizer.config
|
|
||||||
assert (
|
assert (
|
||||||
_get_fake_quant_config_dtype(a_config) == activation_dtype.value
|
child.activation_fake_quantizer.config.dtype
|
||||||
|
== activation_dtype.value
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
assert child.activation_fake_quantizer is None
|
assert child.activation_fake_quantizer is None
|
||||||
@@ -390,16 +374,9 @@ class TestQuantizationCallback:
|
|||||||
|
|
||||||
# ensure model has been quantized
|
# ensure model has been quantized
|
||||||
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
|
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
|
||||||
|
assert model.model.embed_tokens.weight_fake_quantizer.enabled
|
||||||
assert isinstance(model.lm_head, FakeQuantizedLinear)
|
assert isinstance(model.lm_head, FakeQuantizedLinear)
|
||||||
|
assert model.lm_head.weight_fake_quantizer.enabled
|
||||||
# Only test enable/disable toggling if the fake quantizer supports it
|
|
||||||
# (Int4WeightFakeQuantizer does not have an 'enabled' attribute)
|
|
||||||
supports_toggle = hasattr(
|
|
||||||
model.model.embed_tokens.weight_fake_quantizer, "enabled"
|
|
||||||
)
|
|
||||||
if supports_toggle:
|
|
||||||
assert model.model.embed_tokens.weight_fake_quantizer.enabled
|
|
||||||
assert model.lm_head.weight_fake_quantizer.enabled
|
|
||||||
|
|
||||||
qat_callback = QATCallback(cfg)
|
qat_callback = QATCallback(cfg)
|
||||||
|
|
||||||
@@ -411,10 +388,9 @@ class TestQuantizationCallback:
|
|||||||
model=model,
|
model=model,
|
||||||
)
|
)
|
||||||
|
|
||||||
if supports_toggle:
|
# quantization should have been disabled
|
||||||
# quantization should have been disabled
|
assert not model.model.embed_tokens.weight_fake_quantizer.enabled
|
||||||
assert not model.model.embed_tokens.weight_fake_quantizer.enabled
|
assert not model.lm_head.weight_fake_quantizer.enabled
|
||||||
assert not model.lm_head.weight_fake_quantizer.enabled
|
|
||||||
|
|
||||||
trainer_state.global_step = 100
|
trainer_state.global_step = 100
|
||||||
qat_callback.on_step_begin(
|
qat_callback.on_step_begin(
|
||||||
@@ -424,10 +400,9 @@ class TestQuantizationCallback:
|
|||||||
model=model,
|
model=model,
|
||||||
)
|
)
|
||||||
|
|
||||||
if supports_toggle:
|
# quantization should have been enabled
|
||||||
# quantization should have been enabled
|
assert model.model.embed_tokens.weight_fake_quantizer.enabled
|
||||||
assert model.model.embed_tokens.weight_fake_quantizer.enabled
|
assert model.lm_head.weight_fake_quantizer.enabled
|
||||||
assert model.lm_head.weight_fake_quantizer.enabled
|
|
||||||
|
|
||||||
@require_torch_2_8_0
|
@require_torch_2_8_0
|
||||||
def test_qat_callback_fake_quant_after_n_steps_is_none(self, model, trainer_state):
|
def test_qat_callback_fake_quant_after_n_steps_is_none(self, model, trainer_state):
|
||||||
@@ -449,10 +424,9 @@ class TestQuantizationCallback:
|
|||||||
|
|
||||||
# ensure model has been quantized
|
# ensure model has been quantized
|
||||||
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
|
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
|
||||||
|
assert model.model.embed_tokens.weight_fake_quantizer.enabled
|
||||||
assert isinstance(model.lm_head, FakeQuantizedLinear)
|
assert isinstance(model.lm_head, FakeQuantizedLinear)
|
||||||
if hasattr(model.model.embed_tokens.weight_fake_quantizer, "enabled"):
|
assert model.lm_head.weight_fake_quantizer.enabled
|
||||||
assert model.model.embed_tokens.weight_fake_quantizer.enabled
|
|
||||||
assert model.lm_head.weight_fake_quantizer.enabled
|
|
||||||
|
|
||||||
qat_callback = QATCallback(cfg)
|
qat_callback = QATCallback(cfg)
|
||||||
# simulate first training step
|
# simulate first training step
|
||||||
@@ -464,6 +438,5 @@ class TestQuantizationCallback:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# quantization should be enabled from the get-go
|
# quantization should be enabled from the get-go
|
||||||
if hasattr(model.model.embed_tokens.weight_fake_quantizer, "enabled"):
|
assert model.model.embed_tokens.weight_fake_quantizer.enabled
|
||||||
assert model.model.embed_tokens.weight_fake_quantizer.enabled
|
assert model.lm_head.weight_fake_quantizer.enabled
|
||||||
assert model.lm_head.weight_fake_quantizer.enabled
|
|
||||||
|
|||||||
@@ -179,7 +179,7 @@ def check_tensorboard(
|
|||||||
tag: str,
|
tag: str,
|
||||||
lt_val: float,
|
lt_val: float,
|
||||||
assertion_err: str,
|
assertion_err: str,
|
||||||
rtol: float = 0.05,
|
rtol: float = 0.02,
|
||||||
gt_zero: bool = True,
|
gt_zero: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,229 +0,0 @@
|
|||||||
"""
|
|
||||||
Correctness tests for fused RMSNorm + SiLU Gate kernel.
|
|
||||||
|
|
||||||
Tests against the eager Qwen3_5RMSNormGated implementation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
pytest.importorskip("triton", reason="triton required for fused kernels")
|
|
||||||
|
|
||||||
if not torch.cuda.is_available():
|
|
||||||
pytest.skip("CUDA required for fused kernel tests", allow_module_level=True)
|
|
||||||
|
|
||||||
from axolotl.kernels.rms_norm_gated import FusedRMSNormGated
|
|
||||||
|
|
||||||
|
|
||||||
class EagerRMSNormGated(torch.nn.Module):
|
|
||||||
"""Reference implementation matching Qwen3_5RMSNormGated exactly."""
|
|
||||||
|
|
||||||
def __init__(self, hidden_size, eps=1e-6):
|
|
||||||
super().__init__()
|
|
||||||
self.weight = torch.nn.Parameter(torch.ones(hidden_size))
|
|
||||||
self.variance_epsilon = eps
|
|
||||||
|
|
||||||
def forward(self, hidden_states, gate=None):
|
|
||||||
input_dtype = hidden_states.dtype
|
|
||||||
hidden_states = hidden_states.to(torch.float32)
|
|
||||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
|
||||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
|
||||||
hidden_states = self.weight * hidden_states.to(input_dtype)
|
|
||||||
hidden_states = hidden_states * F.silu(gate.to(torch.float32))
|
|
||||||
return hidden_states.to(input_dtype)
|
|
||||||
|
|
||||||
|
|
||||||
def _sync_weights(eager_mod, fused_mod):
|
|
||||||
"""Copy weights from eager to fused module."""
|
|
||||||
fused_mod.weight.data.copy_(eager_mod.weight.data)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"shape",
|
|
||||||
[
|
|
||||||
(2, 128, 256),
|
|
||||||
(4, 64, 512),
|
|
||||||
(1, 32, 1024),
|
|
||||||
(2, 16, 2560), # Qwen3.5-4B hidden_size
|
|
||||||
(2, 16, 4096), # Qwen3.5-9B hidden_size
|
|
||||||
(1, 8, 5120), # Qwen3.5-27B hidden_size
|
|
||||||
(4, 16, 2048), # Qwen3.5-35B-A3B (MoE) hidden_size
|
|
||||||
(4, 16, 3072), # Qwen3.5-122B-A10B (MoE) hidden_size
|
|
||||||
],
|
|
||||||
)
|
|
||||||
class TestRMSNormGatedForward:
|
|
||||||
def test_output_matches_eager(self, dtype, shape):
|
|
||||||
torch.manual_seed(42)
|
|
||||||
B, T, H = shape
|
|
||||||
X = torch.randn(B, T, H, dtype=dtype, device="cuda")
|
|
||||||
G = torch.randn(B, T, H, dtype=dtype, device="cuda")
|
|
||||||
|
|
||||||
eager = EagerRMSNormGated(H).to(dtype=dtype, device="cuda")
|
|
||||||
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
|
|
||||||
_sync_weights(eager, fused)
|
|
||||||
|
|
||||||
y_eager = eager(X, gate=G)
|
|
||||||
y_fused = fused(X, gate=G)
|
|
||||||
|
|
||||||
if dtype == torch.float32:
|
|
||||||
torch.testing.assert_close(y_fused, y_eager, atol=1e-5, rtol=1e-5)
|
|
||||||
else:
|
|
||||||
torch.testing.assert_close(y_fused, y_eager, atol=1e-2, rtol=1e-2)
|
|
||||||
|
|
||||||
def test_output_shape(self, dtype, shape):
|
|
||||||
B, T, H = shape
|
|
||||||
X = torch.randn(B, T, H, dtype=dtype, device="cuda")
|
|
||||||
G = torch.randn(B, T, H, dtype=dtype, device="cuda")
|
|
||||||
|
|
||||||
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
|
|
||||||
y = fused(X, gate=G)
|
|
||||||
assert y.shape == (B, T, H)
|
|
||||||
assert y.dtype == dtype
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"shape",
|
|
||||||
[
|
|
||||||
(2, 32, 256),
|
|
||||||
(2, 16, 512),
|
|
||||||
(2, 16, 2560), # Qwen3.5-4B
|
|
||||||
(1, 8, 4096), # Qwen3.5-9B
|
|
||||||
(1, 8, 5120), # Qwen3.5-27B
|
|
||||||
(2, 16, 2048), # Qwen3.5-35B-A3B (MoE)
|
|
||||||
(2, 16, 3072), # Qwen3.5-122B-A10B (MoE)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
class TestRMSNormGatedBackward:
|
|
||||||
def test_grad_x(self, dtype, shape):
|
|
||||||
torch.manual_seed(42)
|
|
||||||
B, T, H = shape
|
|
||||||
X = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
|
|
||||||
G = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
|
|
||||||
X_ref = X.detach().clone().requires_grad_(True)
|
|
||||||
G_ref = G.detach().clone().requires_grad_(True)
|
|
||||||
|
|
||||||
eager = EagerRMSNormGated(H).to(dtype=dtype, device="cuda")
|
|
||||||
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
|
|
||||||
_sync_weights(eager, fused)
|
|
||||||
|
|
||||||
y_eager = eager(X_ref, gate=G_ref)
|
|
||||||
y_fused = fused(X, gate=G)
|
|
||||||
|
|
||||||
grad_out = torch.randn_like(y_eager)
|
|
||||||
y_eager.backward(grad_out)
|
|
||||||
y_fused.backward(grad_out)
|
|
||||||
|
|
||||||
if dtype == torch.float32:
|
|
||||||
atol, rtol = 1e-4, 1e-4
|
|
||||||
else:
|
|
||||||
atol, rtol = 5e-2, 5e-2
|
|
||||||
|
|
||||||
torch.testing.assert_close(X.grad, X_ref.grad, atol=atol, rtol=rtol)
|
|
||||||
|
|
||||||
def test_grad_gate(self, dtype, shape):
|
|
||||||
torch.manual_seed(42)
|
|
||||||
B, T, H = shape
|
|
||||||
X = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
|
|
||||||
G = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
|
|
||||||
X_ref = X.detach().clone().requires_grad_(True)
|
|
||||||
G_ref = G.detach().clone().requires_grad_(True)
|
|
||||||
|
|
||||||
eager = EagerRMSNormGated(H).to(dtype=dtype, device="cuda")
|
|
||||||
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
|
|
||||||
_sync_weights(eager, fused)
|
|
||||||
|
|
||||||
y_eager = eager(X_ref, gate=G_ref)
|
|
||||||
y_fused = fused(X, gate=G)
|
|
||||||
|
|
||||||
grad_out = torch.randn_like(y_eager)
|
|
||||||
y_eager.backward(grad_out)
|
|
||||||
y_fused.backward(grad_out)
|
|
||||||
|
|
||||||
if dtype == torch.float32:
|
|
||||||
atol, rtol = 1e-4, 1e-4
|
|
||||||
else:
|
|
||||||
atol, rtol = 5e-2, 5e-2
|
|
||||||
|
|
||||||
torch.testing.assert_close(G.grad, G_ref.grad, atol=atol, rtol=rtol)
|
|
||||||
|
|
||||||
def test_grad_weight(self, dtype, shape):
|
|
||||||
torch.manual_seed(42)
|
|
||||||
B, T, H = shape
|
|
||||||
X = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
|
|
||||||
G = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
|
|
||||||
X_ref = X.detach().clone().requires_grad_(True)
|
|
||||||
G_ref = G.detach().clone().requires_grad_(True)
|
|
||||||
|
|
||||||
eager = EagerRMSNormGated(H).to(dtype=dtype, device="cuda")
|
|
||||||
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
|
|
||||||
_sync_weights(eager, fused)
|
|
||||||
|
|
||||||
y_eager = eager(X_ref, gate=G_ref)
|
|
||||||
y_fused = fused(X, gate=G)
|
|
||||||
|
|
||||||
grad_out = torch.randn_like(y_eager)
|
|
||||||
y_eager.backward(grad_out)
|
|
||||||
y_fused.backward(grad_out)
|
|
||||||
|
|
||||||
if dtype == torch.float32:
|
|
||||||
atol, rtol = 1e-4, 1e-4
|
|
||||||
else:
|
|
||||||
atol, rtol = 5e-2, 5e-2
|
|
||||||
|
|
||||||
torch.testing.assert_close(
|
|
||||||
fused.weight.grad, eager.weight.grad, atol=atol, rtol=rtol
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestRMSNormGatedEdgeCases:
|
|
||||||
def test_gate_none_raises(self):
|
|
||||||
fused = FusedRMSNormGated(256).cuda()
|
|
||||||
X = torch.randn(2, 4, 256, device="cuda")
|
|
||||||
with pytest.raises(ValueError, match="requires a gate tensor"):
|
|
||||||
fused(X, gate=None)
|
|
||||||
|
|
||||||
def test_2d_input(self):
|
|
||||||
"""Test with (BxT, H) shaped input instead of (B, T, H)."""
|
|
||||||
torch.manual_seed(42)
|
|
||||||
H = 512
|
|
||||||
X = torch.randn(64, H, dtype=torch.bfloat16, device="cuda", requires_grad=True)
|
|
||||||
G = torch.randn(64, H, dtype=torch.bfloat16, device="cuda", requires_grad=True)
|
|
||||||
X_ref = X.detach().clone().requires_grad_(True)
|
|
||||||
G_ref = G.detach().clone().requires_grad_(True)
|
|
||||||
|
|
||||||
eager = EagerRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda")
|
|
||||||
fused = FusedRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda")
|
|
||||||
_sync_weights(eager, fused)
|
|
||||||
|
|
||||||
y_eager = eager(X_ref, gate=G_ref)
|
|
||||||
y_fused = fused(X, gate=G)
|
|
||||||
|
|
||||||
torch.testing.assert_close(y_fused, y_eager, atol=1e-2, rtol=1e-2)
|
|
||||||
|
|
||||||
grad_out = torch.randn_like(y_eager)
|
|
||||||
y_eager.backward(grad_out)
|
|
||||||
y_fused.backward(grad_out)
|
|
||||||
|
|
||||||
torch.testing.assert_close(X.grad, X_ref.grad, atol=5e-2, rtol=5e-2)
|
|
||||||
torch.testing.assert_close(G.grad, G_ref.grad, atol=5e-2, rtol=5e-2)
|
|
||||||
|
|
||||||
def test_random_weight_init(self):
|
|
||||||
"""Test with non-default weight values."""
|
|
||||||
torch.manual_seed(123)
|
|
||||||
H = 256
|
|
||||||
X = torch.randn(2, 16, H, dtype=torch.bfloat16, device="cuda")
|
|
||||||
G = torch.randn(2, 16, H, dtype=torch.bfloat16, device="cuda")
|
|
||||||
|
|
||||||
eager = EagerRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda")
|
|
||||||
# Randomize weights
|
|
||||||
eager.weight.data = torch.randn_like(eager.weight.data)
|
|
||||||
|
|
||||||
fused = FusedRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda")
|
|
||||||
_sync_weights(eager, fused)
|
|
||||||
|
|
||||||
y_eager = eager(X, gate=G)
|
|
||||||
y_fused = fused(X, gate=G)
|
|
||||||
torch.testing.assert_close(y_fused, y_eager, atol=1e-2, rtol=1e-2)
|
|
||||||
@@ -1,66 +0,0 @@
|
|||||||
"""Tests for the HF Trainer context parallel patch."""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from transformers import Trainer
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.transformers.trainer_context_parallel import (
|
|
||||||
GUARD_PATTERN,
|
|
||||||
PATCHED_GUARD,
|
|
||||||
patch_prepare_context_parallel_inputs,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def restore_trainer_prepare_method():
|
|
||||||
"""Ensure Trainer._prepare_context_parallel_inputs is restored after a test."""
|
|
||||||
original_method = getattr(
|
|
||||||
Trainer,
|
|
||||||
"_original_prepare_context_parallel_inputs",
|
|
||||||
Trainer._prepare_context_parallel_inputs,
|
|
||||||
)
|
|
||||||
patched_attr_present = hasattr(
|
|
||||||
Trainer, "_axolotl_prepare_context_parallel_inputs_patched"
|
|
||||||
)
|
|
||||||
|
|
||||||
yield
|
|
||||||
|
|
||||||
Trainer._prepare_context_parallel_inputs = original_method
|
|
||||||
if patched_attr_present:
|
|
||||||
delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched")
|
|
||||||
if hasattr(Trainer, "_original_prepare_context_parallel_inputs"):
|
|
||||||
delattr(Trainer, "_original_prepare_context_parallel_inputs")
|
|
||||||
if hasattr(Trainer, "_axolotl_prepare_context_parallel_inputs_source"):
|
|
||||||
delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_source")
|
|
||||||
|
|
||||||
|
|
||||||
def test_patch_attention_guard(restore_trainer_prepare_method):
|
|
||||||
"""Patch should swap the guard to allow sdpa or flash attention."""
|
|
||||||
# Ensure we start from the unpatched method
|
|
||||||
if hasattr(Trainer, "_original_prepare_context_parallel_inputs"):
|
|
||||||
Trainer._prepare_context_parallel_inputs = (
|
|
||||||
Trainer._original_prepare_context_parallel_inputs
|
|
||||||
)
|
|
||||||
delattr(Trainer, "_original_prepare_context_parallel_inputs")
|
|
||||||
if hasattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched"):
|
|
||||||
delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched")
|
|
||||||
|
|
||||||
patch_prepare_context_parallel_inputs()
|
|
||||||
|
|
||||||
patched_method = Trainer._prepare_context_parallel_inputs
|
|
||||||
assert patched_method is not None
|
|
||||||
assert getattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched", False)
|
|
||||||
|
|
||||||
source = Trainer._axolotl_prepare_context_parallel_inputs_source
|
|
||||||
assert GUARD_PATTERN not in source
|
|
||||||
assert PATCHED_GUARD in source
|
|
||||||
|
|
||||||
|
|
||||||
def test_patch_is_idempotent(restore_trainer_prepare_method):
|
|
||||||
"""Calling the patch twice should leave the same patched function in place."""
|
|
||||||
patch_prepare_context_parallel_inputs()
|
|
||||||
first_patched = Trainer._prepare_context_parallel_inputs
|
|
||||||
|
|
||||||
patch_prepare_context_parallel_inputs()
|
|
||||||
second_patched = Trainer._prepare_context_parallel_inputs
|
|
||||||
|
|
||||||
assert first_patched is second_patched
|
|
||||||
@@ -13,7 +13,6 @@ from axolotl.utils.config import validate_config
|
|||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
||||||
from axolotl.utils.schemas.config import AxolotlConfigWCapabilities
|
from axolotl.utils.schemas.config import AxolotlConfigWCapabilities
|
||||||
from axolotl.utils.schemas.datasets import SFTDataset
|
|
||||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||||
|
|
||||||
warnings.filterwarnings("error")
|
warnings.filterwarnings("error")
|
||||||
@@ -1732,52 +1731,3 @@ class TestDataloaderValidation(BaseValidation):
|
|||||||
assert new_cfg.dataloader_num_workers == 8
|
assert new_cfg.dataloader_num_workers == 8
|
||||||
assert new_cfg.dataloader_pin_memory is True
|
assert new_cfg.dataloader_pin_memory is True
|
||||||
assert new_cfg.dataloader_prefetch_factor == 256
|
assert new_cfg.dataloader_prefetch_factor == 256
|
||||||
|
|
||||||
|
|
||||||
class TestSyntheticDatasetValidation(BaseValidation):
|
|
||||||
"""
|
|
||||||
Tests for synthetic dataset config validation
|
|
||||||
"""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _make_cfg(minimal_cfg, datasets):
|
|
||||||
raw = dict(minimal_cfg)
|
|
||||||
raw["datasets"] = datasets
|
|
||||||
return DictDefault(raw)
|
|
||||||
|
|
||||||
def test_synthetic_dict_config_validates(self, minimal_cfg):
|
|
||||||
"""Synthetic dataset passed as a raw dict should not raise."""
|
|
||||||
cfg = self._make_cfg(
|
|
||||||
minimal_cfg,
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"path": "synthetic",
|
|
||||||
"type": "_synthetic",
|
|
||||||
"length": 100,
|
|
||||||
"sequence_length": 64,
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
new_cfg = validate_config(cfg)
|
|
||||||
assert new_cfg.datasets[0]["path"] == "synthetic"
|
|
||||||
|
|
||||||
def test_synthetic_already_sft_does_not_crash(self, minimal_cfg):
|
|
||||||
"""Synthetic dataset already parsed as SFTDataset should not raise AttributeError."""
|
|
||||||
sft = SFTDataset(path="synthetic", type="_synthetic")
|
|
||||||
cfg = self._make_cfg(minimal_cfg, [sft])
|
|
||||||
|
|
||||||
# Before the fix, this raised:
|
|
||||||
# AttributeError: 'SFTDataset' object has no attribute 'get'
|
|
||||||
new_cfg = validate_config(cfg)
|
|
||||||
assert new_cfg.datasets[0]["path"] == "synthetic"
|
|
||||||
|
|
||||||
def test_non_synthetic_sft_validates(self, minimal_cfg):
|
|
||||||
"""A regular SFT dataset should validate without being treated as synthetic."""
|
|
||||||
cfg = self._make_cfg(
|
|
||||||
minimal_cfg,
|
|
||||||
[{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}],
|
|
||||||
)
|
|
||||||
|
|
||||||
new_cfg = validate_config(cfg)
|
|
||||||
assert new_cfg.datasets[0]["path"] == "mhenrichsen/alpaca_2k_test"
|
|
||||||
|
|||||||
@@ -1,125 +0,0 @@
|
|||||||
"""Tests for the synthetic dataset generator."""
|
|
||||||
|
|
||||||
import unittest
|
|
||||||
from unittest.mock import MagicMock
|
|
||||||
|
|
||||||
from datasets import Dataset
|
|
||||||
|
|
||||||
from axolotl.prompt_strategies._synthetic import SyntheticDatasetStrategy, load
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
|
|
||||||
|
|
||||||
class TestSyntheticDatasetStrategy(unittest.TestCase):
|
|
||||||
def test_generates_correct_shape(self):
|
|
||||||
strategy = SyntheticDatasetStrategy(
|
|
||||||
sequence_length=128,
|
|
||||||
length=50,
|
|
||||||
min_input_id=1,
|
|
||||||
max_input_id=1000,
|
|
||||||
seed=42,
|
|
||||||
)
|
|
||||||
dummy = Dataset.from_dict({"text": [""]})
|
|
||||||
result = strategy.wrap_dataset(dummy)
|
|
||||||
|
|
||||||
assert len(result) == 50
|
|
||||||
assert len(result[0]["input_ids"]) == 128
|
|
||||||
assert len(result[0]["attention_mask"]) == 128
|
|
||||||
assert len(result[0]["labels"]) == 128
|
|
||||||
|
|
||||||
def test_attention_mask_all_ones(self):
|
|
||||||
strategy = SyntheticDatasetStrategy(sequence_length=64, length=10, seed=0)
|
|
||||||
dummy = Dataset.from_dict({"text": [""]})
|
|
||||||
result = strategy.wrap_dataset(dummy)
|
|
||||||
|
|
||||||
for row in result:
|
|
||||||
assert all(v == 1 for v in row["attention_mask"])
|
|
||||||
|
|
||||||
def test_labels_equal_input_ids(self):
|
|
||||||
strategy = SyntheticDatasetStrategy(sequence_length=64, length=10, seed=0)
|
|
||||||
dummy = Dataset.from_dict({"text": [""]})
|
|
||||||
result = strategy.wrap_dataset(dummy)
|
|
||||||
|
|
||||||
for row in result:
|
|
||||||
assert row["input_ids"] == row["labels"]
|
|
||||||
|
|
||||||
def test_input_id_range(self):
|
|
||||||
strategy = SyntheticDatasetStrategy(
|
|
||||||
sequence_length=64,
|
|
||||||
length=100,
|
|
||||||
min_input_id=500,
|
|
||||||
max_input_id=600,
|
|
||||||
seed=42,
|
|
||||||
)
|
|
||||||
dummy = Dataset.from_dict({"text": [""]})
|
|
||||||
result = strategy.wrap_dataset(dummy)
|
|
||||||
|
|
||||||
for row in result:
|
|
||||||
for token_id in row["input_ids"]:
|
|
||||||
assert 500 <= token_id < 600
|
|
||||||
|
|
||||||
def test_seed_reproducibility(self):
|
|
||||||
kwargs = dict(
|
|
||||||
sequence_length=64, length=20, min_input_id=1, max_input_id=1000, seed=123
|
|
||||||
)
|
|
||||||
dummy = Dataset.from_dict({"text": [""]})
|
|
||||||
|
|
||||||
result1 = SyntheticDatasetStrategy(**kwargs).wrap_dataset(dummy)
|
|
||||||
result2 = SyntheticDatasetStrategy(**kwargs).wrap_dataset(dummy)
|
|
||||||
|
|
||||||
for r1, r2 in zip(result1, result2, strict=True):
|
|
||||||
assert r1["input_ids"] == r2["input_ids"]
|
|
||||||
|
|
||||||
def test_different_seeds_differ(self):
|
|
||||||
common = dict(sequence_length=64, length=20, min_input_id=1, max_input_id=1000)
|
|
||||||
dummy = Dataset.from_dict({"text": [""]})
|
|
||||||
|
|
||||||
result1 = SyntheticDatasetStrategy(seed=1, **common).wrap_dataset(dummy)
|
|
||||||
result2 = SyntheticDatasetStrategy(seed=2, **common).wrap_dataset(dummy)
|
|
||||||
|
|
||||||
any_different = any(
|
|
||||||
r1["input_ids"] != r2["input_ids"]
|
|
||||||
for r1, r2 in zip(result1, result2, strict=True)
|
|
||||||
)
|
|
||||||
assert any_different
|
|
||||||
|
|
||||||
def test_load_function_with_ds_cfg(self):
|
|
||||||
tokenizer = MagicMock()
|
|
||||||
tokenizer.vocab_size = 32000
|
|
||||||
cfg = DictDefault({"sequence_len": 512, "train_on_inputs": False})
|
|
||||||
ds_cfg = {
|
|
||||||
"sequence_length": 256,
|
|
||||||
"length": 5,
|
|
||||||
"min_input_id": 10,
|
|
||||||
"max_input_id": 100,
|
|
||||||
"seed": 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
strategy = load(tokenizer, cfg, ds_cfg=ds_cfg)
|
|
||||||
assert isinstance(strategy, SyntheticDatasetStrategy)
|
|
||||||
assert strategy.sequence_length == 256
|
|
||||||
assert strategy.length == 5
|
|
||||||
assert strategy.min_input_id == 10
|
|
||||||
assert strategy.max_input_id == 100
|
|
||||||
|
|
||||||
def test_load_defaults_from_cfg(self):
|
|
||||||
tokenizer = MagicMock()
|
|
||||||
tokenizer.vocab_size = 32000
|
|
||||||
cfg = DictDefault({"sequence_len": 1024, "train_on_inputs": False})
|
|
||||||
|
|
||||||
strategy = load(tokenizer, cfg, ds_cfg={})
|
|
||||||
assert strategy.sequence_length == 1024
|
|
||||||
assert strategy.max_input_id == 32000
|
|
||||||
assert strategy.length == 1000
|
|
||||||
|
|
||||||
def test_load_with_no_ds_cfg(self):
|
|
||||||
tokenizer = MagicMock()
|
|
||||||
tokenizer.vocab_size = 50000
|
|
||||||
cfg = DictDefault({"sequence_len": 2048, "train_on_inputs": False})
|
|
||||||
|
|
||||||
strategy = load(tokenizer, cfg)
|
|
||||||
assert strategy.sequence_length == 2048
|
|
||||||
assert strategy.max_input_id == 50000
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user