Compare commits
8 Commits
scattermoe
...
fix/cp-was
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
255c5b90ca | ||
|
|
038ffe3f26 | ||
|
|
c13cb7c853 | ||
|
|
b3823cc6b0 | ||
|
|
113d275bd9 | ||
|
|
7920fe74ec | ||
|
|
1fc86d5295 | ||
|
|
bb483ad4c4 |
@@ -11,7 +11,7 @@ ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
|
||||
ENV HF_HOME="{{ HF_HOME }}"
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
|
||||
apt-get install -y --allow-change-held-packages vim curl nano zstd libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ ENV HF_HOME="{{ HF_HOME }}"
|
||||
ENV AXOLOTL_DATASET_NUM_PROC="8"
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
|
||||
apt-get install -y --allow-change-held-packages vim curl nano zstd libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
|
||||
11
cicd/cicd.sh
11
cicd/cicd.sh
@@ -3,11 +3,12 @@ set -e
|
||||
|
||||
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
||||
|
||||
# curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1
|
||||
hf download "NousResearch/Meta-Llama-3-8B"
|
||||
hf download "NousResearch/Meta-Llama-3-8B-Instruct"
|
||||
hf download "microsoft/Phi-4-reasoning"
|
||||
hf download "microsoft/Phi-3.5-mini-instruct"
|
||||
curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1
|
||||
# hf download "NousResearch/Meta-Llama-3-8B"
|
||||
# hf download "NousResearch/Meta-Llama-3-8B-Instruct"
|
||||
# hf download "microsoft/Phi-4-reasoning"
|
||||
# hf download "microsoft/Phi-3.5-mini-instruct"
|
||||
# hf download "microsoft/Phi-3-medium-128k-instruct"
|
||||
|
||||
# Run unit tests with initial coverage report
|
||||
pytest -v --durations=10 -n8 \
|
||||
|
||||
@@ -68,10 +68,6 @@ def run_cmd(cmd: str, run_folder: str):
|
||||
sp_env["AXOLOTL_DATASET_NUM_PROC"] = "8"
|
||||
|
||||
# Propagate errors from subprocess.
|
||||
try:
|
||||
exit_code = subprocess.call(cmd.split(), cwd=run_folder, env=sp_env) # nosec
|
||||
if exit_code:
|
||||
print(f"Command '{cmd}' failed with exit code {exit_code}")
|
||||
return exit_code
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
print(f"Command '{cmd}' failed with exception {e}")
|
||||
exit_code = subprocess.call(cmd.split(), cwd=run_folder, env=sp_env) # nosec
|
||||
if exit_code:
|
||||
raise RuntimeError(f"Command '{cmd}' failed with exit code {exit_code}")
|
||||
|
||||
@@ -20,6 +20,7 @@ format:
|
||||
- [Gemma-3n](#sec-gemma-3n)
|
||||
- [Qwen2-VL](#sec-qwen2-vl)
|
||||
- [Qwen2.5-VL](#sec-qwen25-vl)
|
||||
- [Qwen3.5](#sec-qwen3-5)
|
||||
- [GLM-4.6V](#sec-glm-4-6v)
|
||||
- [SmolVLM2](#sec-smolvlm2)
|
||||
- [LFM2-VL](#sec-lfm2-vl)
|
||||
@@ -191,6 +192,14 @@ base_model: Qwen/Qwen3-VL-4B-Instruct
|
||||
chat_template: qwen2_vl # same as qwen2-vl
|
||||
```
|
||||
|
||||
### Qwen3.5 {#sec-qwen3-5}
|
||||
|
||||
```yaml
|
||||
base_model: Qwen/Qwen3.5-9B
|
||||
|
||||
chat_template: qwen3_5
|
||||
```
|
||||
|
||||
### GLM-4.6V {#sec-glm-4-6v}
|
||||
|
||||
Both GLM-4.6V (106B MoE) and GLM-4.6V-Flash (9B) are supported.
|
||||
|
||||
@@ -40,7 +40,7 @@
|
||||
"%%capture\n",
|
||||
"# This step can take ~5-10 minutes to install dependencies\n",
|
||||
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe\""
|
||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6\""
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
base_model: google/gemma-3-1b-it
|
||||
|
||||
model_type: Gemma3ForCausalLM
|
||||
cls_model_config: Gemma3TextConfig
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
@@ -27,6 +24,11 @@ datasets:
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
# Freeze vision tower
|
||||
unfrozen_parameters:
|
||||
- ^model\.language_model\..*
|
||||
- ^lm_head\..*
|
||||
|
||||
adapter: qlora
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
base_model: google/gemma-3-270m-it
|
||||
|
||||
model_type: Gemma3ForCausalLM
|
||||
cls_model_config: Gemma3TextConfig
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
@@ -27,6 +24,11 @@ datasets:
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
# Freeze vision tower
|
||||
unfrozen_parameters:
|
||||
- ^model\.language_model\..*
|
||||
- ^lm_head\..*
|
||||
|
||||
adapter: qlora
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
base_model: google/gemma-3-4b-it
|
||||
|
||||
# Need to set else transformers tries to load vision too
|
||||
model_type: Gemma3ForCausalLM
|
||||
cls_model_config: Gemma3TextConfig
|
||||
|
||||
load_in_4bit: true
|
||||
|
||||
# gemma3 doesn't seem to play nice with ddp
|
||||
@@ -24,6 +20,11 @@ dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./outputs/out
|
||||
|
||||
# Freeze vision tower
|
||||
unfrozen_parameters:
|
||||
- ^model\.language_model\..*
|
||||
- ^lm_head\..*
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
|
||||
57
examples/nemotron/nemotron-mini-4b-qlora.yaml
Normal file
57
examples/nemotron/nemotron-mini-4b-qlora.yaml
Normal file
@@ -0,0 +1,57 @@
|
||||
base_model: nvidia/Nemotron-Mini-4B-Instruct
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/nemotron-mini-4b-qlora
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- k_proj
|
||||
- v_proj
|
||||
- o_proj
|
||||
- up_proj
|
||||
- down_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
special_tokens:
|
||||
59
examples/qwen3.5/27b-fft.yaml
Normal file
59
examples/qwen3.5/27b-fft.yaml
Normal file
@@ -0,0 +1,59 @@
|
||||
base_model: Qwen/Qwen3.5-27B
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
# Full fine-tune (FFT) of the text-only path of Qwen3.5-27B.
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
strict: false
|
||||
|
||||
chat_template: qwen3_5
|
||||
datasets:
|
||||
- path: mlabonne/FineTome-100k
|
||||
type: chat_template
|
||||
split: train[:20%]
|
||||
field_messages: conversations
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
dataset_prepared_path: last_run_prepared
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
# Freeze vision encoder
|
||||
unfrozen_parameters:
|
||||
- model\.language_model\..*
|
||||
- lm_head\..*
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
49
examples/qwen3.5/9b-fft-vision.yaml
Normal file
49
examples/qwen3.5/9b-fft-vision.yaml
Normal file
@@ -0,0 +1,49 @@
|
||||
base_model: Qwen/Qwen3.5-9B
|
||||
processor_type: AutoProcessor
|
||||
|
||||
# Required for multimodal training
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
chat_template: qwen3_5
|
||||
datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
sequence_len: 4096
|
||||
pad_to_sequence_len: false
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
@@ -1,10 +1,6 @@
|
||||
base_model: Qwen/Qwen3.5-7B
|
||||
base_model: Qwen/Qwen3.5-9B
|
||||
processor_type: AutoProcessor
|
||||
|
||||
# Qwen3.5-7B and above are early-fusion VLMs (Qwen3_5ForConditionalGeneration).
|
||||
# Vision and text tokens are processed together by the same transformer layers.
|
||||
# Note: Qwen3.5-2B is a text-only model — the smallest VLM is Qwen3.5-7B.
|
||||
|
||||
# These 3 lines are required for vision/multimodal training
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
@@ -1,15 +1,20 @@
|
||||
# Finetune Qwen3.5 with Axolotl
|
||||
|
||||
[Qwen3.5](https://huggingface.co/collections/Qwen/qwen35-68452f3bc6e4b7cfb4e1c803) is a hybrid architecture model series combining Gated DeltaNet linear attention with standard Transformer attention. Models from 7B onwards are early-fusion vision-language models (`Qwen3_5ForConditionalGeneration`), meaning vision and text tokens are processed through the same transformer stack. The 2B variant is text-only.
|
||||
[Qwen3.5](https://huggingface.co/collections/Qwen/qwen35) is a hybrid architecture model series combining Gated DeltaNet linear attention with standard Transformer attention. All Qwen3.5 models are early-fusion vision-language models: dense variants use `Qwen3_5ForConditionalGeneration` and MoE variants use `Qwen3_5MoeForConditionalGeneration`.
|
||||
|
||||
Vision and text tokens are processed through the same transformer stack. The configs below train on text-only data unless noted otherwise. See `9b-lora-vision.yaml` for a multimodal example.
|
||||
|
||||
Available configs:
|
||||
|
||||
| Config | Model | Type |
|
||||
|---|---|---|
|
||||
| `27b-qlora.yaml` | Qwen3.5-27B | Dense VLM, text-only path |
|
||||
| `35b-a3b-moe-qlora.yaml` | Qwen3.5-35B-A3B | MoE, text-only path |
|
||||
| `122b-a10b-moe-qlora.yaml` | Qwen3.5-122B-A10B | MoE, text-only path |
|
||||
| `7b-lora-vision.yaml` | Qwen3.5-7B | Vision+text (multimodal) |
|
||||
| Config | Model | Type | Peak VRAM |
|
||||
|---|---|---|---|
|
||||
| `27b-qlora.yaml` | Qwen3.5-27B | Dense VLM, text-only QLoRA | ~47 GiB |
|
||||
| `27b-fft.yaml` | Qwen3.5-27B | Dense VLM, text-only FFT (vision frozen) | ~53 GiB |
|
||||
| `35b-a3b-moe-qlora.yaml` | Qwen3.5-35B-A3B | MoE, text-only QLoRA | — |
|
||||
| `122b-a10b-moe-qlora.yaml` | Qwen3.5-122B-A10B | MoE, text-only QLoRA | — |
|
||||
| `9b-lora-vision.yaml` | Qwen3.5-9B | Vision+text LoRA, single GPU | — |
|
||||
| `9b-fft-vision.yaml` | Qwen3.5-9B | Vision+text FFT, single GPU | ~61 GiB |
|
||||
|
||||
|
||||
## Getting started
|
||||
|
||||
@@ -29,23 +34,31 @@ pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
|
||||
# Dense 27B text-only (QLoRA, ~47 GiB VRAM with sample packing)
|
||||
axolotl train examples/qwen3.5/27b-qlora.yaml
|
||||
|
||||
# Dense 27B text-only FFT with vision encoder frozen (~53 GiB, single 80 GiB GPU)
|
||||
axolotl train examples/qwen3.5/27b-fft.yaml
|
||||
|
||||
# MoE 35B-A3B text-only (QLoRA)
|
||||
axolotl train examples/qwen3.5/35b-a3b-moe-qlora.yaml
|
||||
|
||||
# MoE 122B-A10B text-only (QLoRA)
|
||||
axolotl train examples/qwen3.5/122b-a10b-moe-qlora.yaml
|
||||
|
||||
# 7B vision+text (LoRA, multimodal dataset)
|
||||
axolotl train examples/qwen3.5/7b-lora-vision.yaml
|
||||
# 9B vision+text (LoRA, multimodal dataset)
|
||||
axolotl train examples/qwen3.5/9b-lora-vision.yaml
|
||||
|
||||
# 9B vision+text FFT, single 80 GiB GPU (~61 GiB peak)
|
||||
axolotl train examples/qwen3.5/9b-fft-vision.yaml
|
||||
|
||||
```
|
||||
|
||||
### TIPS
|
||||
|
||||
- For inference, you can experiment with `temperature: 0.7`, `top_p: 0.8`, `top_k: 20`, and `min_p: 0`.
|
||||
- You can run a full finetuning by removing `adapter: qlora` and `load_in_4bit: true`. See [Multi-GPU](#optimization-guides) below.
|
||||
- For **text-only FFT** on 27B, use `27b-fft.yaml` which sets `unfrozen_parameters` to freeze the vision encoder (`model.visual.*`) — this avoids wasting optimizer state on parameters that receive no gradient from text-only data.
|
||||
- You can run a full finetuning of smaller configs by removing `adapter: qlora` and `load_in_4bit: true`. See [Multi-GPU](#optimization-guides) below.
|
||||
- Read more on loading your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
- For **multimodal** finetuning, set `processor_type: AutoProcessor`, `skip_prepare_dataset: true`, and `remove_unused_columns: false` as shown in `7b-lora-vision.yaml`.
|
||||
- For **multimodal** finetuning, set `processor_type: AutoProcessor`, `skip_prepare_dataset: true`, and `remove_unused_columns: false` as shown in `9b-lora-vision.yaml`.
|
||||
- The Gated DeltaNet linear attention layers (`linear_attn.*`) can optionally be added to `lora_target_modules` — they are commented out by default.
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
|
||||
|
||||
print(
|
||||
UNINSTALL_PREFIX
|
||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe"'
|
||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"'
|
||||
)
|
||||
|
||||
@@ -421,6 +421,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
trainer_kwargs["dataset_tags"] = [
|
||||
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
||||
]
|
||||
# TRL's RewardTrainer validates num_labels=1 on pre-loaded models; ensure the
|
||||
# config reflects this regardless of how the model was instantiated.
|
||||
if (
|
||||
self.cfg.reward_model
|
||||
and getattr(self.model.config, "num_labels", None) != 1
|
||||
):
|
||||
self.model.config.num_labels = 1
|
||||
trainer = trainer_cls(
|
||||
model=self.model,
|
||||
train_dataset=self.train_dataset,
|
||||
|
||||
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
|
||||
|
||||
- If you are installing from pip
|
||||
```bash
|
||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe"
|
||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
|
||||
|
||||
_CCE_INSTALL_MESSAGE = (
|
||||
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe"`'
|
||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"`'
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -199,24 +199,30 @@ def _estimate_register_pressure(
|
||||
num_warps: int,
|
||||
*tile_sizes: tuple[int, int],
|
||||
) -> float:
|
||||
"""Estimate per-thread register count from live tile sizes.
|
||||
"""Rough estimate of per-thread register footprint from live tile sizes.
|
||||
|
||||
Each tile of shape (rows, cols) requires rows*cols elements distributed
|
||||
across 32 threads per warp, but each thread in the warp holds a fragment.
|
||||
For Triton GEMM-style kernels, the register footprint per thread is
|
||||
approximately sum(rows * cols) / 32 for each live tile, plus ~40 for
|
||||
scalar overhead (loop counters, pointers, masks, etc.).
|
||||
This is a heuristic, NOT an accurate register count. Triton uses tensor
|
||||
core MMA fragments that pack multiple elements per register, and can spill
|
||||
to local memory when the hardware limit (255 regs/thread) is exceeded.
|
||||
|
||||
The estimate is used to prune only truly extreme configs that would cause
|
||||
excessive spilling or compilation failures. The threshold is set high
|
||||
(``_MAX_REGS_SOFT_LIMIT``) because the heuristic overestimates — it
|
||||
doesn't account for MMA fragment packing. Configs like M=64,N=64,K=64
|
||||
(est ~520) work fine in practice via spilling.
|
||||
|
||||
Returns estimated registers per thread.
|
||||
"""
|
||||
# Each thread in a warp holds 1/32 of the tile elements
|
||||
# Each thread in a warp holds ~1/32 of the tile elements
|
||||
tile_regs = sum(r * c for r, c in tile_sizes) / 32
|
||||
scalar_overhead = 40
|
||||
return tile_regs + scalar_overhead
|
||||
|
||||
|
||||
# Maximum registers per thread on NVIDIA GPUs
|
||||
_MAX_REGS_PER_THREAD = 255
|
||||
# Soft limit for register pressure pruning. Only prune configs with extreme
|
||||
# tile products (e.g. M=128,K=256,N=256) that reliably crash on Blackwell.
|
||||
# Moderate configs (M=64,N=64,K=64, est ~520) work via register spilling.
|
||||
_MAX_REGS_SOFT_LIMIT = 1024
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -419,7 +425,7 @@ def _prune_fwd_configs(configs, named_args, **kwargs):
|
||||
(block_r, block_k), # a tile
|
||||
(block_n, block_r), # b tile (epilogue)
|
||||
)
|
||||
if est_regs > _MAX_REGS_PER_THREAD:
|
||||
if est_regs > _MAX_REGS_SOFT_LIMIT:
|
||||
continue
|
||||
|
||||
scored.append((smem, config))
|
||||
@@ -999,7 +1005,7 @@ def _prune_dX_configs(configs, named_args, **kwargs):
|
||||
(block_n, block_r), # b tile
|
||||
(block_r, block_k), # a tile (epilogue)
|
||||
)
|
||||
if est_regs > _MAX_REGS_PER_THREAD:
|
||||
if est_regs > _MAX_REGS_SOFT_LIMIT:
|
||||
continue
|
||||
|
||||
scored.append((smem, config))
|
||||
@@ -1332,7 +1338,7 @@ def _prune_bwd_lora_configs(configs, named_args, **kwargs):
|
||||
(block_n, block_r), # b tile
|
||||
(block_m, block_r), # xa intermediate
|
||||
)
|
||||
if est_regs > _MAX_REGS_PER_THREAD:
|
||||
if est_regs > _MAX_REGS_SOFT_LIMIT:
|
||||
continue
|
||||
|
||||
scored.append((smem, config))
|
||||
@@ -1581,7 +1587,7 @@ def _prune_split_configs(configs, named_args, **kwargs):
|
||||
(block_m, block_dim), # other tile
|
||||
(block_r, BLOCK_INNER), # lora weight
|
||||
)
|
||||
if est_regs > _MAX_REGS_PER_THREAD:
|
||||
if est_regs > _MAX_REGS_SOFT_LIMIT:
|
||||
continue
|
||||
|
||||
if smem <= smem_cap - _SMEM_SLACK:
|
||||
|
||||
@@ -640,7 +640,9 @@ class LoRA_QKV(torch.autograd.Function):
|
||||
del q_weight
|
||||
del q_weight_t
|
||||
if A_q is not None and B_q is not None:
|
||||
grad_X.addmm_(q_grad, torch.mm(B_q_scaled, A_q_scaled))
|
||||
# Stay decomposed: dQ @ B^T gives [T, R], then [T, R] @ (s*A) gives [T, in]
|
||||
# This is 65x fewer FLOPs than materializing B@A into [out, in]
|
||||
grad_X.addmm_(torch.mm(q_grad, B_q_scaled), A_q_scaled)
|
||||
|
||||
# K path
|
||||
k_weight_t = dequantize(k_weight, k_quant)
|
||||
@@ -648,7 +650,7 @@ class LoRA_QKV(torch.autograd.Function):
|
||||
del k_weight
|
||||
del k_weight_t
|
||||
if A_k is not None and B_k is not None:
|
||||
grad_X.addmm_(k_grad, torch.mm(B_k_scaled, A_k_scaled))
|
||||
grad_X.addmm_(torch.mm(k_grad, B_k_scaled), A_k_scaled)
|
||||
|
||||
# V path
|
||||
v_weight_t = dequantize(v_weight, v_quant)
|
||||
@@ -656,7 +658,7 @@ class LoRA_QKV(torch.autograd.Function):
|
||||
del v_weight
|
||||
del v_weight_t
|
||||
if A_v is not None and B_v is not None:
|
||||
grad_X.addmm_(v_grad, torch.mm(B_v_scaled, A_v_scaled))
|
||||
grad_X.addmm_(torch.mm(v_grad, B_v_scaled), A_v_scaled)
|
||||
|
||||
# Transpose gradients if needed
|
||||
if d_A_q is not None:
|
||||
@@ -819,7 +821,8 @@ class LoRA_O(torch.autograd.Function):
|
||||
del W
|
||||
|
||||
A, B = A.to(dtype), B.to(dtype)
|
||||
dX += s * dY @ B @ A
|
||||
# Stay decomposed: dY @ B gives [T, R], then [T, R] @ A gives [T, in]
|
||||
dX.addmm_(torch.mm(dY, B), A, alpha=s)
|
||||
|
||||
# W, b, W_quant, A, B, s
|
||||
return dX.view(batch, seq_len, hd), None, None, None, d_A.t(), d_B.t(), None
|
||||
|
||||
@@ -133,13 +133,6 @@ class PatchManager:
|
||||
patch_evaluation_loop()
|
||||
patch_maybe_log_save_evaluate()
|
||||
|
||||
if self.cfg.context_parallel_size > 1:
|
||||
from axolotl.monkeypatch.transformers.trainer_context_parallel import (
|
||||
patch_prepare_context_parallel_inputs,
|
||||
)
|
||||
|
||||
patch_prepare_context_parallel_inputs()
|
||||
|
||||
def apply_post_model_build_patches(self, model: PreTrainedModel):
|
||||
"""Apply patches right after model build, before post-load setup."""
|
||||
self._finalize_moe_expert_quantization(model)
|
||||
|
||||
@@ -78,30 +78,29 @@ def patch_parallelism_config():
|
||||
|
||||
|
||||
def patch_prepare_cp():
|
||||
import functools
|
||||
import contextlib
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from transformers import Trainer
|
||||
|
||||
def patched_prepare_cp(self, *args):
|
||||
if self.parallelism_config.cp_backend == "deepspeed":
|
||||
return args
|
||||
|
||||
from accelerate.big_modeling import _attach_context_parallel_hooks
|
||||
from torch.distributed.tensor.experimental import context_parallel
|
||||
from torch.distributed.tensor.experimental._attention import set_rotate_method
|
||||
|
||||
cp_comm_strategy = self.parallelism_config.cp_handler.cp_comm_strategy
|
||||
set_rotate_method(cp_comm_strategy)
|
||||
|
||||
self._cp_context = functools.partial(
|
||||
context_parallel, mesh=self.torch_device_mesh["cp"]
|
||||
)
|
||||
|
||||
for arg in args:
|
||||
if isinstance(arg, torch.nn.Module):
|
||||
_attach_context_parallel_hooks(arg)
|
||||
@contextlib.contextmanager
|
||||
def _noop_cp_context(
|
||||
buffers=None, buffer_seq_dims=None, no_restore_buffers=None
|
||||
):
|
||||
yield
|
||||
|
||||
self._cp_context = _noop_cp_context
|
||||
return args
|
||||
|
||||
def _noop_prepare_context_parallel_inputs(self, model, inputs):
|
||||
return contextlib.nullcontext, inputs
|
||||
|
||||
# prevent double CP partition
|
||||
Accelerator._prepare_cp = patched_prepare_cp
|
||||
|
||||
# remove unneeded calculation upstream
|
||||
Trainer._prepare_context_parallel_inputs = _noop_prepare_context_parallel_inputs
|
||||
|
||||
@@ -51,6 +51,29 @@ QKV_PATCHES = [
|
||||
value_states = value_states.view(hidden_shape).transpose(1, 2)
|
||||
""".lstrip("\n"),
|
||||
),
|
||||
(
|
||||
"""
|
||||
query_states, gate = torch.chunk(
|
||||
self.q_proj(hidden_states).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
|
||||
)
|
||||
gate = gate.reshape(*input_shape, -1)
|
||||
|
||||
query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
|
||||
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
""".lstrip("\n"),
|
||||
"""
|
||||
query_states, key_states, value_states = self.apply_qkv(hidden_states)
|
||||
query_states, gate = torch.chunk(
|
||||
query_states.view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
|
||||
)
|
||||
gate = gate.reshape(*input_shape, -1)
|
||||
|
||||
query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
|
||||
key_states = self.k_norm(key_states.view(hidden_shape)).transpose(1, 2)
|
||||
value_states = value_states.view(hidden_shape).transpose(1, 2)
|
||||
""".lstrip("\n"),
|
||||
),
|
||||
]
|
||||
|
||||
ORIGINAL_O_CODE = """
|
||||
@@ -299,6 +322,8 @@ def get_layers(model: PeftModelForCausalLM) -> list[nn.Module]:
|
||||
if hasattr(pretrained_model, "language_model"):
|
||||
return pretrained_model.language_model.layers
|
||||
if hasattr(pretrained_model, "model"):
|
||||
if hasattr(pretrained_model.model, "language_model"):
|
||||
return pretrained_model.model.language_model.layers
|
||||
return pretrained_model.model.layers
|
||||
|
||||
raise NotImplementedError(
|
||||
|
||||
@@ -59,6 +59,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
||||
"ministral3",
|
||||
"mistral4",
|
||||
"afmoe",
|
||||
"nemotron",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
"""Monkey patch to allow context parallelism with FlashAttention in HF Trainer."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
|
||||
from transformers import Trainer
|
||||
|
||||
from axolotl.monkeypatch.utils import detab_code
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
GUARD_PATTERN = 'if model.config._attn_implementation != "sdpa":'
|
||||
PATCHED_GUARD = 'if (attn_impl := (getattr(model.config, "_attn_implementation", None) or getattr(model.model.config, "_attn_implementation", None))) and attn_impl not in ("sdpa", "flash_attention_2"):'
|
||||
|
||||
|
||||
def patch_prepare_context_parallel_inputs() -> None:
|
||||
"""Relax the SDPA-only guard when running context parallelism with FlashAttention."""
|
||||
if getattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched", False):
|
||||
LOG.debug("Trainer._prepare_context_parallel_inputs already patched")
|
||||
return
|
||||
|
||||
try:
|
||||
original_source = inspect.getsource(Trainer._prepare_context_parallel_inputs)
|
||||
except OSError as exc: # pragma: no cover - occurs when source is unavailable
|
||||
LOG.warning("Unable to patch Trainer._prepare_context_parallel_inputs: %s", exc)
|
||||
return
|
||||
|
||||
if GUARD_PATTERN not in original_source:
|
||||
LOG.warning(
|
||||
"Expected guard not found in Trainer._prepare_context_parallel_inputs; \n"
|
||||
"skipping FlashAttention context parallelism patch"
|
||||
)
|
||||
return
|
||||
|
||||
patched_source = original_source.replace(GUARD_PATTERN, PATCHED_GUARD)
|
||||
patched_source, _ = detab_code(patched_source)
|
||||
patched_source = patched_source.replace(
|
||||
"def _prepare_context_parallel_inputs(",
|
||||
"def axolotl_prepare_context_parallel_inputs(",
|
||||
1,
|
||||
)
|
||||
|
||||
module_name = Trainer.__module__
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
# import symbols referenced in the method so exec can succeed
|
||||
items_to_import = []
|
||||
for item in dir(module):
|
||||
if item in patched_source:
|
||||
items_to_import.append(item)
|
||||
|
||||
# Use a separate namespace to capture the exec'd function
|
||||
namespace = {}
|
||||
exec(f"from {module_name} import ({', '.join(items_to_import)})", namespace)
|
||||
exec(patched_source, namespace)
|
||||
|
||||
# Explicitly get the function from the namespace
|
||||
axolotl_prepare_context_parallel_inputs = namespace[
|
||||
"axolotl_prepare_context_parallel_inputs"
|
||||
]
|
||||
Trainer._original_prepare_context_parallel_inputs = (
|
||||
Trainer._prepare_context_parallel_inputs
|
||||
)
|
||||
Trainer._prepare_context_parallel_inputs = axolotl_prepare_context_parallel_inputs
|
||||
Trainer._axolotl_prepare_context_parallel_inputs_source = patched_source
|
||||
Trainer._axolotl_prepare_context_parallel_inputs_patched = True
|
||||
LOG.debug(
|
||||
"Patched Trainer._prepare_context_parallel_inputs for FlashAttention + CP"
|
||||
)
|
||||
@@ -253,6 +253,23 @@ class TrainingValidationMixin:
|
||||
data["pad_to_sequence_len"] = True
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def set_reward_model_defaults(cls, data):
|
||||
if data.get("reward_model"):
|
||||
if data.get("num_labels") is None:
|
||||
data["num_labels"] = 1
|
||||
if not (data.get("type_of_model") or data.get("model_type")):
|
||||
data["model_type"] = "AutoModelForSequenceClassification"
|
||||
|
||||
if data.get("process_reward_model"):
|
||||
if data.get("num_labels") is None:
|
||||
data["num_labels"] = 2
|
||||
if not (data.get("type_of_model") or data.get("model_type")):
|
||||
data["model_type"] = "AutoModelForTokenClassification"
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_gas_bsz(cls, data):
|
||||
|
||||
@@ -536,7 +536,7 @@ class TestHFCausalTrainerBuilder:
|
||||
"cfg_string",
|
||||
[
|
||||
"sft_cfg",
|
||||
# "rm_cfg", # TODO fix for num_labels = 2 vs 1
|
||||
"rm_cfg",
|
||||
"prm_cfg",
|
||||
],
|
||||
)
|
||||
|
||||
@@ -4,8 +4,7 @@ E2E tests for lora llama
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_available
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
@@ -68,51 +67,3 @@ class TestLoraLlama(unittest.TestCase):
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@pytest.mark.skipif(not is_auto_gptq_available(), reason="auto-gptq not available")
|
||||
@with_temp_dir
|
||||
def test_lora_gptq_packed(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "lilmeaty/SmolLM2-135M-Instruct-GPTQ",
|
||||
"model_type": "AutoModelForCausalLM",
|
||||
"tokenizer_type": "AutoTokenizer",
|
||||
"sequence_len": 1024,
|
||||
"sample_packing": True,
|
||||
"flash_attention": True,
|
||||
"load_in_8bit": True,
|
||||
"adapter": "lora",
|
||||
"gptq": True,
|
||||
"gptq_disable_exllama": True,
|
||||
"lora_r": 32,
|
||||
"lora_alpha": 64,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.02,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 2,
|
||||
"max_steps": 20,
|
||||
"save_steps": 0.5,
|
||||
"micro_batch_size": 8,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -1,66 +0,0 @@
|
||||
"""Tests for the HF Trainer context parallel patch."""
|
||||
|
||||
import pytest
|
||||
from transformers import Trainer
|
||||
|
||||
from axolotl.monkeypatch.transformers.trainer_context_parallel import (
|
||||
GUARD_PATTERN,
|
||||
PATCHED_GUARD,
|
||||
patch_prepare_context_parallel_inputs,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def restore_trainer_prepare_method():
|
||||
"""Ensure Trainer._prepare_context_parallel_inputs is restored after a test."""
|
||||
original_method = getattr(
|
||||
Trainer,
|
||||
"_original_prepare_context_parallel_inputs",
|
||||
Trainer._prepare_context_parallel_inputs,
|
||||
)
|
||||
patched_attr_present = hasattr(
|
||||
Trainer, "_axolotl_prepare_context_parallel_inputs_patched"
|
||||
)
|
||||
|
||||
yield
|
||||
|
||||
Trainer._prepare_context_parallel_inputs = original_method
|
||||
if patched_attr_present:
|
||||
delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched")
|
||||
if hasattr(Trainer, "_original_prepare_context_parallel_inputs"):
|
||||
delattr(Trainer, "_original_prepare_context_parallel_inputs")
|
||||
if hasattr(Trainer, "_axolotl_prepare_context_parallel_inputs_source"):
|
||||
delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_source")
|
||||
|
||||
|
||||
def test_patch_attention_guard(restore_trainer_prepare_method):
|
||||
"""Patch should swap the guard to allow sdpa or flash attention."""
|
||||
# Ensure we start from the unpatched method
|
||||
if hasattr(Trainer, "_original_prepare_context_parallel_inputs"):
|
||||
Trainer._prepare_context_parallel_inputs = (
|
||||
Trainer._original_prepare_context_parallel_inputs
|
||||
)
|
||||
delattr(Trainer, "_original_prepare_context_parallel_inputs")
|
||||
if hasattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched"):
|
||||
delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched")
|
||||
|
||||
patch_prepare_context_parallel_inputs()
|
||||
|
||||
patched_method = Trainer._prepare_context_parallel_inputs
|
||||
assert patched_method is not None
|
||||
assert getattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched", False)
|
||||
|
||||
source = Trainer._axolotl_prepare_context_parallel_inputs_source
|
||||
assert GUARD_PATTERN not in source
|
||||
assert PATCHED_GUARD in source
|
||||
|
||||
|
||||
def test_patch_is_idempotent(restore_trainer_prepare_method):
|
||||
"""Calling the patch twice should leave the same patched function in place."""
|
||||
patch_prepare_context_parallel_inputs()
|
||||
first_patched = Trainer._prepare_context_parallel_inputs
|
||||
|
||||
patch_prepare_context_parallel_inputs()
|
||||
second_patched = Trainer._prepare_context_parallel_inputs
|
||||
|
||||
assert first_patched is second_patched
|
||||
@@ -277,6 +277,34 @@ class TestValidation(BaseValidation):
|
||||
new_cfg = validate_config(cfg)
|
||||
assert new_cfg.type_of_model == "AutoModelForCausalLM"
|
||||
|
||||
def test_reward_model_defaults(self, minimal_cfg):
|
||||
cfg = (
|
||||
DictDefault(
|
||||
{
|
||||
"reward_model": True,
|
||||
}
|
||||
)
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
new_cfg = validate_config(cfg)
|
||||
assert new_cfg.num_labels == 1
|
||||
assert new_cfg.type_of_model == "AutoModelForSequenceClassification"
|
||||
|
||||
def test_process_reward_model_defaults(self, minimal_cfg):
|
||||
cfg = (
|
||||
DictDefault(
|
||||
{
|
||||
"process_reward_model": True,
|
||||
}
|
||||
)
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
new_cfg = validate_config(cfg)
|
||||
assert new_cfg.num_labels == 2
|
||||
assert new_cfg.type_of_model == "AutoModelForTokenClassification"
|
||||
|
||||
def test_model_revision_remap(self, minimal_cfg):
|
||||
cfg = (
|
||||
DictDefault(
|
||||
|
||||
Reference in New Issue
Block a user