Compare commits

..

3 Commits

Author SHA1 Message Date
Dan Saunders
9f30d3d33a reworking SP logic into composed handler 2025-04-04 02:25:00 +00:00
Dan Saunders
ce07081d6c doc updates; config fix 2025-04-01 20:35:10 +00:00
Dan Saunders
3ce43b6db9 simplifying trainer mixins and adding to rl trainers 2025-04-01 17:53:12 +00:00
33 changed files with 420 additions and 587 deletions

View File

@@ -52,12 +52,6 @@ jobs:
python_version: "3.11"
pytorch: nightly
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: next
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -79,7 +73,7 @@ jobs:
uses: docker/build-push-action@v4
with:
context: .
file: ${{ matrix.pytorch == 'nightly' && './docker/Dockerfile-base-nightly' || matrix.pytorch == 'next' && './docker/Dockerfile-base-next' || './docker/Dockerfile-base' }}
file: ${{ matrix.pytorch == 'nightly' && './docker/Dockerfile-base-nightly' || './docker/Dockerfile-base' }}
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }}

View File

@@ -20,9 +20,9 @@ WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
fi
RUN python scripts/unsloth_install.py | sh

View File

@@ -1,38 +0,0 @@
ARG CUDA_VERSION="12.8.1"
ARG CUDNN_VERSION="8"
ARG UBUNTU_VERSION="22.04"
ARG MAX_JOBS=4
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
ENV PATH="/root/miniconda3/bin:${PATH}"
ARG PYTHON_VERSION="3.11"
ARG PYTORCH_VERSION="next"
ARG CUDA="128"
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
ENV PYTHON_VERSION=$PYTHON_VERSION
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
&& wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir /root/.conda \
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
python3 -m pip install --no-cache-dir -U torch==2.7.0 --extra-index-url https://download.pytorch.org/whl/test/cu$CUDA && \
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"
RUN git lfs install --skip-repo && \
pip3 install awscli && \
pip3 install -U --no-cache-dir pydantic==2.10.6

View File

@@ -510,8 +510,7 @@ train_on_inputs: false
# Note that training loss may have an oscillating pattern with this enabled.
group_by_length: false
# Whether to use gradient checkpointing. Available options are: true, false, "offload".
# https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
gradient_checkpointing: false
# additional kwargs to pass to the trainer for gradient checkpointing
# gradient_checkpointing_kwargs:
@@ -687,9 +686,10 @@ ddp_broadcast_buffers:
# E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized
# subsequences, or set to 4 to split into four equal-sized subsequences.
# See https://axolotl-ai-cloud.github.io/axolotl/docs/sequence_parallelism.html for more details.
sequence_parallel_degree:
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
# Must evenly divide the number of KV heads in your model.
sequence_parallel_degree: 4 # Set to the number of GPUs to split sequences across
flash_attention: true # SP requires flash attention
micro_batch_size: 1 # SP requires this is set to 1
# (optional) strides across the key dimension; larger values use more memory but should make training a bit faster
heads_k_stride: 1
# Path to torch distx for optim 'adamw_anyprecision'

View File

