Compare commits

..

13 Commits

Author SHA1 Message Date
Dan Saunders
700409be6f removing deepspeed guard for LoRA Triton kernels 2025-04-03 16:44:45 +00:00
NanoCode012
64d8035f50 fix(example): align example to correct adapter (#2478)
* fix(example): align example to correct adapter

* fix: add missing load in 4 bit
2025-04-03 08:48:14 -04:00
Wing Lian
5249e98058 add additional tf32 opt for cudnn (#2477) [skip ci] 2025-04-03 08:47:52 -04:00
Wing Lian
3877c5c69d set release version 0.8.0 (#2476)
Some checks failed
ci-cd / build-axolotl (<nil>, 124, 12.4.1, 3.11, 2.4.1) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 124, 12.4.1, true, 3.11, 2.6.0) (push) Has been cancelled
ci-cd / build-axolotl (vllm, 124, 12.4.1, 3.11, 2.5.1) (push) Has been cancelled
publish pypi / Create Release (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 124, 12.4.1, 3.11, 2.4.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 124, 12.4.1, 3.11, 2.5.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 124, 12.4.1, true, 3.11, 2.6.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud-no-tmux (<nil>, 124, 12.4.1, 3.11, 2.4.1) (push) Has been cancelled
publish pypi / Upload release to PyPI (push) Has been cancelled
* set release version 0.8.0

* make sure to include ring-flash-attn in docker image build
2025-04-02 09:50:56 -04:00
NanoCode012
adb593abac fix: document offload gradient_checkpointing option (#2475) 2025-04-02 09:35:42 -04:00
NanoCode012
a0117c9bce fix: separate gemma3 text and vision example config (#2471) [skip ci]
* fix: separate gemma3 text and vision example config

* fix: update to use a text-only dataset

* fix: typo
2025-04-02 09:35:29 -04:00
NanoCode012
e6cfb093d2 fix: disable SP during merge (#2470) [skip ci] 2025-04-02 09:35:00 -04:00
NanoCode012
7abc71dc0b fix: gemma3 loss in forward pass (#2473) [skip ci]
* fix: gemma3 loss in forward pass

* fix: lint

* fix: move patch before plugins

* Update src/axolotl/monkeypatch/gemma3.py

Co-authored-by: salman <salman.mohammadi@outlook.com>

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: salman <salman.mohammadi@outlook.com>
2025-04-02 09:34:41 -04:00
NanoCode012
45bf634d17 feat: add support for multimodal in lora kernels (#2472) [skip ci]
* feat: add support for multimodal in lora kernels

* fix: improve multimodal checks

* fix: add fallback for model config

* chor: add gemma3 to docs
2025-04-02 09:33:46 -04:00
NanoCode012
80ba4b69f1 fix: pydantic warning validator not returning self (#2474) 2025-04-02 07:40:49 -04:00
Wing Lian
0bfa180f7d torch 2.7.0 base image for testing (#2467) 2025-04-01 15:38:26 -04:00
NanoCode012
9e22c4ca6a fix: set rl=None during inference (#2463) 2025-04-01 12:25:53 -04:00
NanoCode012
990b5896bc fix: downgrade deepspeed to fix grad checkpoint oom (#2465) [skip ci] 2025-04-01 12:25:05 -04:00
34 changed files with 592 additions and 425 deletions

View File

@@ -52,6 +52,12 @@ jobs:
python_version: "3.11"
pytorch: nightly
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: next
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -73,7 +79,7 @@ jobs:
uses: docker/build-push-action@v4
with:
context: .
file: ${{ matrix.pytorch == 'nightly' && './docker/Dockerfile-base-nightly' || './docker/Dockerfile-base' }}
file: ${{ matrix.pytorch == 'nightly' && './docker/Dockerfile-base-nightly' || matrix.pytorch == 'next' && './docker/Dockerfile-base-next' || './docker/Dockerfile-base' }}
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }}

View File

@@ -20,9 +20,9 @@ WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
fi
RUN python scripts/unsloth_install.py | sh

View File

@@ -0,0 +1,38 @@
ARG CUDA_VERSION="12.8.1"
ARG CUDNN_VERSION="8"
ARG UBUNTU_VERSION="22.04"
ARG MAX_JOBS=4
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
ENV PATH="/root/miniconda3/bin:${PATH}"
ARG PYTHON_VERSION="3.11"
ARG PYTORCH_VERSION="next"
ARG CUDA="128"
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
ENV PYTHON_VERSION=$PYTHON_VERSION
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
&& wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir /root/.conda \
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
python3 -m pip install --no-cache-dir -U torch==2.7.0 --extra-index-url https://download.pytorch.org/whl/test/cu$CUDA && \
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"
RUN git lfs install --skip-repo && \
pip3 install awscli && \
pip3 install -U --no-cache-dir pydantic==2.10.6

View File

@@ -510,7 +510,8 @@ train_on_inputs: false
# Note that training loss may have an oscillating pattern with this enabled.
group_by_length: false
# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
# Whether to use gradient checkpointing. Available options are: true, false, "offload".
# https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
gradient_checkpointing: false
# additional kwargs to pass to the trainer for gradient checkpointing
# gradient_checkpointing_kwargs:
@@ -686,10 +687,9 @@ ddp_broadcast_buffers:
# E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized
# subsequences, or set to 4 to split into four equal-sized subsequences.
# See https://axolotl-ai-cloud.github.io/axolotl/docs/sequence_parallelism.html for more details.
sequence_parallel_degree: 4 # Set to the number of GPUs to split sequences across
flash_attention: true # SP requires flash attention
micro_batch_size: 1 # SP requires this is set to 1
# (optional) strides across the key dimension; larger values use more memory but should make training a bit faster
sequence_parallel_degree:
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
# Must evenly divide the number of KV heads in your model.
heads_k_stride: 1
# Path to torch distx for optim 'adamw_anyprecision'

View File

@@ -17,6 +17,7 @@ We currently support several common model architectures, including (but not limi
- `qwen2`
- `gemma`
- `gemma2`
- `gemma3`
<details>

View File

@@ -23,10 +23,9 @@ Use sequence parallelism when:
To enable sequence parallelism, add the following to your configuration file:
```yaml
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
flash_attention: true # SP requires flash attention
micro_batch_size: 1 # SP requires this is set to 1
# (optional) strides across the key dimension; larger values use more memory but should make training a bit faster
# Set to a divisor (> 1) of the number of GPUs available
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1
```
@@ -67,16 +66,15 @@ sequence_len: 8192
...
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
flash_attention: true # SP requires flash attention
micro_batch_size: 1 # SP requires this is set to 1
# (optional) strides across the key dimension; larger values use more memory but should make training a bit faster
flash_attention: true # Required with sequence parallelism
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1
...
```
This will train the Llama 3 8B model with 8192 context length, with each sequence split
into 4 subsequences of length 2048 across 4 GPUs.
This will train the Llama 3 8B model with 8K context length, with each sequence split
into 2 subsequences of length 4096 across 2 GPUs.
## Sample Packing with Sequence Parallelism
@@ -88,14 +86,12 @@ Sequence parallelism is compatible with Axolotl's sample packing functionality.
## Effect on Batch Size
First, note that sequence parallelism supports only the case where `micro_batch_size: 1`.
When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:
- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
- The number of batches processed per step decreases
For example:
- With 8 GPUs and no sequence parallelism: 8 different batches are processed per step
- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
- If your per-GPU `micro_batch_size` is 1, the global batch size decreases from 8 to 2
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4

View File

@@ -0,0 +1,68 @@
base_model: google/gemma-3-4b-it
strict: false
load_in_4bit: true
# gemma3 doesn't seem to play nice with ddp
ddp_find_unused_parameters: true
chat_template: gemma3
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
field_messages: conversations
message_property_mappings:
role: from
content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
local_rank:
logging_steps: 1
flash_attention: true
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:

View File

@@ -2,6 +2,8 @@ base_model: google/gemma-3-4b-it
processor_type: AutoProcessor
strict: false
load_in_4bit: true
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
@@ -20,7 +22,7 @@ dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
adapter: lora
adapter: qlora
lora_model_dir:
sequence_len: 2048

View File

@@ -82,6 +82,3 @@ deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: "<|end_of_text|>"

View File

@@ -16,7 +16,7 @@ transformers==4.50.3
tokenizers>=0.21.1
accelerate==1.5.2
datasets==3.5.0
deepspeed==0.16.4
deepspeed==0.15.4
trl==0.16.0
optimum==1.16.2

View File

@@ -112,7 +112,7 @@ extras_require = {
"yunchang==0.6.0",
],
"deepspeed": [
"deepspeed==0.16.4",
"deepspeed==0.15.4",
"deepspeed-kernels",
],
"mamba-ssm": [

View File

@@ -4,4 +4,4 @@ import pkgutil
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
__version__ = "0.8.0.dev0"
__version__ = "0.8.0"

View File

@@ -256,7 +256,7 @@ def do_cli(
"""
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, inference=True, **kwargs)
parsed_cfg = load_cfg(config, inference=True, rl=None, **kwargs)
parsed_cfg.sample_packing = False
parser = transformers.HfArgumentParser(InferenceCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(

View File

@@ -74,8 +74,10 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
load_in_8bit=False,
load_in_4bit=False,
flash_attention=False,
sequence_parallel_degree=None,
deepspeed=None,
fsdp=None,
fsdp_config=None,
**kwargs,
)
@@ -86,13 +88,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
f"Target directory for merge: `{parsed_cfg.lora_model_dir}` does not exist."
)
parsed_cfg.load_in_4bit = False
parsed_cfg.load_in_8bit = False
parsed_cfg.flash_attention = False
parsed_cfg.deepspeed = None
parsed_cfg.fsdp = None
parsed_cfg.fsdp_config = None
do_merge_lora(cfg=parsed_cfg)

View File

@@ -1043,10 +1043,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.rpo_alpha is not None:
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
training_args_kwargs["sequence_parallel_degree"] = (
self.cfg.sequence_parallel_degree
)
training_args_cls = None
blocklist_args_kwargs = []
if self.cfg.rl == "simpo":
@@ -1165,7 +1161,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer_kwargs["dataset_tags"] = [
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
]
dpo_trainer = trainer_cls(
*trainer_cls_args,
args=training_args,
@@ -1183,3 +1178,21 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer.add_callback(callback)
return dpo_trainer
class HFPPOTrainerBuilder(TrainerBuilderBase):
"""
HF Factory class for PPO Trainer
"""
def get_callbacks(self):
callbacks = super().get_callbacks()
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
return callbacks
def build(self, total_num_steps):
# build PPOConfig
pass

View File

@@ -3,16 +3,16 @@
# pylint: disable=unused-import
# flake8: noqa
from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.core.trainers.dpo import AxolotlDPOTrainer
from axolotl.core.trainers.grpo import AxolotlGRPOTrainer
from axolotl.core.trainers.mamba import AxolotlMambaTrainer
from axolotl.core.trainers.relora import ReLoRATrainer
from axolotl.core.trainers.trl import (
from .base import AxolotlTrainer
from .dpo.trainer import AxolotlDPOTrainer
from .grpo.trainer import AxolotlGRPOTrainer
from .mamba import AxolotlMambaTrainer
from .relora import ReLoRATrainer
from .trl import (
AxolotlCPOTrainer,
AxolotlKTOTrainer,
AxolotlORPOTrainer,
AxolotlPPOTrainer,
AxolotlPRMTrainer,
AxolotlRewardTrainer,
TRLPPOTrainer,
)

View File

@@ -8,11 +8,10 @@ import logging
import os
from collections import defaultdict
from functools import wraps
from typing import Any, Literal
from typing import Literal
import datasets
import torch
import torch.nn as nn
from datasets import Dataset
from torch.utils.data import (
BatchSampler,
@@ -26,8 +25,12 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
from trl.trainer.utils import pad_to_length
from typing_extensions import override
from axolotl.core.trainers.handlers import SequenceParallelHandler
from axolotl.core.trainers.mixins import TrainerMixins
from axolotl.core.trainers.mixins import (
OptimizerMixin,
RngLoaderMixin,
SchedulerMixin,
SequenceParallelMixin,
)
from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging,
sanitize_kwargs_for_tagging,
@@ -37,7 +40,9 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = logging.getLogger(__name__)
class AxolotlTrainer(TrainerMixins, Trainer):
class AxolotlTrainer(
SchedulerMixin, OptimizerMixin, RngLoaderMixin, SequenceParallelMixin, Trainer
):
"""Extend the base Trainer for axolotl helpers"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
@@ -63,7 +68,9 @@ class AxolotlTrainer(TrainerMixins, Trainer):
if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
self.sequence_parallel_handler = SequenceParallelHandler(self.args)
# Initialize sequence parallelism if enabled
if self.args.sequence_parallel_degree > 1:
self._setup_sequence_parallel()
def _wrap_model(self, model, training=True, dataloader=None):
if self.args.torch_compile:
@@ -124,7 +131,7 @@ class AxolotlTrainer(TrainerMixins, Trainer):
# Determine the base sampler first
if self.args.sequence_parallel_degree > 1:
base_sampler = self.sequence_parallel_handler._get_train_sampler(self.train_dataset)
base_sampler = self._sp_get_train_sampler(self.train_dataset)
elif self.args.curriculum_sampling:
base_sampler = SequentialSampler(self.train_dataset)
elif use_sample_packing:
@@ -160,7 +167,7 @@ class AxolotlTrainer(TrainerMixins, Trainer):
# Determine the base sampler
if self.args.sequence_parallel_degree > 1:
base_sampler = self.sequence_parallel_handler._get_eval_sampler(eval_dataset)
base_sampler = self._sp_get_eval_sampler(eval_dataset)
elif use_multipack:
base_sampler = SequentialSampler(eval_dataset)
else:
@@ -232,10 +239,7 @@ class AxolotlTrainer(TrainerMixins, Trainer):
return dataloader
# Otherwise prepare with accelerator
dataloader = self.accelerator.prepare_data_loader(dataloader)
return dataloader
return self.accelerator.prepare_data_loader(dataloader)
def get_train_dataloader(self) -> DataLoader:
"""Get dataloader for training"""
@@ -344,57 +348,7 @@ class AxolotlTrainer(TrainerMixins, Trainer):
dataloader_params["drop_last"] = self.args.dataloader_drop_last
return DataLoader(bench_dataset, **dataloader_params)
def training_step(
self,
model: nn.Module,
inputs: dict[str, torch.Tensor | Any],
num_items_in_batch: int | None = None,
) -> torch.Tensor:
"""
Perform a training step on a batch of inputs. Overrides the
`transformers.trainer.Trainer` method to handle sequence parallelism if
enabled.
Args:
model: Model to perform training step for.
inputs: Dictionary mapping of inputs.
num_items_in_batch: The number of items in the batch.
"""
# Set up sequence parallelism for this step if enabled
if self.args.sequence_parallel_degree > 1:
self.sequence_parallel_handler._update_ring_flash_attn_params(inputs)
# Proceed with normal training step
return super().training_step(model, inputs, num_items_in_batch) # type: ignore
def prediction_step(
self,
model: nn.Module,
inputs: dict[str, torch.Tensor | Any],
prediction_loss_only: bool,
ignore_keys: list[str] | None = None,
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
"""
Perform a prediction step on a batch of inputs. Overrides the
`transformers.trainer.Trainer` method to handle sequence parallelism if
enabled.
Args:
model: Model to perform prediction step for.
inputs: Dictionary mapping of inputs.
prediction_loss_only: Whether to return only the loss.
ignore_keys: Keys to ignore in the inputs.
Returns:
Tuple of (loss, logits, labels).
"""
# Set up sequence parallelism for this prediction step if enabled
if self.args.sequence_parallel_degree > 1:
self.sequence_parallel_handler._update_ring_flash_attn_params(inputs)
# Proceed with normal prediction step
return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys) # type: ignore
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
@override
def compute_loss(

View File

@@ -1,10 +1,14 @@
"""DPO Specific Strategy for training"""
"""
DPO Specific Strategy for training
"""
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer
class DPOStrategy:
"""Strategy for DPO training"""
"""
Strategy for DPO training
"""
@classmethod
def get_trainer_class(cls):

View File

@@ -1,4 +1,6 @@
"""Axolotl specific DPO args"""
"""
Axolotl specific DPO args
"""
from dataclasses import dataclass
@@ -9,4 +11,6 @@ from axolotl.core.training_args import AxolotlTrainingMixins
@dataclass
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
"""DPO config for DPO training"""
"""
DPO config for DPO training
"""

View File

@@ -1,7 +1,10 @@
"""DPO trainer for axolotl"""
"""
DPO trainer for axolotl
"""
import gc
from functools import wraps
from typing import Any
from typing import Any, Dict, Union
import torch
from peft.optimizers import create_loraplus_optimizer
@@ -10,8 +13,7 @@ from transformers import Trainer
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer
from axolotl.core.trainers.handlers import SequenceParallelHandler
from axolotl.core.trainers.mixins import TrainerMixins
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging,
sanitize_kwargs_for_tagging,
@@ -21,18 +23,18 @@ if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
class AxolotlDPOTrainer(TrainerMixins, DPOTrainer):
"""Extend the base DPOTrainer for axolotl helpers"""
class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
"""
Extend the base DPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "dpo"]
def __init__(self, *args, dataset_tags=None, **kwargs):
super().__init__(*args, **kwargs)
self.dataset_tags = dataset_tags
self.optimizer = None
self.model_accepts_loss_kwargs = False
self.sequence_parallel_handler = SequenceParallelHandler(args=self.args)
def create_optimizer(self):
# pylint: disable=duplicate-code
@@ -86,7 +88,7 @@ class AxolotlDPOTrainer(TrainerMixins, DPOTrainer):
max_prompt_length,
max_completion_length,
add_special_tokens,
) -> dict:
) -> Dict:
res = DPOTrainer.tokenize_row(
features,
processing_class,
@@ -115,9 +117,10 @@ class AxolotlDPOTrainer(TrainerMixins, DPOTrainer):
def training_step(
self,
model: nn.Module,
inputs: dict[str, torch.Tensor | Any | None],
inputs: Dict[str, Union[torch.Tensor, Any]],
num_items_in_batch=None,
) -> torch.Tensor:
self.sequence_parallel_handler.prepare_for_training_step(self, inputs)
return super().training_step(model, inputs, num_items_in_batch)
loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
gc.collect()
torch.cuda.empty_cache()
return loss

View File

@@ -1,4 +1,6 @@
"""Axolotl GRPO trainer"""
"""
Axolotl GRPO trainer
"""
from contextlib import nullcontext
@@ -6,14 +8,16 @@ from accelerate.utils import is_deepspeed_available, is_peft_model
from trl import GRPOTrainer
from trl.extras.profiling import profiling_decorator
from axolotl.core.trainers.mixins import TrainerMixins
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
if is_deepspeed_available():
import deepspeed
class AxolotlGRPOTrainer(TrainerMixins, GRPOTrainer):
"""Extend the base GRPOTrainer for axolotl helpers"""
class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
"""
Extend the base GRPOTrainer for axolotl helpers
"""
_tag_names = ["trl", "grpo", "axolotl"]

View File

@@ -1,3 +0,0 @@
"""Init for trainer handlers"""
from axolotl.core.trainers.handlers.sequence_parallel import SequenceParallelHandler

View File

@@ -1,123 +0,0 @@
"""Handler class for sequence parallel trainer logic"""
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.utils.data import DistributedSampler
class SequenceParallelHandler:
"""
Handler class that encapsulates sequence parallelism functionality.
This replaces the SequenceParallelMixin with a composition-based approach.
"""
def __init__(self, args=None):
"""
Initialize the sequence parallel handler.
Args:
args: The arguments object containing sequence parallelism settings.
"""
self.args = args
self.ring_attn_group = None
# Set up sequence parallelism if enabled
if self.args.sequence_parallel_degree > 1:
self._setup_sequence_parallel()
def _setup_sequence_parallel(self):
"""Set up sequence parallelism environment."""
from ring_flash_attn import update_ring_flash_attn_params
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
self.update_ring_flash_attn_params = update_ring_flash_attn_params
self.ring_attn_group = get_ring_attn_group()
def create_sequence_parallel_sampler(
self,
dataset,
shuffle=True,
is_eval=False,
):
"""
Helper method to create sampler for sequence parallelism (SP).
Args:
dataset: Dataset to sample from.
shuffle: Whether to shuffle the dataset.
is_eval: Whether we are creating a sampler for evaluation or training.
Returns:
Distributed sampler.
"""
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
return DistributedSampler(
dataset,
num_replicas=num_sp_groups,
rank=sp_group_id,
seed=self.args.seed if shuffle else None,
shuffle=shuffle,
drop_last=not is_eval,
)
def _get_train_sampler(self, dataset):
"""
Get a training sampler configured for sequence parallelism.
Args:
dataset: The training dataset.
Returns:
Configured sequence parallel sampler.
"""
return self.create_sequence_parallel_sampler(
dataset,
shuffle=not self.args.curriculum_sampling,
)
def _get_eval_sampler(self, eval_dataset):
"""
Get an evaluation sampler configured for sequence parallelism.
Args:
eval_dataset: The evaluation dataset.
Returns:
Configured sequence parallel sampler.
"""
return self.create_sequence_parallel_sampler(
eval_dataset, shuffle=False, is_eval=True
)
def _update_ring_flash_attn_params(self, inputs):
"""
Calculate the cu_seqlens for the current forward pass and pass the value to
the substituted ring_flash_attn.
Args:
inputs: Current batch of inputs.
"""
# At this point, inputs should already be partitioned by the sequence
# parallel data collator
batch_size = inputs["input_ids"].shape[0]
seq_len = inputs["input_ids"].shape[1]
packed_seq_lens = [seq_len] * batch_size
# Calculate the full sequence length across all GPUs in this SP group
total_seq_len = seq_len * self.args.sequence_parallel_degree
cu_seqlens = torch.cumsum(
torch.tensor(
packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
),
dim=-1,
dtype=torch.int32,
)
cu_seqlens = F.pad(
F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
)
self.update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)

View File

@@ -3,12 +3,7 @@
# pylint: disable=unused-import
# flake8: noqa
from axolotl.core.trainers.mixins.optimizer import OptimizerMixin
from axolotl.core.trainers.mixins.rng_state_loader import RngLoaderMixin
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
class TrainerMixins(
OptimizerMixin, RngLoaderMixin, SchedulerMixin
):
"""Stub class combining all mixins for Axolotl trainers."""
from .optimizer import OptimizerMixin
from .rng_state_loader import RngLoaderMixin
from .scheduler import SchedulerMixin
from .sequence_parallel import SequenceParallelMixin

View File

@@ -21,7 +21,9 @@ LOG = logging.getLogger(__name__)
class RngLoaderMixin(Trainer):
"""Mixin for method override to load RNG states from a checkpoint"""
"""
mixin for method override to load RNG states from a checkpoint
"""
def _load_rng_state(self, checkpoint):
# Load RNG states from `checkpoint`

View File

@@ -1,5 +1,4 @@
"""Module for Axolotl trainer sequence parallelism mixin"""
# TODO(Dan): remove
import logging
from typing import Any
@@ -8,6 +7,7 @@ import torch
import torch.distributed as dist
import torch.nn.functional as F
from datasets import Dataset
from torch import nn
from torch.utils.data import DistributedSampler, Sampler
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
@@ -71,12 +71,12 @@ class SequenceParallelMixin:
drop_last=not is_eval,
)
def _get_train_sampler(self, dataset) -> Sampler | None:
def _sp_get_train_sampler(self, dataset) -> Sampler | None:
"""
Get a training sampler configured for sequence parallelism.
Args:
dataset: The training dataset.
dataset: The training dataset
Returns:
Configured sequence parallel sampler.
@@ -86,7 +86,7 @@ class SequenceParallelMixin:
shuffle=not self.args.curriculum_sampling,
)
def _get_eval_sampler(self, eval_dataset) -> Sampler | None:
def _sp_get_eval_sampler(self, eval_dataset) -> Sampler | None:
"""
Get an evaluation sampler configured for sequence parallelism.
@@ -130,3 +130,53 @@ class SequenceParallelMixin:
)
update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)
def training_step(
self,
model: nn.Module,
inputs: dict[str, torch.Tensor | Any],
num_items_in_batch: int | None = None,
) -> torch.Tensor:
"""
Perform a training step on a batch of inputs. Overrides the
`transformers.trainer.Trainer` method to handle sequence parallelism if
enabled.
Args:
model: Model to perform training step for.
inputs: Dictionary mapping.
"""
# Set up sequence parallelism for this step if enabled
if self.args.sequence_parallel_degree > 1:
self._update_ring_flash_attn_params(inputs)
# Proceed with normal training step
return super().training_step(model, inputs, num_items_in_batch) # type: ignore
def prediction_step(
self,
model: nn.Module,
inputs: dict[str, torch.Tensor | Any],
prediction_loss_only: bool,
ignore_keys: list[str] | None = None,
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
"""
Perform a prediction step on a batch of inputs. Overrides the
`transformers.trainer.Trainer` method to handle sequence parallelism if
enabled.
Args:
model: Model to perform prediction step for.
inputs: Dictionary mapping of inputs.
prediction_loss_only: Whether to return only the loss.
ignore_keys: Keys to ignore in the inputs.
Returns:
Tuple of (loss, logits, labels).
"""
# Set up sequence parallelism for this prediction step if enabled
if self.args.sequence_parallel_degree > 1:
self._update_ring_flash_attn_params(inputs)
# Proceed with normal prediction step
return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys) # type: ignore

View File

@@ -13,10 +13,11 @@ from trl import (
RewardTrainer,
)
from axolotl.core.trainers.mixins import TrainerMixins
from axolotl.core.trainers.mixins import RngLoaderMixin
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
class AxolotlPPOTrainer(TrainerMixins, PPOTrainer):
class TRLPPOTrainer(PPOTrainer):
"""Wrapper for TRL PPO trainer to handle customizations"""
tag_names = ["axolotl", "ppo"]
@@ -74,8 +75,10 @@ class AxolotlPPOTrainer(TrainerMixins, PPOTrainer):
)
class AxolotlORPOTrainer(TrainerMixins, ORPOTrainer):
"""Extend the base ORPOTrainer for axolotl helpers"""
class AxolotlORPOTrainer(RngLoaderMixin, SchedulerMixin, ORPOTrainer):
"""
Extend the base ORPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "orpo"]
@@ -152,14 +155,18 @@ class AxolotlORPOTrainer(TrainerMixins, ORPOTrainer):
return loss, metrics
class AxolotlKTOTrainer(TrainerMixins, KTOTrainer):
"""Extend the base KTOTrainer for axolotl helpers"""
class AxolotlKTOTrainer(RngLoaderMixin, SchedulerMixin, KTOTrainer):
"""
Extend the base KTOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "kto"]
class AxolotlCPOTrainer(TrainerMixins, CPOTrainer):
"""Extend the base CPOTrainer for axolotl helpers"""
class AxolotlCPOTrainer(RngLoaderMixin, SchedulerMixin, CPOTrainer):
"""
Extend the base CPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "cpo"]
@@ -238,13 +245,17 @@ class AxolotlCPOTrainer(TrainerMixins, CPOTrainer):
return loss, metrics
class AxolotlRewardTrainer(TrainerMixins, RewardTrainer):
"""Extend the base RewardTrainer for axolotl helpers"""
class AxolotlRewardTrainer(RngLoaderMixin, SchedulerMixin, RewardTrainer):
"""
Extend the base RewardTrainer for axolotl helpers
"""
tag_names = ["axolotl", "reward"]
class AxolotlPRMTrainer(TrainerMixins, PRMTrainer):
"""Extend the base trl.PRMTrainer for axolotl helpers"""
class AxolotlPRMTrainer(RngLoaderMixin, SchedulerMixin, PRMTrainer):
"""
Extend the base trl.PRMTrainer for axolotl helpers
"""
tag_names = ["axolotl", "prm"]

View File

@@ -12,7 +12,9 @@ from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
@dataclass
class AxolotlTrainingMixins:
"""Mixin class for the Axolotl training args."""
"""
Mixin class for the Axolotl training args.
"""
# pylint: disable=duplicate-code
model_type: Optional[str] = field(

View File

@@ -6,22 +6,11 @@ package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patc
their sequence parallel version of Flash Attention 2.
"""
import torch
import torch.distributed as dist
import torch.nn.functional as F
from accelerate.logging import get_logger
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from axolotl.logging_config import configure_logging
try:
from ring_flash_attn import update_ring_flash_attn_params
except ImportError:
# We pass silently here, but raise an ImportError in our Axolotl config validation
# if cfg.sequence_parallel_degree > 1 and `ring-flash-attn` is not installed.
pass
configure_logging()
LOG = get_logger(__name__)
@@ -43,120 +32,12 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
Setter for ring attention group on this rank.
Args:
ring_attn_group: Process group for ring attention.
Process group for ring attention.
"""
global RING_ATTN_GROUP # pylint: disable=global-statement
RING_ATTN_GROUP = ring_attn_group
def patch_flash_attention_for_sequential_batch(sequence_parallel_degree: int):
    """
    Patch flash attention a second time to handle batched data. This is a hack to
    accommodate certain RL trainers which batch data even when `micro_batch_size: 1` is
    specified in the Axolotl config.

    Args:
        sequence_parallel_degree: Sequence parallelism factor.
    """
    # Store the original flash attention function so the wrapper below can
    # delegate to it for the actual attention computation.
    original_flash_attention = ALL_ATTENTION_FUNCTIONS["flash_attention_2"]

    def sequential_batch_flash_attention(
        module: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: torch.Tensor | None,
        dropout: float = 0.0,
        scaling: float | None = None,
        sliding_window: int | None = None,
        softcap: float | None = None,
        **kwargs,
    ) -> tuple[torch.Tensor, None]:
        # Check if we have a batch dimension > 1; single-item batches take the
        # original (unsplit) fast path.
        batch_size = query.shape[0]
        if batch_size <= 1:
            return original_flash_attention(
                module,
                query,
                key,
                value,
                attention_mask,
                dropout,
                scaling,
                sliding_window,
                softcap,
                **kwargs
            )

        # Process each item in the batch separately
        outputs = []
        for i in range(batch_size):
            # Extract single batch item (slicing with i:i+1 keeps a batch dim of 1)
            q_item = query[i:i+1]
            k_item = key[i:i+1]
            v_item = value[i:i+1]

            # Handle attention mask - it might be None or have different shapes
            mask_item = None
            if attention_mask is not None:
                # The mask could have different formats depending on implementation
                if attention_mask.dim() >= 3 and attention_mask.shape[0] == batch_size:
                    mask_item = attention_mask[i:i+1]
                else:
                    # For broadcast masks that don't have a batch dimension
                    mask_item = attention_mask

            # At this point, inputs should already be partitioned by the sequence
            # parallel data collator
            # NOTE(review): this rebinds `batch_size` (always 1 here, since q_item
            # was sliced to a single item) and shadows the loop bound read above.
            batch_size = q_item.shape[0]
            # Assumes query layout is (batch, heads, seq_len, head_dim) — the
            # per-rank sequence length is taken from dim 2; TODO confirm.
            seq_len = q_item.shape[2]
            packed_seq_lens = [seq_len] * batch_size

            # Calculate the full sequence length across all GPUs in this SP group
            total_seq_len = seq_len * sequence_parallel_degree

            # Build cumulative sequence lengths in the varlen (`cu_seqlens`)
            # format expected by ring flash attention: [0, seq_len, total_seq_len].
            cu_seqlens = torch.cumsum(
                torch.tensor(
                    packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
                ),
                dim=-1,
                dtype=torch.int32,
            )
            cu_seqlens = F.pad(
                F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
            )
            update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())

            # Call the original function for a single batch item
            output, _ = original_flash_attention(
                module,
                q_item,
                k_item,
                v_item,
                mask_item,
                dropout,
                scaling,
                sliding_window,
                softcap,
                **kwargs
            )
            outputs.append(output)
            # Keep all ranks in lockstep between batch items; presumably required
            # so ring-attn param updates don't race across ranks — verify.
            dist.barrier()

        # Concatenate results along batch dimension
        concatenated_output = torch.cat(outputs, dim=0)
        return concatenated_output, None

    # Replace the original function with our sequential version
    ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = sequential_batch_flash_attention
def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None):
"""
Create ring attention group and substitute flash attn with ring flash attn.
@@ -217,4 +98,3 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
substitute_hf_flash_attn(
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride
)
patch_flash_attention_for_sequential_batch(sequence_parallel_degree)

View File

@@ -0,0 +1,238 @@
"""Monkeypatch for gemma3 conditional generation forward to fix loss exploding"""
# pylint: disable=duplicate-code
from typing import Optional, Tuple, Union
import torch
from transformers.cache_utils import Cache
from transformers.models.gemma3.modeling_gemma3 import (
_CONFIG_FOR_DOC,
GEMMA3_INPUTS_DOCSTRING,
Gemma3CausalLMOutputWithPast,
logger,
)
from transformers.utils import (
add_start_docstrings_to_model_forward,
is_torchdynamo_compiling,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
@replace_return_docstrings(
    output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def new_forward(
    self,
    input_ids: torch.LongTensor = None,
    pixel_values: torch.FloatTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
    token_type_ids: Optional[torch.LongTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    **lm_kwargs,
) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.

    logits_to_keep (`int` or `torch.Tensor`, *optional*):
        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
        `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
        token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
        If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
        This is useful when using packed tensor format (single dimension for batch and sequence length).

    Returns:

    Example:

    ```python
    >>> from PIL import Image
    >>> import requests
    >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration

    >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf")
    >>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf")

    >>> prompt = "answer en Where is the cow standing?"
    >>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> inputs = processor(images=image, text=prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(**inputs, max_length=30)
    >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "answer en Where is the cow standing?\nbeach"
    ```"""
    # Exactly one of input_ids / inputs_embeds must be provided.
    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

    # Fall back to model-config defaults when the caller passes None.
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # Training is inferred from having both token_type_ids and labels; this flag
    # is only forwarded to the causal-mask construction below.
    is_training = token_type_ids is not None and labels is not None

    # Replace image id with PAD if the image token is OOV, to avoid index-errors
    if input_ids is not None and self.config.image_token_index >= self.vocab_size:
        special_image_mask = input_ids == self.config.image_token_index
        llm_input_ids = input_ids.clone()
        llm_input_ids[special_image_mask] = 0
    else:
        llm_input_ids = input_ids

    if inputs_embeds is None:
        inputs_embeds = self.get_input_embeddings()(llm_input_ids)

    if cache_position is None:
        past_seen_tokens = (
            past_key_values.get_seq_length() if past_key_values is not None else 0
        )
        cache_position = torch.arange(
            past_seen_tokens,
            past_seen_tokens + inputs_embeds.shape[1],
            device=inputs_embeds.device,
        )

    # Merge text and images
    if pixel_values is not None:
        image_features = self.get_image_features(pixel_values)

        if input_ids is None:
            # Without input_ids, locate image positions by comparing embeddings
            # against the image-token embedding.
            special_image_mask = inputs_embeds == self.get_input_embeddings()(
                torch.tensor(
                    self.config.image_token_index,
                    dtype=torch.long,
                    device=inputs_embeds.device,
                )
            )
        else:
            special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(
                -1
            )
            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
                inputs_embeds.device
            )

        # Sanity check: the number of image placeholder positions in the text must
        # match the number of image embedding tokens (skipped under torchdynamo).
        if (
            not is_torchdynamo_compiling()
            and inputs_embeds[special_image_mask].numel() != image_features.numel()
        ):
            image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
            raise ValueError(
                f"Number of images does not match number of special image tokens in the input text. "
                f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
                "tokens from image embeddings."
            )
        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
        # Scatter image embeddings into the placeholder positions.
        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

    # mask out pad-token-ids in labels for BC
    if labels is not None and self.pad_token_id in labels:
        logger.warning_once(
            "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
            "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
        )
        labels = torch.where(
            input_ids == self.pad_token_id, self.config.ignore_index, labels
        )

    causal_mask = self._update_causal_mask(  # pylint: disable=protected-access
        attention_mask,
        token_type_ids,
        past_key_values,
        cache_position,
        inputs_embeds,
        is_training,
    )
    outputs = self.language_model(
        attention_mask=causal_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        logits_to_keep=logits_to_keep,
        **lm_kwargs,
    )

    logits = outputs[0]
    loss = None
    if labels is not None:
        if attention_mask is not None:
            # Axolotl patch: restrict the loss to attended positions. Logits and
            # labels are filtered by the (shifted) attention mask before calling
            # the loss function, so masked-out positions cannot affect the loss.
            # Get the shifted attention mask
            shift_attention_mask = attention_mask[:, -logits.shape[1] + 1 :].to(
                logits.device
            )  # +1 for shift

            # Filter logits and labels based on attention mask
            valid_indices = shift_attention_mask != 0
            filtered_logits = logits[..., :-1, :][valid_indices]
            filtered_labels = labels[..., 1:][valid_indices.to(labels.device)]

            # TODO: do we need to handle num_items_in_batch given we filter the logits and labels?
            loss = self.loss_function(
                logits=filtered_logits,
                labels=None,  # we pass shift_labels
                shift_labels=filtered_labels,
                vocab_size=self.config.text_config.vocab_size,
                **lm_kwargs,
            )
        else:
            # Standard case without filtering
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.text_config.vocab_size,
                **lm_kwargs,
            )

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return Gemma3CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        image_hidden_states=image_features if pixel_values is not None else None,
    )
def patch_gemma3conditionalgeneration_forward():
    """Install `new_forward` as `Gemma3ForConditionalGeneration.forward`."""
    # Import lazily so merely importing this module does not pull in gemma3.
    from transformers.models.gemma3 import modeling_gemma3

    setattr(modeling_gemma3.Gemma3ForConditionalGeneration, "forward", new_forward)

View File

@@ -252,12 +252,38 @@ def apply_lora_kernel_patches(
LOG.setLevel(logging.INFO)
# Choose activation based on model type
activation = model.config.hidden_act
activation = None
text_config = (
model.config.get_text_config()
if hasattr(model.config, "get_text_config")
else model.config
)
if hasattr(text_config, "hidden_act"):
activation = text_config.hidden_act
elif hasattr(text_config, "hidden_activation"):
activation = text_config.hidden_activation
# map activation to supported activation
if "gelu" in activation:
# gemma3 uses gelu_pytorch_tanh
activation = "gelu"
if activation not in SUPPORTED_ACTIVATIONS:
raise NotImplementedError(f"Activation {activation} is not supported")
layers = []
# check for multimodal models first
if hasattr(model, "language_model"):
layers = model.language_model.model.layers
elif hasattr(model, "model"):
layers = model.model.model.layers
else:
raise NotImplementedError(
f"Model type {model.config.model_type} is not supported yet. Please create an Issue."
)
# Patch each layer
for layer in model.model.model.layers:
for layer in layers:
# Add QKV, O fallback implementations to start
# These will be overwritten later (if some conditions apply)
layer.self_attn.apply_qkv = types.MethodType(

View File

@@ -78,6 +78,7 @@ def resolve_dtype(cfg):
cfg.bf16 = False
else:
torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
torch.backends.cudnn.allow_tf32 = cfg.tf32 or False
if cfg.bf16:
cfg.fp16 = False

View File

@@ -535,6 +535,15 @@ class ModelLoader:
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
def apply_patches(self) -> None:
# patch gemma3 conditional generation forward before loading plugins
# as it could be overridden by plugins
if self.cfg.model_config_type == "gemma3":
from axolotl.monkeypatch.gemma3 import (
patch_gemma3conditionalgeneration_forward,
)
patch_gemma3conditionalgeneration_forward()
# load any patches from plugins
from axolotl.integrations.base import PluginManager
@@ -1351,7 +1360,9 @@ def load_model(
reference_model: bool = False,
**kwargs, # pylint: disable=unused-argument
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
"""Load a model for a given configuration and tokenizer."""
"""
Load a model for a given configuration and tokenizer.
"""
model_loader = ModelLoader(
cfg,
tokenizer,
@@ -1360,16 +1371,12 @@ def load_model(
reference_model=reference_model,
**kwargs,
)
return model_loader.load_model()
def load_adapter(
model: PreTrainedModel,
cfg: DictDefault,
adapter: str | None,
inference: bool = False,
) -> tuple[PreTrainedModel, PeftConfig | None]:
def load_adapter(model, cfg, adapter, inference=False):
# type: (PreTrainedModel, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
if adapter is None:
return model, None
if hasattr(model, "enable_input_require_grads"):
@@ -1382,9 +1389,8 @@ def load_adapter(
raise NotImplementedError(f"{adapter} peft adapter not available")
def load_llama_adapter(
model: PreTrainedModel, cfg: DictDefault
) -> tuple[PreTrainedModel, PeftConfig | None]:
def load_llama_adapter(model, cfg):
# type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
from peft import AdaptionPromptConfig, get_peft_model
peft_config = AdaptionPromptConfig(
@@ -1408,7 +1414,7 @@ def load_llama_adapter(
return model, peft_config
def find_all_linear_names(model: PreTrainedModel):
def find_all_linear_names(model):
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
lora_module_names = set()
for name, module in model.named_modules():

View File

@@ -1224,17 +1224,12 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
):
capabilities = data.get("capabilities")
is_fsdp = data.get("fsdp") is not None
is_deepspeed = data.get("deepspeed") is not None
if capabilities and capabilities.get("n_gpu", 0) > 1:
if is_fsdp:
raise ValueError(
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with FSDP."
)
if is_deepspeed:
raise ValueError(
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with DeepSpeed."
)
return data
@model_validator(mode="before")
@@ -1290,3 +1285,5 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
LOG.warning(
f"torch=={torch_version} may not be supported in future versions. Please consider upgrading to torch>=2.5.1."
)
return self