Compare commits
13 Commits
sp-rl
...
lora-kerne
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
700409be6f | ||
|
|
64d8035f50 | ||
|
|
5249e98058 | ||
|
|
3877c5c69d | ||
|
|
adb593abac | ||
|
|
a0117c9bce | ||
|
|
e6cfb093d2 | ||
|
|
7abc71dc0b | ||
|
|
45bf634d17 | ||
|
|
80ba4b69f1 | ||
|
|
0bfa180f7d | ||
|
|
9e22c4ca6a | ||
|
|
990b5896bc |
8
.github/workflows/base.yml
vendored
8
.github/workflows/base.yml
vendored
@@ -52,6 +52,12 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: nightly
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: next
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
@@ -73,7 +79,7 @@ jobs:
|
||||
uses: docker/build-push-action@v4
|
||||
with:
|
||||
context: .
|
||||
file: ${{ matrix.pytorch == 'nightly' && './docker/Dockerfile-base-nightly' || './docker/Dockerfile-base' }}
|
||||
file: ${{ matrix.pytorch == 'nightly' && './docker/Dockerfile-base-nightly' || matrix.pytorch == 'next' && './docker/Dockerfile-base-next' || './docker/Dockerfile-base' }}
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
|
||||
@@ -20,9 +20,9 @@ WORKDIR /workspace/axolotl
|
||||
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
||||
fi
|
||||
|
||||
RUN python scripts/unsloth_install.py | sh
|
||||
|
||||
38
docker/Dockerfile-base-next
Normal file
38
docker/Dockerfile-base-next
Normal file
@@ -0,0 +1,38 @@
|
||||
ARG CUDA_VERSION="12.8.1"
|
||||
ARG CUDNN_VERSION="8"
|
||||
ARG UBUNTU_VERSION="22.04"
|
||||
ARG MAX_JOBS=4
|
||||
|
||||
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
|
||||
|
||||
ENV PATH="/root/miniconda3/bin:${PATH}"
|
||||
|
||||
ARG PYTHON_VERSION="3.11"
|
||||
ARG PYTORCH_VERSION="next"
|
||||
ARG CUDA="128"
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
|
||||
ENV PYTHON_VERSION=$PYTHON_VERSION
|
||||
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
|
||||
&& wget \
|
||||
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
|
||||
&& mkdir /root/.conda \
|
||||
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
|
||||
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
|
||||
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
|
||||
|
||||
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
|
||||
python3 -m pip install --no-cache-dir -U torch==2.7.0 --extra-index-url https://download.pytorch.org/whl/test/cu$CUDA && \
|
||||
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
|
||||
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"
|
||||
|
||||
RUN git lfs install --skip-repo && \
|
||||
pip3 install awscli && \
|
||||
pip3 install -U --no-cache-dir pydantic==2.10.6
|
||||
@@ -510,7 +510,8 @@ train_on_inputs: false
|
||||
# Note that training loss may have an oscillating pattern with this enabled.
|
||||
group_by_length: false
|
||||
|
||||
# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
|
||||
# Whether to use gradient checkpointing. Available options are: true, false, "offload".
|
||||
# https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
|
||||
gradient_checkpointing: false
|
||||
# additional kwargs to pass to the trainer for gradient checkpointing
|
||||
# gradient_checkpointing_kwargs:
|
||||
@@ -686,10 +687,9 @@ ddp_broadcast_buffers:
|
||||
# E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized
|
||||
# subsequences, or set to 4 to split into four equal-sized subsequences.
|
||||
# See https://axolotl-ai-cloud.github.io/axolotl/docs/sequence_parallelism.html for more details.
|
||||
sequence_parallel_degree: 4 # Set to the number of GPUs to split sequences across
|
||||
flash_attention: true # SP requires flash attention
|
||||
micro_batch_size: 1 # SP requires this is set to 1
|
||||
# (optional) strides across the key dimension; larger values use more memory but should make training a bit faster
|
||||
sequence_parallel_degree:
|
||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||
# Must evenly divide the number of KV heads in your model.
|
||||
heads_k_stride: 1
|
||||
|
||||
# Path to torch distx for optim 'adamw_anyprecision'
|
||||
|
||||
@@ -17,6 +17,7 @@ We currently support several common model architectures, including (but not limi
|
||||
- `qwen2`
|
||||
- `gemma`
|
||||
- `gemma2`
|
||||
- `gemma3`
|
||||
|
||||
<details>
|
||||
|
||||
|
||||
@@ -23,10 +23,9 @@ Use sequence parallelism when:
|
||||
To enable sequence parallelism, add the following to your configuration file:
|
||||
|
||||
```yaml
|
||||
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
|
||||
flash_attention: true # SP requires flash attention
|
||||
micro_batch_size: 1 # SP requires this is set to 1
|
||||
# (optional) strides across the key dimension; larger values use more memory but should make training a bit faster
|
||||
# Set to a divisor (> 1) of the number of GPUs available
|
||||
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
|
||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||
heads_k_stride: 1
|
||||
```
|
||||
|
||||
@@ -67,16 +66,15 @@ sequence_len: 8192
|
||||
...
|
||||
|
||||
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
|
||||
flash_attention: true # SP requires flash attention
|
||||
micro_batch_size: 1 # SP requires this is set to 1
|
||||
# (optional) strides across the key dimension; larger values use more memory but should make training a bit faster
|
||||
flash_attention: true # Required with sequence parallelism
|
||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||
heads_k_stride: 1
|
||||
|
||||
...
|
||||
```
|
||||
|
||||
This will train the Llama 3 8B model with 8192 context length, with each sequence split
|
||||
into 4 subsequences of length 2048 across 4 GPUs.
|
||||
This will train the Llama 3 8B model with 8K context length, with each sequence split
|
||||
into 2 subsequences of length 4096 across 2 GPUs.
|
||||
|
||||
## Sample Packing with Sequence Parallelism
|
||||
|
||||
@@ -88,14 +86,12 @@ Sequence parallelism is compatible with Axolotl's sample packing functionality.
|
||||
|
||||
## Effect on Batch Size
|
||||
|
||||
First, note that sequence parallelism supports only the case where `micro_batch_size: 1`.
|
||||
|
||||
When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:
|
||||
|
||||
- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
|
||||
- The number of batches processed per step decreases
|
||||
|
||||
For example:
|
||||
- With 8 GPUs and no sequence parallelism: 8 different batches are processed per step
|
||||
- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
|
||||
- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
|
||||
- If your per-GPU `micro_batch_size` is 1, the global batch size decreases from 8 to 2
|
||||
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4
|
||||
|
||||
68
examples/gemma3/gemma-3-4b-qlora.yml
Normal file
68
examples/gemma3/gemma-3-4b-qlora.yml
Normal file
@@ -0,0 +1,68 @@
|
||||
base_model: google/gemma-3-4b-it
|
||||
strict: false
|
||||
|
||||
load_in_4bit: true
|
||||
|
||||
# gemma3 doesn't seem to play nice with ddp
|
||||
ddp_find_unused_parameters: true
|
||||
|
||||
chat_template: gemma3
|
||||
datasets:
|
||||
- path: cgato/SlimOrcaDedupCleaned
|
||||
type: chat_template
|
||||
field_messages: conversations
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./outputs/out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16:
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
eager_attention:
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
@@ -2,6 +2,8 @@ base_model: google/gemma-3-4b-it
|
||||
processor_type: AutoProcessor
|
||||
strict: false
|
||||
|
||||
load_in_4bit: true
|
||||
|
||||
# these 3 lines are needed for now to handle vision chat templates w images
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
@@ -20,7 +22,7 @@ dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./outputs/out
|
||||
|
||||
adapter: lora
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
@@ -82,6 +82,3 @@ deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
|
||||
special_tokens:
|
||||
pad_token: "<|end_of_text|>"
|
||||
|
||||
@@ -16,7 +16,7 @@ transformers==4.50.3
|
||||
tokenizers>=0.21.1
|
||||
accelerate==1.5.2
|
||||
datasets==3.5.0
|
||||
deepspeed==0.16.4
|
||||
deepspeed==0.15.4
|
||||
trl==0.16.0
|
||||
|
||||
optimum==1.16.2
|
||||
|
||||
2
setup.py
2
setup.py
@@ -112,7 +112,7 @@ extras_require = {
|
||||
"yunchang==0.6.0",
|
||||
],
|
||||
"deepspeed": [
|
||||
"deepspeed==0.16.4",
|
||||
"deepspeed==0.15.4",
|
||||
"deepspeed-kernels",
|
||||
],
|
||||
"mamba-ssm": [
|
||||
|
||||
@@ -4,4 +4,4 @@ import pkgutil
|
||||
|
||||
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
|
||||
|
||||
__version__ = "0.8.0.dev0"
|
||||
__version__ = "0.8.0"
|
||||
|
||||
@@ -256,7 +256,7 @@ def do_cli(
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
parsed_cfg = load_cfg(config, inference=True, **kwargs)
|
||||
parsed_cfg = load_cfg(config, inference=True, rl=None, **kwargs)
|
||||
parsed_cfg.sample_packing = False
|
||||
parser = transformers.HfArgumentParser(InferenceCliArgs)
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
|
||||
@@ -74,8 +74,10 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||
load_in_8bit=False,
|
||||
load_in_4bit=False,
|
||||
flash_attention=False,
|
||||
sequence_parallel_degree=None,
|
||||
deepspeed=None,
|
||||
fsdp=None,
|
||||
fsdp_config=None,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -86,13 +88,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||
f"Target directory for merge: `{parsed_cfg.lora_model_dir}` does not exist."
|
||||
)
|
||||
|
||||
parsed_cfg.load_in_4bit = False
|
||||
parsed_cfg.load_in_8bit = False
|
||||
parsed_cfg.flash_attention = False
|
||||
parsed_cfg.deepspeed = None
|
||||
parsed_cfg.fsdp = None
|
||||
parsed_cfg.fsdp_config = None
|
||||
|
||||
do_merge_lora(cfg=parsed_cfg)
|
||||
|
||||
|
||||
|
||||
@@ -1043,10 +1043,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.rpo_alpha is not None:
|
||||
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
|
||||
|
||||
training_args_kwargs["sequence_parallel_degree"] = (
|
||||
self.cfg.sequence_parallel_degree
|
||||
)
|
||||
|
||||
training_args_cls = None
|
||||
blocklist_args_kwargs = []
|
||||
if self.cfg.rl == "simpo":
|
||||
@@ -1165,7 +1161,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
dpo_trainer_kwargs["dataset_tags"] = [
|
||||
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
||||
]
|
||||
|
||||
dpo_trainer = trainer_cls(
|
||||
*trainer_cls_args,
|
||||
args=training_args,
|
||||
@@ -1183,3 +1178,21 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
dpo_trainer.add_callback(callback)
|
||||
|
||||
return dpo_trainer
|
||||
|
||||
|
||||
class HFPPOTrainerBuilder(TrainerBuilderBase):
|
||||
"""
|
||||
HF Factory class for PPO Trainer
|
||||
"""
|
||||
|
||||
def get_callbacks(self):
|
||||
callbacks = super().get_callbacks()
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
|
||||
return callbacks
|
||||
|
||||
def build(self, total_num_steps):
|
||||
# build PPOConfig
|
||||
pass
|
||||
|
||||
@@ -3,16 +3,16 @@
|
||||
# pylint: disable=unused-import
|
||||
# flake8: noqa
|
||||
|
||||
from axolotl.core.trainers.base import AxolotlTrainer
|
||||
from axolotl.core.trainers.dpo import AxolotlDPOTrainer
|
||||
from axolotl.core.trainers.grpo import AxolotlGRPOTrainer
|
||||
from axolotl.core.trainers.mamba import AxolotlMambaTrainer
|
||||
from axolotl.core.trainers.relora import ReLoRATrainer
|
||||
from axolotl.core.trainers.trl import (
|
||||
from .base import AxolotlTrainer
|
||||
from .dpo.trainer import AxolotlDPOTrainer
|
||||
from .grpo.trainer import AxolotlGRPOTrainer
|
||||
from .mamba import AxolotlMambaTrainer
|
||||
from .relora import ReLoRATrainer
|
||||
from .trl import (
|
||||
AxolotlCPOTrainer,
|
||||
AxolotlKTOTrainer,
|
||||
AxolotlORPOTrainer,
|
||||
AxolotlPPOTrainer,
|
||||
AxolotlPRMTrainer,
|
||||
AxolotlRewardTrainer,
|
||||
TRLPPOTrainer,
|
||||
)
|
||||
|
||||
@@ -8,11 +8,10 @@ import logging
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from functools import wraps
|
||||
from typing import Any, Literal
|
||||
from typing import Literal
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from datasets import Dataset
|
||||
from torch.utils.data import (
|
||||
BatchSampler,
|
||||
@@ -26,8 +25,12 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
||||
from trl.trainer.utils import pad_to_length
|
||||
from typing_extensions import override
|
||||
|
||||
from axolotl.core.trainers.handlers import SequenceParallelHandler
|
||||
from axolotl.core.trainers.mixins import TrainerMixins
|
||||
from axolotl.core.trainers.mixins import (
|
||||
OptimizerMixin,
|
||||
RngLoaderMixin,
|
||||
SchedulerMixin,
|
||||
SequenceParallelMixin,
|
||||
)
|
||||
from axolotl.core.trainers.utils import (
|
||||
sanitize_kwargs_for_ds_tagging,
|
||||
sanitize_kwargs_for_tagging,
|
||||
@@ -37,7 +40,9 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AxolotlTrainer(TrainerMixins, Trainer):
|
||||
class AxolotlTrainer(
|
||||
SchedulerMixin, OptimizerMixin, RngLoaderMixin, SequenceParallelMixin, Trainer
|
||||
):
|
||||
"""Extend the base Trainer for axolotl helpers"""
|
||||
|
||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||
@@ -63,7 +68,9 @@ class AxolotlTrainer(TrainerMixins, Trainer):
|
||||
if self.args.orpo_alpha:
|
||||
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
||||
|
||||
self.sequence_parallel_handler = SequenceParallelHandler(self.args)
|
||||
# Initialize sequence parallelism if enabled
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
self._setup_sequence_parallel()
|
||||
|
||||
def _wrap_model(self, model, training=True, dataloader=None):
|
||||
if self.args.torch_compile:
|
||||
@@ -124,7 +131,7 @@ class AxolotlTrainer(TrainerMixins, Trainer):
|
||||
|
||||
# Determine the base sampler first
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
base_sampler = self.sequence_parallel_handler._get_train_sampler(self.train_dataset)
|
||||
base_sampler = self._sp_get_train_sampler(self.train_dataset)
|
||||
elif self.args.curriculum_sampling:
|
||||
base_sampler = SequentialSampler(self.train_dataset)
|
||||
elif use_sample_packing:
|
||||
@@ -160,7 +167,7 @@ class AxolotlTrainer(TrainerMixins, Trainer):
|
||||
|
||||
# Determine the base sampler
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
base_sampler = self.sequence_parallel_handler._get_eval_sampler(eval_dataset)
|
||||
base_sampler = self._sp_get_eval_sampler(eval_dataset)
|
||||
elif use_multipack:
|
||||
base_sampler = SequentialSampler(eval_dataset)
|
||||
else:
|
||||
@@ -232,10 +239,7 @@ class AxolotlTrainer(TrainerMixins, Trainer):
|
||||
return dataloader
|
||||
|
||||
# Otherwise prepare with accelerator
|
||||
dataloader = self.accelerator.prepare_data_loader(dataloader)
|
||||
|
||||
return dataloader
|
||||
|
||||
return self.accelerator.prepare_data_loader(dataloader)
|
||||
|
||||
def get_train_dataloader(self) -> DataLoader:
|
||||
"""Get dataloader for training"""
|
||||
@@ -344,57 +348,7 @@ class AxolotlTrainer(TrainerMixins, Trainer):
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
|
||||
return DataLoader(bench_dataset, **dataloader_params)
|
||||
|
||||
def training_step(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: dict[str, torch.Tensor | Any],
|
||||
num_items_in_batch: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Perform a training step on a batch of inputs. Overrides the
|
||||
`transformers.trainer.Trainer` method to handle sequence parallelism if
|
||||
enabled.
|
||||
|
||||
Args:
|
||||
model: Model to perform training step for.
|
||||
inputs: Dictionary mapping of inputs.
|
||||
num_items_in_batch: The number of items in the batch.
|
||||
"""
|
||||
# Set up sequence parallelism for this step if enabled
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
self.sequence_parallel_handler._update_ring_flash_attn_params(inputs)
|
||||
|
||||
# Proceed with normal training step
|
||||
return super().training_step(model, inputs, num_items_in_batch) # type: ignore
|
||||
|
||||
def prediction_step(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: dict[str, torch.Tensor | Any],
|
||||
prediction_loss_only: bool,
|
||||
ignore_keys: list[str] | None = None,
|
||||
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
|
||||
"""
|
||||
Perform a prediction step on a batch of inputs. Overrides the
|
||||
`transformers.trainer.Trainer` method to handle sequence parallelism if
|
||||
enabled.
|
||||
|
||||
Args:
|
||||
model: Model to perform prediction step for.
|
||||
inputs: Dictionary mapping of inputs.
|
||||
prediction_loss_only: Whether to return only the loss.
|
||||
ignore_keys: Keys to ignore in the inputs.
|
||||
|
||||
Returns:
|
||||
Tuple of (loss, logits, labels).
|
||||
"""
|
||||
# Set up sequence parallelism for this prediction step if enabled
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
self.sequence_parallel_handler._update_ring_flash_attn_params(inputs)
|
||||
|
||||
# Proceed with normal prediction step
|
||||
return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys) # type: ignore
|
||||
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
|
||||
|
||||
@override
|
||||
def compute_loss(
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
"""DPO Specific Strategy for training"""
|
||||
"""
|
||||
DPO Specific Strategy for training
|
||||
"""
|
||||
|
||||
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer
|
||||
|
||||
|
||||
class DPOStrategy:
|
||||
"""Strategy for DPO training"""
|
||||
"""
|
||||
Strategy for DPO training
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_trainer_class(cls):
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
"""Axolotl specific DPO args"""
|
||||
"""
|
||||
Axolotl specific DPO args
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
@@ -9,4 +11,6 @@ from axolotl.core.training_args import AxolotlTrainingMixins
|
||||
|
||||
@dataclass
|
||||
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
|
||||
"""DPO config for DPO training"""
|
||||
"""
|
||||
DPO config for DPO training
|
||||
"""
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
"""DPO trainer for axolotl"""
|
||||
"""
|
||||
DPO trainer for axolotl
|
||||
"""
|
||||
|
||||
import gc
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
import torch
|
||||
from peft.optimizers import create_loraplus_optimizer
|
||||
@@ -10,8 +13,7 @@ from transformers import Trainer
|
||||
from transformers.utils import is_sagemaker_mp_enabled
|
||||
from trl import DPOTrainer
|
||||
|
||||
from axolotl.core.trainers.handlers import SequenceParallelHandler
|
||||
from axolotl.core.trainers.mixins import TrainerMixins
|
||||
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
|
||||
from axolotl.core.trainers.utils import (
|
||||
sanitize_kwargs_for_ds_tagging,
|
||||
sanitize_kwargs_for_tagging,
|
||||
@@ -21,18 +23,18 @@ if is_sagemaker_mp_enabled():
|
||||
import smdistributed.modelparallel.torch as smp
|
||||
|
||||
|
||||
class AxolotlDPOTrainer(TrainerMixins, DPOTrainer):
|
||||
"""Extend the base DPOTrainer for axolotl helpers"""
|
||||
class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
|
||||
"""
|
||||
Extend the base DPOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "dpo"]
|
||||
|
||||
def __init__(self, *args, dataset_tags=None, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.dataset_tags = dataset_tags
|
||||
self.optimizer = None
|
||||
self.model_accepts_loss_kwargs = False
|
||||
self.sequence_parallel_handler = SequenceParallelHandler(args=self.args)
|
||||
|
||||
def create_optimizer(self):
|
||||
# pylint: disable=duplicate-code
|
||||
@@ -86,7 +88,7 @@ class AxolotlDPOTrainer(TrainerMixins, DPOTrainer):
|
||||
max_prompt_length,
|
||||
max_completion_length,
|
||||
add_special_tokens,
|
||||
) -> dict:
|
||||
) -> Dict:
|
||||
res = DPOTrainer.tokenize_row(
|
||||
features,
|
||||
processing_class,
|
||||
@@ -115,9 +117,10 @@ class AxolotlDPOTrainer(TrainerMixins, DPOTrainer):
|
||||
def training_step(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: dict[str, torch.Tensor | Any | None],
|
||||
inputs: Dict[str, Union[torch.Tensor, Any]],
|
||||
num_items_in_batch=None,
|
||||
) -> torch.Tensor:
|
||||
self.sequence_parallel_handler.prepare_for_training_step(self, inputs)
|
||||
|
||||
return super().training_step(model, inputs, num_items_in_batch)
|
||||
loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
return loss
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
"""Axolotl GRPO trainer"""
|
||||
"""
|
||||
Axolotl GRPO trainer
|
||||
"""
|
||||
|
||||
from contextlib import nullcontext
|
||||
|
||||
@@ -6,14 +8,16 @@ from accelerate.utils import is_deepspeed_available, is_peft_model
|
||||
from trl import GRPOTrainer
|
||||
from trl.extras.profiling import profiling_decorator
|
||||
|
||||
from axolotl.core.trainers.mixins import TrainerMixins
|
||||
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
|
||||
|
||||
if is_deepspeed_available():
|
||||
import deepspeed
|
||||
|
||||
|
||||
class AxolotlGRPOTrainer(TrainerMixins, GRPOTrainer):
|
||||
"""Extend the base GRPOTrainer for axolotl helpers"""
|
||||
class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
|
||||
"""
|
||||
Extend the base GRPOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
_tag_names = ["trl", "grpo", "axolotl"]
|
||||
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
"""Init for trainer handlers"""
|
||||
|
||||
from axolotl.core.trainers.handlers.sequence_parallel import SequenceParallelHandler
|
||||
@@ -1,123 +0,0 @@
|
||||
"""Handler class for sequence parallel trainer logic"""
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import DistributedSampler
|
||||
|
||||
|
||||
class SequenceParallelHandler:
|
||||
"""
|
||||
Handler class that encapsulates sequence parallelism functionality.
|
||||
This replaces the SequenceParallelMixin with a composition-based approach.
|
||||
"""
|
||||
|
||||
def __init__(self, args=None):
|
||||
"""
|
||||
Initialize the sequence parallel handler.
|
||||
|
||||
Args:
|
||||
args: The arguments object containing sequence parallelism settings.
|
||||
"""
|
||||
self.args = args
|
||||
self.ring_attn_group = None
|
||||
|
||||
# Set up sequence parallelism if enabled
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
self._setup_sequence_parallel()
|
||||
|
||||
def _setup_sequence_parallel(self):
|
||||
"""Set up sequence parallelism environment."""
|
||||
from ring_flash_attn import update_ring_flash_attn_params
|
||||
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
||||
|
||||
self.update_ring_flash_attn_params = update_ring_flash_attn_params
|
||||
self.ring_attn_group = get_ring_attn_group()
|
||||
|
||||
def create_sequence_parallel_sampler(
|
||||
self,
|
||||
dataset,
|
||||
shuffle=True,
|
||||
is_eval=False,
|
||||
):
|
||||
"""
|
||||
Helper method to create sampler for sequence parallelism (SP).
|
||||
|
||||
Args:
|
||||
dataset: Dataset to sample from.
|
||||
shuffle: Whether to shuffle the dataset.
|
||||
is_eval: Whether we are creating a sampler for evaluation or training.
|
||||
|
||||
Returns:
|
||||
Distributed sampler.
|
||||
"""
|
||||
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
|
||||
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
|
||||
|
||||
return DistributedSampler(
|
||||
dataset,
|
||||
num_replicas=num_sp_groups,
|
||||
rank=sp_group_id,
|
||||
seed=self.args.seed if shuffle else None,
|
||||
shuffle=shuffle,
|
||||
drop_last=not is_eval,
|
||||
)
|
||||
|
||||
def _get_train_sampler(self, dataset):
|
||||
"""
|
||||
Get a training sampler configured for sequence parallelism.
|
||||
|
||||
Args:
|
||||
dataset: The training dataset.
|
||||
|
||||
Returns:
|
||||
Configured sequence parallel sampler.
|
||||
"""
|
||||
return self.create_sequence_parallel_sampler(
|
||||
dataset,
|
||||
shuffle=not self.args.curriculum_sampling,
|
||||
)
|
||||
|
||||
def _get_eval_sampler(self, eval_dataset):
|
||||
"""
|
||||
Get an evaluation sampler configured for sequence parallelism.
|
||||
|
||||
Args:
|
||||
eval_dataset: The evaluation dataset.
|
||||
|
||||
Returns:
|
||||
Configured sequence parallel sampler.
|
||||
"""
|
||||
return self.create_sequence_parallel_sampler(
|
||||
eval_dataset, shuffle=False, is_eval=True
|
||||
)
|
||||
|
||||
def _update_ring_flash_attn_params(self, inputs):
|
||||
"""
|
||||
Calculate the cu_seqlens for the current forward pass and pass the value to
|
||||
the substituted ring_flash_attn.
|
||||
|
||||
Args:
|
||||
inputs: Current batch of inputs.
|
||||
"""
|
||||
# At this point, inputs should already be partitioned by the sequence
|
||||
# parallel data collator
|
||||
batch_size = inputs["input_ids"].shape[0]
|
||||
seq_len = inputs["input_ids"].shape[1]
|
||||
packed_seq_lens = [seq_len] * batch_size
|
||||
|
||||
# Calculate the full sequence length across all GPUs in this SP group
|
||||
total_seq_len = seq_len * self.args.sequence_parallel_degree
|
||||
|
||||
cu_seqlens = torch.cumsum(
|
||||
torch.tensor(
|
||||
packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
|
||||
),
|
||||
dim=-1,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
cu_seqlens = F.pad(
|
||||
F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
|
||||
)
|
||||
|
||||
self.update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)
|
||||
@@ -3,12 +3,7 @@
|
||||
# pylint: disable=unused-import
|
||||
# flake8: noqa
|
||||
|
||||
from axolotl.core.trainers.mixins.optimizer import OptimizerMixin
|
||||
from axolotl.core.trainers.mixins.rng_state_loader import RngLoaderMixin
|
||||
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
|
||||
|
||||
|
||||
class TrainerMixins(
|
||||
OptimizerMixin, RngLoaderMixin, SchedulerMixin
|
||||
):
|
||||
"""Stub class combining all mixins for Axolotl trainers."""
|
||||
from .optimizer import OptimizerMixin
|
||||
from .rng_state_loader import RngLoaderMixin
|
||||
from .scheduler import SchedulerMixin
|
||||
from .sequence_parallel import SequenceParallelMixin
|
||||
|
||||
@@ -21,7 +21,9 @@ LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RngLoaderMixin(Trainer):
|
||||
"""Mixin for method override to load RNG states from a checkpoint"""
|
||||
"""
|
||||
mixin for method override to load RNG states from a checkpoint
|
||||
"""
|
||||
|
||||
def _load_rng_state(self, checkpoint):
|
||||
# Load RNG states from `checkpoint`
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""Module for Axolotl trainer sequence parallelism mixin"""
|
||||
# TODO(Dan): remove
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
@@ -8,6 +7,7 @@ import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from datasets import Dataset
|
||||
from torch import nn
|
||||
from torch.utils.data import DistributedSampler, Sampler
|
||||
|
||||
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
||||
@@ -71,12 +71,12 @@ class SequenceParallelMixin:
|
||||
drop_last=not is_eval,
|
||||
)
|
||||
|
||||
def _get_train_sampler(self, dataset) -> Sampler | None:
|
||||
def _sp_get_train_sampler(self, dataset) -> Sampler | None:
|
||||
"""
|
||||
Get a training sampler configured for sequence parallelism.
|
||||
|
||||
Args:
|
||||
dataset: The training dataset.
|
||||
dataset: The training dataset
|
||||
|
||||
Returns:
|
||||
Configured sequence parallel sampler.
|
||||
@@ -86,7 +86,7 @@ class SequenceParallelMixin:
|
||||
shuffle=not self.args.curriculum_sampling,
|
||||
)
|
||||
|
||||
def _get_eval_sampler(self, eval_dataset) -> Sampler | None:
|
||||
def _sp_get_eval_sampler(self, eval_dataset) -> Sampler | None:
|
||||
"""
|
||||
Get an evaluation sampler configured for sequence parallelism.
|
||||
|
||||
@@ -130,3 +130,53 @@ class SequenceParallelMixin:
|
||||
)
|
||||
|
||||
update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)
|
||||
|
||||
def training_step(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: dict[str, torch.Tensor | Any],
|
||||
num_items_in_batch: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Perform a training step on a batch of inputs. Overrides the
|
||||
`transformers.trainer.Trainer` method to handle sequence parallelism if
|
||||
enabled.
|
||||
|
||||
Args:
|
||||
model: Model to perform training step for.
|
||||
inputs: Dictionary mapping.
|
||||
"""
|
||||
# Set up sequence parallelism for this step if enabled
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
self._update_ring_flash_attn_params(inputs)
|
||||
|
||||
# Proceed with normal training step
|
||||
return super().training_step(model, inputs, num_items_in_batch) # type: ignore
|
||||
|
||||
def prediction_step(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: dict[str, torch.Tensor | Any],
|
||||
prediction_loss_only: bool,
|
||||
ignore_keys: list[str] | None = None,
|
||||
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
|
||||
"""
|
||||
Perform a prediction step on a batch of inputs. Overrides the
|
||||
`transformers.trainer.Trainer` method to handle sequence parallelism if
|
||||
enabled.
|
||||
|
||||
Args:
|
||||
model: Model to perform prediction step for.
|
||||
inputs: Dictionary mapping of inputs.
|
||||
prediction_loss_only: Whether to return only the loss.
|
||||
ignore_keys: Keys to ignore in the inputs.
|
||||
|
||||
Returns:
|
||||
Tuple of (loss, logits, labels).
|
||||
"""
|
||||
# Set up sequence parallelism for this prediction step if enabled
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
self._update_ring_flash_attn_params(inputs)
|
||||
|
||||
# Proceed with normal prediction step
|
||||
return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys) # type: ignore
|
||||
|
||||
@@ -13,10 +13,11 @@ from trl import (
|
||||
RewardTrainer,
|
||||
)
|
||||
|
||||
from axolotl.core.trainers.mixins import TrainerMixins
|
||||
from axolotl.core.trainers.mixins import RngLoaderMixin
|
||||
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
|
||||
|
||||
|
||||
class AxolotlPPOTrainer(TrainerMixins, PPOTrainer):
|
||||
class TRLPPOTrainer(PPOTrainer):
|
||||
"""Wrapper for TRL PPO trainer to handle customizations"""
|
||||
|
||||
tag_names = ["axolotl", "ppo"]
|
||||
@@ -74,8 +75,10 @@ class AxolotlPPOTrainer(TrainerMixins, PPOTrainer):
|
||||
)
|
||||
|
||||
|
||||
class AxolotlORPOTrainer(TrainerMixins, ORPOTrainer):
|
||||
"""Extend the base ORPOTrainer for axolotl helpers"""
|
||||
class AxolotlORPOTrainer(RngLoaderMixin, SchedulerMixin, ORPOTrainer):
|
||||
"""
|
||||
Extend the base ORPOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "orpo"]
|
||||
|
||||
@@ -152,14 +155,18 @@ class AxolotlORPOTrainer(TrainerMixins, ORPOTrainer):
|
||||
return loss, metrics
|
||||
|
||||
|
||||
class AxolotlKTOTrainer(TrainerMixins, KTOTrainer):
|
||||
"""Extend the base KTOTrainer for axolotl helpers"""
|
||||
class AxolotlKTOTrainer(RngLoaderMixin, SchedulerMixin, KTOTrainer):
|
||||
"""
|
||||
Extend the base KTOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "kto"]
|
||||
|
||||
|
||||
class AxolotlCPOTrainer(TrainerMixins, CPOTrainer):
|
||||
"""Extend the base CPOTrainer for axolotl helpers"""
|
||||
class AxolotlCPOTrainer(RngLoaderMixin, SchedulerMixin, CPOTrainer):
|
||||
"""
|
||||
Extend the base CPOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "cpo"]
|
||||
|
||||
@@ -238,13 +245,17 @@ class AxolotlCPOTrainer(TrainerMixins, CPOTrainer):
|
||||
return loss, metrics
|
||||
|
||||
|
||||
class AxolotlRewardTrainer(TrainerMixins, RewardTrainer):
|
||||
"""Extend the base RewardTrainer for axolotl helpers"""
|
||||
class AxolotlRewardTrainer(RngLoaderMixin, SchedulerMixin, RewardTrainer):
|
||||
"""
|
||||
Extend the base RewardTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "reward"]
|
||||
|
||||
|
||||
class AxolotlPRMTrainer(TrainerMixins, PRMTrainer):
|
||||
"""Extend the base trl.PRMTrainer for axolotl helpers"""
|
||||
class AxolotlPRMTrainer(RngLoaderMixin, SchedulerMixin, PRMTrainer):
|
||||
"""
|
||||
Extend the base trl.PRMTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "prm"]
|
||||
|
||||
@@ -12,7 +12,9 @@ from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
|
||||
|
||||
@dataclass
|
||||
class AxolotlTrainingMixins:
|
||||
"""Mixin class for the Axolotl training args."""
|
||||
"""
|
||||
Mixin class for the Axolotl training args.
|
||||
"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
model_type: Optional[str] = field(
|
||||
|
||||
@@ -6,22 +6,11 @@ package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patc
|
||||
their sequence parallel version of Flash Attention 2.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from accelerate.logging import get_logger
|
||||
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
|
||||
from axolotl.logging_config import configure_logging
|
||||
|
||||
try:
|
||||
from ring_flash_attn import update_ring_flash_attn_params
|
||||
except ImportError:
|
||||
# We pass silently here, but raise an ImportError in our Axolotl config validation
|
||||
# if cfg.sequence_parallel_degree > 1 and `ring-flash-attn` is not installed.
|
||||
pass
|
||||
|
||||
|
||||
configure_logging()
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
@@ -43,120 +32,12 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
|
||||
Setter for ring attention group on this rank.
|
||||
|
||||
Args:
|
||||
ring_attn_group: Process group for ring attention.
|
||||
Process group for ring attention.
|
||||
"""
|
||||
global RING_ATTN_GROUP # pylint: disable=global-statement
|
||||
RING_ATTN_GROUP = ring_attn_group
|
||||
|
||||
|
||||
def patch_flash_attention_for_sequential_batch(sequence_parallel_degree: int):
|
||||
"""
|
||||
Patch flash attention a second time to handle batched data. This is a hack to
|
||||
accommodate certain RL trainers which batch data even when `micro_batch_size: 1` is
|
||||
specified in the Axolotl config.
|
||||
|
||||
Args:
|
||||
sequence_parallel_degree: Sequence parallelism factor.
|
||||
"""
|
||||
# Store the original flash attention function
|
||||
original_flash_attention = ALL_ATTENTION_FUNCTIONS["flash_attention_2"]
|
||||
|
||||
def sequential_batch_flash_attention(
|
||||
module: torch.nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None,
|
||||
dropout: float = 0.0,
|
||||
scaling: float | None = None,
|
||||
sliding_window: int | None = None,
|
||||
softcap: float | None = None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, None]:
|
||||
# Check if we have a batch dimension > 1
|
||||
batch_size = query.shape[0]
|
||||
|
||||
if batch_size <= 1:
|
||||
return original_flash_attention(
|
||||
module,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attention_mask,
|
||||
dropout,
|
||||
scaling,
|
||||
sliding_window,
|
||||
softcap,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# Process each item in the batch separately
|
||||
outputs = []
|
||||
|
||||
for i in range(batch_size):
|
||||
# Extract single batch item
|
||||
q_item = query[i:i+1]
|
||||
k_item = key[i:i+1]
|
||||
v_item = value[i:i+1]
|
||||
|
||||
# Handle attention mask - it might be None or have different shapes
|
||||
mask_item = None
|
||||
if attention_mask is not None:
|
||||
# The mask could have different formats depending on implementation
|
||||
if attention_mask.dim() >= 3 and attention_mask.shape[0] == batch_size:
|
||||
mask_item = attention_mask[i:i+1]
|
||||
else:
|
||||
# For broadcast masks that don't have a batch dimension
|
||||
mask_item = attention_mask
|
||||
|
||||
# At this point, inputs should already be partitioned by the sequence
|
||||
# parallel data collator
|
||||
batch_size = q_item.shape[0]
|
||||
seq_len = q_item.shape[2]
|
||||
packed_seq_lens = [seq_len] * batch_size
|
||||
|
||||
# Calculate the full sequence length across all GPUs in this SP group
|
||||
total_seq_len = seq_len * sequence_parallel_degree
|
||||
|
||||
cu_seqlens = torch.cumsum(
|
||||
torch.tensor(
|
||||
packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
|
||||
),
|
||||
dim=-1,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
cu_seqlens = F.pad(
|
||||
F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
|
||||
)
|
||||
|
||||
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
|
||||
|
||||
# Call the original function for a single batch item
|
||||
output, _ = original_flash_attention(
|
||||
module,
|
||||
q_item,
|
||||
k_item,
|
||||
v_item,
|
||||
mask_item,
|
||||
dropout,
|
||||
scaling,
|
||||
sliding_window,
|
||||
softcap,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
outputs.append(output)
|
||||
|
||||
dist.barrier()
|
||||
|
||||
# Concatenate results along batch dimension
|
||||
concatenated_output = torch.cat(outputs, dim=0)
|
||||
return concatenated_output, None
|
||||
|
||||
# Replace the original function with our sequential version
|
||||
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = sequential_batch_flash_attention
|
||||
|
||||
|
||||
def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None):
|
||||
"""
|
||||
Create ring attention group and substitute flash attn with ring flash attn.
|
||||
@@ -217,4 +98,3 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
|
||||
substitute_hf_flash_attn(
|
||||
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride
|
||||
)
|
||||
patch_flash_attention_for_sequential_batch(sequence_parallel_degree)
|
||||
|
||||
238
src/axolotl/monkeypatch/gemma3.py
Normal file
238
src/axolotl/monkeypatch/gemma3.py
Normal file
@@ -0,0 +1,238 @@
|
||||
"""Monkeypatch for gemma3 conditional generation forward to fix loss exploding"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.models.gemma3.modeling_gemma3 import (
|
||||
_CONFIG_FOR_DOC,
|
||||
GEMMA3_INPUTS_DOCSTRING,
|
||||
Gemma3CausalLMOutputWithPast,
|
||||
logger,
|
||||
)
|
||||
from transformers.utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torchdynamo_compiling,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(
|
||||
output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||
)
|
||||
def new_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**lm_kwargs,
|
||||
) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
|
||||
|
||||
>>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf")
|
||||
>>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf")
|
||||
|
||||
>>> prompt = "answer en Where is the cow standing?"
|
||||
>>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(**inputs, max_length=30)
|
||||
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"answer en Where is the cow standing?\nbeach"
|
||||
```"""
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
is_training = token_type_ids is not None and labels is not None
|
||||
|
||||
# Replace image id with PAD if the image token is OOV, to avoid index-errors
|
||||
if input_ids is not None and self.config.image_token_index >= self.vocab_size:
|
||||
special_image_mask = input_ids == self.config.image_token_index
|
||||
llm_input_ids = input_ids.clone()
|
||||
llm_input_ids[special_image_mask] = 0
|
||||
else:
|
||||
llm_input_ids = input_ids
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = (
|
||||
past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
)
|
||||
cache_position = torch.arange(
|
||||
past_seen_tokens,
|
||||
past_seen_tokens + inputs_embeds.shape[1],
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
|
||||
# Merge text and images
|
||||
if pixel_values is not None:
|
||||
image_features = self.get_image_features(pixel_values)
|
||||
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(
|
||||
self.config.image_token_index,
|
||||
dtype=torch.long,
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(
|
||||
-1
|
||||
)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
|
||||
inputs_embeds.device
|
||||
)
|
||||
|
||||
if (
|
||||
not is_torchdynamo_compiling()
|
||||
and inputs_embeds[special_image_mask].numel() != image_features.numel()
|
||||
):
|
||||
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
|
||||
raise ValueError(
|
||||
f"Number of images does not match number of special image tokens in the input text. "
|
||||
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
|
||||
"tokens from image embeddings."
|
||||
)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
|
||||
# mask out pad-token-ids in labels for BC
|
||||
if labels is not None and self.pad_token_id in labels:
|
||||
logger.warning_once(
|
||||
"`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
|
||||
"You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
|
||||
)
|
||||
labels = torch.where(
|
||||
input_ids == self.pad_token_id, self.config.ignore_index, labels
|
||||
)
|
||||
|
||||
causal_mask = self._update_causal_mask( # pylint: disable=protected-access
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
past_key_values,
|
||||
cache_position,
|
||||
inputs_embeds,
|
||||
is_training,
|
||||
)
|
||||
outputs = self.language_model(
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
logits_to_keep=logits_to_keep,
|
||||
**lm_kwargs,
|
||||
)
|
||||
|
||||
logits = outputs[0]
|
||||
loss = None
|
||||
if labels is not None:
|
||||
if attention_mask is not None:
|
||||
# Get the shifted attention mask
|
||||
shift_attention_mask = attention_mask[:, -logits.shape[1] + 1 :].to(
|
||||
logits.device
|
||||
) # +1 for shift
|
||||
|
||||
# Filter logits and labels based on attention mask
|
||||
valid_indices = shift_attention_mask != 0
|
||||
filtered_logits = logits[..., :-1, :][valid_indices]
|
||||
filtered_labels = labels[..., 1:][valid_indices.to(labels.device)]
|
||||
|
||||
# TODO: do we need to handle num_items_in_batch given we filter the logits and labels?
|
||||
|
||||
loss = self.loss_function(
|
||||
logits=filtered_logits,
|
||||
labels=None, # we pass shift_labels
|
||||
shift_labels=filtered_labels,
|
||||
vocab_size=self.config.text_config.vocab_size,
|
||||
**lm_kwargs,
|
||||
)
|
||||
else:
|
||||
# Standard case without filtering
|
||||
loss = self.loss_function(
|
||||
logits=logits,
|
||||
labels=labels,
|
||||
vocab_size=self.config.text_config.vocab_size,
|
||||
**lm_kwargs,
|
||||
)
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return Gemma3CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=image_features if pixel_values is not None else None,
|
||||
)
|
||||
|
||||
|
||||
def patch_gemma3conditionalgeneration_forward():
|
||||
from transformers.models.gemma3.modeling_gemma3 import (
|
||||
Gemma3ForConditionalGeneration,
|
||||
)
|
||||
|
||||
Gemma3ForConditionalGeneration.forward = new_forward
|
||||
@@ -252,12 +252,38 @@ def apply_lora_kernel_patches(
|
||||
LOG.setLevel(logging.INFO)
|
||||
|
||||
# Choose activation based on model type
|
||||
activation = model.config.hidden_act
|
||||
activation = None
|
||||
text_config = (
|
||||
model.config.get_text_config()
|
||||
if hasattr(model.config, "get_text_config")
|
||||
else model.config
|
||||
)
|
||||
if hasattr(text_config, "hidden_act"):
|
||||
activation = text_config.hidden_act
|
||||
elif hasattr(text_config, "hidden_activation"):
|
||||
activation = text_config.hidden_activation
|
||||
|
||||
# map activation to supported activation
|
||||
if "gelu" in activation:
|
||||
# gemma3 uses gelu_pytorch_tanh
|
||||
activation = "gelu"
|
||||
|
||||
if activation not in SUPPORTED_ACTIVATIONS:
|
||||
raise NotImplementedError(f"Activation {activation} is not supported")
|
||||
|
||||
layers = []
|
||||
# check for multimodal models first
|
||||
if hasattr(model, "language_model"):
|
||||
layers = model.language_model.model.layers
|
||||
elif hasattr(model, "model"):
|
||||
layers = model.model.model.layers
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Model type {model.config.model_type} is not supported yet. Please create an Issue."
|
||||
)
|
||||
|
||||
# Patch each layer
|
||||
for layer in model.model.model.layers:
|
||||
for layer in layers:
|
||||
# Add QKV, O fallback implementations to start
|
||||
# These will be overwritten later (if some conditions apply)
|
||||
layer.self_attn.apply_qkv = types.MethodType(
|
||||
|
||||
@@ -78,6 +78,7 @@ def resolve_dtype(cfg):
|
||||
cfg.bf16 = False
|
||||
else:
|
||||
torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
|
||||
torch.backends.cudnn.allow_tf32 = cfg.tf32 or False
|
||||
if cfg.bf16:
|
||||
cfg.fp16 = False
|
||||
|
||||
|
||||
@@ -535,6 +535,15 @@ class ModelLoader:
|
||||
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
|
||||
|
||||
def apply_patches(self) -> None:
|
||||
# patch gemma3 conditional generation forward before loading plugins
|
||||
# as it could be overridden by plugins
|
||||
if self.cfg.model_config_type == "gemma3":
|
||||
from axolotl.monkeypatch.gemma3 import (
|
||||
patch_gemma3conditionalgeneration_forward,
|
||||
)
|
||||
|
||||
patch_gemma3conditionalgeneration_forward()
|
||||
|
||||
# load any patches from plugins
|
||||
from axolotl.integrations.base import PluginManager
|
||||
|
||||
@@ -1351,7 +1360,9 @@ def load_model(
|
||||
reference_model: bool = False,
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
|
||||
"""Load a model for a given configuration and tokenizer."""
|
||||
"""
|
||||
Load a model for a given configuration and tokenizer.
|
||||
"""
|
||||
model_loader = ModelLoader(
|
||||
cfg,
|
||||
tokenizer,
|
||||
@@ -1360,16 +1371,12 @@ def load_model(
|
||||
reference_model=reference_model,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return model_loader.load_model()
|
||||
|
||||
|
||||
def load_adapter(
|
||||
model: PreTrainedModel,
|
||||
cfg: DictDefault,
|
||||
adapter: str | None,
|
||||
inference: bool = False,
|
||||
) -> tuple[PreTrainedModel, PeftConfig | None]:
|
||||
def load_adapter(model, cfg, adapter, inference=False):
|
||||
# type: (PreTrainedModel, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||
|
||||
if adapter is None:
|
||||
return model, None
|
||||
if hasattr(model, "enable_input_require_grads"):
|
||||
@@ -1382,9 +1389,8 @@ def load_adapter(
|
||||
raise NotImplementedError(f"{adapter} peft adapter not available")
|
||||
|
||||
|
||||
def load_llama_adapter(
|
||||
model: PreTrainedModel, cfg: DictDefault
|
||||
) -> tuple[PreTrainedModel, PeftConfig | None]:
|
||||
def load_llama_adapter(model, cfg):
|
||||
# type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||
from peft import AdaptionPromptConfig, get_peft_model
|
||||
|
||||
peft_config = AdaptionPromptConfig(
|
||||
@@ -1408,7 +1414,7 @@ def load_llama_adapter(
|
||||
return model, peft_config
|
||||
|
||||
|
||||
def find_all_linear_names(model: PreTrainedModel):
|
||||
def find_all_linear_names(model):
|
||||
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
|
||||
lora_module_names = set()
|
||||
for name, module in model.named_modules():
|
||||
|
||||
@@ -1224,17 +1224,12 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
):
|
||||
capabilities = data.get("capabilities")
|
||||
is_fsdp = data.get("fsdp") is not None
|
||||
is_deepspeed = data.get("deepspeed") is not None
|
||||
|
||||
if capabilities and capabilities.get("n_gpu", 0) > 1:
|
||||
if is_fsdp:
|
||||
raise ValueError(
|
||||
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with FSDP."
|
||||
)
|
||||
if is_deepspeed:
|
||||
raise ValueError(
|
||||
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with DeepSpeed."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@@ -1290,3 +1285,5 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
LOG.warning(
|
||||
f"torch=={torch_version} may not be supported in future versions. Please consider upgrading to torch>=2.5.1."
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
Reference in New Issue
Block a user