@@ -17,7 +17,6 @@ We currently support several common model architectures, including (but not limi
- `qwen2`
- `gemma`
- `gemma2`
- `gemma3`
<details>

View File

@@ -23,9 +23,10 @@ Use sequence parallelism when:
To enable sequence parallelism, add the following to your configuration file:
```yaml
# Set to a divisor (> 1) of the number of GPUs available
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
flash_attention: true # SP requires flash attention
micro_batch_size: 1 # SP requires this is set to 1
# (optional) strides across the key dimension; larger values use more memory but should make training a bit faster
heads_k_stride: 1
```
@@ -66,15 +67,16 @@ sequence_len: 8192
...
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
flash_attention: true # Required with sequence parallelism
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
flash_attention: true # SP requires flash attention
micro_batch_size: 1 # SP requires this is set to 1
# (optional) strides across the key dimension; larger values use more memory but should make training a bit faster
heads_k_stride: 1
...
```
This will train the Llama 3 8B model with 8K context length, with each sequence split
into 2 subsequences of length 4096 across 2 GPUs.
This will train the Llama 3 8B model with 8192 context length, with each sequence split
into 4 subsequences of length 2048 across 4 GPUs.
## Sample Packing with Sequence Parallelism
@@ -86,12 +88,14 @@ Sequence parallelism is compatible with Axolotl's sample packing functionality.
## Effect on Batch Size
First, note that sequence parallelism supports only the case where `micro_batch_size: 1`.
When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:
- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
- The number of batches processed per step decreases
For example:
- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
- With 8 GPUs and no sequence parallelism: 8 different batches are processed per step
- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4
- If your per-GPU `micro_batch_size` is 1, the global batch size decreases from 8 to 2

View File

@@ -20,7 +20,7 @@ dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
adapter: qlora
adapter: lora
lora_model_dir:
sequence_len: 2048

View File

@@ -1,66 +0,0 @@
base_model: google/gemma-3-4b-it
strict: false
# gemma3 doesn't seem to play nice with ddp
ddp_find_unused_parameters: true
chat_template: gemma3
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
field_messages: conversations
message_property_mappings:
role: from
content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
adapter: lora
lora_model_dir:
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
local_rank:
logging_steps: 1
flash_attention: true
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:

View File

@@ -82,3 +82,6 @@ deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: "<|end_of_text|>"

View File

@@ -16,7 +16,7 @@ transformers==4.50.3
tokenizers>=0.21.1
accelerate==1.5.2
datasets==3.5.0
deepspeed==0.15.4
deepspeed==0.16.4
trl==0.16.0
optimum==1.16.2

View File

@@ -112,7 +112,7 @@ extras_require = {
"yunchang==0.6.0",
],
"deepspeed": [
"deepspeed==0.15.4",
"deepspeed==0.16.4",
"deepspeed-kernels",
],
"mamba-ssm": [

View File

@@ -4,4 +4,4 @@ import pkgutil
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
__version__ = "0.8.0"
__version__ = "0.8.0.dev0"

View File

@@ -256,7 +256,7 @@ def do_cli(
"""
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, inference=True, rl=None, **kwargs)
parsed_cfg = load_cfg(config, inference=True, **kwargs)
parsed_cfg.sample_packing = False
parser = transformers.HfArgumentParser(InferenceCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(

View File

@@ -74,10 +74,8 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
load_in_8bit=False,
load_in_4bit=False,
flash_attention=False,
sequence_parallel_degree=None,
deepspeed=None,
fsdp=None,
fsdp_config=None,
**kwargs,
)
@@ -88,6 +86,13 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
f"Target directory for merge: `{parsed_cfg.lora_model_dir}` does not exist."
)
parsed_cfg.load_in_4bit = False
parsed_cfg.load_in_8bit = False
parsed_cfg.flash_attention = False
parsed_cfg.deepspeed = None
parsed_cfg.fsdp = None
parsed_cfg.fsdp_config = None
do_merge_lora(cfg=parsed_cfg)

View File

@@ -1043,6 +1043,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.rpo_alpha is not None:
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
training_args_kwargs["sequence_parallel_degree"] = (
self.cfg.sequence_parallel_degree
)
training_args_cls = None
blocklist_args_kwargs = []
if self.cfg.rl == "simpo":
@@ -1161,6 +1165,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer_kwargs["dataset_tags"] = [
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
]
dpo_trainer = trainer_cls(
*trainer_cls_args,
args=training_args,
@@ -1178,21 +1183,3 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer.add_callback(callback)
return dpo_trainer
class HFPPOTrainerBuilder(TrainerBuilderBase):
"""
HF Factory class for PPO Trainer
"""
def get_callbacks(self):
callbacks = super().get_callbacks()
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
return callbacks
def build(self, total_num_steps):
# build PPOConfig
pass

View File

@@ -3,16 +3,16 @@
# pylint: disable=unused-import
# flake8: noqa
from .base import AxolotlTrainer
from .dpo.trainer import AxolotlDPOTrainer
from .grpo.trainer import AxolotlGRPOTrainer
from .mamba import AxolotlMambaTrainer
from .relora import ReLoRATrainer
from .trl import (
from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.core.trainers.dpo import AxolotlDPOTrainer
from axolotl.core.trainers.grpo import AxolotlGRPOTrainer
from axolotl.core.trainers.mamba import AxolotlMambaTrainer
from axolotl.core.trainers.relora import ReLoRATrainer
from axolotl.core.trainers.trl import (
AxolotlCPOTrainer,
AxolotlKTOTrainer,
AxolotlORPOTrainer,
AxolotlPPOTrainer,
AxolotlPRMTrainer,
AxolotlRewardTrainer,
TRLPPOTrainer,
)

View File

@@ -8,10 +8,11 @@ import logging
import os
from collections import defaultdict
from functools import wraps
from typing import Literal
from typing import Any, Literal
import datasets
import torch
import torch.nn as nn
from datasets import Dataset
from torch.utils.data import (
BatchSampler,
@@ -25,12 +26,8 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
from trl.trainer.utils import pad_to_length
from typing_extensions import override
from axolotl.core.trainers.mixins import (
OptimizerMixin,
RngLoaderMixin,
SchedulerMixin,
SequenceParallelMixin,
)
from axolotl.core.trainers.handlers import SequenceParallelHandler
from axolotl.core.trainers.mixins import TrainerMixins
from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging,
sanitize_kwargs_for_tagging,
@@ -40,9 +37,7 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = logging.getLogger(__name__)
class AxolotlTrainer(
SchedulerMixin, OptimizerMixin, RngLoaderMixin, SequenceParallelMixin, Trainer
):
class AxolotlTrainer(TrainerMixins, Trainer):
"""Extend the base Trainer for axolotl helpers"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
@@ -68,9 +63,7 @@ class AxolotlTrainer(
if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
# Initialize sequence parallelism if enabled
if self.args.sequence_parallel_degree > 1:
self._setup_sequence_parallel()
self.sequence_parallel_handler = SequenceParallelHandler(self.args)
def _wrap_model(self, model, training=True, dataloader=None):
if self.args.torch_compile:
@@ -131,7 +124,7 @@ class AxolotlTrainer(
# Determine the base sampler first
if self.args.sequence_parallel_degree > 1:
base_sampler = self._sp_get_train_sampler(self.train_dataset)
base_sampler = self.sequence_parallel_handler._get_train_sampler(self.train_dataset)
elif self.args.curriculum_sampling:
base_sampler = SequentialSampler(self.train_dataset)
elif use_sample_packing:
@@ -167,7 +160,7 @@ class AxolotlTrainer(
# Determine the base sampler
if self.args.sequence_parallel_degree > 1:
base_sampler = self._sp_get_eval_sampler(eval_dataset)
base_sampler = self.sequence_parallel_handler._get_eval_sampler(eval_dataset)
elif use_multipack:
base_sampler = SequentialSampler(eval_dataset)
else:
@@ -239,7 +232,10 @@ class AxolotlTrainer(
return dataloader
# Otherwise prepare with accelerator
return self.accelerator.prepare_data_loader(dataloader)
dataloader = self.accelerator.prepare_data_loader(dataloader)
return dataloader
def get_train_dataloader(self) -> DataLoader:
"""Get dataloader for training"""
@@ -348,7 +344,57 @@ class AxolotlTrainer(
dataloader_params["drop_last"] = self.args.dataloader_drop_last
return DataLoader(bench_dataset, **dataloader_params)
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
def training_step(
self,
model: nn.Module,
inputs: dict[str, torch.Tensor | Any],
num_items_in_batch: int | None = None,
) -> torch.Tensor:
"""
Perform a training step on a batch of inputs. Overrides the
`transformers.trainer.Trainer` method to handle sequence parallelism if
enabled.
Args:
model: Model to perform training step for.
inputs: Dictionary mapping of inputs.
num_items_in_batch: The number of items in the batch.
"""
# Set up sequence parallelism for this step if enabled
if self.args.sequence_parallel_degree > 1:
self.sequence_parallel_handler._update_ring_flash_attn_params(inputs)
# Proceed with normal training step
return super().training_step(model, inputs, num_items_in_batch) # type: ignore
def prediction_step(
self,
model: nn.Module,
inputs: dict[str, torch.Tensor | Any],
prediction_loss_only: bool,
ignore_keys: list[str] | None = None,
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
"""
Perform a prediction step on a batch of inputs. Overrides the
`transformers.trainer.Trainer` method to handle sequence parallelism if
enabled.
Args:
model: Model to perform prediction step for.
inputs: Dictionary mapping of inputs.
prediction_loss_only: Whether to return only the loss.
ignore_keys: Keys to ignore in the inputs.
Returns:
Tuple of (loss, logits, labels).
"""
# Set up sequence parallelism for this prediction step if enabled
if self.args.sequence_parallel_degree > 1:
self.sequence_parallel_handler._update_ring_flash_attn_params(inputs)
# Proceed with normal prediction step
return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys) # type: ignore
@override
def compute_loss(

View File

@@ -1,14 +1,10 @@
"""
DPO Specific Strategy for training
"""
"""DPO Specific Strategy for training"""
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer
class DPOStrategy:
"""
Strategy for DPO training
"""
"""Strategy for DPO training"""
@classmethod
def get_trainer_class(cls):

View File

@@ -1,6 +1,4 @@
"""
Axolotl specific DPO args
"""
"""Axolotl specific DPO args"""
from dataclasses import dataclass
@@ -11,6 +9,4 @@ from axolotl.core.training_args import AxolotlTrainingMixins
@dataclass
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
"""
DPO config for DPO training
"""
"""DPO config for DPO training"""

View File

@@ -1,10 +1,7 @@
"""
DPO trainer for axolotl
"""
"""DPO trainer for axolotl"""
import gc
from functools import wraps
from typing import Any, Dict, Union
from typing import Any
import torch
from peft.optimizers import create_loraplus_optimizer
@@ -13,7 +10,8 @@ from transformers import Trainer
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.core.trainers.handlers import SequenceParallelHandler
from axolotl.core.trainers.mixins import TrainerMixins
from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging,
sanitize_kwargs_for_tagging,
@@ -23,18 +21,18 @@ if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
"""
Extend the base DPOTrainer for axolotl helpers
"""
class AxolotlDPOTrainer(TrainerMixins, DPOTrainer):
"""Extend the base DPOTrainer for axolotl helpers"""
tag_names = ["axolotl", "dpo"]
def __init__(self, *args, dataset_tags=None, **kwargs):
super().__init__(*args, **kwargs)
self.dataset_tags = dataset_tags
self.optimizer = None
self.model_accepts_loss_kwargs = False
self.sequence_parallel_handler = SequenceParallelHandler(args=self.args)
def create_optimizer(self):
# pylint: disable=duplicate-code
@@ -88,7 +86,7 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
max_prompt_length,
max_completion_length,
add_special_tokens,
) -> Dict:
) -> dict:
res = DPOTrainer.tokenize_row(
features,
processing_class,
@@ -117,10 +115,9 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
def training_step(
self,
model: nn.Module,
inputs: Dict[str, Union[torch.Tensor, Any]],
inputs: dict[str, torch.Tensor | Any | None],
num_items_in_batch=None,
) -> torch.Tensor:
loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
gc.collect()
torch.cuda.empty_cache()
return loss
self.sequence_parallel_handler.prepare_for_training_step(self, inputs)
return super().training_step(model, inputs, num_items_in_batch)

View File

@@ -1,6 +1,4 @@
"""
Axolotl GRPO trainer
"""
"""Axolotl GRPO trainer"""
from contextlib import nullcontext
@@ -8,16 +6,14 @@ from accelerate.utils import is_deepspeed_available, is_peft_model
from trl import GRPOTrainer
from trl.extras.profiling import profiling_decorator
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.core.trainers.mixins import TrainerMixins
if is_deepspeed_available():
import deepspeed
class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
"""
Extend the base GRPOTrainer for axolotl helpers
"""
class AxolotlGRPOTrainer(TrainerMixins, GRPOTrainer):
"""Extend the base GRPOTrainer for axolotl helpers"""
_tag_names = ["trl", "grpo", "axolotl"]

View File

@@ -0,0 +1,3 @@
"""Init for trainer handlers"""
from axolotl.core.trainers.handlers.sequence_parallel import SequenceParallelHandler

View File

@@ -0,0 +1,123 @@
"""Handler class for sequence parallel trainer logic"""
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.utils.data import DistributedSampler
class SequenceParallelHandler:
"""
Handler class that encapsulates sequence parallelism functionality.
This replaces the SequenceParallelMixin with a composition-based approach.
"""
def __init__(self, args=None):
"""
Initialize the sequence parallel handler.
Args:
args: The arguments object containing sequence parallelism settings.
"""
self.args = args
self.ring_attn_group = None
# Set up sequence parallelism if enabled
if self.args.sequence_parallel_degree > 1:
self._setup_sequence_parallel()
def _setup_sequence_parallel(self):
"""Set up sequence parallelism environment."""
from ring_flash_attn import update_ring_flash_attn_params
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
self.update_ring_flash_attn_params = update_ring_flash_attn_params
self.ring_attn_group = get_ring_attn_group()
def create_sequence_parallel_sampler(
self,
dataset,
shuffle=True,
is_eval=False,
):
"""
Helper method to create sampler for sequence parallelism (SP).
Args:
dataset: Dataset to sample from.
shuffle: Whether to shuffle the dataset.
is_eval: Whether we are creating a sampler for evaluation or training.
Returns:
Distributed sampler.
"""
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
return DistributedSampler(
dataset,
num_replicas=num_sp_groups,
rank=sp_group_id,
seed=self.args.seed if shuffle else None,
shuffle=shuffle,
drop_last=not is_eval,
)
def _get_train_sampler(self, dataset):
"""
Get a training sampler configured for sequence parallelism.
Args:
dataset: The training dataset.
Returns:
Configured sequence parallel sampler.
"""
return self.create_sequence_parallel_sampler(
dataset,
shuffle=not self.args.curriculum_sampling,
)
def _get_eval_sampler(self, eval_dataset):
"""
Get an evaluation sampler configured for sequence parallelism.
Args:
eval_dataset: The evaluation dataset.
Returns:
Configured sequence parallel sampler.
"""
return self.create_sequence_parallel_sampler(
eval_dataset, shuffle=False, is_eval=True
)
def _update_ring_flash_attn_params(self, inputs):
"""
Calculate the cu_seqlens for the current forward pass and pass the value to
the substituted ring_flash_attn.
Args:
inputs: Current batch of inputs.
"""
# At this point, inputs should already be partitioned by the sequence
# parallel data collator
batch_size = inputs["input_ids"].shape[0]
seq_len = inputs["input_ids"].shape[1]
packed_seq_lens = [seq_len] * batch_size
# Calculate the full sequence length across all GPUs in this SP group
total_seq_len = seq_len * self.args.sequence_parallel_degree
cu_seqlens = torch.cumsum(
torch.tensor(
packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
),
dim=-1,
dtype=torch.int32,
)
cu_seqlens = F.pad(
F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
)
self.update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)

View File

@@ -3,7 +3,12 @@
# pylint: disable=unused-import
# flake8: noqa
from .optimizer import OptimizerMixin
from .rng_state_loader import RngLoaderMixin
from .scheduler import SchedulerMixin
from .sequence_parallel import SequenceParallelMixin
from axolotl.core.trainers.mixins.optimizer import OptimizerMixin
from axolotl.core.trainers.mixins.rng_state_loader import RngLoaderMixin
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
class TrainerMixins(
OptimizerMixin, RngLoaderMixin, SchedulerMixin
):
"""Stub class combining all mixins for Axolotl trainers."""

View File

@@ -21,9 +21,7 @@ LOG = logging.getLogger(__name__)
class RngLoaderMixin(Trainer):
"""
mixin for method override to load RNG states from a checkpoint
"""
"""Mixin for method override to load RNG states from a checkpoint"""
def _load_rng_state(self, checkpoint):
# Load RNG states from `checkpoint`

View File

@@ -1,4 +1,5 @@
"""Module for Axolotl trainer sequence parallelism mixin"""
# TODO(Dan): remove
import logging
from typing import Any
@@ -7,7 +8,6 @@ import torch
import torch.distributed as dist
import torch.nn.functional as F
from datasets import Dataset
from torch import nn
from torch.utils.data import DistributedSampler, Sampler
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
@@ -71,12 +71,12 @@ class SequenceParallelMixin:
drop_last=not is_eval,
)
def _sp_get_train_sampler(self, dataset) -> Sampler | None:
def _get_train_sampler(self, dataset) -> Sampler | None:
"""
Get a training sampler configured for sequence parallelism.
Args:
dataset: The training dataset
dataset: The training dataset.
Returns:
Configured sequence parallel sampler.
@@ -86,7 +86,7 @@ class SequenceParallelMixin:
shuffle=not self.args.curriculum_sampling,
)
def _sp_get_eval_sampler(self, eval_dataset) -> Sampler | None:
def _get_eval_sampler(self, eval_dataset) -> Sampler | None:
"""
Get an evaluation sampler configured for sequence parallelism.
@@ -130,53 +130,3 @@ class SequenceParallelMixin:
)
update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)
def training_step(
self,
model: nn.Module,
inputs: dict[str, torch.Tensor | Any],
num_items_in_batch: int | None = None,
) -> torch.Tensor:
"""
Perform a training step on a batch of inputs. Overrides the
`transformers.trainer.Trainer` method to handle sequence parallelism if
enabled.
Args:
model: Model to perform training step for.
inputs: Dictionary mapping.
"""
# Set up sequence parallelism for this step if enabled
if self.args.sequence_parallel_degree > 1:
self._update_ring_flash_attn_params(inputs)
# Proceed with normal training step
return super().training_step(model, inputs, num_items_in_batch) # type: ignore
def prediction_step(
self,
model: nn.Module,
inputs: dict[str, torch.Tensor | Any],
prediction_loss_only: bool,
ignore_keys: list[str] | None = None,
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
"""
Perform a prediction step on a batch of inputs. Overrides the
`transformers.trainer.Trainer` method to handle sequence parallelism if
enabled.
Args:
model: Model to perform prediction step for.
inputs: Dictionary mapping of inputs.
prediction_loss_only: Whether to return only the loss.
ignore_keys: Keys to ignore in the inputs.
Returns:
Tuple of (loss, logits, labels).
"""
# Set up sequence parallelism for this prediction step if enabled
if self.args.sequence_parallel_degree > 1:
self._update_ring_flash_attn_params(inputs)
# Proceed with normal prediction step
return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys) # type: ignore

View File

@@ -13,11 +13,10 @@ from trl import (
RewardTrainer,
)
from axolotl.core.trainers.mixins import RngLoaderMixin
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
from axolotl.core.trainers.mixins import TrainerMixins
class TRLPPOTrainer(PPOTrainer):
class AxolotlPPOTrainer(TrainerMixins, PPOTrainer):
"""Wrapper for TRL PPO trainer to handle customizations"""
tag_names = ["axolotl", "ppo"]
@@ -75,10 +74,8 @@ class TRLPPOTrainer(PPOTrainer):
)
class AxolotlORPOTrainer(RngLoaderMixin, SchedulerMixin, ORPOTrainer):
"""
Extend the base ORPOTrainer for axolotl helpers
"""
class AxolotlORPOTrainer(TrainerMixins, ORPOTrainer):
"""Extend the base ORPOTrainer for axolotl helpers"""
tag_names = ["axolotl", "orpo"]
@@ -155,18 +152,14 @@ class AxolotlORPOTrainer(RngLoaderMixin, SchedulerMixin, ORPOTrainer):
return loss, metrics
class AxolotlKTOTrainer(RngLoaderMixin, SchedulerMixin, KTOTrainer):
"""
Extend the base KTOTrainer for axolotl helpers
"""
class AxolotlKTOTrainer(TrainerMixins, KTOTrainer):
"""Extend the base KTOTrainer for axolotl helpers"""
tag_names = ["axolotl", "kto"]
class AxolotlCPOTrainer(RngLoaderMixin, SchedulerMixin, CPOTrainer):
"""
Extend the base CPOTrainer for axolotl helpers
"""
class AxolotlCPOTrainer(TrainerMixins, CPOTrainer):
"""Extend the base CPOTrainer for axolotl helpers"""
tag_names = ["axolotl", "cpo"]
@@ -245,17 +238,13 @@ class AxolotlCPOTrainer(RngLoaderMixin, SchedulerMixin, CPOTrainer):
return loss, metrics
class AxolotlRewardTrainer(RngLoaderMixin, SchedulerMixin, RewardTrainer):
"""
Extend the base RewardTrainer for axolotl helpers
"""
class AxolotlRewardTrainer(TrainerMixins, RewardTrainer):
"""Extend the base RewardTrainer for axolotl helpers"""
tag_names = ["axolotl", "reward"]
class AxolotlPRMTrainer(RngLoaderMixin, SchedulerMixin, PRMTrainer):
"""
Extend the base trl.PRMTrainer for axolotl helpers
"""
class AxolotlPRMTrainer(TrainerMixins, PRMTrainer):
"""Extend the base trl.PRMTrainer for axolotl helpers"""
tag_names = ["axolotl", "prm"]

View File

@@ -12,9 +12,7 @@ from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
@dataclass
class AxolotlTrainingMixins:
"""
Mixin class for the Axolotl training args.
"""
"""Mixin class for the Axolotl training args."""
# pylint: disable=duplicate-code
model_type: Optional[str] = field(

View File

@@ -6,11 +6,22 @@ package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patc
their sequence parallel version of Flash Attention 2.
"""
import torch
import torch.distributed as dist
import torch.nn.functional as F
from accelerate.logging import get_logger
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from axolotl.logging_config import configure_logging
try:
from ring_flash_attn import update_ring_flash_attn_params
except ImportError:
# We pass silently here, but raise an ImportError in our Axolotl config validation
# if cfg.sequence_parallel_degree > 1 and `ring-flash-attn` is not installed.
pass
configure_logging()
LOG = get_logger(__name__)
@@ -32,12 +43,120 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
Setter for ring attention group on this rank.
Args:
Process group for ring attention.
ring_attn_group: Process group for ring attention.
"""
global RING_ATTN_GROUP # pylint: disable=global-statement
RING_ATTN_GROUP = ring_attn_group
def patch_flash_attention_for_sequential_batch(sequence_parallel_degree: int):
"""
Patch flash attention a second time to handle batched data. This is a hack to
accommodate certain RL trainers which batch data even when `micro_batch_size: 1` is
specified in the Axolotl config.
Args:
sequence_parallel_degree: Sequence parallelism factor.
"""
# Store the original flash attention function
original_flash_attention = ALL_ATTENTION_FUNCTIONS["flash_attention_2"]
def sequential_batch_flash_attention(
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor | None,
dropout: float = 0.0,
scaling: float | None = None,
sliding_window: int | None = None,
softcap: float | None = None,
**kwargs,
) -> tuple[torch.Tensor, None]:
# Check if we have a batch dimension > 1
batch_size = query.shape[0]
if batch_size <= 1:
return original_flash_attention(
module,
query,
key,
value,
attention_mask,
dropout,
scaling,
sliding_window,
softcap,
**kwargs
)
# Process each item in the batch separately
outputs = []
for i in range(batch_size):
# Extract single batch item
q_item = query[i:i+1]
k_item = key[i:i+1]
v_item = value[i:i+1]
# Handle attention mask - it might be None or have different shapes
mask_item = None
if attention_mask is not None:
# The mask could have different formats depending on implementation
if attention_mask.dim() >= 3 and attention_mask.shape[0] == batch_size:
mask_item = attention_mask[i:i+1]
else:
# For broadcast masks that don't have a batch dimension
mask_item = attention_mask
# At this point, inputs should already be partitioned by the sequence
# parallel data collator
batch_size = q_item.shape[0]
seq_len = q_item.shape[2]
packed_seq_lens = [seq_len] * batch_size
# Calculate the full sequence length across all GPUs in this SP group
total_seq_len = seq_len * sequence_parallel_degree
cu_seqlens = torch.cumsum(
torch.tensor(
packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
),
dim=-1,
dtype=torch.int32,
)
cu_seqlens = F.pad(
F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
)
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
# Call the original function for a single batch item
output, _ = original_flash_attention(
module,
q_item,
k_item,
v_item,
mask_item,
dropout,
scaling,
sliding_window,
softcap,
**kwargs
)
outputs.append(output)
dist.barrier()
# Concatenate results along batch dimension
concatenated_output = torch.cat(outputs, dim=0)
return concatenated_output, None
# Replace the original function with our sequential version
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = sequential_batch_flash_attention
def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None):
"""
Create ring attention group and substitute flash attn with ring flash attn.
@@ -98,3 +217,4 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
substitute_hf_flash_attn(
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride
)
patch_flash_attention_for_sequential_batch(sequence_parallel_degree)

View File

@@ -1,238 +0,0 @@
"""Monkeypatch for gemma3 conditional generation forward to fix loss exploding"""
# pylint: disable=duplicate-code
from typing import Optional, Tuple, Union
import torch
from transformers.cache_utils import Cache
from transformers.models.gemma3.modeling_gemma3 import (
_CONFIG_FOR_DOC,
GEMMA3_INPUTS_DOCSTRING,
Gemma3CausalLMOutputWithPast,
logger,
)
from transformers.utils import (
add_start_docstrings_to_model_forward,
is_torchdynamo_compiling,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def new_forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
token_type_ids: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**lm_kwargs,
) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
>>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf")
>>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf")
>>> prompt = "answer en Where is the cow standing?"
>>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(**inputs, max_length=30)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"answer en Where is the cow standing?\nbeach"
```"""
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
is_training = token_type_ids is not None and labels is not None
# Replace image id with PAD if the image token is OOV, to avoid index-errors
if input_ids is not None and self.config.image_token_index >= self.vocab_size:
special_image_mask = input_ids == self.config.image_token_index
llm_input_ids = input_ids.clone()
llm_input_ids[special_image_mask] = 0
else:
llm_input_ids = input_ids
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
if cache_position is None:
past_seen_tokens = (
past_key_values.get_seq_length() if past_key_values is not None else 0
)
cache_position = torch.arange(
past_seen_tokens,
past_seen_tokens + inputs_embeds.shape[1],
device=inputs_embeds.device,
)
# Merge text and images
if pixel_values is not None:
image_features = self.get_image_features(pixel_values)
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(
self.config.image_token_index,
dtype=torch.long,
device=inputs_embeds.device,
)
)
else:
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(
-1
)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
inputs_embeds.device
)
if (
not is_torchdynamo_compiling()
and inputs_embeds[special_image_mask].numel() != image_features.numel()
):
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
raise ValueError(
f"Number of images does not match number of special image tokens in the input text. "
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
"tokens from image embeddings."
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
# mask out pad-token-ids in labels for BC
if labels is not None and self.pad_token_id in labels:
logger.warning_once(
"`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
"You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
)
labels = torch.where(
input_ids == self.pad_token_id, self.config.ignore_index, labels
)
causal_mask = self._update_causal_mask( # pylint: disable=protected-access
attention_mask,
token_type_ids,
past_key_values,
cache_position,
inputs_embeds,
is_training,
)
outputs = self.language_model(
attention_mask=causal_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
**lm_kwargs,
)
logits = outputs[0]
loss = None
if labels is not None:
if attention_mask is not None:
# Get the shifted attention mask
shift_attention_mask = attention_mask[:, -logits.shape[1] + 1 :].to(
logits.device
) # +1 for shift
# Filter logits and labels based on attention mask
valid_indices = shift_attention_mask != 0
filtered_logits = logits[..., :-1, :][valid_indices]
filtered_labels = labels[..., 1:][valid_indices.to(labels.device)]
# TODO: do we need to handle num_items_in_batch given we filter the logits and labels?
loss = self.loss_function(
logits=filtered_logits,
labels=None, # we pass shift_labels
shift_labels=filtered_labels,
vocab_size=self.config.text_config.vocab_size,
**lm_kwargs,
)
else:
# Standard case without filtering
loss = self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.config.text_config.vocab_size,
**lm_kwargs,
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return Gemma3CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
def patch_gemma3conditionalgeneration_forward():
from transformers.models.gemma3.modeling_gemma3 import (
Gemma3ForConditionalGeneration,
)
Gemma3ForConditionalGeneration.forward = new_forward

View File

@@ -252,38 +252,12 @@ def apply_lora_kernel_patches(
LOG.setLevel(logging.INFO)
# Choose activation based on model type
activation = None
text_config = (
model.config.get_text_config()
if hasattr(model.config, "get_text_config")
else model.config
)
if hasattr(text_config, "hidden_act"):
activation = text_config.hidden_act
elif hasattr(text_config, "hidden_activation"):
activation = text_config.hidden_activation
# map activation to supported activation
if "gelu" in activation:
# gemma3 uses gelu_pytorch_tanh
activation = "gelu"
activation = model.config.hidden_act
if activation not in SUPPORTED_ACTIVATIONS:
raise NotImplementedError(f"Activation {activation} is not supported")
layers = []
# check for multimodal models first
if hasattr(model, "language_model"):
layers = model.language_model.model.layers
elif hasattr(model, "model"):
layers = model.model.model.layers
else:
raise NotImplementedError(
f"Model type {model.config.model_type} is not supported yet. Please create an Issue."
)
# Patch each layer
for layer in layers:
for layer in model.model.model.layers:
# Add QKV, O fallback implementations to start
# These will be overwritten later (if some conditions apply)
layer.self_attn.apply_qkv = types.MethodType(

View File

@@ -535,15 +535,6 @@ class ModelLoader:
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
def apply_patches(self) -> None:
# patch gemma3 conditional generation forward before loading plugins
# as it could be overridden by plugins
if self.cfg.model_config_type == "gemma3":
from axolotl.monkeypatch.gemma3 import (
patch_gemma3conditionalgeneration_forward,
)
patch_gemma3conditionalgeneration_forward()
# load any patches from plugins
from axolotl.integrations.base import PluginManager
@@ -1360,9 +1351,7 @@ def load_model(
reference_model: bool = False,
**kwargs, # pylint: disable=unused-argument
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
"""
Load a model for a given configuration and tokenizer.
"""
"""Load a model for a given configuration and tokenizer."""
model_loader = ModelLoader(
cfg,
tokenizer,
@@ -1371,12 +1360,16 @@ def load_model(
reference_model=reference_model,
**kwargs,
)
return model_loader.load_model()
def load_adapter(model, cfg, adapter, inference=False):
# type: (PreTrainedModel, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
def load_adapter(
model: PreTrainedModel,
cfg: DictDefault,
adapter: str | None,
inference: bool = False,
) -> tuple[PreTrainedModel, PeftConfig | None]:
if adapter is None:
return model, None
if hasattr(model, "enable_input_require_grads"):
@@ -1389,8 +1382,9 @@ def load_adapter(model, cfg, adapter, inference=False):
raise NotImplementedError(f"{adapter} peft adapter not available")
def load_llama_adapter(model, cfg):
# type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
def load_llama_adapter(
model: PreTrainedModel, cfg: DictDefault
) -> tuple[PreTrainedModel, PeftConfig | None]:
from peft import AdaptionPromptConfig, get_peft_model
peft_config = AdaptionPromptConfig(
@@ -1414,7 +1408,7 @@ def load_llama_adapter(model, cfg):
return model, peft_config
def find_all_linear_names(model):
def find_all_linear_names(model: PreTrainedModel):
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
lora_module_names = set()
for name, module in model.named_modules():

View File

@@ -1290,5 +1290,3 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
LOG.warning(
f"torch=={torch_version} may not be supported in future versions. Please consider upgrading to torch>=2.5.1."
)
return self