Compare commits

...

17 Commits

Author SHA1 Message Date
Dan Saunders
30981328fc draft config for devstral 2025-05-23 20:04:21 +00:00
Dan Saunders
b5f1e53a0f models.py -> loaders/ module refactor (#2680)
* models.py -> loaders/ module refactor

* refactor ModelLoader class

* plugin manager changes

* circular import fix

* pytest

* pytest

* minor improvements

* fix

* minor changes

* fix test

* remove dead code

* coderabbit comments

* lint

* fix

* coderabbit suggestion I liked

* more coderabbit

* review comments, yak shaving

* lint

* updating in light of SP ctx manager changes

* review comment

* review comment 2
2025-05-23 15:51:11 -04:00
Dan Saunders
8cde256db2 Remove unused const (#2714)
* remove unused const

* accidentally committed benchmark plot
2025-05-23 12:27:38 -04:00
Dan Saunders
5f8f817200 SP context manager update (#2699)
* utilize accelerate prepare_data_loader with patching

* lint

* cleanup, fix

* update to support DPO quirk

* coderabbit commits, cleanup, remove dead code

* fix

* move ring attn patching to sp ctx manager

* lint

* lint

* test fix

* test fix
2025-05-22 11:18:32 -04:00
NanoCode012
aa0492c366 feat: do not find turn indices if turn is not trainable (#2696)
* feat: do not find turn indices if turn is not trainable

* fix: handle edge case where train on eos/eot is all

* fix: improve warning message
2025-05-22 19:19:59 +07:00
NanoCode012
798b5f5cfd fix(RL): address plugin rl overwriting trainer_cls (#2697) [skip ci]
* fix: plugin rl overwrite trainer_cls

* feat(test): add test to catch trainer_cls is not None
2025-05-22 19:19:12 +07:00
NanoCode012
1c83a1a020 feat(doc): clarify minimum pytorch and cuda to use blackwell (#2704) [skip ci] 2025-05-22 19:18:27 +07:00
Dan Saunders
6aa41740df SP dataloader patching + removing custom sampler / dataloader logic (#2686)
* utilize accelerate prepare_data_loader with patching

* lint

* cleanup, fix

* update to support DPO quirk

* small change

* coderabbit commits, cleanup, remove dead code

* quarto fix

* patch fix

* review comments

* moving monkeypatch up one level

* fix
2025-05-21 11:20:20 -04:00
Wing Lian
a27b909c5c GRPO fixes (peft) (#2676)
* don't set peft_config on grpo to prevent double peft wrap

* remove overrides needed to support bug

* fix grpo tests

* require more CPU for multigpu to help with torch compile for vllm
2025-05-16 15:47:03 -04:00
xzuyn
6cb07b9d12 Fix for setting adam_beta3 and adam_epsilon2 for CAME Optimizer (#2654) [skip ci]
* make setting `adam_beta3` and `adam_epsilon2` work correctly

* update config docs so users know args are specific to CAME optim

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-05-16 15:46:50 -04:00
C080
288653adb6 Fix: Make MLflow config artifact logging respect hf_mlflow_log_artifa… (#2675) [skip ci]
* Fix: Make MLflow config artifact logging respect hf_mlflow_log_artifacts setting

* cleanup and lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-05-16 15:46:31 -04:00
NanoCode012
3a5b495a74 Fix: improve doc on merge/inference cli visibility (#2674)
* feat: improve visibility for merge doc

* feat: add tip on reuse config between modes
2025-05-16 13:07:40 -04:00
xzuyn
f661858fc4 Print dataset name (#2668) [skip ci] 2025-05-16 13:06:58 -04:00
Eric Meier
c837c4a424 Add missing init file to liger plugin (#2670) [skip ci] 2025-05-16 13:06:46 -04:00
michelyang
c9797de6bb Add num_proc to fix data set slow processing issue (#2681) [skip ci] 2025-05-16 13:06:20 -04:00
Wing Lian
8f8a7afb05 Add ci and images for CUDA 12.8 for B200s (#2683) [skip ci]
* Add ci and images for CUDA 12.8 for B200s

* add comments explaining CI [skip e2e]
2025-05-16 13:06:08 -04:00
NanoCode012
86472715da fix: remove doc string imports in monkeypatches (#2671) [skip ci] 2025-05-16 13:05:55 -04:00
75 changed files with 2791 additions and 2733 deletions

View File

@@ -31,6 +31,11 @@ jobs:
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -94,6 +99,11 @@ jobs:
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -295,6 +295,7 @@ jobs:
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
docker-e2e-tests-1st:
# Run this job first as a gate for running the remainder of the test matrix
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
@@ -341,6 +342,8 @@ jobs:
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 90
# Only run the remainder of the matrix if the first e2e check passed;
# this is to save on wasted compute costs for known failures that get caught in the first run
needs: [pre-commit, pytest, docker-e2e-tests-1st]
strategy:
@@ -365,6 +368,12 @@ jobs:
pytorch: 2.7.0
num_gpus: 1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.7.0
num_gpus: 1
axolotl_extras:
steps:
- name: Checkout
uses: actions/checkout@v4

View File

@@ -60,7 +60,6 @@ quartodoc:
- core.trainers.mixins.optimizer
- core.trainers.mixins.rng_state_loader
- core.trainers.mixins.scheduler
- core.trainers.mixins.sequence_parallel
- title: Context Managers
desc: Context managers for altering trainer behaviors
contents:

View File

@@ -70,7 +70,7 @@ def run_cmd(cmd: str, run_folder: str):
image=cicd_image,
gpu=GPU_CONFIG,
timeout=90 * 60,
cpu=8.0,
cpu=16.0,
memory=131072 * N_GPUS,
volumes=VOLUME_CONFIG,
)

View File

@@ -633,7 +633,9 @@ weight_decay:
# adamw hyperparams
adam_beta1:
adam_beta2:
adam_beta3: # only used for CAME Optimizer
adam_epsilon:
adam_epsilon2: # only used for CAME Optimizer
# Gradient clipping max norm
max_grad_norm:

View File

@@ -8,6 +8,10 @@ format:
This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai).
::: {.callout-important}
For Blackwell GPUs, please use the tags with PyTorch 2.7.0 and CUDA 12.8.
:::
## Base
The base image is the most minimal image that can install Axolotl. It is based on the `nvidia/cuda` image. It includes python, torch, git, git-lfs, awscli, pydantic, and more.

View File

@@ -104,7 +104,7 @@ the `alpaca` dataset format, which has the following format:
Please see our [Dataset Formats](dataset-formats) for more dataset formats and how to
format them.
2. Prepare your JSONL data in the specified format (in this case, the expected `alpaca
2. Prepare your JSONL data in the specified format (in this case, the expected `alpaca`
format):
```json
@@ -120,6 +120,12 @@ axolotl train my_training.yml
## Common Tasks {#sec-common-tasks}
::: {.callout-tip}
The same yaml file is used for training, inference, and merging.
:::
### Testing Your Model {#sec-testing}
After training, test your model:
@@ -128,6 +134,16 @@ After training, test your model:
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out"
```
More details can be found in [Inference](inference.qmd).
### Using a UI {#sec-ui}
Launch a Gradio interface:
```bash
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out" --gradio
```
### Preprocessing Data {#sec-preprocessing}
For large datasets, preprocess first:
@@ -136,14 +152,22 @@ For large datasets, preprocess first:
axolotl preprocess my_training.yml
```
### Using a UI {#sec-ui}
Please make sure to set `dataset_prepared_path: ` in your config to specify where the prepared dataset is saved.
Launch a Gradio interface:
More details can be found in [Dataset Preprocessing](dataset_preprocessing.qmd).
### Merging LoRA weights {#sec-merging-lora}
To merge the LoRA weights back into the base model, run:
```bash
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out" --gradio
axolotl merge-lora my_training.yml --lora-model-dir="./outputs/lora-out"
```
The merged model will be saved in the `{output_dir}/merged` directory.
More details can be found in [Merging LoRA weights](inference.qmd#sec-merging).
## Next Steps {#sec-next-steps}
Now that you have the basics, you might want to:
@@ -156,6 +180,7 @@ Now that you have the basics, you might want to:
Check our other guides for details on these topics:
- [Configuration Guide](config.qmd) - Full configuration options
- [Dataset Loading](dataset-loading.qmd) - Loading datasets from various sources
- [Dataset Formats](dataset-formats) - Working with different data formats
- [Multi-GPU Training](multi-gpu.qmd)
- [Multi-Node Training](multi-node.qmd)

View File

@@ -25,6 +25,10 @@ Please make sure to have Pytorch installed before installing Axolotl in your loc
Follow the instructions at: [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/)
:::
::: {.callout-important}
For Blackwell GPUs, please use PyTorch 2.7.0 and CUDA 12.8.
:::
### PyPI Installation (Recommended) {#sec-pypi}
```{.bash}
@@ -72,6 +76,10 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it \
```
:::
::: {.callout-important}
For Blackwell GPUs, please use `axolotlai/axolotl:main-py3.11-cu128-2.7.0` or the cloud variant `axolotlai/axolotl-cloud:main-py3.11-cu128-2.7.0`.
:::
Please refer to the [Docker documentation](docker.qmd) for more information on the different Docker images that are available.
## Cloud Environments {#sec-cloud}

View File

@@ -87,20 +87,7 @@ We support sequence parallelism (SP) via the
allows one to split up sequences across GPUs, which is useful in the event that a
single sequence causes OOM errors during model training.
First, install `ring-flash-attn`, recommended via `pip install axolotl[ring-flash-attn]`,
or from source with `pip install .[ring-flash-attn]`.
Your Axolotl YAML config should contain the following lines:
```{.yaml}
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
flash_attention: true # Required with sequence parallelism
# Optional; strides across the key dimension. Larger values use more memory but will make training faster.
heads_k_stride: 1
```
See our [dedicated guide](sequence_parallelism.qmd) for more details.
See our [dedicated guide](sequence_parallelism.qmd) for more information.
### FSDP + QLoRA {#sec-fsdp-qlora}

View File

@@ -41,7 +41,7 @@ When sequence parallelism is enabled:
1. Each sequence is divided into equal chunks across the GPUs in a sequence parallel group
2. The data collator handles the chunking of input_ids, attention_mask, labels, and position_ids
3. Position IDs are adjusted to maintain proper relative positions, especially for packed sequences
3. Position IDs are adjusted to maintain proper relative positions
4. The trainer uses special ring communication patterns for attention operations
## Requirements
@@ -67,9 +67,11 @@ sequence_len: 8192
...
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
flash_attention: true # Required with sequence parallelism
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
# "varlen_llama3" when `sample_packing: true`, and "batch_ring" otherwise.
ring_attn_func:
...
```

View File

@@ -0,0 +1,48 @@
base_model: mistralai/Devstral-Small-2505
processor_type: AutoProcessor
# these 3 lines are needed for now to handle vision chat templates with images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: mistral_v7_tekken
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
sequence_len: 2048
pad_to_sequence_len: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
logging_steps: 1
flash_attention: false
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

View File

@@ -2,7 +2,6 @@ base_model: Qwen/Qwen2.5-0.5B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
chat_template: qwen_25
rl: dpo
datasets:

View File

@@ -20,8 +20,9 @@ from transformers import (
ProcessorMixin,
)
from axolotl.loaders import load_processor, load_tokenizer
from axolotl.loaders.model import ModelLoader
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_processor, load_tokenizer
LOG = logging.getLogger(__name__)
@@ -318,7 +319,8 @@ def load_model_and_tokenizer(
tokenizer = load_tokenizer(cfg)
LOG.info("loading model...")
model, _ = load_model(cfg, tokenizer, inference=inference)
model_loader = ModelLoader(cfg, tokenizer, inference=inference)
model, _ = model_loader.load()
processor = None
if cfg.is_multimodal:

View File

@@ -10,10 +10,10 @@ from datasets import Dataset
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
from axolotl.loaders import load_processor, load_tokenizer
from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_processor, load_tokenizer
from axolotl.utils.schemas.enums import RLType
from axolotl.utils.tokenization import check_dataset_labels

View File

@@ -59,6 +59,7 @@ from axolotl.core.training_args import (
AxolotlTrainingArguments,
)
from axolotl.integrations.base import PluginManager
from axolotl.loaders.utils import ensure_dtype
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback
from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
@@ -86,7 +87,6 @@ from axolotl.utils.collators import (
V2BatchSamplerDataCollatorForSeq2Seq,
)
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.models import ensure_dtype
from axolotl.utils.schemas.enums import CustomSupportedOptimizers, RLType
try:
@@ -387,8 +387,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1
if self.cfg.adam_beta2:
training_arguments_kwargs["adam_beta2"] = self.cfg.adam_beta2
if self.cfg.adam_beta3:
training_arguments_kwargs["adam_beta3"] = self.cfg.adam_beta3
if self.cfg.adam_epsilon:
training_arguments_kwargs["adam_epsilon"] = self.cfg.adam_epsilon
if self.cfg.adam_epsilon2:
training_arguments_kwargs["adam_epsilon2"] = self.cfg.adam_epsilon2
if self.cfg.max_grad_norm:
training_arguments_kwargs["max_grad_norm"] = self.cfg.max_grad_norm
@@ -713,7 +717,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
beta1 = training_arguments_kwargs.get("adam_beta1", 0.9)
beta2 = training_arguments_kwargs.get("adam_beta2", 0.999)
beta3 = training_arguments_kwargs.get("adam_beta2", 0.9999)
beta3 = training_arguments_kwargs.get("adam_beta3", 0.9999)
eps1 = training_arguments_kwargs.get("adam_epsilon", 1e-30)
eps2 = training_arguments_kwargs.get("adam_epsilon2", 1e-16)
adam_kwargs["betas"] = (beta1, beta2, beta3)
@@ -794,11 +798,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self.cfg.kd_top_k_before_softmax
)
training_arguments_kwargs["sequence_parallel_degree"] = (
self.cfg.sequence_parallel_degree
)
training_arguments_kwargs["ring_attn_func"] = self.cfg.ring_attn_func
if self.cfg.reward_model:
training_args_cls = AxolotlRewardConfig
elif self.cfg.process_reward_model:
@@ -1079,10 +1078,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.use_wandb:
training_args_kwargs["run_name"] = self.cfg.wandb_name
training_args_kwargs["sequence_parallel_degree"] = (
self.cfg.sequence_parallel_degree
)
training_args_cls = None
blocklist_args_kwargs = []
if self.cfg.rl is RLType.SIMPO:
@@ -1170,7 +1165,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.eval_dataset:
trainer_kwargs["eval_dataset"] = self.eval_dataset
if self.cfg.adapter and self.peft_config:
trainer_kwargs["peft_config"] = self.peft_config
if self.cfg.rl is not RLType.GRPO:
trainer_kwargs["peft_config"] = self.peft_config
if self.cfg.precompute_ref_log_probs is not None:
trainer_kwargs["precompute_ref_log_probs"] = (
self.cfg.precompute_ref_log_probs
@@ -1199,7 +1195,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
temp_trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
if temp_trainer_cls is not None:
trainer_cls = temp_trainer_cls
sig = inspect.signature(trainer_cls)
if "tokenizer" in sig.parameters.keys():

View File

@@ -29,7 +29,6 @@ from axolotl.core.trainers.mixins import (
OptimizerMixin,
RngLoaderMixin,
SchedulerMixin,
SequenceParallelMixin,
)
from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging,
@@ -40,9 +39,7 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = logging.getLogger(__name__)
class AxolotlTrainer(
SchedulerMixin, OptimizerMixin, RngLoaderMixin, SequenceParallelMixin, Trainer
):
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
"""Extend the base Trainer for axolotl helpers"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
@@ -68,10 +65,6 @@ class AxolotlTrainer(
if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
# Initialize sequence parallelism if enabled
if self.args.sequence_parallel_degree > 1:
self._setup_sequence_parallel()
def _wrap_model(self, model, training=True, dataloader=None):
if self.args.torch_compile:
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
@@ -122,8 +115,8 @@ class AxolotlTrainer(
def _get_train_sampler(self) -> Sampler | None:
"""
Helper method to get the sampler for training. Handles cases for sequence
parallelism, sample packing, and curriculum sampling (sequential).
Helper method to get the sampler for training. Handles cases for sample packing
and curriculum sampling (sequential).
Returns:
If the dataset is non-empty, a sampler is returned, the type of which
@@ -132,9 +125,7 @@ class AxolotlTrainer(
use_sample_packing = self.args.sample_packing and not self.args.pretraining
# Determine the base sampler first
if self.args.sequence_parallel_degree > 1:
base_sampler = self._sp_get_train_sampler(self.train_dataset)
elif self.args.curriculum_sampling:
if self.args.curriculum_sampling:
base_sampler = SequentialSampler(self.train_dataset)
elif use_sample_packing:
base_sampler = RandomSampler(self.train_dataset)
@@ -153,8 +144,7 @@ class AxolotlTrainer(
def _get_eval_sampler(self, eval_dataset: Dataset | None = None) -> Sampler | None:
"""
Helper method to get the sampler for evaluation. Handles sequence parallelism
and sample packing cases.
Helper method to get the sampler for evaluation. Handles sample packing case.
Returns:
If the dataset is non-empty, a sampler is returned, the type of which
@@ -168,9 +158,7 @@ class AxolotlTrainer(
)
# Determine the base sampler
if self.args.sequence_parallel_degree > 1:
base_sampler = self._sp_get_eval_sampler(eval_dataset)
elif use_multipack:
if use_multipack:
base_sampler = SequentialSampler(eval_dataset)
else:
return super()._get_eval_sampler(eval_dataset)
@@ -236,14 +224,6 @@ class AxolotlTrainer(
):
self.accelerator.even_batches = False
# Return unprepared dataloader if using sequence parallelism
# TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
# if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
# slice each batch along the sequence dimension).
if self.args.sequence_parallel_degree > 1:
return dataloader
# Otherwise prepare with accelerator
return self.accelerator.prepare_data_loader(dataloader)
def get_train_dataloader(self) -> DataLoader:
@@ -287,12 +267,7 @@ class AxolotlTrainer(
return dataloader
# Handle sample packing or sequence parallelism
if (
self.args.sample_packing
and self.args.eval_sample_packing is not False
or self.args.sequence_parallel_degree > 1
):
if self.args.sample_packing and self.args.eval_sample_packing is not False:
# Get appropriate data collator
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.eval_data_collator
@@ -302,17 +277,6 @@ class AxolotlTrainer(
if "length" in eval_dataset.column_names:
eval_dataset = eval_dataset.remove_columns(["length"])
# Handle dataset preprocessing for SP
if self.args.sequence_parallel_degree > 1:
if isinstance(eval_dataset, datasets.Dataset):
eval_dataset = self._remove_unused_columns(
eval_dataset, description="evaluation"
)
else:
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
self.data_collator, description="evaluation"
)
# Use eval_batch_size for sample packing, per_device_eval_batch_size otherwise
batch_size = (
self.args.eval_batch_size

View File

@@ -1,31 +1,15 @@
"""
DPO trainer for axolotl
"""
"""DPO trainer for axolotl"""
import gc
import random
from functools import wraps
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Union
import pandas as pd
import torch
import wandb
from accelerate import PartialState
from datasets import Dataset, IterableDataset
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from torch.utils.data import DataLoader
from transformers import (
BaseImageProcessor,
FeatureExtractionMixin,
PreTrainedTokenizerBase,
ProcessorMixin,
Trainer,
)
from transformers.trainer_utils import EvalLoopOutput
from transformers import Trainer
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOConfig, DPOTrainer, maybe_apply_chat_template, maybe_extract_prompt
from trl.trainer.utils import log_table_to_comet_experiment
from trl import DPOTrainer
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.core.trainers.utils import (
@@ -38,9 +22,7 @@ if is_sagemaker_mp_enabled():
class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
"""
Extend the base DPOTrainer for axolotl helpers
"""
"""Extend the base DPOTrainer for axolotl helpers."""
tag_names = ["axolotl", "dpo"]
@@ -85,8 +67,9 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
@wraps(DPOTrainer.push_to_hub)
def push_to_hub(self, *args, **kwargs) -> str:
"""
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
Overwrite the `push_to_hub` method in order to force-add the tags when pushing
the model on the Hub. Please refer to `~transformers.Trainer.push_to_hub`
for more details.
"""
kwargs = sanitize_kwargs_for_ds_tagging(
dataset_tags=self.dataset_tags, kwargs=kwargs
@@ -95,64 +78,6 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
return super().push_to_hub(*args, **kwargs)
# TODO: remove this once https://github.com/huggingface/trl/pull/3377 is in a release
def _prepare_dataset(
self,
dataset: Union[Dataset, IterableDataset],
processing_class: Union[
PreTrainedTokenizerBase,
BaseImageProcessor,
FeatureExtractionMixin,
ProcessorMixin,
],
args: DPOConfig,
dataset_name: str,
) -> Union[Dataset, IterableDataset]:
# Build the kwargs for the `map` function
map_kwargs: Dict[str, Any] = {"writer_batch_size": 10}
if isinstance(dataset, Dataset): # IterableDataset does not support num_proc
map_kwargs["num_proc"] = args.dataset_num_proc
with PartialState().main_process_first():
# Extract prompt if needed
if isinstance(
dataset, Dataset
): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Extracting prompt in {dataset_name} dataset"
dataset = dataset.map(maybe_extract_prompt, **map_kwargs)
# Apply the chat template if needed
if isinstance(
dataset, Dataset
): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset"
dataset = dataset.map(
maybe_apply_chat_template,
fn_kwargs={"tokenizer": processing_class, "tools": args.tools},
**map_kwargs,
)
# Tokenize the dataset
if isinstance(
dataset, Dataset
): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
dataset = dataset.map(
self.tokenize_row if not self.is_vision_model else self.process_row,
remove_columns=["chosen", "rejected"],
fn_kwargs={
"processing_class": processing_class,
"max_prompt_length": args.max_prompt_length,
"max_completion_length": args.max_completion_length,
# for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token])
"add_special_tokens": False,
},
**map_kwargs,
)
return dataset
@staticmethod
def tokenize_row(
features,
@@ -192,69 +117,3 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
gc.collect()
torch.cuda.empty_cache()
return loss
# TODO: remove this once https://github.com/huggingface/trl/pull/3377 is in a release
def evaluation_loop(
self,
dataloader: DataLoader,
description: str,
prediction_loss_only: Optional[bool] = None,
ignore_keys: Optional[list[str]] = None,
metric_key_prefix: str = "eval",
) -> EvalLoopOutput:
"""
Overriding built-in evaluation loop to store metrics for each batch.
Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
Works both with or without labels.
"""
# Sample and save to game log if requested (for one batch to save time)
if self.generate_during_eval:
# Generate random indices within the range of the total number of samples
num_samples = len(dataloader.dataset)
random_indices = random.sample(
range(num_samples), k=self.args.eval_batch_size
)
# Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
random_batch_dataset = dataloader.dataset.select(random_indices)
random_batch = self.data_collator(random_batch_dataset)
random_batch = self._prepare_inputs(random_batch)
policy_output_decoded, ref_output_decoded = (
self.generate_from_model_and_ref(self.model, random_batch)
)
table = pd.DataFrame(
columns=["Prompt", "Policy", "Ref Model"],
data=[
[prompt, pol[len(prompt) :], ref[len(prompt) :]]
for prompt, pol, ref in zip(
random_batch_dataset["prompt"],
policy_output_decoded,
ref_output_decoded,
)
],
)
if "wandb" in self.args.report_to and self.accelerator.is_main_process:
wandb.log({"game_log": wandb.Table(data=table)})
if "comet_ml" in self.args.report_to:
log_table_to_comet_experiment(
name="game_log.csv",
table=table,
)
# Base evaluation
initial_output = super( # pylint: disable=bad-super-call
DPOTrainer, self
).evaluation_loop(
dataloader,
description,
prediction_loss_only,
ignore_keys,
metric_key_prefix,
)
return initial_output

View File

@@ -3,7 +3,6 @@
# pylint: disable=too-many-lines,duplicate-code,protected-access,no-member
import warnings
from contextlib import nullcontext
from typing import Any
import datasets
@@ -14,7 +13,7 @@ from accelerate.utils import (
broadcast_object_list,
gather,
gather_object,
is_peft_model,
is_peft_available,
)
from datasets import Dataset, IterableDataset
from torch import nn
@@ -30,15 +29,13 @@ from transformers import (
TrainerCallback,
)
from transformers.trainer_utils import seed_worker
from transformers.utils import is_peft_available
from trl import GRPOTrainer
from trl.data_utils import (
apply_chat_template,
is_conversational,
maybe_apply_chat_template,
)
from trl.extras.profiling import profiling_context, profiling_decorator
from trl.import_utils import is_deepspeed_available
from trl.extras.profiling import profiling_context
from trl.models import unwrap_model_for_generation
from trl.trainer.grpo_config import GRPOConfig
from trl.trainer.grpo_trainer import RewardFunc, nanstd
@@ -46,68 +43,18 @@ from trl.trainer.utils import pad
from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.monkeypatch.attention.ring_attn.patch import get_ring_attn_group
from axolotl.monkeypatch.ring_attn import get_ring_attn_group
if is_peft_available():
# pylint: disable=unused-import
from peft import PeftConfig
if is_deepspeed_available():
import deepspeed
class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
"""Extend the base GRPOTrainer for axolotl helpers"""
_tag_names = ["trl", "grpo", "axolotl"]
@profiling_decorator
def _move_model_to_vllm(self):
# For DeepSpeed ZeRO-3, we need to gather all parameters before operations
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
gather_if_zero3 = (
deepspeed.zero.GatheredParameters if zero_stage_3 else nullcontext
)
if is_peft_model(self.model):
# With PEFT and DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as merging
# adapters in a sharded manner is not supported.
with gather_if_zero3(list(self.model.parameters())):
self.model.merge_adapter()
# Update vLLM weights while parameters are gathered
for name, param in self.model.named_parameters():
# When using PEFT, we need to recover the original parameter name and discard some parameters
name = (
name.removeprefix("base_model.model.")
.removeprefix("base_model.model.")
.replace(".base_layer", "")
)
if self.model.prefix in name:
continue
# When module to save, remove its prefix and discard the original module
if "original_module" in name:
continue
name = name.replace("modules_to_save.default.", "")
if self.accelerator.is_main_process:
self.vllm_client.update_named_param(name, param.data)
# Unmerge adapters while parameters are still gathered
self.model.unmerge_adapter()
# Parameters will automatically be repartitioned when exiting the context
else:
# For non-PEFT models, simply gather and update each parameter individually.
for name, param in self.model.named_parameters():
with gather_if_zero3([param]):
if self.accelerator.is_main_process:
self.vllm_client.update_named_param(name, param.data)
# Reset cache on main process
if self.accelerator.is_main_process:
self.vllm_client.reset_prefix_cache()
class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
"""Extend the base GRPOTrainer for sequence parallelism handling"""

View File

@@ -6,4 +6,3 @@
from .optimizer import OptimizerMixin
from .rng_state_loader import RngLoaderMixin
from .scheduler import SchedulerMixin
from .sequence_parallel import SequenceParallelMixin

View File

@@ -1,87 +0,0 @@
"""Module for Axolotl trainer sequence parallelism mixin"""
import torch.distributed as dist
from datasets import Dataset
from torch.utils.data import DistributedSampler, Sampler
from axolotl.monkeypatch.attention.ring_attn import (
get_ring_attn_group,
)
class SequenceParallelMixin:
    """Trainer mixin that adds sequence parallelism (SP) aware sampling.

    The helpers below build `DistributedSampler` objects that are keyed on the
    SP group a rank belongs to instead of on the rank itself.
    """

    args = None  # type: "AxolotlTrainingArguments"  # type: ignore[name-defined]

    def _setup_sequence_parallel(self):
        """Store the ring attention process group on the trainer instance."""
        self.ring_attn_group = get_ring_attn_group()

    def _create_sequence_parallel_sampler(
        self,
        dataset: Dataset,
        shuffle: bool = True,
        is_eval: bool = False,
    ) -> DistributedSampler:
        """Build the distributed sampler used under sequence parallelism (SP).

        The sampler's rank is the SP group ID and its replica count is the number
        of SP groups, so every rank inside one SP group is handed the same
        sample(s) on each step. This is a slightly unorthodox use of
        `DistributedSampler`, but it gives the desired behavior.

        Args:
            dataset: Dataset to draw samples from.
            shuffle: Whether the sampler should shuffle the dataset.
            is_eval: Whether the sampler is for evaluation (`True`) or training.

        Returns:
            The configured distributed sampler.
        """
        group_count = self.args.world_size // self.args.sequence_parallel_degree
        group_index = dist.get_rank() // self.args.sequence_parallel_degree

        # The configured seed is only forwarded when shuffling is requested.
        if shuffle:
            sampler_seed = self.args.seed
        else:
            sampler_seed = None

        # Incomplete trailing batches are dropped for training, kept for eval.
        keep_tail = is_eval

        sampler = DistributedSampler(
            dataset,
            num_replicas=group_count,
            rank=group_index,
            seed=sampler_seed,
            shuffle=shuffle,
            drop_last=not keep_tail,
        )
        return sampler

    def _sp_get_train_sampler(self, dataset) -> Sampler | None:
        """Return the SP sampler for training.

        Shuffling is turned off when `args.curriculum_sampling` is set.

        Args:
            dataset: The training dataset.

        Returns:
            Sequence parallel sampler for the training dataset.
        """
        should_shuffle = not self.args.curriculum_sampling
        return self._create_sequence_parallel_sampler(dataset, shuffle=should_shuffle)

    def _sp_get_eval_sampler(self, eval_dataset) -> Sampler | None:
        """Return the SP sampler for evaluation (never shuffled).

        Args:
            eval_dataset: The evaluation dataset.

        Returns:
            Sequence parallel sampler for the evaluation dataset.
        """
        return self._create_sequence_parallel_sampler(
            eval_dataset,
            shuffle=False,
            is_eval=True,
        )

View File

@@ -9,8 +9,6 @@ from PIL.Image import Resampling
from transformers import TrainingArguments
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
from axolotl.utils.schemas.enums import RingAttnFunc
@dataclass
class AxolotlTrainingMixins:
@@ -216,14 +214,16 @@ class AxolotlTrainingMixins:
},
)
sequence_parallel_degree: Optional[int] = field(
default=1,
metadata={"help": "The number of workers to use in sequence parallelism"},
)
ring_attn_func: Optional[RingAttnFunc] = field(
adam_beta3: Optional[float] = field(
default=None,
metadata={
"help": "The ring-flash-attn function to use in sequence parallelism"
"help": "The beta3 hyperparameter used in some optimizers such as CAME"
},
)
adam_epsilon2: Optional[float] = field(
default=None,
metadata={
"help": "The epsilon2 hyperparameter used in some optimizers such as CAME"
},
)

View File

@@ -10,71 +10,73 @@
# License for the specific language governing permissions and limitations under
# the License.
"""
Base class for all plugins.
"""Base class for all plugins.
A plugin is a reusable, modular, and self-contained piece of code that extends the functionality of Axolotl.
Plugins can be used to integrate third-party models, modify the training process, or add new features.
To create a new plugin, you need to inherit from the BasePlugin class and implement the required methods.
"""
from __future__ import annotations
import collections
import importlib
import logging
from typing import OrderedDict
from typing import TYPE_CHECKING, Callable, OrderedDict, Union
import torch
from peft import PeftModel
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from transformers import PreTrainedModel, Trainer
from axolotl.utils.dict import DictDefault
if TYPE_CHECKING:
from axolotl.common.datasets import TrainDatasetMeta
class BasePlugin:
"""
Base class for all plugins. Defines the interface for plugin methods.
Attributes:
None
"""Base class for all plugins. Defines the interface for plugin methods.
Methods:
register(cfg): Registers the plugin with the given configuration.
load_datasets(cfg): Loads and preprocesses the dataset for training.
pre_model_load(cfg): Performs actions before the model is loaded.
post_model_build(cfg, model): Performs actions after the model is loaded, but before LoRA adapters are applied.
pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.
post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
post_model_load(cfg, model): Performs actions after the model is loaded, inclusive of any adapters.
post_trainer_create(cfg, trainer): Performs actions after the trainer is created.
create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and returns a learning rate scheduler.
add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before training.
add_callbacks_post_trainer(cfg, trainer): Adds callbacks to the trainer after training.
register(cfg): Registers the plugin with the given configuration.
load_datasets(cfg): Loads and preprocesses the dataset for training.
pre_model_load(cfg): Performs actions before the model is loaded.
post_model_build(cfg, model): Performs actions after the model is loaded, but
before LoRA adapters are applied.
pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.
post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
post_model_load(cfg, model): Performs actions after the model is loaded,
inclusive of any adapters.
post_trainer_create(cfg, trainer): Performs actions after the trainer is
created.
create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and
returns a learning rate scheduler.
add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before
training.
add_callbacks_post_trainer(cfg, trainer): Adds callbacks to the trainer after
training.
"""
def __init__(self):
"""
Initializes the BasePlugin.
"""
"""Initializes the BasePlugin."""
def register(self, cfg): # pylint: disable=unused-argument
"""
Registers the plugin with the given configuration.
"""Registers the plugin with the given configuration.
Parameters:
cfg (dict): The configuration for the plugin.
Returns:
None
Args:
cfg: The configuration for the plugin.
"""
def get_input_args(self) -> str | None:
"""
Returns a pydantic model for the plugin's input arguments.
"""
"""Returns a pydantic model for the plugin's input arguments."""
def load_datasets(self, cfg: DictDefault, preprocess: bool = False):
"""
Loads and preprocesses the dataset for training.
def load_datasets(
self, cfg: DictDefault, preprocess: bool = False
) -> Union["TrainDatasetMeta", None]:
"""Loads and preprocesses the dataset for training.
Args:
cfg: The configuration for the plugin.
@@ -84,181 +86,164 @@ class BasePlugin:
dataset_meta: The metadata for the training dataset.
"""
def pre_model_load(self, cfg): # pylint: disable=unused-argument
"""
Performs actions before the model is loaded.
def pre_model_load(self, cfg: DictDefault): # pylint: disable=unused-argument
"""Performs actions before the model is loaded.
Args:
cfg (dict): The configuration for the plugin.
cfg: The configuration for the plugin.
"""
# pylint: disable=unused-argument
def post_model_build(self, cfg: DictDefault, model: PreTrainedModel):
"""Performs actions after the model is built/loaded, but before any adapters are applied.
Args:
cfg: The configuration for the plugin.
"""
# pylint: disable=unused-argument
def pre_lora_load(self, cfg: DictDefault, model: PreTrainedModel):
"""Performs actions before LoRA weights are loaded.
Args:
cfg: The configuration for the plugin.
model: The loaded model.
"""
# pylint: disable=unused-argument
def post_lora_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
"""Performs actions after LoRA weights are loaded.
Args:
cfg: The configuration for the plugin.
model: The loaded model.
"""
# pylint: disable=unused-argument
def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
"""Performs actions after the model is loaded.
Args:
cfg: The configuration for the plugin.
model: The loaded model.
"""
# pylint: disable=unused-argument
def get_trainer_cls(self, cfg: DictDefault) -> Trainer | None:
"""Returns a custom class for the trainer.
Args:
cfg: The global axolotl configuration.
Returns:
None
The first non-`None` trainer class returned by a plugin.
"""
def post_model_build(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions after the model is built/loaded, but before any adapters are applied.
# pylint: disable=unused-argument
def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):
"""Performs actions after the trainer is created.
Args:
cfg (dict): The configuration for the plugin.
cfg: The configuration for the plugin.
trainer: The trainer object for training.
"""
def post_model_load(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions after the model is loaded.
# pylint: disable=unused-argument
def create_optimizer(self, cfg: DictDefault, trainer: Trainer) -> Optimizer | None:
"""Creates and returns an optimizer for training.
Args:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
cfg: The configuration for the plugin.
trainer: The trainer object for training.
Returns:
None
"""
def pre_lora_load(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions before LoRA weights are loaded.
Args:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Returns:
None
"""
def post_lora_load(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions after LoRA weights are loaded.
Args:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Returns:
None
"""
def get_trainer_cls(self, cfg): # pylint: disable=unused-argument):
"""
Returns a custom class for the trainer.
Args:
cfg (dict): The global axolotl configuration.
Returns:
class: The class for the trainer.
"""
def post_trainer_create(self, cfg, trainer): # pylint: disable=unused-argument
"""
Performs actions after the trainer is created.
Args:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
Returns:
None
"""
def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument
"""
Creates and returns an optimizer for training.
Args:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
Returns:
object: The created optimizer.
The created optimizer.
"""
# pylint: disable=unused-argument
def create_lr_scheduler(
self, cfg, trainer, optimizer, num_training_steps
) -> LRScheduler | None: # pylint: disable=unused-argument
"""
Creates and returns a learning rate scheduler.
self,
cfg: DictDefault,
trainer: Trainer,
optimizer: Optimizer,
num_training_steps: int,
) -> LRScheduler | None:
"""Creates and returns a learning rate scheduler.
Args:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
optimizer (object): The optimizer for training.
num_training_steps (int): Total number of training steps
cfg: The configuration for the plugin.
trainer: The trainer object for training.
optimizer: The optimizer for training.
num_training_steps: Total number of training steps
Returns:
object (LRScheduler): The created learning rate scheduler.
The created learning rate scheduler.
"""
def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument
"""
setup callbacks before creating the trainer.
# pylint: disable=unused-argument
def add_callbacks_pre_trainer(
self, cfg: DictDefault, model: PreTrainedModel
) -> list[Callable]:
"""Set up callbacks before creating the trainer.
Args:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
cfg: The configuration for the plugin.
model: The loaded model.
Returns:
List[callable]: A list of callback functions to be added to the TrainingArgs
A list of callback functions to be added to the `TrainingArgs`.
"""
return []
# pylint: disable=unused-argument
def add_callbacks_post_trainer(
self, cfg, trainer
): # pylint: disable=unused-argument
"""
Adds callbacks to the trainer after creating the trainer.
This is useful for callbacks that require access to the model or trainer.
self, cfg: DictDefault, trainer: Trainer
) -> list[Callable]:
"""Adds callbacks to the trainer after creating the trainer. This is useful for
callbacks that require access to the model or trainer.
Args:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
cfg: The configuration for the plugin.
trainer: The trainer object for training.
Returns:
List[callable]: A list of callback functions to be added
A list of callback functions to be added
"""
return []
def post_train(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions after training is complete.
# pylint: disable=unused-argument
def post_train(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
"""Performs actions after training is complete.
Args:
cfg (dict): The axolotl configuration
model (object): The loaded model.
Returns:
None
cfg: The axolotl configuration.
model: The loaded model.
"""
def post_train_unload(self, cfg): # pylint: disable=unused-argument
"""
Performs actions after training is complete and the model is unloaded.
def post_train_unload(self, cfg: DictDefault): # pylint: disable=unused-argument
"""Performs actions after training is complete and the model is unloaded.
Args:
cfg (dict): The configuration for the plugin.
Returns:
None
cfg: The configuration for the plugin.
"""
def load_plugin(plugin_name: str) -> BasePlugin:
"""
Loads a plugin based on the given plugin name.
"""Loads a plugin based on the given plugin name.
The plugin name should be in the format "module_name.class_name".
This function splits the plugin name into module and class, imports the module,
retrieves the class from the module, and creates an instance of the class.
The plugin name should be in the format "module_name.class_name". This function
splits the plugin name into module and class, imports the module, retrieves the
class from the module, and creates an instance of the class.
Parameters:
plugin_name (str): The name of the plugin to be loaded. The name should be in the format "module_name.class_name".
Args:
plugin_name: The name of the plugin to be loaded. The name should be in the
format "module_name.class_name".
Returns:
BasePlugin: An instance of the loaded plugin.
An instance of the loaded plugin.
Raises:
ImportError: If the plugin module cannot be imported.
ImportError: If the plugin module cannot be imported.
"""
# split the plugin name into module and class
module_name, class_name = plugin_name.rsplit(".", 1)
@@ -284,28 +269,25 @@ def load_plugin(plugin_name: str) -> BasePlugin:
class PluginManager:
"""
The PluginManager class is responsible for loading and managing plugins.
It should be a singleton so it can be accessed from anywhere in the codebase.
"""The `PluginManager` class is responsible for loading and managing plugins. It
should be a singleton so it can be accessed from anywhere in the codebase.
Attributes:
plugins (List[BasePlugin]): A list of loaded plugins.
plugins: A list of loaded plugins.
Methods:
get_instance(): Static method to get the singleton instance of PluginManager.
register(plugin_name: str): Registers a new plugin by its name.
pre_model_load(cfg): Calls the pre_model_load method of all registered plugins.
get_instance(): Static method to get the singleton instance of `PluginManager`.
register(plugin_name: str): Registers a new plugin by its name.
pre_model_load(cfg): Calls the pre_model_load method of all registered plugins.
"""
plugins: OrderedDict[str, BasePlugin] = collections.OrderedDict()
_instance = None
_cfg = None
_instance: PluginManager | None = None
_cfg: DictDefault | None = None
def __new__(cls):
"""
Creates a new instance of PluginManager if it doesn't exist yet.
"""
"""Creates a new instance of PluginManager if it doesn't exist yet."""
if cls._instance is None:
cls._instance = super(PluginManager, cls).__new__(cls)
cls._instance.plugins: OrderedDict[str, BasePlugin] = (
@@ -315,9 +297,8 @@ class PluginManager:
@staticmethod
def get_instance() -> "PluginManager":
"""
Returns the singleton instance of PluginManager.
If the instance doesn't exist, it creates a new one.
"""Returns the singleton instance of PluginManager. If the instance doesn't
exist, it creates a new one.
"""
if PluginManager._instance is None:
PluginManager()
@@ -332,17 +313,13 @@ class PluginManager:
self._cfg = cfg
def register(self, plugin_name: str):
"""
Registers a new plugin by its name.
"""Registers a new plugin by its name.
Parameters:
plugin_name (str): The name of the plugin to be registered.
Returns:
None
Args:
plugin_name: The name of the plugin to be registered.
Raises:
ImportError: If the plugin module cannot be imported.
ImportError: If the plugin module cannot be imported.
"""
try:
logging.info(f"Attempting to load plugin: {plugin_name}")
@@ -352,12 +329,11 @@ class PluginManager:
except ImportError:
logging.error(f"Failed to load plugin: {plugin_name}")
def get_input_args(self):
"""
Returns a list of Pydantic classes for all registered plugins' input arguments.'
def get_input_args(self) -> list[str]:
"""Returns a list of Pydantic classes for all registered plugins' input arguments.'
Returns:
list[str]: A list of Pydantic classes for all registered plugins' input arguments.'
A list of Pydantic classes for all registered plugins' input arguments.
"""
input_args = []
for plugin in self.plugins.values():
@@ -366,16 +342,17 @@ class PluginManager:
input_args.append(input_args_from_plugin)
return input_args
def load_datasets(self, cfg, preprocess: bool = False):
"""
Calls the load_datasets method of each registered plugin.
def load_datasets(
self, cfg: DictDefault, preprocess: bool = False
) -> Union["TrainDatasetMeta", None]:
"""Calls the load_datasets method of each registered plugin.
Args:
cfg: The configuration for the plugins.
preprocess : Whether this is preprocess step of the datasets.
preprocess: Whether this is preprocess step of the datasets.
Returns:
dataset_meta: The dataset metadata loaded from all registered plugins.
The dataset metadata loaded from all registered plugins.
"""
return_ds_meta = None
for plugin in self.plugins.values():
@@ -387,83 +364,66 @@ class PluginManager:
raise RuntimeError("Multiple plugins loaded datasets")
return return_ds_meta
def pre_model_load(self, cfg):
"""
Calls the pre_model_load method of all registered plugins.
def pre_model_load(self, cfg: DictDefault):
"""Calls the pre_model_load method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
Returns:
None
Args:
cfg: The configuration for the plugins.
"""
for plugin in self.plugins.values():
plugin.pre_model_load(cfg)
def post_model_build(self, cfg, model):
"""
Calls the post_model_build method of all registered plugins after the model has been built/loaded,
but before any adapters have been applied.
def post_model_build(self, cfg: DictDefault, model: PreTrainedModel):
"""Calls the `post_model_build` method of all registered plugins after the
model has been built / loaded, but before any adapters have been applied.
Args:
cfg (dict): The configuration for the plugins.
model (object): The loaded model.
cfg: The configuration for the plugins.
model: The loaded model.
"""
for plugin in self.plugins.values():
plugin.post_model_build(cfg, model)
def post_model_load(self, cfg, model):
"""
Calls the post_model_load method of all registered plugins after the model has been loaded
inclusive of any adapters
def pre_lora_load(self, cfg: DictDefault, model: PreTrainedModel):
"""Calls the `pre_lora_load` method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
model (object): The loaded model.
Returns:
None
"""
for plugin in self.plugins.values():
plugin.post_model_load(cfg, model)
def pre_lora_load(self, cfg, model):
"""
Calls the pre_lora_load method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
model (object): The loaded model.
Returns:
None
Args:
cfg: The configuration for the plugins.
model: The loaded model.
"""
for plugin in self.plugins.values():
plugin.pre_lora_load(cfg, model)
def post_lora_load(self, cfg, model):
"""
Calls the post_lora_load method of all registered plugins.
def post_lora_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
"""Calls the `post_lora_load` method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
model (object): The loaded model.
Returns:
None
Args:
cfg: The configuration for the plugins.
model: The loaded model.
"""
for plugin in self.plugins.values():
plugin.post_lora_load(cfg, model)
def get_trainer_cls(self, cfg):
"""
Calls the get_trainer_cls method of all registered plugins and returns the first non-None trainer class.
def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
"""Calls the `post_model_load` method of all registered plugins after the model
has been loaded inclusive of any adapters.
Parameters:
cfg (dict): The configuration for the plugins.
Args:
cfg: The configuration for the plugins.
model: The loaded model.
"""
for plugin in self.plugins.values():
plugin.post_model_load(cfg, model)
def get_trainer_cls(self, cfg: DictDefault) -> Trainer | None:
"""Calls the `get_trainer_cls` method of all registered plugins and returns the
first non-`None` trainer class.
Args:
cfg: The configuration for the plugins.
Returns:
object: The trainer class, or None if none was found.
The first non-`None` trainer class returned by a plugin.
"""
for plugin in self.plugins.values():
trainer_cls = plugin.get_trainer_cls(cfg)
@@ -471,29 +431,25 @@ class PluginManager:
return trainer_cls
return None
def post_trainer_create(self, cfg, trainer):
"""
Calls the post_trainer_create method of all registered plugins.
def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):
"""Calls the `post_trainer_create` method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
trainer (object): The trainer object for training.
Returns:
None
Args:
cfg: The configuration for the plugins.
trainer: The trainer object for training.
"""
for plugin in self.plugins.values():
plugin.post_trainer_create(cfg, trainer)
def create_optimizer(self, trainer):
"""
Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer.
def create_optimizer(self, trainer: Trainer) -> Optimizer | None:
"""Calls the `create_optimizer` method of all registered plugins and returns
the first non-`None` optimizer.
Parameters:
trainer (object): The trainer object for training.
Args:
trainer: The trainer object for training.
Returns:
object: The created optimizer, or None if none was found.
The created optimizer, or `None` if none was found.
"""
for plugin in self.plugins.values():
optimizer = plugin.create_optimizer(self.cfg, trainer)
@@ -502,17 +458,17 @@ class PluginManager:
return None
def create_lr_scheduler(
self, trainer, optimizer, num_training_steps
self, trainer: Trainer, optimizer: Optimizer, num_training_steps: int
) -> LRScheduler | None:
"""
Calls the create_lr_scheduler method of all registered plugins and returns the first non-None scheduler.
"""Calls the `create_lr_scheduler` method of all registered plugins and returns
the first non-`None` scheduler.
Parameters:
trainer (object): The trainer object for training.
optimizer (object): The optimizer for training.
Args:
trainer: The trainer object for training.
optimizer: The optimizer for training.
Returns:
object: The created learning rate scheduler, or None if none was found.
The created learning rate scheduler, or `None` if not found.
"""
for plugin in self.plugins.values():
scheduler: LRScheduler | None = plugin.create_lr_scheduler(
@@ -525,16 +481,17 @@ class PluginManager:
return scheduler
return None
def add_callbacks_pre_trainer(self, cfg, model):
"""
Calls the add_callbacks_pre_trainer method of all registered plugins.
def add_callbacks_pre_trainer(
self, cfg: DictDefault, model: PreTrainedModel
) -> list[Callable]:
"""Calls the add_callbacks_pre_trainer method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
model (object): The loaded model.
Args:
cfg: The configuration for the plugins.
model: The loaded model.
Returns:
List[callable]: A list of callback functions to be added to the TrainingArgs.
A list of callback functions to be added to the `TrainingArgs`.
"""
callbacks = []
for plugin in self.plugins.values():
@@ -543,16 +500,17 @@ class PluginManager:
callbacks.extend(plugin_callbacks)
return callbacks
def add_callbacks_post_trainer(self, cfg, trainer):
"""
Calls the add_callbacks_post_trainer method of all registered plugins.
def add_callbacks_post_trainer(
self, cfg: DictDefault, trainer: Trainer
) -> list[Callable]:
"""Calls the `add_callbacks_post_trainer` method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
trainer (object): The trainer object for training.
Args:
cfg: The configuration for the plugins.
trainer: The trainer object for training.
Returns:
List[callable]: A list of callback functions to be added to the TrainingArgs.
A list of callback functions to be added to the `TrainingArgs`.
"""
callbacks = []
for plugin in self.plugins.values():
@@ -561,41 +519,31 @@ class PluginManager:
callbacks.extend(plugin_callbacks)
return callbacks
def post_train(self, cfg, model):
"""
Calls the post_train method of all registered plugins.
def post_train(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
"""Calls the post_train method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
model (object): The loaded model.
Returns:
None
Args:
cfg: The configuration for the plugins.
model: The loaded model.
"""
for plugin in self.plugins.values():
plugin.post_train(cfg, model)
def post_train_unload(self, cfg):
"""
Calls the post_train_unload method of all registered plugins.
def post_train_unload(self, cfg: DictDefault):
"""Calls the post_train_unload method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
model (object): The loaded model.
Returns:
None
Args:
cfg: The configuration for the plugins.
model: The loaded model.
"""
for plugin in self.plugins.values():
plugin.post_train_unload(cfg)
class BaseOptimizerFactory:
"""
Base class for factories to create custom optimizers
"""
"""Base class for factories to create custom optimizers"""
def __call__(
self, opt_model, training_args, **optimizer_kwargs
) -> "torch.optim.Optimizer":
) -> Optimizer | None:
pass

View File

@@ -20,25 +20,15 @@ from cut_cross_entropy.transformers.utils import (
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.cohere.modeling_cohere import (
_CONFIG_FOR_DOC,
COHERE_INPUTS_DOCSTRING,
KwargsForCausalLM,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
_PATCH_OPTS: PatchOptions | None = None
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(COHERE_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward(
self,
input_ids: torch.LongTensor | None = None,

View File

@@ -17,25 +17,15 @@ from cut_cross_entropy.transformers.utils import (
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.gemma.modeling_gemma import (
_CONFIG_FOR_DOC,
GEMMA_INPUTS_DOCSTRING,
KwargsForCausalLM,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
_PATCH_OPTS: PatchOptions | None = None
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward(
self,
input_ids: torch.LongTensor | None = None,

View File

@@ -20,15 +20,11 @@ from torch import nn
from transformers.cache_utils import Cache, HybridCache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.gemma3.modeling_gemma3 import (
_CONFIG_FOR_DOC,
GEMMA3_INPUTS_DOCSTRING,
Gemma3CausalLMOutputWithPast,
logger,
)
from transformers.utils import (
add_start_docstrings_to_model_forward,
is_torchdynamo_compiling,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
@@ -38,10 +34,6 @@ _PATCH_OPTS: PatchOptions | None = None
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward(
self,
input_ids: torch.LongTensor | None = None,
@@ -170,10 +162,6 @@ def cce_forward(
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward_multimodal(
self,
input_ids: torch.LongTensor | None = None,

View File

@@ -19,15 +19,9 @@ from transformers.modeling_outputs import (
CausalLMOutputWithPast,
)
from transformers.models.llama.modeling_llama import (
_CONFIG_FOR_DOC,
LLAMA_INPUTS_DOCSTRING,
KwargsForCausalLM,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import can_return_tuple
@@ -36,10 +30,6 @@ _PATCH_OPTS: PatchOptions | None = None
@can_return_tuple
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward(
self,
input_ids: Optional[torch.LongTensor] = None,

View File

@@ -16,22 +16,12 @@ from torch import nn
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama4.modeling_llama4 import (
_CONFIG_FOR_DOC,
LLAMA4_INPUTS_DOCSTRING,
Llama4CausalLMOutputWithPast,
)
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
_PATCH_OPTS: PatchOptions | None = None
@add_start_docstrings_to_model_forward(LLAMA4_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward(
self,
input_ids: torch.LongTensor | None = None,
@@ -160,9 +150,6 @@ def cce_forward(
)
@replace_return_docstrings(
output_type=Llama4CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward_multimodal(
self,
input_ids: torch.LongTensor | None = None, # type: ignore

View File

@@ -19,15 +19,11 @@ from transformers.models.mistral3.modeling_mistral3 import (
Mistral3CausalLMOutputWithPast,
)
from transformers.models.mistral.modeling_mistral import (
_CONFIG_FOR_DOC,
MISTRAL_INPUTS_DOCSTRING,
KwargsForCausalLM,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
add_start_docstrings_to_model_forward,
is_torchdynamo_compiling,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
@@ -35,10 +31,6 @@ _PATCH_OPTS: PatchOptions | None = None
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward(
self,
input_ids: torch.LongTensor | None = None,

View File

@@ -13,16 +13,10 @@ from cut_cross_entropy.transformers.utils import (
apply_lce,
)
from transformers.models.qwen2_moe.modeling_qwen2_moe import (
_CONFIG_FOR_DOC,
QWEN2MOE_INPUTS_DOCSTRING,
MoeCausalLMOutputWithPast,
MoeModelOutputWithPast,
load_balancing_loss_func,
)
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import can_return_tuple
@@ -31,10 +25,6 @@ _PATCH_OPTS: PatchOptions | None = None
@can_return_tuple
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,

View File

@@ -14,22 +14,12 @@ from cut_cross_entropy.transformers.utils import (
)
from torch.nn import CrossEntropyLoss
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
_CONFIG_FOR_DOC,
QWEN2_VL_INPUTS_DOCSTRING,
Qwen2VLCausalLMOutputWithPast,
)
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
_PATCH_OPTS: PatchOptions | None = None
@add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward_multimodal(
self,
input_ids: Optional[torch.LongTensor] = None,

View File

@@ -12,20 +12,13 @@ from cut_cross_entropy.transformers.utils import (
TransformersModelT,
apply_lce,
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.qwen3_moe.modeling_qwen3_moe import (
_CONFIG_FOR_DOC,
QWEN3_MOE_INPUTS_DOCSTRING,
KwargsForCausalLM,
MoeCausalLMOutputWithPast,
MoeModelOutputWithPast,
load_balancing_loss_func,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import can_return_tuple
@@ -34,10 +27,6 @@ _PATCH_OPTS: PatchOptions | None = None
@can_return_tuple
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,

View File

@@ -14,10 +14,6 @@ from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import CausalLMOutputWithPast
# @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)
# @replace_return_docstrings(
# output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
# )
def lce_forward(
self,
input_ids: torch.LongTensor = None,

View File

@@ -13,21 +13,11 @@ from liger_kernel.transformers.fused_linear_cross_entropy import (
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
from transformers.models.jamba.modeling_jamba import (
_CONFIG_FOR_DOC,
JAMBA_INPUTS_DOCSTRING,
HybridMambaAttentionDynamicCache,
load_balancing_loss_func,
)
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
@add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def lce_forward(
self,
input_ids: torch.LongTensor = None,

View File

@@ -1,5 +1,4 @@
"""
Module for definition of GEGLU Triton kernels.
"""Module for definition of GEGLU Triton kernels.
See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).
@@ -12,8 +11,6 @@ import torch
import triton
import triton.language as tl
SQRT_2_PI: tl.constexpr = 0.7978845608028654 # sqrt(2/π)
@triton.jit
def _geglu_fwd_kernel(

View File

@@ -0,0 +1,10 @@
"""Init for axolotl.loaders module"""
# pylint: disable=unused-import
# flake8: noqa
from .adapter import load_adapter, load_lora
from .constants import MULTIMODAL_AUTO_MODEL_MAPPING
from .model import ModelLoader
from .processor import load_processor
from .tokenizer import load_tokenizer

View File

@@ -0,0 +1,206 @@
"""Adapter loading functionality, including LoRA / QLoRA and associated utils"""
import logging
import os
import types
from typing import Any
import bitsandbytes as bnb
import torch
from bitsandbytes.nn import Params4bit
from peft import (
AdaptionPromptConfig,
LoftQConfig,
LoraConfig,
PeftConfig,
PeftMixedModel,
PeftModel,
get_peft_model,
)
from transformers import PreTrainedModel
from axolotl.loaders.utils import get_linear_embedding_layers
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__)
def setup_quantized_meta_for_peft(model: torch.nn.Module):
    """Replaces `quant_state.to` with a dummy function to prevent PEFT from moving
    `quant_state` to meta device.

    The original bound method is stashed on `quant_state._orig_to` so that
    `setup_quantized_peft_meta_for_training` can restore it. A parameter that is
    already patched is left untouched, so calling this function twice does not
    overwrite the stashed original with the dummy.

    Args:
        model: Module whose `Params4bit` parameters are patched in place.
    """

    def temp_to_method(self, *args, **kwargs):  # pylint: disable=unused-argument
        return self

    for param in model.parameters():
        if not isinstance(param, Params4bit):
            continue
        quant_state = param.quant_state
        # pylint: disable=protected-access
        if getattr(quant_state, "_orig_to", None) is not None:
            # Already patched: `quant_state.to` is currently the dummy, so saving
            # it again would lose the real method.
            continue
        quant_state._orig_to = quant_state.to
        quant_state.to = types.MethodType(temp_to_method, quant_state)
def setup_quantized_peft_meta_for_training(model: torch.nn.Module):
    """Replaces dummy `quant_state.to` method with the original function to allow
    training to continue.

    Only parameters that currently hold a stashed original (`_orig_to` is set and
    not `None`) are restored. After a restore `_orig_to` is reset to `None`, so a
    second call is a no-op instead of assigning `None` to `quant_state.to`.

    Args:
        model: Module whose `Params4bit` parameters are restored in place.
    """
    for param in model.parameters():
        if not isinstance(param, Params4bit):
            continue
        quant_state = param.quant_state
        # pylint: disable=protected-access
        original_to = getattr(quant_state, "_orig_to", None)
        if original_to is None:
            # Never patched, or already restored.
            continue
        quant_state.to = original_to
        quant_state._orig_to = None
def find_all_linear_names(model):
    """Return the leaf names of all linear-like submodules of `model`.

    A submodule counts as linear when it is an instance of a bitsandbytes 4-bit /
    8-bit linear layer or `torch.nn.Linear`, or when its class name contains
    "Linear" (except `LlamaLinearScalingRotaryEmbedding`). The output embedding
    layer reported by `get_linear_embedding_layers` is removed from the result.

    Args:
        model: Model exposing `named_modules()` and `config.model_type`.

    Returns:
        List of unique leaf module names (last component of the dotted name).
    """
    linear_types = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
    excluded_class_names = ("LlamaLinearScalingRotaryEmbedding",)

    def is_linear_like(module):
        if isinstance(module, linear_types):
            return True
        class_name = module.__class__.__name__
        return "Linear" in class_name and class_name not in excluded_class_names

    leaf_names = {
        qualified_name.split(".")[-1]
        for qualified_name, module in model.named_modules()
        if is_linear_like(module)
    }

    output_embedding = get_linear_embedding_layers(model.config.model_type)[1]
    leaf_names.discard(output_embedding)  # needed for 16-bit

    return list(leaf_names)
def load_lora(
    model: PreTrainedModel,
    cfg: DictDefault,
    inference: bool = False,
    config_only: bool = False,
) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]:
    """Build a `LoraConfig` from `cfg` and, unless `config_only`, apply it to `model`.

    Args:
        model: Base model the adapter is attached to.
        cfg: Configuration; the `lora_*` and `peft_*` keys drive the adapter config.
        inference: When an adapter is loaded from `cfg.lora_model_dir`, it is loaded
            as trainable only if this is `False`.
        config_only: If `True`, return `(None, lora_config)` without wrapping `model`.

    Returns:
        `(model, lora_config)`, where `model` is `None` when `config_only` is set.
    """
    target_modules = cfg.lora_target_modules or []
    if cfg.lora_target_linear:
        linear_names = find_all_linear_names(model)
        LOG.info(f"found linear modules: {repr(sorted(linear_names))}")
        if not isinstance(target_modules, list):
            target_modules = [target_modules]
        # Union of the configured targets and every discovered linear layer
        target_modules = list(set(target_modules + linear_names))

    extra_config_kwargs = {}
    loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
    if loftq_bits:
        extra_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
        extra_config_kwargs["init_lora_weights"] = "loftq"
    # An explicit init setting takes precedence over the LoftQ default above
    if cfg.peft_init_lora_weights:
        extra_config_kwargs["init_lora_weights"] = cfg.peft_init_lora_weights
    if cfg.peft_use_dora:
        extra_config_kwargs["use_dora"] = cfg.peft_use_dora
        LOG.info("Initializing LoRA weights using dora. This might take longer.")
    if cfg.peft_use_rslora:
        extra_config_kwargs["use_rslora"] = cfg.peft_use_rslora
    if cfg.peft_layer_replication:
        extra_config_kwargs["layer_replication"] = cfg.peft_layer_replication

    lora_config = LoraConfig(
        r=cfg.lora_r,
        lora_alpha=cfg.lora_alpha,
        target_modules=target_modules,
        layers_to_transform=cfg.peft_layers_to_transform,
        layers_pattern=cfg.peft_layers_pattern,
        lora_dropout=cfg.lora_dropout,
        fan_in_fan_out=cfg.lora_fan_in_fan_out,
        modules_to_save=cfg.lora_modules_to_save or None,
        bias="none",
        task_type="CAUSAL_LM",
        **extra_config_kwargs,
    )

    if config_only:
        return None, lora_config

    rank = int(os.environ.get("LOCAL_RANK", 0))
    # On non-zero ranks with FSDP CPU-RAM-efficient loading, `quant_state.to` is
    # swapped for a dummy around the PEFT wrapping and restored afterwards.
    patch_quant_state = (
        cfg.fsdp
        and cfg.adapter
        and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
        and rank != 0
    )

    if patch_quant_state:
        setup_quantized_meta_for_peft(model)

    if not cfg.lora_model_dir:
        model = get_peft_model(model, lora_config)
    else:
        LOG.debug("Loading pretrained PEFT - LoRA")
        from_pretrained_kwargs: Any = {}
        if cfg.lora_on_cpu:
            from_pretrained_kwargs["max_memory"] = {"cpu": "256GiB"}
            from_pretrained_kwargs["device_map"] = {"": "cpu"}
        model = PeftModel.from_pretrained(
            model,
            cfg.lora_model_dir,
            is_trainable=(not inference),
            **from_pretrained_kwargs,
        )

    if rank == 0:
        try:
            model.print_trainable_parameters()
        except AttributeError as exc:
            LOG.warning(
                "Exception caught during model.print_trainable_parameters(): %s", exc
            )
    elif patch_quant_state:
        setup_quantized_peft_meta_for_training(model)

    return model, lora_config
def load_adapter(
    model: PreTrainedModel,
    cfg: DictDefault,
    adapter: str | None,
    inference: bool = False,
) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel, PeftConfig | None]:
    """Dispatch to the loader for the requested PEFT adapter type.

    Args:
        model: Base model to wrap.
        cfg: Configuration passed through to the adapter loader.
        adapter: Adapter name: `"lora"`, `"qlora"`, `"llama-adapter"`, or `None`.
        inference: Forwarded to `load_lora` for LoRA / QLoRA adapters.

    Returns:
        `(model, peft_config)`; `(model, None)` unchanged when `adapter` is `None`.

    Raises:
        NotImplementedError: If `adapter` is not one of the supported names.
    """
    if adapter is None:
        return model, None

    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()

    if adapter == "llama-adapter":
        adapted_model, adapter_config = load_llama_adapter(model, cfg)
    elif adapter in ["lora", "qlora"]:
        adapted_model, adapter_config = load_lora(model, cfg, inference=inference)
    else:
        raise NotImplementedError(f"{adapter} PEFT adapter not available")

    return adapted_model, adapter_config
def load_llama_adapter(
    model: PreTrainedModel, cfg: DictDefault
) -> tuple[PeftModel | PeftMixedModel, PeftConfig]:
    """Wrap `model` with an adaption-prompt (llama-adapter) PEFT adapter.

    Args:
        model: Base model to wrap.
        cfg: Configuration; reads `peft_adapter.layers`, `peft_adapter.len` and
            `lora_model_dir`.

    Returns:
        `(peft_model, peft_config)`.
    """
    peft_config = AdaptionPromptConfig(
        adapter_layers=cfg.peft_adapter.layers,  # layers (L)
        adapter_len=cfg.peft_adapter.len,  # prompt length (K)
        task_type="CAUSAL_LM",
    )

    if not cfg.lora_model_dir:
        peft_model = get_peft_model(model, peft_config)
    else:
        LOG.debug("Loading pretrained PEFT - llama_adapter")
        peft_model = PeftModel.from_pretrained(
            model,
            cfg.lora_model_dir,
            torch_dtype=torch.float16,
        )

    peft_model.print_trainable_parameters()

    return peft_model, peft_config

View File

@@ -0,0 +1,21 @@
"""Shared constants for axolotl.loaders module"""
from transformers import (
Gemma3ForConditionalGeneration,
Llama4ForConditionalGeneration,
LlavaForConditionalGeneration,
Mistral3ForConditionalGeneration,
MllamaForConditionalGeneration,
Qwen2_5_VLForConditionalGeneration,
Qwen2VLForConditionalGeneration,
)
# Maps a model config's `model_type` string to the transformers
# `*ForConditionalGeneration` class for that multimodal architecture.
MULTIMODAL_AUTO_MODEL_MAPPING = {
"mllama": MllamaForConditionalGeneration,
"llama4": Llama4ForConditionalGeneration,
"llava": LlavaForConditionalGeneration,
"qwen2_vl": Qwen2VLForConditionalGeneration,
"qwen2_5_vl": Qwen2_5_VLForConditionalGeneration,
"mistral3": Mistral3ForConditionalGeneration,
"gemma3": Gemma3ForConditionalGeneration,
}

View File

@@ -0,0 +1,754 @@
"""Model loader class implementation for loading, configuring, and patching various
models.
"""
import gc
import logging
import math
import os
from functools import cached_property
from importlib.util import find_spec
from typing import Any
import peft
import torch
import transformers
import transformers.modeling_utils
from accelerate import init_empty_weights
from peft import PeftConfig, PeftMixedModel, PeftModel, prepare_model_for_kbit_training
from transformers import (
AutoModelForCausalLM,
AutoModelForVision2Seq,
AwqConfig,
BitsAndBytesConfig,
GPTQConfig,
PreTrainedModel,
PreTrainedTokenizerBase,
)
from transformers.integrations.deepspeed import (
HfTrainerDeepSpeedConfig,
is_deepspeed_zero3_enabled,
)
from axolotl.common.architectures import MOE_ARCH_BLOCK
from axolotl.integrations.base import PluginManager
from axolotl.loaders.adapter import load_adapter, load_lora
from axolotl.loaders.constants import MULTIMODAL_AUTO_MODEL_MAPPING
from axolotl.loaders.patch_manager import PatchManager
from axolotl.loaders.utils import (
get_linear_embedding_layers,
get_module_class_from_name,
load_model_config,
)
from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import (
get_device_count,
get_device_type,
)
from axolotl.utils.model_shard_quant import load_sharded_model_quant
from axolotl.utils.schemas.enums import RLType
LOG = logging.getLogger(__name__)
PLUGIN_MANAGER = PluginManager.get_instance()
class ModelLoader:
"""Manages model configuration, initialization and application of patches during
model loading.
This class orchestrates the entire process of loading a model from configuration to
final preparation. It handles device mapping, quantization, attention mechanisms,
adapter integration, and various optimizations.
The loading process includes:
- Loading and validating model configuration
- Applying monkey patches for optimizations / fixes
- Setting up device mapping (including multi-GPU configurations)
- Configuring quantization
- Setting attention mechanisms (Flash Attention, SDPA, etc.)
- Loading and initializing the model
- Applying adapters (LoRA, QLoRA, etc.)
Attributes:
cfg: Configuration dictionary the loader was created with.
tokenizer: Tokenizer instance associated with the model.
inference: Whether the model is being loaded for inference.
reference_model: Whether this loader builds a reference model.
model: The loaded model instance (available after load() is called).
model_kwargs: Dictionary of keyword arguments passed to model initialization.
base_model: Name or path of the base model to load.
model_type: Type of model to load (e.g., `AutoModelForCausalLM`).
model_config: Configuration object for the model.
auto_model_loader: class used for loading the model (default:
`AutoModelForCausalLM`).
patch_manager: `PatchManager` created from `cfg` and `model_config`.
"""
def __init__(
self,
cfg: DictDefault,
tokenizer: PreTrainedTokenizerBase,
*,
inference: bool = False,
reference_model: bool = False,
**kwargs, # pylint: disable=unused-argument
):
"""Initializes the ModelLoader.
Args:
cfg: Configuration dictionary with model and training settings.
tokenizer: Tokenizer instance associated with the model.
inference: Whether the model is being loaded for inference mode. Defaults
to False.
reference_model: Whether this is a reference model (used in setups like DPO
training). Defaults to False.
**kwargs: Additional keyword arguments (ignored).
"""
self.cfg = cfg
self.tokenizer = tokenizer
self.inference: bool = inference
self.reference_model: bool = reference_model
# Init model kwargs, seeded with any `cfg.overrides_of_model_kwargs` entries
self.model_kwargs: dict[str, Any] = {}
if cfg.overrides_of_model_kwargs:
for key, val in cfg.overrides_of_model_kwargs.items():
self.model_kwargs[key] = val
# Init model (`self.model` is only declared here; it is assigned during load)
self.model: PreTrainedModel | PeftModel | PeftMixedModel
self.base_model = cfg.base_model
self.model_type = cfg.type_of_model
# Init model config
self.model_config = load_model_config(cfg)
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
# Initialize the patch manager
self.patch_manager = PatchManager(
cfg=cfg,
model_config=self.model_config,
inference=inference,
)
@cached_property
def has_flash_attn(self) -> bool:
    """Whether an importable `flash_attn` module spec can be found."""
    flash_attn_spec = find_spec("flash_attn")
    return flash_attn_spec is not None
@cached_property
def qlora_fsdp(self):
    """Truthy when FSDP is configured and the adapter is QLoRA.

    Returns the falsy `cfg.fsdp` value itself when FSDP is not configured,
    otherwise the boolean result of the adapter comparison.
    """
    cfg = self.cfg
    return cfg.fsdp and cfg.adapter == "qlora"
def load(self) -> tuple[PreTrainedModel, PeftConfig | None]:
"""Load and prepare the model with all configurations and patches.
The steps run in a fixed order: pre-load patches and setup, model build,
post-build configuration, adapter loading, then final setup and post-load
patches. `PLUGIN_MANAGER` hooks are invoked around the build and adapter steps
and once more at the very end.
Returns:
A tuple with the loaded model and its LoRA configuration (if applicable).
"""
# Initial setup and patches
self.patch_manager.apply_pre_model_load_patches()
self._apply_pre_model_load_setup()
# Build the model
PLUGIN_MANAGER.pre_model_load(self.cfg)
# `_build_model` reports whether the later device move must be skipped
skip_move_to_device = self._build_model()
PLUGIN_MANAGER.post_model_build(self.cfg, self.model)
# Post-build model configuration
self._apply_post_model_load_setup()
# Load adapters (LoRA, etc.)
PLUGIN_MANAGER.pre_lora_load(self.cfg, self.model)
lora_config = self._load_adapters()
PLUGIN_MANAGER.post_lora_load(self.cfg, self.model)
# Apply remaining patches and finalize
self._apply_post_lora_load_setup(skip_move_to_device)
self.patch_manager.apply_post_model_load_patches(self.model)
PLUGIN_MANAGER.post_model_load(self.cfg, self.model)
return self.model, lora_config
def _apply_pre_model_load_setup(self):
    """Populate loader state and `model_kwargs` before the model is built.

    Runs, in order: auto-model-loader selection, device map setup, optional
    `revision` kwarg, quantization config and attention config.
    """
    self._set_auto_model_loader()
    self._set_device_map_config()

    revision = self.cfg.revision_of_model
    if revision:
        self.model_kwargs["revision"] = revision

    self._set_quantization_config()
    self._set_attention_config()
def _apply_post_model_load_setup(self):
    """Configure the freshly built model before adapters are loaded.

    A model that is already a PEFT model is merged and unloaded (unless QLoRA
    with FSDP is in use); then embeddings are resized, the model config is
    adjusted, memory usage is logged and embedding dtypes are configured.
    """
    # Handle PeftModel if needed
    peft_model_types = (peft.PeftModel, peft.PeftModelForCausalLM)
    already_peft = isinstance(self.model, peft_model_types)
    if already_peft and not self.qlora_fsdp:
        self.model = self.model.merge_and_unload()

    self._resize_token_embeddings()
    self._adjust_model_config()
    self._log_memory_usage()
    self._configure_embedding_dtypes()
def _resize_token_embeddings(self):
    """Resize the model's token embeddings to match the tokenizer, or tie weights.

    The target size is the tokenizer length, rounded up to a multiple of 32 when
    `cfg.resize_token_embeddings_to_32x` is set. The embeddings are resized when
    they are smaller than the target, or larger and `cfg.shrink_embeddings` is
    set; in every other case `tie_weights()` is called instead.
    """
    vocab_size = len(self.tokenizer)
    if self.cfg.resize_token_embeddings_to_32x:
        target_len = math.ceil(vocab_size / 32) * 32
    else:
        target_len = vocab_size

    must_resize = False
    if hasattr(self.model, "get_input_embeddings"):
        current_len = self.model.get_input_embeddings().num_embeddings
        if current_len < target_len:
            must_resize = True
        elif current_len > target_len:
            must_resize = bool(self.cfg.shrink_embeddings)

    if not must_resize:
        self.model.tie_weights()
        return

    resize_kwargs = {}
    mean_resizing = self.cfg.mean_resizing_embeddings
    # `mean_resizing` is not forwarded for llava models
    if mean_resizing is not None and self.model_config.model_type != "llava":
        resize_kwargs["mean_resizing"] = mean_resizing
    self.model.resize_token_embeddings(target_len, **resize_kwargs)
def _adjust_model_config(self):
    """Align selected `model.config` fields with the training config and tokenizer.

    - Raises `max_position_embeddings` to `cfg.sequence_len` when it is set and
      smaller, logging a warning.
    - Overwrites a truthy `bos_token_id` / `eos_token_id` on the model config
      with the tokenizer's value when they differ.
    """
    model_config = getattr(self.model, "config", None)

    max_positions = getattr(model_config, "max_position_embeddings", None)
    if max_positions and self.cfg.sequence_len > max_positions:
        LOG.warning(
            "increasing model.config.max_position_embeddings from "
            f"{max_positions} to {self.cfg.sequence_len}"
        )
        model_config.max_position_embeddings = self.cfg.sequence_len

    for token_attr in ("bos_token_id", "eos_token_id"):
        model_token_id = getattr(model_config, token_attr, None)
        # Falsy ids (missing, None or 0) on the model config are left untouched
        if model_token_id and model_token_id != getattr(self.tokenizer, token_attr):
            setattr(model_config, token_attr, getattr(self.tokenizer, token_attr))
def _log_memory_usage(self):
    """Log device memory usage after model load.

    Nothing is logged unless the model exposes a `device` whose type is one of
    `cuda`, `mps` or `npu`.
    """
    if not hasattr(self.model, "device"):
        return

    accelerator_types = ("cuda", "mps", "npu")
    if self.model.device.type in accelerator_types:
        log_gpu_memory_usage(LOG, "after model load", self.model.device)
def _configure_embedding_dtypes(self):
"""Configure embedding module dtypes.
Besides dtype conversion, this method also registers DeepSpeed ZeRO-3 leaf
modules, enables gradient checkpointing when configured, and calls
`_prepare_model_for_quantization`. A second conversion to `cfg.torch_dtype` is
applied at the end when `should_convert` is truthy.
"""
# Get embedding modules
embedding_modules = get_linear_embedding_layers(self.cfg.model_config_type)
# Initial dtype conversion
if not self.cfg.fsdp:
# We don't run this during FSDP because this will leave mixed and bfloat16
# dtypes in the model which FSDP doesn't like
if self.cfg.load_in_4bit and self.cfg.embeddings_skip_upcast:
embedding_modules = []
self._convert_embedding_modules_dtype(
embedding_modules,
dist_dtype=torch.float32,
before_kbit_train_or_finetune=True,
)
# Handle DeepSpeed Zero3
if is_deepspeed_zero3_enabled():
self._set_z3_leaf_modules()
# Decide whether an adapter / FSDP setup calls for the final dtype conversion
needs_fa2_dtype = self.cfg.adapter or self.cfg.fsdp
if self.cfg.adapter in ["lora", "qlora"]:
needs_fa2_dtype = True
# Apply gradient checkpointing if needed
if self.cfg.gradient_checkpointing:
self.model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs=self.cfg.gradient_checkpointing_kwargs
)
self._prepare_model_for_quantization()
# Convert dtypes if needed
should_convert = (
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so
# we need to convert them back to fp16/bf16 for flash-attn compatibility.
(
(needs_fa2_dtype or self.cfg.flash_attention or self.cfg.flex_attention)
and not self.qlora_fsdp
)
# CCE requires embedding layers to be in fp16/bf16 for backward pass
or self.cfg.cut_cross_entropy
)
if should_convert:
LOG.info("Converting modules to %s", self.cfg.torch_dtype)
self._convert_embedding_modules_dtype(
embedding_modules=embedding_modules,
dist_dtype=self.cfg.torch_dtype,
before_kbit_train_or_finetune=False,
)
def _load_adapters(self) -> PeftConfig | None:
    """Load LoRA or other adapters.

    Returns:
        The PEFT config, or `None` when nothing is loaded (a reference model
        without `cfg.lora_model_dir`).
    """
    if self.reference_model and not self.cfg.lora_model_dir:
        return None

    # If we're not loading the reference model, then we're loading the model
    # for training. Then, the DPO trainer doesn't want the PEFT model loaded
    # over it, it just wants the LoRA / PEFT config.
    wants_config_only = (
        self.cfg.adapter
        and self.cfg.rl in [RLType.DPO, RLType.IPO, RLType.KTO]
        and not self.cfg.merge_lora
    )
    if wants_config_only:
        _, lora_config = load_lora(
            self.model, self.cfg, inference=False, config_only=True
        )
        return lora_config

    self.model, lora_config = load_adapter(self.model, self.cfg, self.cfg.adapter)
    return lora_config
def _apply_post_lora_load_setup(self, skip_move_to_device: bool):
"""Apply final optimizations and patches.
Covers device placement under DDP, model-parallel flags for a single process
with several devices, a warning when no parameter requires gradients, the
optional BetterTransformer transform, memory logging and cache cleanup.
Args:
skip_move_to_device: When True, the DDP device move below is skipped.
"""
# Place model on accelerator
if (
self.cfg.ddp
and not self.cfg.load_in_8bit
and not (self.cfg.rl and self.cfg.load_in_4bit)
and not skip_move_to_device
):
# TODO: validate this conditional
self.model.to(f"{str(get_device_type())}:{self.cfg.local_rank}")
# Several devices but a single process (WORLD_SIZE == 1)
if get_device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
self.model.is_parallelizable = True
self.model.model_parallel = True
if not any(
param.requires_grad
for _, param in self.model.named_parameters(recurse=True)
):
LOG.warning("There are no parameters that require gradient updates")
if self.cfg.flash_optimum:
# Imported lazily so `optimum` is only needed when this option is set
from optimum.bettertransformer import BetterTransformer
self.model = BetterTransformer.transform(self.model)
if self.cfg.adapter is not None:
log_gpu_memory_usage(LOG, "after adapters", self.model.device)
# Collect garbage and release cached CUDA memory
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()
def _set_auto_model_loader(self):
    """Pick the loader class for multimodal models.

    `self.auto_model_loader` keeps the value set in `__init__` unless
    `cfg.is_multimodal` is set. In that case it is looked up by
    `model_config.model_type` in `MULTIMODAL_AUTO_MODEL_MAPPING`, falling back to
    `AutoModelForVision2Seq`.
    """
    if not self.cfg.is_multimodal:
        return

    model_type = self.model_config.model_type
    self.auto_model_loader = MULTIMODAL_AUTO_MODEL_MAPPING.get(
        model_type, AutoModelForVision2Seq
    )
def _set_device_map_config(self):
"""Setup `device_map` according to config
Writes `torch_dtype` and, depending on the checks below, `device_map` into
`self.model_kwargs`. When a memory limit is available, the device map is
inferred from an empty-weights instance of the model.
"""
device_map = self.cfg.device_map
max_memory = self.cfg.max_memory
if self.cfg.gpu_memory_limit:
# An int limit is interpreted as GiB; any other value is used as given
gpu_memory_limit = (
str(self.cfg.gpu_memory_limit) + "GiB"
if isinstance(self.cfg.gpu_memory_limit, int)
else self.cfg.gpu_memory_limit
)
max_memory = {}
num_device = get_device_count()
for i in range(num_device):
max_memory[i] = gpu_memory_limit
max_memory["cpu"] = "256GiB" # something sufficiently large to fit anything
if max_memory is not None:
# Based on https://github.com/togethercomputer/OpenChatKit/blob/main/inference/bot.py
from accelerate import infer_auto_device_map
# Instantiate the model without allocating weights, only to infer the map
with init_empty_weights():
model_canvas = self.auto_model_loader.from_config(
self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
)
model_canvas.tie_weights()
device_map = infer_auto_device_map(
model_canvas,
max_memory=max_memory,
dtype=self.cfg.torch_dtype,
)
# We can discard max_memory now as we have a device map set up
max_memory = None
self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype
if not is_deepspeed_zero3_enabled():
self.model_kwargs["device_map"] = device_map
cur_device = get_device_type()
if "mps" in str(cur_device):
self.model_kwargs["device_map"] = "mps:0"
elif "npu" in str(cur_device):
self.model_kwargs["device_map"] = "npu:0"
# TODO: can we put the reference model on its own gpu? I think we have to move
# logits around to calculate loss
# if cfg.rl:
# if torch.cuda.device_count() > 1:
# if reference_model:
# model_kwargs["device_map"] = "cuda:" + str(
# torch.cuda.current_device() + 1
# )
# else:
# model_kwargs["device_map"] = "cuda:" + str(torch.cuda.current_device())
def _set_quantization_config(self):
"""Set up quantization config (bitsandbytes, awq, gptq, etc.)
Writes `load_in_8bit` / `load_in_4bit` into `self.model_kwargs`, then selects a
`quantization_config` from either the model config's own `quantization_config`
or the adapter settings. The two `load_in_*` keys are removed again at the end
when a `quantization_config` was set or `cfg.gptq` is enabled.
"""
self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
if self.cfg.gptq:
if not hasattr(self.model_config, "quantization_config"):
LOG.warning(
"model config does not contain quantization_config information"
)
else:
if self.cfg.gptq_disable_exllama is not None:
self.model_config.quantization_config["disable_exllama"] = (
self.cfg.gptq_disable_exllama
)
self.model_kwargs["quantization_config"] = GPTQConfig(
**self.model_config.quantization_config
)
# LoRA / QLoRA on a checkpoint whose config already declares a quant method
if (
self.cfg.adapter in ["qlora", "lora"]
and hasattr(self.model_config, "quantization_config")
and self.model_config.quantization_config["quant_method"]
in ["gptq", "awq", "bitsandbytes"]
):
if self.model_config.quantization_config["quant_method"] == "gptq":
self.model_kwargs["quantization_config"] = GPTQConfig(
**self.model_config.quantization_config
)
elif self.model_config.quantization_config["quant_method"] == "awq":
self.model_kwargs["quantization_config"] = AwqConfig(
**self.model_config.quantization_config
)
elif (
self.model_config.quantization_config["quant_method"] == "bitsandbytes"
):
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
**self.model_config.quantization_config
)
# QLoRA with 4-bit loading: build a bitsandbytes config from defaults
elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]:
bnb_config = {
"load_in_4bit": True,
"llm_int8_threshold": 6.0,
"llm_int8_has_fp16_weight": False,
"bnb_4bit_compute_dtype": self.cfg.torch_dtype,
"bnb_4bit_use_double_quant": True,
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_quant_storage": torch.bfloat16,
}
if self.cfg.model_config_type in ["jamba", "qwen2_moe"] and not (
self.cfg.deepspeed or self.cfg.fsdp
):
# for some reason, this causes the loss to be off by an order of magnitude
# but deepspeed needs this still in bfloat16
bnb_config["bnb_4bit_quant_storage"] = torch.float32
# User-supplied overrides are applied on top of the defaults
if self.cfg.bnb_config_kwargs:
bnb_config.update(self.cfg.bnb_config_kwargs)
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
**bnb_config,
)
# LoRA with 8-bit loading
elif self.cfg.adapter == "lora" and self.model_kwargs["load_in_8bit"]:
bnb_config = {
"load_in_8bit": True,
}
# Exclude mamba blocks from int8 quantization for jamba
if self.cfg.model_config_type == "jamba":
bnb_config["llm_int8_skip_modules"] = ["mamba"]
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
**bnb_config,
)
# no longer needed per https://github.com/huggingface/transformers/pull/26610
if "quantization_config" in self.model_kwargs or self.cfg.gptq:
self.model_kwargs.pop("load_in_8bit", None)
self.model_kwargs.pop("load_in_4bit", None)
def _set_attention_config(self):
    """Select the attention implementation requested by the config.

    The first enabled option wins, in this order: `flex_attention`,
    `flash_attention`, `sdp_attention`, `eager_attention`. The chosen name is
    written both to `model_kwargs["attn_implementation"]` and to
    `model_config._attn_implementation`. When none is enabled, neither is
    touched. Independently, `low_cpu_mem_usage=True` is added to `model_kwargs`
    when `cfg.low_cpu_mem_usage` is set.
    """
    if self.cfg.flex_attention:
        attn_implementation = "flex_attention"
    elif self.cfg.flash_attention:
        attn_implementation = "flash_attention_2"
    elif self.cfg.sdp_attention:
        attn_implementation = "sdpa"
    elif self.cfg.eager_attention:
        attn_implementation = "eager"
    else:
        attn_implementation = None

    if attn_implementation is not None:
        self.model_kwargs["attn_implementation"] = attn_implementation
        self.model_config._attn_implementation = (  # pylint: disable=protected-access
            attn_implementation
        )

    if self.cfg.low_cpu_mem_usage:
        self.model_kwargs["low_cpu_mem_usage"] = True
def _configure_zero3_memory_efficient_loading(self):
    """Set the deepspeed config to load the model into RAM first before moving
    to VRAM.

    Only acts when the `ACCELERATE_DEEPSPEED_ZERO_STAGE` environment variable is
    `"3"`. In that case the batch-size fields of the DeepSpeed config are filled
    in, `device_map` is dropped from `model_kwargs`, and the transformers
    `is_deepspeed_zero3_enabled` helpers are replaced with functions returning
    `True`.

    We need to return `hf_ds_cfg` as it needs to exist before model loading.

    Returns:
        The `HfTrainerDeepSpeedConfig` instance, or `None` when ZeRO stage 3 is
        not selected.
    """
    if os.getenv("ACCELERATE_DEEPSPEED_ZERO_STAGE") != "3":
        return None

    micro_batch_size = self.cfg.micro_batch_size
    grad_accum_steps = self.cfg.gradient_accumulation_steps

    hf_ds_cfg = HfTrainerDeepSpeedConfig(self.cfg.deepspeed)
    hf_ds_cfg.fill_match("train_micro_batch_size_per_gpu", micro_batch_size)
    hf_ds_cfg.fill_match("gradient_accumulation_steps", grad_accum_steps)
    hf_ds_cfg.fill_match(
        "train_batch_size",
        int(os.getenv("WORLD_SIZE", "1")) * micro_batch_size * grad_accum_steps,
    )

    self.model_kwargs.pop("device_map", None)

    transformers.modeling_utils.is_deepspeed_zero3_enabled = lambda: True
    transformers.integrations.deepspeed.is_deepspeed_zero3_enabled = lambda: True

    return hf_ds_cfg
def _build_model(self) -> bool:
    """Load model, with load strategy depending on config.

    One of five strategies is chosen: sharded quantized loading (QLoRA + FSDP),
    the llama / llama4 path (optionally with random initialization), the Mamba
    path, an explicitly named transformers class, or the generic auto-model
    loader. The result is stored in `self.model`.

    Returns:
        `True` when the caller must skip moving the model to a device: sharded
        quantized loading, FSDP CPU-RAM-efficient loading on the llama or generic
        path, or DeepSpeed ZeRO-3 being enabled.
    """
    skip_move_to_device = False
    # Holds the object returned by `_configure_zero3_memory_efficient_loading`
    # until this method returns, so it stays referenced for the whole
    # `from_pretrained` call. transformers appears to track the DeepSpeed config
    # object only weakly, so discarding the return value lets it be collected
    # before the weights are loaded.
    _hf_ds_cfg = None

    if (
        self.qlora_fsdp
        and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
        and (
            self.cfg.model_config_type == "dbrx"
            or self.cfg.qlora_sharded_model_loading
        )
    ):
        quant_storage = self.cfg.torch_dtype
        # Prefer the checkpoint's own quantization config over the computed one
        quantization_config = getattr(
            self.model_config, "quantization_config", None
        )
        quantization_config = (
            quantization_config or self.model_kwargs["quantization_config"]
        )
        self.model = load_sharded_model_quant(
            self.base_model,
            self.model_config,
            self.cfg,
            quant_storage=quant_storage,
            quantization_config=quantization_config,
        )
        skip_move_to_device = True
    elif (
        self.model_config.model_type in ["llama", "llama4"]
        and not self.cfg.trust_remote_code
        and not self.cfg.gptq
    ):
        # TODO: Do we need to open this up for all models?
        if self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
            skip_move_to_device = True
            if "device_map" in self.model_kwargs:
                del self.model_kwargs["device_map"]

        _hf_ds_cfg = self._configure_zero3_memory_efficient_loading()

        # Load model with random initialization if specified
        if self.cfg.random_init_weights:
            # AutoModel classes support the from_config method
            if self.auto_model_loader in [
                AutoModelForCausalLM,
                AutoModelForVision2Seq,
            ]:
                self.model = self.auto_model_loader.from_config(
                    config=self.model_config,
                )
            else:
                self.model = self.auto_model_loader(config=self.model_config)
        else:
            self.model = self.auto_model_loader.from_pretrained(
                self.base_model,
                config=self.model_config,
                **self.model_kwargs,
            )
    elif self.model_type == "MambaLMHeadModel":
        # FIXME this is janky at best and hacked together to make it work
        MambaLMHeadModel = fix_mamba_attn_for_loss()  # pylint: disable=invalid-name

        # This loader takes `dtype` / `device` instead of the HF-style kwargs
        self.model_kwargs["dtype"] = self.model_kwargs["torch_dtype"]
        self.model_kwargs["device"] = torch.cuda.current_device()
        self.model_kwargs.pop("torch_dtype", None)
        self.model_kwargs.pop("device_map", None)

        self.model = MambaLMHeadModel.from_pretrained(
            self.base_model,
            **self.model_kwargs,
        )
    elif (
        self.model_type
        and self.model_type != "AutoModelForCausalLM"
        and not self.cfg.trust_remote_code
    ):
        if self.cfg.gptq:
            self.model = self.auto_model_loader.from_pretrained(
                self.base_model,
                config=self.model_config,
                trust_remote_code=self.cfg.trust_remote_code or False,
                **self.model_kwargs,
            )
        else:
            # Resolve the configured class name on the transformers package
            self.model = getattr(transformers, self.model_type).from_pretrained(
                self.base_model,
                config=self.model_config,
                trust_remote_code=self.cfg.trust_remote_code or False,
                **self.model_kwargs,
            )
    else:
        if self.cfg.gptq:
            self.model = self.auto_model_loader.from_pretrained(
                self.base_model,
                config=self.model_config,
                trust_remote_code=self.cfg.trust_remote_code or False,
                **self.model_kwargs,
            )
        else:
            if (
                self.cfg.fsdp
                and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
            ):
                # disabling either of these two still leads to VRAM spike before setting back down
                skip_move_to_device = True
                if "device_map" in self.model_kwargs:
                    del self.model_kwargs["device_map"]

            _hf_ds_cfg = self._configure_zero3_memory_efficient_loading()

            self.model = self.auto_model_loader.from_pretrained(
                self.base_model,
                config=self.model_config,
                trust_remote_code=self.cfg.trust_remote_code or False,
                **self.model_kwargs,
            )

    if is_deepspeed_zero3_enabled():
        skip_move_to_device = True

    return skip_move_to_device
def _set_z3_leaf_modules(self):
    """Register the architecture's MoE blocks as DeepSpeed ZeRO-3 leaf modules.

    Looks up `self.cfg.model_config_type` in `MOE_ARCH_BLOCK`. When there is no
    entry for it, nothing is registered.
    """
    # Function-scope import: DeepSpeed is only needed when this method runs.
    from deepspeed.utils import set_z3_leaf_modules

    arch = self.cfg.model_config_type
    if arch not in MOE_ARCH_BLOCK:
        return

    # An entry may be a single class name or a collection of class names.
    block_names = MOE_ARCH_BLOCK[arch]
    if isinstance(block_names, str):
        block_names = [block_names]

    leaf_classes = [
        get_module_class_from_name(self.model, block_name)
        for block_name in block_names
    ]
    set_z3_leaf_modules(self.model, leaf_classes)
def _prepare_model_for_quantization(self):
    """Run `prepare_model_for_kbit_training` on the model unless it must be skipped.

    The preparation is skipped when any of these hold:
    - the model config type is `qwen` and the adapter is `lora`;
    - the adapter is `lora` and LoftQ bits are configured under `cfg.peft`;
    - `self.qlora_fsdp` is set, FSDP CPU-RAM-efficient loading is configured, or
      DeepSpeed ZeRO-3 is enabled.

    Otherwise it runs only for `lora` / `qlora` adapters loaded in 8-bit or 4-bit,
    and replaces `self.model` with the prepared model.
    """
    cfg = self.cfg
    adapter_is_lora = cfg.adapter == "lora"

    # Qwen doesn't play nicely with LoRA if this is enabled
    qwen_with_lora = cfg.model_config_type == "qwen" and adapter_is_lora

    peft_cfg = cfg.peft
    loftq_bits = (
        peft_cfg and peft_cfg.loftq_config and peft_cfg.loftq_config.loftq_bits
    )
    loftq_with_lora = adapter_is_lora and loftq_bits

    # Make sure everything is in the same dtype
    sharded_loading = (
        self.qlora_fsdp
        or (cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading)
        or is_deepspeed_zero3_enabled()
    )

    if qwen_with_lora or loftq_with_lora or sharded_loading:
        return
    if cfg.adapter not in ["lora", "qlora"]:
        return
    if not (cfg.load_in_8bit or cfg.load_in_4bit):
        return

    LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
    self.model = prepare_model_for_kbit_training(
        self.model, use_gradient_checkpointing=cfg.gradient_checkpointing
    )
def _convert_embedding_modules_dtype(
self,
embedding_modules: list[str],
dist_dtype: torch.dtype,
before_kbit_train_or_finetune: bool,
):
for name, module in self.model.named_modules():
if "norm" in name:
module.to(dist_dtype)
if before_kbit_train_or_finetune:
if name.endswith(".gate"):
module.to(dist_dtype)
if self.model_config.model_type == "btlm":
# don't upcast lm_head for btlm
continue
if any(m in name for m in embedding_modules) and hasattr(module, "weight"):
module.to(dist_dtype)

View File

@@ -0,0 +1,380 @@
"""Patch manager class implementation to complement `axolotl.loaders.ModelLoader`.
Applies pre- and post-model load patches for various fixes and optimizations.
"""
import importlib.util
import logging
from functools import cached_property
import addict
import transformers
from transformers import PretrainedConfig, PreTrainedModel
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import (
SUPPORTED_MULTIPACK_MODEL_TYPES,
patch_for_multipack,
)
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__)
PLUGIN_MANAGER = PluginManager.get_instance()
class PatchManager:
    """Applies config-driven monkeypatches around model loading.

    Patches come in two phases: `apply_pre_model_load_patches` runs patches that
    need no model instance, `apply_post_model_load_patches` runs those that take
    the loaded model. Each patch module is imported inside the method that uses
    it, so it is only imported when its condition is met.
    """

    def __init__(
        self,
        cfg: DictDefault,
        model_config: PretrainedConfig | addict.Dict,
        inference: bool = False,
    ):
        """Store the settings that decide which patches are applied.

        Args:
            cfg: Configuration dictionary with model and training settings.
            model_config: Configuration object for the model.
            inference: Whether the model is being loaded for inference mode.
        """
        self.inference = inference
        self.model_config = model_config
        self.cfg = cfg

    @cached_property
    def has_flash_attn(self) -> bool:
        """Whether a `flash_attn` module spec can be found (cached per instance)."""
        spec = importlib.util.find_spec("flash_attn")
        return spec is not None

    def apply_pre_model_load_patches(self):
        """Run every pre-load patch step, in a fixed order."""
        pre_load_steps = (
            self._apply_flash_attention_patches,
            self._apply_fsdp_patches,
            self._apply_adapter_patches,
            self._apply_flex_attention_patches,
            self._apply_model_specific_patches,
            self._apply_fp8_patches,
            self._apply_flash_attention_peft_patches,
            self._apply_gradient_checkpointing_patches,
            self._patch_attention,
            self._apply_multipack_patches,
            self._patch_llama_derived_model,
            self._apply_mistral_cross_entropy_patch,
            self._apply_unsloth_self_attention_patch,
        )
        for step in pre_load_steps:
            step()

    def apply_post_model_load_patches(self, model: PreTrainedModel):
        """Run every patch step that needs the loaded model instance."""
        post_load_steps = (
            self._apply_llama_flash_attn_patches,
            self._apply_unsloth_patches,
            self._apply_lora_kernel_patch,
        )
        for step in post_load_steps:
            step(model)

    def _apply_flash_attention_patches(self):
        """Patch xformers attention over FA2 when xformers and sample packing are on.

        Also sets `cfg.flash_attention = True` in that case.
        """
        if not (self.cfg.xformers_attention and self.cfg.sample_packing):
            return

        from axolotl.monkeypatch.attention import patch_xformers_attn_over_fa2

        patch_xformers_attn_over_fa2()
        self.cfg.flash_attention = True

    def _apply_fsdp_patches(self):
        """Patch accelerate's FSDP utils when `fsdp_config.fsdp_version` is 2."""
        fsdp_config = self.cfg.fsdp_config
        if not (fsdp_config and str(fsdp_config.fsdp_version) == "2"):
            return

        from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp_utils

        patch_accelerate_fsdp_utils()

    def _apply_adapter_patches(self):
        """Patch PEFT prep code when an adapter and `embeddings_skip_upcast` are set."""
        if not (self.cfg.adapter and self.cfg.embeddings_skip_upcast):
            return

        from axolotl.monkeypatch.peft.utils import patch_peft_prep_code

        patch_peft_prep_code()

    def _apply_flex_attention_patches(self):
        """Patch the flex attention wrapper and mask builder when flex attention is on."""
        if not self.cfg.flex_attention:
            return

        from axolotl.monkeypatch.attention.flex_attn import (
            patch_flex_make_mask,
            patch_flex_wrapper,
        )

        compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
        patch_flex_wrapper(**compile_kwargs)
        patch_flex_make_mask()

    def _apply_model_specific_patches(self):
        """Apply the llama4 linearized-experts and gemma3 forward patches."""
        wants_linearized_llama4 = (
            self.cfg.model_config_type == "llama4"
            and self.cfg.llama4_linearized_experts
        )
        if wants_linearized_llama4:
            from axolotl.monkeypatch.models.llama4.modeling import (
                patch_llama4_linearized_modeling,
            )

            patch_llama4_linearized_modeling()

        if self.cfg.model_config_type == "gemma3":
            from axolotl.monkeypatch.gemma3 import (
                patch_gemma3conditionalgeneration_forward,
            )

            patch_gemma3conditionalgeneration_forward()

    def _apply_fp8_patches(self):
        """Patch the trainer's accelerate setup code when `cfg.fp8` is set."""
        if not self.cfg.fp8:
            return

        from axolotl.monkeypatch.trainer_accelerator_args import (
            patch_create_accelerate_code_for_fp8,
        )

        patch_create_accelerate_code_for_fp8()

    def _apply_flash_attention_peft_patches(self):
        """Patch the flash-attention / PEFT integration when an adapter is used."""
        if not self.cfg.adapter:
            return

        from axolotl.monkeypatch.transformers_fa_utils import (
            patch_fa_peft_integration,
        )

        patch_fa_peft_integration()

    def _apply_gradient_checkpointing_patches(self):
        """Swap `transformers.modeling_utils.checkpoint` for an offloading wrapper.

        `"unsloth"` and `"offload"` select the offload wrapper; `"offload_disk"`
        selects the disk-offload wrapper. Any other value leaves it untouched.
        """
        mode = self.cfg.gradient_checkpointing

        if mode in ["unsloth", "offload"]:
            from axolotl.monkeypatch.gradient_checkpointing import (
                hf_grad_checkpoint_offload_wrapper,
            )

            transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper

        if mode == "offload_disk":
            from axolotl.monkeypatch.gradient_checkpointing import (
                hf_grad_checkpoint_disk_offload_wrapper,
            )

            transformers.modeling_utils.checkpoint = (
                hf_grad_checkpoint_disk_offload_wrapper
            )

    def _apply_mistral_cross_entropy_patch(self):
        """Patch Mistral cross entropy when `flash_attn_cross_entropy_loss` is set."""
        wants_patch = (
            self.cfg.model_config_type == "mistral"
            and self.cfg.flash_attn_cross_entropy_loss
        )
        if not wants_patch:
            return

        from axolotl.monkeypatch.mistral_attn_hijack_flash import (
            patch_mistral_cross_entropy,
        )

        patch_mistral_cross_entropy()

    def _apply_unsloth_self_attention_patch(self):
        """Patch self-attention for LoRA when unsloth QKV or O LoRA is configured."""
        if not (self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o):
            return

        from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora

        patch_self_attn_lora(self.cfg)

    def _apply_multipack_patches(self):
        """Apply multipack patching for supported model types using sample packing.

        Runs only when the model type is in `SUPPORTED_MULTIPACK_MODEL_TYPES`,
        flash or flex attention is enabled, and sample packing is on. For
        llama-derived models the llama loss patches are applied afterwards.
        """
        cfg = self.cfg
        multipack_enabled = (
            cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
            and (cfg.flash_attention or cfg.flex_attention)
            and cfg.sample_packing
        )
        if not multipack_enabled:
            return

        # Get automap config if it exists
        model_config = self.model_config
        auto_map = None
        if isinstance(model_config, dict) and "auto_map" in model_config:
            auto_map = model_config["auto_map"]
        elif hasattr(model_config, "auto_map"):
            auto_map = model_config.auto_map

        # Determine if the model has remote code
        has_remote_code = auto_map is not None and "AutoModelForCausalLM" in auto_map
        if has_remote_code and cfg.trust_remote_code is False:
            # If explicitly set in YAML, prefer that
            has_remote_code = False

        patch_for_multipack(
            cfg.model_config_type,
            model_name=cfg.base_model,
            has_remote_code=has_remote_code,
        )

        if cfg.is_llama_derived_model:
            self._patch_loss_llama()

    def _patch_attention(self):
        """Apply flash-attention patches keyed on `model_config.model_type`.

        Does nothing unless flash attention is enabled and the model config has a
        `model_type` attribute.
        """
        if not self.cfg.flash_attention:
            return
        if not hasattr(self.model_config, "model_type"):
            return

        # flash attention is known to be enabled from here on
        model_type = self.model_config.model_type

        if model_type == "mllama":
            from axolotl.monkeypatch.attention.mllama import patch_mllama

            patch_mllama()
        elif model_type == "btlm":
            from axolotl.monkeypatch.btlm_attn_hijack_flash import (
                replace_btlm_attn_with_flash_attn,
            )

            replace_btlm_attn_with_flash_attn(self.cfg.base_model)
        elif model_type == "stablelm_epoch" and self.cfg.sample_packing:
            from axolotl.monkeypatch.stablelm_attn_hijack_flash import (
                replace_stablelm_attn_with_flash_attn,
            )

            replace_stablelm_attn_with_flash_attn(self.cfg.base_model)

    def _patch_loss_llama(self):
        """Apply cross-entropy, RMS-norm and LoRA self-attention patches for LLaMA.

        For cross entropy and RMS norm, the flash-attn variant is used when it is
        configured and `flash_attn` is installed; otherwise the unsloth variant
        is used if configured.
        """
        cfg = self.cfg

        if cfg.flash_attn_cross_entropy and self.has_flash_attn:
            from axolotl.monkeypatch.llama_attn_hijack_flash import (
                patch_fa_llama_cross_entropy,
            )

            patch_fa_llama_cross_entropy()
        elif cfg.unsloth_cross_entropy_loss:
            from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch

            integrate_cross_entropy_loss_patch(model_type="llama")

        if cfg.flash_attn_rms_norm and self.has_flash_attn:
            from axolotl.monkeypatch.llama_attn_hijack_flash import patch_llama_rms_norm

            patch_llama_rms_norm()
        elif cfg.unsloth_rms_norm:
            from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm

            patch_unsloth_layernorm()

        if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
            from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora

            patch_self_attn_lora()

    def _patch_llama_flash_attention(self, packed=False):
        """Replace LLaMA attention with a flash-attention variant.

        Args:
            packed: Whether sample packing is in use. When truthy the packed
                variant is installed, unless the device is `mps`/`cpu` or the
                manager is in inference mode, in which case nothing is patched.
        """
        from axolotl.monkeypatch.llama_attn_hijack_flash import (
            replace_llama_attn_with_flash_attn,
        )

        fused_options = {
            "cross_entropy": self.cfg.flash_attn_cross_entropy,
            "rms_norm": self.cfg.flash_attn_rms_norm,
        }

        if packed:
            if self.cfg.device in ["mps", "cpu"] or self.inference:
                return
            LOG.info("patching with flash attention for sample packing")
            replace_llama_attn_with_flash_attn(packed=True, **fused_options)
        elif self.cfg.s2_attention:
            LOG.info("patching w/ flash-enabled, shifted-sparse attention")
            replace_llama_attn_with_flash_attn(
                packed=False,
                use_shifted_sparse_attn=True,
                **fused_options,
            )
        elif fused_options["cross_entropy"] or fused_options["rms_norm"]:
            replace_llama_attn_with_flash_attn(packed=False, **fused_options)

    def _patch_llama_xformers_attention(self):
        """Hijack LLaMA attention with the xformers implementation."""
        from axolotl.monkeypatch.llama_attn_hijack_xformers import (
            hijack_llama_attention,
        )

        LOG.info("Patching with xformers attention...")
        hijack_llama_attention()

    def _patch_llama_sample_packing(self):
        """Hijack LLaMA's 4D causal mask preparation for sample packing."""
        from axolotl.monkeypatch.llama_patch_multipack import (
            hijack_llama_prepare_4d_mask,
        )

        LOG.info("Patching llama _prepare_4d_causal_attention_mask*...")
        hijack_llama_prepare_4d_mask()

    def _patch_llama_derived_model(self):
        """Patch llama-derived models that the multipack path does not cover.

        Raises:
            NotImplementedError: If shifted-sparse attention is requested without
                flash attention, xformers attention or sample packing.
        """
        cfg = self.cfg
        if not cfg.is_llama_derived_model:
            return

        # Same condition under which `_apply_multipack_patches` does its work.
        covered_by_multipack = (
            cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
            and (cfg.flash_attention or cfg.flex_attention)
            and cfg.sample_packing
        )
        if covered_by_multipack:
            return

        self._patch_loss_llama()

        if cfg.flash_attention:
            self._patch_llama_flash_attention(packed=cfg.sample_packing)
        elif cfg.xformers_attention:
            self._patch_llama_xformers_attention()
        elif cfg.sample_packing:
            self._patch_llama_sample_packing()
        elif cfg.s2_attention:
            raise NotImplementedError(
                "Shifted-sparse attention not currently implemented without flash attention."
            )

    def _apply_llama_flash_attn_patches(self, model):
        """Apply SwiGLU / fused-QKV replacements to a loaded LLaMA model.

        Only runs for `llama` / `llama4` model types with flash attention enabled,
        outside inference mode, and without `trust_remote_code` or `gptq`.
        """
        eligible = (
            self.model_config.model_type in ["llama", "llama4"]
            and not self.cfg.trust_remote_code
            and not self.cfg.gptq
            and self.cfg.flash_attention
            and not self.inference
        )
        if not eligible:
            return

        # TODO(MengqingCao): split these patches separately
        from axolotl.monkeypatch.llama_attn_hijack_flash import (
            is_xformers_swiglu_available,
            replace_llama_mlp_with_swiglu,
            replace_llama_qkv_with_fused,
        )

        if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
            LOG.info("Patching with SwiGLU...")
            replace_llama_mlp_with_swiglu(model)

        if self.cfg.flash_attn_fuse_qkv:
            LOG.info("Patching with fused QKV...")
            replace_llama_qkv_with_fused(model)

    def _apply_unsloth_patches(self, model):
        """Apply the configured unsloth LoRA MLP, LoRA attention and RoPE patches."""
        cfg = self.cfg

        if cfg.unsloth_lora_mlp:
            from axolotl.monkeypatch.unsloth_ import integrate_lora_mlp_patch

            integrate_lora_mlp_patch(peft_model=model)

        if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
            from axolotl.monkeypatch.unsloth_ import integrate_lora_patch

            integrate_lora_patch(peft_model=model, cfg=cfg)

        if cfg.unsloth_rope:
            from axolotl.monkeypatch.unsloth_ import integrate_rope_embeddings

            integrate_rope_embeddings()

    def _apply_lora_kernel_patch(self, model):
        """Apply LoRA kernel patches when any LoRA kernel option is enabled."""
        cfg = self.cfg
        any_kernel = cfg.lora_mlp_kernel or cfg.lora_qkv_kernel or cfg.lora_o_kernel
        if not any_kernel:
            return

        from axolotl.monkeypatch.lora_kernels import apply_lora_kernel_patches

        apply_lora_kernel_patches(model=model, cfg=cfg)

View File

@@ -0,0 +1,56 @@
"""Processor loading functionality for multi-modal models"""
import logging
from typing import Any
import transformers
from transformers import (
AutoProcessor,
PreTrainedTokenizerBase,
)
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__)
def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
    """Load the processor for a multi-modal model and infer `cfg.image_size`.

    Args:
        cfg: Config providing `processor_type`, `processor_config`,
            `trust_remote_code` and `image_size`.
        tokenizer: Tokenizer passed through to the processor's `from_pretrained`.

    Returns:
        The loaded processor. When `cfg.image_size` is `None` and the processor
        exposes a `size` with a `width` and/or `height` entry, `cfg.image_size`
        is set from it as a side effect.
    """
    processor_kwargs: dict[str, Any] = {}  # Do we actually need this?

    if cfg.processor_type:
        processor_cls = getattr(transformers, cfg.processor_type)
    else:
        processor_cls = AutoProcessor

    processor = processor_cls.from_pretrained(
        cfg.processor_config,
        trust_remote_code=cfg.trust_remote_code or False,
        tokenizer=tokenizer,
        **processor_kwargs,
    )

    # Attempt to load image size from processor if available
    if cfg.image_size is None and hasattr(processor, "size"):
        size = processor.size
        if "width" in size or "height" in size:
            im_width = size["width"] if "width" in size else None
            im_height = size["height"] if "height" in size else None

            if im_width is not None and im_height is not None:
                # Both dimensions known: store a (width, height) tuple
                cfg.image_size = (im_width, im_height)
            elif im_width is not None:
                # Only width known: store it as a single value
                cfg.image_size = im_width
            elif im_height is not None:
                # Only height known: store it as a single value
                cfg.image_size = im_height

            LOG.debug(f"Loaded image size: {cfg.image_size} from processor")

    return processor

View File

@@ -0,0 +1,281 @@
"""Tokenizer loading functionality and associated utils"""
import json
import logging
import os
import transformers
from transformers import (
AddedToken,
AutoTokenizer,
)
from axolotl.integrations.base import PluginManager
from axolotl.loaders.utils import get_linear_embedding_layers, load_model_config
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.distributed import (
barrier,
is_local_main_process,
is_main_process,
)
LOG = logging.getLogger(__name__)
PLUGIN_MANAGER = PluginManager.get_instance()
def modify_tokenizer_files(
    tokenizer_path: str, token_mappings: dict[int, str], output_dir: str
) -> str:
    """Save a copy of a tokenizer with selected added-token strings replaced.

    The tokenizer is saved under `<output_dir>/tokenizer`, then its
    `tokenizer_config.json` and `tokenizer.json` are rewritten in place. Only the
    local main process writes files; every process then waits on a barrier.

    This only works with reserved tokens that were added to the tokenizer, not
    tokens already part of the vocab.

    Args:
        tokenizer_path: Path or name of the original tokenizer
        token_mappings: Dict mapping {token_id (int): new_token_string}
        output_dir: Directory to save the modified tokenizer

    Returns:
        Path to the modified tokenizer directory

    Raises:
        ValueError: If a token ID is missing from `added_tokens_decoder` in
            `tokenizer_config.json` or from `added_tokens` in `tokenizer.json`.

    Ref: https://github.com/huggingface/transformers/issues/27974#issuecomment-1854188941
    """
    tokenizer_dir = os.path.join(output_dir, "tokenizer")
    os.makedirs(tokenizer_dir, exist_ok=True)

    if is_local_main_process():  # pylint: disable=too-many-nested-blocks
        # Materialize the original tokenizer files in the output directory
        original = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
        original.save_pretrained(tokenizer_dir)

        # Normalize mapping keys to integer token IDs
        overrides = {
            int(raw_id): content for raw_id, content in token_mappings.items()
        }

        # 1. tokenizer_config.json: rewrite `added_tokens_decoder` entries
        config_file = os.path.join(tokenizer_dir, "tokenizer_config.json")
        if os.path.exists(config_file):
            with open(config_file, "r", encoding="utf-8") as handle:
                config_data = json.load(handle)

            if "added_tokens_decoder" in config_data:
                decoder = config_data["added_tokens_decoder"]
                for token_id, content in overrides.items():
                    # JSON object keys are strings
                    decoder_key = str(token_id)
                    if decoder_key not in decoder:
                        raise ValueError(
                            f"Token ID {decoder_key} not found in added_tokens_decoder"
                        )
                    decoder[decoder_key]["content"] = content

            with open(config_file, "w", encoding="utf-8") as handle:
                json.dump(config_data, handle, indent=2)

        # 2. tokenizer.json: rewrite `added_tokens` and the model vocab
        tokenizer_file = os.path.join(tokenizer_dir, "tokenizer.json")
        if os.path.exists(tokenizer_file):
            with open(tokenizer_file, "r", encoding="utf-8") as handle:
                tokenizer_data = json.load(handle)

            if "added_tokens" in tokenizer_data:
                added_tokens = tokenizer_data["added_tokens"]
                for token_id, content in overrides.items():
                    entry = next(
                        (item for item in added_tokens if item["id"] == token_id),
                        None,
                    )
                    if entry is None:
                        raise ValueError(
                            f"Token ID {token_id} not found in added_tokens"
                        )
                    entry["content"] = content

            if "model" in tokenizer_data and "vocab" in tokenizer_data["model"]:
                vocab = tokenizer_data["model"]["vocab"]
                for token_id, content in overrides.items():
                    # First vocab string currently mapped to this ID, if any
                    previous = next(
                        (text for text, idx in vocab.items() if idx == token_id),
                        None,
                    )
                    if previous is not None:
                        del vocab[previous]
                        vocab[content] = token_id

            with open(tokenizer_file, "w", encoding="utf-8") as handle:
                json.dump(tokenizer_data, handle, indent=2)

    barrier()
    return tokenizer_dir
def load_tokenizer(cfg):
    """Load and configure the tokenizer based on the provided config.

    The tokenizer is loaded from `cfg.tokenizer_config` (or from a modified copy
    when `cfg.added_tokens_overrides` is set) and then adjusted in place: pad
    token defaults, padding side, special tokens, extra tokens and chat template.
    The steps below mutate the tokenizer in sequence.

    Args:
        cfg: `axolotl` config; read via attribute access throughout.

    Returns:
        The configured tokenizer instance.

    Raises:
        ValueError: If an adapter is used, a special token is being changed, and
            `cfg.lora_modules_to_save` does not include the embedding layers
            returned by `get_linear_embedding_layers`.
    """
    model_config = load_model_config(cfg)
    tokenizer_kwargs = {}
    use_fast = True  # this is the default

    if cfg.tokenizer_use_fast is not None:
        use_fast = cfg.tokenizer_use_fast
    if cfg.tokenizer_legacy is not None:
        # True is the default w/ https://github.com/huggingface/transformers/pull/25224
        tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy

    # Use AutoTokenizer unless a specific `transformers` class name is configured
    tokenizer_cls = AutoTokenizer
    if cfg.tokenizer_type:
        tokenizer_cls = getattr(transformers, cfg.tokenizer_type)

    # Set base tokenizer path
    tokenizer_path = cfg.tokenizer_config

    # Apply token string overrides if specified
    if cfg.added_tokens_overrides:
        # Modify tokenizer files and get path to modified tokenizer
        tokenizer_path = modify_tokenizer_files(
            tokenizer_path, cfg.added_tokens_overrides, output_dir=cfg.output_dir
        )

    tokenizer = tokenizer_cls.from_pretrained(
        tokenizer_path,
        trust_remote_code=cfg.trust_remote_code or False,
        use_fast=use_fast,
        **tokenizer_kwargs,
    )

    if (
        tokenizer.__class__.__name__
        in [
            "LlamaTokenizer",
            "LlamaTokenizerFast",
            "CodeLlamaTokenizer",
            "CodeLlamaTokenizerFast",
        ]
        and hasattr(tokenizer, "pad_token")
        and not tokenizer.pad_token
    ):
        # set a pad_token, but use eos_token so we don't add a new token
        tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN

    if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        # Process-wide side effect: sets this environment variable
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Mistral's official FA implementation requires left padding
    if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing:
        tokenizer.padding_side = "left"

    # Qwen base only has single token, so we need to set the special tokens
    if cfg.is_qwen_derived_model:
        # Fill any unset special token IDs with the tokenizer's `eod_id`
        token_ids = ["bos_token_id", "eos_token_id", "pad_token_id", "unk_token_id"]
        for attr_name in token_ids:
            if getattr(tokenizer, attr_name) is None:
                setattr(tokenizer, attr_name, tokenizer.eod_id)

        # Fill any unset special token strings with "<|endoftext|>"
        token_names = ["bos_token", "eos_token", "pad_token", "unk_token"]
        for attr_name in token_names:
            if getattr(tokenizer, attr_name) is None:
                setattr(tokenizer, attr_name, "<|endoftext|>")

    additional_special_tokens = None
    if cfg.special_tokens:
        special_tokens = cfg.special_tokens.to_dict()
        # Handled separately below, after `cfg.tokens` have been added
        additional_special_tokens = special_tokens.pop(
            "additional_special_tokens", None
        )
        lora_modules_to_save = get_linear_embedding_layers(model_config.model_type)
        for k, val in special_tokens.items():
            # check if new special token is not already in tokenizer and
            # is adapter training to make sure lora_modules_to_save is set
            # pylint: disable=too-many-boolean-expressions
            if (
                (getattr(tokenizer, k) is None or getattr(tokenizer, k) != val)
                and (len(tokenizer.encode(val, add_special_tokens=False)) > 2)
                and cfg.adapter
                and (
                    not cfg.lora_modules_to_save
                    or not all(
                        x in cfg.lora_modules_to_save for x in lora_modules_to_save
                    )
                )
                and k != "pad_token"
            ):
                # Rebinding the list to a string is safe only because we raise next
                lora_modules_to_save = ", ".join(
                    [f"`{x}`" for x in lora_modules_to_save]
                )
                raise ValueError(
                    f"Please set lora_modules_to_save to [{lora_modules_to_save}] when using an adapter and changing the special tokens."
                )

            tokenizer.add_special_tokens(
                {k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)}
            )

        # If we add bos_token and eos_token, we need to update the post processor to
        # handle them correctly.
        # https://github.com/huggingface/transformers/pull/24132
        # NOTE(review): the name says "or" but the expression requires BOTH
        # tokens to be configured; confirm which is intended.
        bos_or_eos_in_special_tokens = (
            "bos_token" in cfg.special_tokens and "eos_token" in cfg.special_tokens
        )
        if (
            tokenizer.__class__.__name__
            in (
                "LlamaTokenizerFast",
                "CodeLlamaTokenizerFast",
            )
            and bos_or_eos_in_special_tokens
        ):
            tokenizer.update_post_processor()

    if cfg.tokens:
        tokenizer.add_tokens(
            [
                AddedToken(token, rstrip=False, lstrip=False, normalized=False)
                for token in cfg.tokens
            ]
        )

    # Additional special tokens are a List, and need to be treated differently than regular special
    # tokens. We add them after we have called `add_tokens` in case these additional special tokens
    # are new tokens.
    #
    # Usage:
    #
    # ```py
    # special_tokens:
    #   additional_special_tokens: ["<|im_start|>", "<|im_end|>"]
    # ```
    if additional_special_tokens is not None:
        tokenizer.add_special_tokens(
            {"additional_special_tokens": additional_special_tokens}
        )

    # Log the resolved special tokens on the main process only
    if is_main_process(use_environ=True):
        LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
        LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
        LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
        LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")

    if cfg.chat_template:
        chat_template_string = get_chat_template_from_config(
            cfg=cfg,
            tokenizer=tokenizer,
        )
        # For chatml, substitute the configured system message for the literal
        # default text in the template string
        if cfg.default_system_message and cfg.chat_template == "chatml":
            chat_template_string = chat_template_string.replace(
                "You are a helpful assistant.", cfg.default_system_message
            )

        tokenizer.chat_template = chat_template_string
    else:
        LOG.info(
            "No Chat template selected. Consider adding a chat template for easier inference."
        )

    return tokenizer

View File

@@ -0,0 +1,211 @@
"""Utilities for axolotl.loaders module"""
import contextlib
import logging
from typing import Type
import addict
import torch
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__)
def get_module_class_from_name(
module: torch.nn.Module, name: str
) -> Type[torch.nn.Module] | None:
"""Gets a class from a module by its name. Copied from `accelerate.utils.dataclasses`
(https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/dataclasses.py#L2805).
Args:
module: The module to get the class from.
name: The name of the class.
Returns:
The class type of the matching module, or `None` if no match is found.
"""
modules_children = list(module.children())
if module.__class__.__name__ == name:
return module.__class__
if len(modules_children) == 0:
return None
for child_module in modules_children:
module_class = get_module_class_from_name(child_module, name)
if module_class is not None:
return module_class
return None
def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
    """Validates and adjusts model config based on `axolotl` config.

    This function performs several checks and adjustments:
        - Sets `use_cache = False` on the model config (and on the text config
          for multimodal models)
        - Fills `cfg.image_size` from the vision config for multimodal models
          when it is unset
        - Validates GPTQ settings (skipped for compressed-tensors configs)
        - Ensures `lora_modules_to_save` covers the embedding layers when an
          adapter is used together with new tokens

    Args:
        cfg: Dictionary mapping `axolotl` config keys to values.
        model_config: The model's configuration object from `transformers`.

    Raises:
        ValueError: If a multimodal model lacks text configuration, if GPTQ settings
            are inconsistent, or if LoRA `modules_to_save` is improperly configured
            with new tokens.
    """
    if hasattr(model_config, "use_cache"):
        model_config.use_cache = False

    if cfg.is_multimodal:
        # For multimodal configs, use_cache is set in the text_config
        if hasattr(model_config, "get_text_config"):
            text_config = model_config.get_text_config()
            if hasattr(text_config, "use_cache"):
                text_config.use_cache = False
        else:
            raise ValueError(
                "No text config found for multimodal model. Please raise an Issue with model details."
            )

        # Check if image_size is not set and load image size from model config if available
        if (
            cfg.image_size is None
            and hasattr(model_config, "vision_config")
            and hasattr(model_config.vision_config, "image_size")
        ):
            cfg.image_size = model_config.vision_config.image_size
            LOG.debug(f"Loaded image size: {cfg.image_size} from model config")

    quant_config_exists = (
        hasattr(model_config, "quantization_config")
        and model_config.quantization_config
    )

    # Detect compressed-tensors config
    is_compressed_tensors_config = (
        quant_config_exists
        and model_config.quantization_config.get("quant_method") == "compressed-tensors"
    )

    if is_compressed_tensors_config:
        if model_config.quantization_config.get("config_groups"):
            LOG.warning(
                "Found `config_groups` in a compressed-tensors config. "
                "QAT integration with llmcompressor is not tested."
            )
        # Only the GPTQ check is skipped for compressed-tensors. The LoRA
        # validation below is not a quantization check, so it still runs
        # (previously an early return here bypassed it as well).
    else:
        quant_config_method_is_gptq = (
            quant_config_exists
            and "quant_method" in model_config.quantization_config
            and model_config.quantization_config["quant_method"] == "gptq"
        )

        if cfg.gptq and not quant_config_method_is_gptq:
            raise ValueError(
                "model_config.quantization_config is not set or quant_method is not set to gptq. "
                "Please make sure to point to a GPTQ model."
            )

    # New tokens with an adapter require the embedding layers to be saved.
    lora_modules_to_save = get_linear_embedding_layers(model_config.model_type)
    if (
        cfg.adapter
        and cfg.tokens
        and (
            not cfg.lora_modules_to_save
            or not all(x in cfg.lora_modules_to_save for x in lora_modules_to_save)
        )
    ):
        lora_modules_to_save_joined = ", ".join(
            f"`{x}`" for x in lora_modules_to_save
        )
        raise ValueError(
            "`lora_modules_to_save` not properly set when adding new tokens. "
            f"Please include [{lora_modules_to_save_joined}] in `lora_modules_to_save`."
        )
def load_model_config(cfg: DictDefault) -> PretrainedConfig | addict.Dict:
    """Load the model configuration named by the `axolotl` config and validate it.

    The config source is `cfg.base_model_config`, falling back to
    `cfg.base_model`, then to `cfg.tokenizer_config`. After loading, any
    `cfg.overrides_of_model_config` entries are set as attributes and the result
    is passed through `check_model_config`.

    Args:
        cfg: Dictionary mapping `axolotl` config keys to values.

    Returns:
        The loaded model configuration, or a minimal `addict.Dict` with
        `model_type` set to `"mamba"` when loading raises `ValueError` and the
        source name contains `"mamba"` (this fallback is returned without
        overrides or validation).

    Raises:
        ValueError: If configuration loading fails and the source name does not
            contain `"mamba"`.
    """
    config_source = cfg.base_model_config or cfg.base_model
    if not config_source and cfg.tokenizer_config:
        config_source = cfg.tokenizer_config

    # Forward only the optional settings that are actually configured.
    # num_labels is used to initialize classifier models
    optional_kwargs = {
        "revision": cfg.revision_of_model,
        "num_labels": cfg.num_labels,
    }
    config_kwargs = {key: value for key, value in optional_kwargs.items() if value}

    try:
        model_config = AutoConfig.from_pretrained(
            config_source,
            trust_remote_code=cfg.trust_remote_code is True,
            **config_kwargs,
        )
    except ValueError as error:
        if "mamba" not in config_source:
            raise error
        return addict.Dict({"model_type": "mamba"})

    overrides = cfg.overrides_of_model_config
    if overrides:
        for attr_name, attr_value in overrides.items():
            setattr(model_config, attr_name, attr_value)

    check_model_config(cfg, model_config)
    return model_config
def ensure_dtype(model: PreTrainedModel, dtype: torch.dtype = torch.bfloat16):
    """Cast every submodule whose weight or bias dtype differs from `dtype`.

    A line is printed for each mismatching `weight` / `bias`, and the owning
    module is then converted with `module.to(dtype)`. Modules without those
    attributes (or where they are `None`) are left alone.

    Args:
        model: Model whose `named_modules()` are inspected.
        dtype: Target data type.
    """
    for name, module in model.named_modules():
        needs_cast = False
        for attr in ("weight", "bias"):
            try:
                current_dtype = getattr(module, attr).dtype
            except AttributeError:
                # Attribute missing, or present but None
                continue
            if current_dtype != dtype:
                print(f"Converting module {name}.{attr}: {current_dtype} -> {dtype}")
                needs_cast = True

        if needs_cast:
            module.to(dtype)
def get_linear_embedding_layers(model_type: str) -> list[str]:
    """Return the embedding layer names LoRA needs for the given model type.

    `gpt_neox` and `falcon` have their own layer names; every other model type
    gets `embed_tokens` and `lm_head`.
    """
    layers_by_model_type = {
        "gpt_neox": ("embed_in", "embed_out"),
        "falcon": ("word_embeddings", "lm_head"),
    }
    default_layers = ("embed_tokens", "lm_head")
    # Build a fresh list per call so callers can mutate the result freely.
    return list(layers_by_model_type.get(model_type, default_layers))

View File

@@ -1,11 +0,0 @@
"""Init for ring attention monkeypatch module"""
# pylint: disable=unused-import
# flake8: noqa
from .patch import (
get_ring_attn_group,
register_ring_attn,
set_ring_attn_group,
update_ring_attn_params,
)

View File

@@ -1,131 +0,0 @@
"""
Ring attention group registration and flash attention patching.
Make use of the `ring-flash-attn` (https://github.com/zhuzilin/ring-flash-attention)
package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patch in
their sequence parallel version of Flash Attention 2.
"""
import torch
import torch.distributed as dist
from accelerate.logging import get_logger
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.utils.schemas.enums import RingAttnFunc
LOG = get_logger(__name__)
RING_ATTN_GROUP = None
def get_ring_attn_group() -> dist.ProcessGroup | None:
    """
    Getter for ring attention group on this rank.

    Returns:
        The current value of the module-level `RING_ATTN_GROUP`: the process
        group for ring attention for this rank, or `None` when none is set.
    """
    return RING_ATTN_GROUP
def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
    """
    Setter for ring attention group on this rank.

    Args:
        ring_attn_group: Process group for ring attention, stored in the
            module-level `RING_ATTN_GROUP`. Passing `None` clears it.
    """
    global RING_ATTN_GROUP  # pylint: disable=global-statement
    RING_ATTN_GROUP = ring_attn_group
def register_ring_attn(
    sequence_parallel_degree: int,
    heads_k_stride: int | None,
    ring_attn_func: RingAttnFunc | None,
):
    """
    Build the ring attention process groups and swap flash attention for ring
    flash attention.

    If a ring attention group is already set, only a log message is emitted.

    Args:
        sequence_parallel_degree: Sequence parallelism factor.
        heads_k_stride: Sequence parallelism K head stride size. Passed
            through to `ring_flash_attn.substitute_hf_flash_attn`; a falsy value
            is replaced by 1.
        ring_attn_func: `ring_flash_attn` ring attention implementation. If sample
            packing is enabled, it must be a `varlen` function; otherwise, it must
            be a `batch` function.
    """
    already_registered = get_ring_attn_group() is not None
    if already_registered:
        LOG.info("Ring attention already registered, exiting early...")
        return

    LOG.info(
        "Enabling ring attention sequence parallelism: "
        f"each sequence will be processed across {sequence_parallel_degree} GPUs"
    )

    rank = dist.get_rank()
    world_size = dist.get_world_size()

    assert sequence_parallel_degree <= world_size, (
        f"sequence_parallel_degree ({sequence_parallel_degree}) "
        f"must be less than or equal to world_size ({world_size})"
    )
    assert world_size % sequence_parallel_degree == 0, (
        f"sequence_parallel_degree ({sequence_parallel_degree}) "
        f"must evenly divide world_size ({world_size})"
    )

    # One process group per consecutive run of `sequence_parallel_degree` ranks
    num_groups = world_size // sequence_parallel_degree
    group_assignments = {}
    for group_idx in range(num_groups):
        first_rank = group_idx * sequence_parallel_degree
        member_ranks = list(range(first_rank, first_rank + sequence_parallel_degree))
        new_group = dist.new_group(ranks=member_ranks, backend="nccl")

        # Remember which group each rank landed in (only used for logging)
        group_assignments.update(dict.fromkeys(member_ranks, group_idx))

        if rank in member_ranks:
            set_ring_attn_group(new_group)

    # Only rank 0 reports the assignments
    if rank == 0:
        LOG.info(f"Sequence parallel group assignments: {group_assignments}")

    if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:
        from ring_flash_attn import substitute_hf_flash_attn

        substitute_hf_flash_attn(
            process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride or 1
        )
        return

    if ring_attn_func is RingAttnFunc.BATCH_RING:
        from axolotl.monkeypatch.attention.ring_attn.adapters.batch import (
            substitute_hf_flash_attn,
        )

        substitute_hf_flash_attn(
            process_group=get_ring_attn_group(),
            ring_attn_func=ring_attn_func,
        )
def update_ring_attn_params(position_ids: torch.Tensor | None):
    """
    Derive the cumulative sequence lengths for the current forward pass from
    `position_ids` and hand them to the substituted `ring_flash_attn`.

    Args:
        position_ids: Optional tensor of position IDs (for sample packed data).
    """
    from ring_flash_attn import update_ring_flash_attn_params

    # Second element of the returned pair is not needed here
    cumulative_lengths, _unused = get_cu_seqlens_from_pos_ids(position_ids)

    squeezed_lengths = cumulative_lengths.squeeze()
    current_device = torch.cuda.current_device()
    lengths_on_device = squeezed_lengths.to(device=current_device)

    update_ring_flash_attn_params(lengths_on_device, get_ring_attn_group())

View File

@@ -7,24 +7,16 @@ from typing import Optional, Tuple, Union
import torch
from transformers.cache_utils import Cache
from transformers.models.gemma3.modeling_gemma3 import (
_CONFIG_FOR_DOC,
GEMMA3_INPUTS_DOCSTRING,
Gemma3CausalLMOutputWithPast,
logger,
)
from transformers.utils import (
add_start_docstrings_to_model_forward,
is_torchdynamo_compiling,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def new_forward(
self,
input_ids: torch.LongTensor = None,

View File

@@ -5,10 +5,10 @@ from functools import partial
from packaging import version
from axolotl.utils.gradient_checkpointing.offload_cpu import (
from axolotl.monkeypatch.gradient_checkpointing.offload_cpu import (
CPU_Offloaded_Gradient_Checkpointer,
)
from axolotl.utils.gradient_checkpointing.offload_disk import (
from axolotl.monkeypatch.gradient_checkpointing.offload_disk import (
Disco,
)

View File

@@ -75,4 +75,4 @@ def patch_peft_prep_code():
exec(prep_code, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching prepare_model_for_kbit_training to allow for overrides")
peft.utils.other.prepare_model_for_kbit_training = fixed_prepare_model_for_kbit_training # pylint: disable=protected-access # pylint: disable=undefined-variable # noqa: F821
axolotl.utils.models.prepare_model_for_kbit_training = fixed_prepare_model_for_kbit_training # pylint: disable=protected-access # pylint: disable=undefined-variable # noqa: F821
axolotl.loaders.model.prepare_model_for_kbit_training = fixed_prepare_model_for_kbit_training # pylint: disable=protected-access # pylint: disable=undefined-variable # noqa: F821

View File

@@ -0,0 +1,22 @@
"""Init for ring attention monkeypatch module"""
# pylint: disable=unused-import
# flake8: noqa
from .patch import (
get_ring_attn_group,
patch_prepare_data_loader,
patch_prepare_device_mesh,
register_ring_attn,
set_ring_attn_group,
update_ring_attn_params,
)
__all__ = (
"get_ring_attn_group",
"patch_prepare_data_loader",
"patch_prepare_device_mesh",
"register_ring_attn",
"set_ring_attn_group",
"update_ring_attn_params",
)

View File

@@ -0,0 +1,225 @@
"""Ring attention group registration and flash attention patching.
Make use of the `ring-flash-attn` (https://github.com/zhuzilin/ring-flash-attention)
package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patch in
their sequence parallel version of Flash Attention 2.
We also provide some patches for accelerate functions to prepare the dataloader for
sequence parallelism training.
"""
import inspect
import accelerate
import torch
import torch.distributed as dist
from accelerate.logging import get_logger
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.utils.schemas.enums import RingAttnFunc
LOG = get_logger(__name__)
RING_ATTN_GROUP = None
ORIGINAL_PREPARE_DATALOADER_CODE = """ submesh_fsdp_size = 1
submesh_dp_size = 1
submesh_tp_size = 1
if "tp" in torch_device_mesh.mesh_dim_names:
submesh_tp_size = torch_device_mesh["tp"].size()
if "dp" in torch_device_mesh.mesh_dim_names:
submesh_dp_size = torch_device_mesh["dp"].size()
if "fsdp" in torch_device_mesh.mesh_dim_names:
submesh_fsdp_size = torch_device_mesh["fsdp"].size()
process_index = process_index // submesh_tp_size"""
NEW_PREPARE_DATALOADER_CODE = """ submesh_fsdp_size = 1
submesh_dp_size = 1
submesh_tp_size = 1
submesh_cp_size = 1
if "cp" in torch_device_mesh.mesh_dim_names:
submesh_cp_size = torch_device_mesh["cp"].size()
if "tp" in torch_device_mesh.mesh_dim_names:
submesh_tp_size = torch_device_mesh["tp"].size()
if "dp" in torch_device_mesh.mesh_dim_names:
submesh_dp_size = torch_device_mesh["dp"].size()
if "fsdp" in torch_device_mesh.mesh_dim_names:
submesh_fsdp_size = torch_device_mesh["fsdp"].size()
process_index = process_index // (submesh_tp_size * submesh_cp_size)"""
def get_ring_attn_group() -> dist.ProcessGroup:
    """Return the ring attention group for this rank.

    Raises:
        RuntimeError: If the module-level `RING_ATTN_GROUP` is still `None`.
    """
    group = RING_ATTN_GROUP
    if group is not None:
        return group
    raise RuntimeError("register_ring_attn() not yet called")
def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
"""Setter for ring attention group on this rank."""
global RING_ATTN_GROUP # pylint: disable=global-statement
RING_ATTN_GROUP = ring_attn_group
def register_ring_attn(
    sequence_parallel_degree: int,
    heads_k_stride: int | None,
    ring_attn_func: RingAttnFunc | None,
):
    """Build the ring attention process groups and swap in ring flash attention.

    Args:
        sequence_parallel_degree: Sequence parallelism factor.
        heads_k_stride: Sequence parallelism K head stride size. Passed through to
            `varlen_llama3` `ring_flash_attn` implementation; a falsy value is
            replaced by 1.
        ring_attn_func: `ring_flash_attn` ring attention implementation. If sample
            packing is enabled, it must be a `varlen` function; otherwise, it must
            be a `batch` function.
    """
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    is_rank_zero = rank == 0

    if is_rank_zero:
        LOG.info(
            "Enabling ring attention sequence parallelism: "
            f"each sequence will be processed across {sequence_parallel_degree} GPUs"
        )

    assert sequence_parallel_degree <= world_size, (
        f"sequence_parallel_degree ({sequence_parallel_degree}) "
        f"must be less than or equal to world_size ({world_size})"
    )
    assert world_size % sequence_parallel_degree == 0, (
        f"sequence_parallel_degree ({sequence_parallel_degree}) "
        f"must evenly divide world_size ({world_size})"
    )

    # Consecutive runs of `sequence_parallel_degree` ranks form one group each
    rank_groups = [
        list(
            range(
                group_idx * sequence_parallel_degree,
                (group_idx + 1) * sequence_parallel_degree,
            )
        )
        for group_idx in range(world_size // sequence_parallel_degree)
    ]

    group_assignments = {}
    for group_idx, member_ranks in enumerate(rank_groups):
        process_group = dist.new_group(ranks=member_ranks, backend="nccl")

        # Record group membership per rank (only used for logging)
        for member in member_ranks:
            group_assignments[member] = group_idx

        if rank in member_ranks:
            set_ring_attn_group(process_group)

    # Only rank 0 reports the assignments
    if is_rank_zero:
        LOG.info(f"Sequence parallel group assignments: {group_assignments}")

    if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:
        from ring_flash_attn import substitute_hf_flash_attn

        substitute_hf_flash_attn(
            process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride or 1
        )
        return

    if ring_attn_func is RingAttnFunc.BATCH_RING:
        from axolotl.monkeypatch.ring_attn.adapters.batch import (
            substitute_hf_flash_attn,
        )

        substitute_hf_flash_attn(
            process_group=get_ring_attn_group(),
            ring_attn_func=ring_attn_func,
        )
def update_ring_attn_params(position_ids: torch.Tensor | None):
    """
    Compute the cumulative sequence lengths for the current forward pass and pass
    them, with this rank's ring attention group, to the substituted
    `ring_flash_attn`.

    Args:
        position_ids: Optional tensor of position IDs (for sample packed data).
    """
    from ring_flash_attn import update_ring_flash_attn_params

    # Second element of the returned pair is not needed here
    cumulative_lengths, _unused = get_cu_seqlens_from_pos_ids(position_ids)

    squeezed_lengths = cumulative_lengths.squeeze()
    current_device = torch.cuda.current_device()
    lengths_on_device = squeezed_lengths.to(device=current_device)

    update_ring_flash_attn_params(lengths_on_device, get_ring_attn_group())
def patch_prepare_data_loader():
    """Patch `accelerate.data_loader.prepare_data_loader` to respect the SP degree.

    The function's source is read with `inspect.getsource`, the snippet in
    `ORIGINAL_PREPARE_DATALOADER_CODE` is replaced by `NEW_PREPARE_DATALOADER_CODE`,
    and the result is re-executed in the `accelerate.data_loader` module namespace.
    Calling this again after a successful patch is a no-op.

    Raises:
        RuntimeError: If source code to patch does not exist.
    """
    patched_marker = "_axolotl_sp_patched"
    original_fn = accelerate.data_loader.prepare_data_loader

    # A function created by `exec` has no retrievable source, so re-running the
    # patch on an already-patched function would fail in `inspect.getsource`.
    if getattr(original_fn, patched_marker, False):
        return

    original_source = inspect.getsource(original_fn)

    if ORIGINAL_PREPARE_DATALOADER_CODE not in original_source:
        raise RuntimeError(
            "SP patch failed - target snippet not found. "
            "Check accelerate's version or update the patch."
        )

    patched_source = original_source.replace(
        ORIGINAL_PREPARE_DATALOADER_CODE, NEW_PREPARE_DATALOADER_CODE
    )

    # Create a new function from the patched source. Globals are the accelerate
    # module's own namespace; the new definition lands in `namespace`.
    namespace = {}
    exec(  # pylint: disable=exec-used  # nosec B102
        patched_source, accelerate.data_loader.__dict__, namespace
    )

    patched_function = namespace["prepare_data_loader"]
    setattr(patched_function, patched_marker, True)
    accelerate.data_loader.prepare_data_loader = patched_function
    LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support")
def patch_prepare_device_mesh(sequence_parallel_degree: int):
    """Replace `Accelerator._prepare_device_mesh` with a version whose device mesh
    includes a sequence parallel dimension of the given degree.

    Args:
        sequence_parallel_degree (int): The degree of sequence parallelism to use.
    """

    def _prepare_device_mesh(self):
        """Build the device mesh for distributed training. The dataloader will
        determine how to load data based on the device mesh.
        """
        if self.state.torch_tp_plugin:
            return self.state.torch_tp_plugin.torch_device_mesh

        is_deepspeed = (
            self.distributed_type == accelerate.accelerator.DistributedType.DEEPSPEED
        )
        if is_deepspeed and hasattr(self.state, "ds_device_mesh"):
            return self.state.ds_device_mesh

        # Lay all ranks out on a (data parallel, sequence parallel) grid
        world_size = dist.get_world_size()
        data_parallel_degree = world_size // sequence_parallel_degree
        rank_grid = torch.tensor(list(range(world_size))).reshape(
            (data_parallel_degree, sequence_parallel_degree)
        )

        # The dimension is named "cp" rather than "sp" to match the PyTorch native
        # "context parallelism" implementation naming
        return dist.DeviceMesh("cuda", rank_grid, mesh_dim_names=("dp", "cp"))

    # Install the replacement on the class
    # pylint: disable=protected-access
    accelerate.accelerator.Accelerator._prepare_device_mesh = _prepare_device_mesh

    LOG.info(
        "Successfully patched Accelerator._prepare_device_mesh "
        f"with sequence_parallel_degree={sequence_parallel_degree}"
    )

View File

@@ -424,6 +424,20 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
LOG.debug(f"Should train: {should_train}")
# turn not trainable, skip having to find the turn indices
# unless last turn and train_on_eos/train_on_eot is all
if not should_train and (
self.train_on_eos != "all" and self.train_on_eot != "all"
):
if index == len(turns) - 1:
LOG.warning(
"Last turn is not trainable, skipping having to find the turn indices. "
"This may cause incorrect last EOT/EOS token to be unmasked."
"This is likely a dataset design issue. Please ensure last turn is trainable."
)
continue
turn_start_idx, turn_end_idx = self.find_turn(turns=turns, turn_idx=index)
LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")

View File

@@ -28,11 +28,15 @@ from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
)
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.integrations.base import PluginManager
from axolotl.loaders import (
ModelLoader,
load_processor,
load_tokenizer,
)
from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.freeze import freeze_layers_except
from axolotl.utils.models import load_model, load_processor, load_tokenizer
from axolotl.utils.schemas.enums import RLType
from axolotl.utils.trainer import setup_trainer
@@ -76,7 +80,8 @@ def setup_model_and_tokenizer(
msg += " and peft_config..."
LOG.debug(msg)
model, peft_config = load_model(cfg, tokenizer, processor=processor)
model_loader = ModelLoader(cfg, tokenizer, processor=processor)
model, peft_config = model_loader.load()
if model.generation_config is not None:
model.generation_config.do_sample = True
@@ -113,7 +118,8 @@ def setup_reference_model(
model_ref = None # explicit setting to None
else:
# load the model again for model_ref/baseline
model_ref, _ = load_model(cfg, tokenizer, reference_model=True)
model_loader = ModelLoader(cfg, tokenizer, reference_model=True)
model_ref, _ = model_loader.load()
return model_ref
@@ -209,6 +215,7 @@ def execute_training(
sequence_parallel_degree=cfg.sequence_parallel_degree,
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
ring_attn_func=cfg.ring_attn_func,
heads_k_stride=cfg.heads_k_stride,
)
)

View File

@@ -1,6 +1,7 @@
"""MLFlow module for trainer callbacks"""
import logging
import os
from shutil import copyfile
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING
@@ -16,6 +17,11 @@ if TYPE_CHECKING:
LOG = logging.getLogger("axolotl.callbacks")
def should_log_artifacts() -> bool:
    """Report whether the `HF_MLFLOW_LOG_ARTIFACTS` environment variable is truthy.

    The value is upper-cased and compared against `TRUE`, `1` and `YES`; an unset
    variable is treated as `FALSE`.
    """
    raw_flag = os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE")
    return raw_flag.upper() in ("TRUE", "1", "YES")
class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
# pylint: disable=duplicate-code
"""Callback to save axolotl config to mlflow"""
@@ -32,13 +38,18 @@ class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
):
if is_main_process():
try:
with NamedTemporaryFile(
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
) as temp_file:
copyfile(self.axolotl_config_path, temp_file.name)
mlflow.log_artifact(temp_file.name, artifact_path="")
if should_log_artifacts():
with NamedTemporaryFile(
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
) as temp_file:
copyfile(self.axolotl_config_path, temp_file.name)
mlflow.log_artifact(temp_file.name, artifact_path="")
LOG.info(
"The Axolotl config has been saved to the MLflow artifacts."
)
else:
LOG.info(
"The Axolotl config has been saved to the MLflow artifacts."
"Skipping logging artifacts to MLflow (hf_mlflow_log_artifacts is false)"
)
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to MLflow: {err}")

View File

@@ -11,9 +11,10 @@ from transformers.utils.import_utils import is_torch_npu_available
from axolotl.integrations.base import PluginManager
from axolotl.integrations.config import merge_input_args
from axolotl.loaders import MULTIMODAL_AUTO_MODEL_MAPPING
from axolotl.loaders.utils import load_model_config
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import MULTIMODAL_AUTO_MODEL_MAPPING, load_model_config
from axolotl.utils.schemas.config import (
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
)

View File

@@ -1,6 +1,7 @@
"""Module for Axolotl trainer sequence parallelism manager and utilities"""
import functools
import inspect
import torch
import torch.distributed as dist
@@ -9,8 +10,11 @@ from torch.utils.hooks import RemovableHandle
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import ModelOutput
from axolotl.monkeypatch.attention.ring_attn.patch import (
from axolotl.monkeypatch.ring_attn import (
get_ring_attn_group,
patch_prepare_data_loader,
patch_prepare_device_mesh,
register_ring_attn,
update_ring_attn_params,
)
from axolotl.utils.schemas.enums import RingAttnFunc
@@ -168,6 +172,8 @@ class SequenceParallelContextManager:
sequence_parallel_degree: Number of processes to split sequences over.
gradient_accumulation_steps: Number of steps to accumulate gradients over.
ring_attn_func: Which ring attention function to use. Passed through to `register_ring_attn`.
heads_k_stride: Sequence parallelism K head stride size. Passed through to
`varlen_llama3` `ring_flash_attn` implementation.
"""
def __init__(
@@ -176,14 +182,17 @@ class SequenceParallelContextManager:
sequence_parallel_degree: int,
gradient_accumulation_steps: int,
ring_attn_func: RingAttnFunc,
heads_k_stride: int | None,
):
self.models = models
self.sequence_parallel_degree = sequence_parallel_degree
self.gradient_accumulation_steps = gradient_accumulation_steps
self.ring_attn_func = ring_attn_func
self.process_group = get_ring_attn_group()
self.heads_k_stride = heads_k_stride
self._register_ring_attn()
# Initialize sequence parallel group details
# Set distributed info for local rank
self.process_group = get_ring_attn_group()
self.local_rank = dist.get_rank(self.process_group)
self.local_world_size = dist.get_world_size(self.process_group)
@@ -204,19 +213,59 @@ class SequenceParallelContextManager:
)
def __enter__(self):
# Register the model hooks on entry, then return the manager itself so it can be
# bound by `with ... as`.
self._register_model_hooks()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
# Called when the `with` block ends. The exception arguments are not inspected
# and nothing is returned, so an exception raised in the block still propagates.
# Remove all hooks
for handle in self.hook_handles:
handle.remove()
# Reset the list only after every handle has been removed
self.hook_handles = []
# TODO(djsaunde): Un-patch attention and accelerate functions (low priority)
def _register_ring_attn(self):
    """Register ring attention and apply the accelerate patches for sequence
    parallelism, using this manager's configured degree, stride and function.
    """
    sp_degree = self.sequence_parallel_degree

    # Set up ring attention for sequence parallelism
    register_ring_attn(
        sequence_parallel_degree=sp_degree,
        heads_k_stride=self.heads_k_stride,
        ring_attn_func=self.ring_attn_func,
    )

    # Accelerate patches: dataloader preparation and device mesh creation
    patch_prepare_data_loader()
    patch_prepare_device_mesh(sequence_parallel_degree=sp_degree)
def _register_model_hooks(self):
# Forward pre-hook to apply sequence parallelism
def sequence_parallel_pre_hook(_, args, kwargs):
# Apply sequence parallelism to kwargs and get original sequence length and padding info
kwargs, self.original_seq_len, self.pad_len = (
self.apply_sequence_parallelism(batch=kwargs)
# Get parameter names from the model's forward function
forward_params = list(
inspect.signature(self.models[0].forward).parameters.keys()
)
return args, kwargs
updated_kwargs = kwargs.copy()
for i, arg in enumerate(args):
if i < len(forward_params):
updated_kwargs[forward_params[i]] = arg
# Any excess positional arguments are kept as-is
remaining_args = args[len(forward_params) :]
# Apply sequence parallelism to updated kwargs
updated_kwargs, self.original_seq_len, self.pad_len = (
self.apply_sequence_parallelism(updated_kwargs)
)
return remaining_args, updated_kwargs
# Forward post-hook to gather outputs
def sequence_parallel_post_hook(_, __, output: ModelOutput) -> ModelOutput:
# Gather the sharded outputs
output = self.gather_outputs(output)
output = self._gather_outputs(output)
# Remove padding if it was added
if self.pad_len > 0:
@@ -239,15 +288,7 @@ class SequenceParallelContextManager:
model.register_forward_hook(sequence_parallel_post_hook)
)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
# Remove all hooks
for handle in self.hook_handles:
handle.remove()
self.hook_handles = []
def gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast:
def _gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast:
"""Gather sharded outputs from all ranks and reconstruct the full tensor."""
for key, value in output.items():
if isinstance(value, torch.Tensor) and value.dim() > 1:

View File

@@ -10,6 +10,7 @@ import yaml
from datasets import Dataset, DatasetDict, concatenate_datasets, load_from_disk
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.loaders import load_tokenizer
from axolotl.prompt_strategies.dpo import load as load_dpo
from axolotl.prompt_strategies.kto import load as load_kto
from axolotl.prompt_strategies.orpo import load as load_orpo
@@ -17,7 +18,6 @@ from axolotl.utils.data.shared import datasets_w_name_generator, load_dataset_w_
from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.models import load_tokenizer
from axolotl.utils.schemas.enums import RLType
LOG = logging.getLogger(__name__)
@@ -72,6 +72,7 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs):
data_set = data_set.map(
ds_transform_fn,
desc="Mapping RL Dataset",
num_proc=cfg.dataset_processes,
**map_kwargs,
)

View File

@@ -484,7 +484,7 @@ def get_dataset_wrapper(
}
LOG.info(
f"Loading dataset with base_type: {d_base_type} and prompt_style: {d_prompt_style}"
f"Loading dataset: {config_dataset['path']} with base_type: {d_base_type} and prompt_style: {d_prompt_style}"
)
if (

View File

@@ -1,14 +0,0 @@
"""
helpers for lora embeddings
"""
def get_linear_embedding_layers(model_type):
    """
    Pick the linear embedding layer names needed for LoRAs for the given model
    arch; anything other than `gpt_neox` or `falcon` gets `embed_tokens` / `lm_head`.
    """
    return (
        ["embed_in", "embed_out"]
        if model_type == "gpt_neox"
        else ["word_embeddings", "lm_head"]
        if model_type == "falcon"
        else ["embed_tokens", "lm_head"]
    )

File diff suppressed because it is too large Load Diff

View File

@@ -470,6 +470,16 @@ class AxolotlInputConfig(
return data
@model_validator(mode="before")
@classmethod
def check_sample_packing_with_s2attn(cls, data):
    """Reject configs that enable both `sample_packing` and `s2_attention`.

    Args:
        data: Raw config mapping handed to this `before` validator.

    Returns:
        `data`, unchanged, when the two options are not both truthy.

    Raises:
        ValueError: If `sample_packing` and `s2_attention` are both truthy.
    """
    if data.get("sample_packing") and data.get("s2_attention"):
        # Adjacent literals instead of a backslash continuation inside the string:
        # a continuation embeds the next line's leading whitespace in the message.
        raise ValueError(
            "Received `sample_packing=true` and `s2_attention=true`; however, "
            "shifted-sparse attention does not currently support sample packing."
        )
    return data
@model_validator(mode="before")
@classmethod
def check_batch_flattening_fa(cls, data):

View File

@@ -1,13 +1,12 @@
"""
unit tests for axolotl.core.trainer_builder
"""
"""Unit tests for axolotl.core.trainer_builder"""
import pytest
from axolotl.core.trainer_builder import HFRLTrainerBuilder
from axolotl.loaders import ModelLoader, load_tokenizer
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.schemas.enums import RLType
@pytest.fixture(name="cfg")
@@ -49,7 +48,7 @@ def fixture_tokenizer(cfg):
@pytest.fixture(name="model")
def fixture_model(cfg, tokenizer):
return load_model(cfg, tokenizer)
return ModelLoader(cfg, tokenizer).load()
class TestHFRLTrainerBuilder:
@@ -65,3 +64,27 @@ class TestHFRLTrainerBuilder:
assert training_arguments.adam_epsilon == 0.00001
assert training_arguments.dataloader_num_workers == 1
assert training_arguments.dataloader_pin_memory is True
class TestTrainerClsPlugin:
    """
    Trainer builder test cases that run with a plugin enabled
    """

    def test_trainer_cls_is_not_none_with_plugin(self, cfg, model, tokenizer):
        """
        Check that the trainer cls is not none when a plugin is configured

        Fixes #2693
        """
        cfg.plugins = ["axolotl.integrations.liger.LigerPlugin"]
        cfg.rl = RLType.KTO

        # Expected AttributeError as we don't pass regular model configs to RL trainer builder
        # If it throws `TypeError: None is not a callable object`, trainer_cls could be None
        expected_message = r".*'tuple' object has no attribute 'config'.*"
        with pytest.raises(AttributeError, match=expected_message):
            HFRLTrainerBuilder(cfg, model, tokenizer).build(100)

View File

@@ -166,7 +166,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"""
)
@pytest.mark.skip(reason="flaky test")
@pytest.mark.parametrize(
"num_gpus",
[1, 2],
@@ -231,8 +230,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"NCCL_P2P_LEVEL": "LOC",
**current_env,
"CUDA_VISIBLE_DEVICES": "1",
"VLLM_DISABLE_COMPILE_CACHE": "1",
# "VLLM_USE_V1": "0",
}
vllm_process = start_vllm(
cfg.base_model,
@@ -266,7 +263,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
finally:
recursive_kill(vllm_process)
@pytest.mark.skip(reason="flaky test")
@pytest.mark.parametrize(
"num_gpus",
[1, 2],
@@ -325,8 +321,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"NCCL_P2P_LEVEL": "LOC", # nccl can be brittle, assume P2P isn't reliable
**current_env,
"CUDA_VISIBLE_DEVICES": "1",
"VLLM_DISABLE_COMPILE_CACHE": "1",
# "VLLM_USE_V1": "0",
}
vllm_process = start_vllm(
cfg.base_model,

View File

@@ -6,9 +6,9 @@ import unittest
import transformers
from axolotl.loaders import ModelLoader, load_tokenizer
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
from ..utils import with_temp_dir
@@ -50,7 +50,7 @@ class TestModelPatches(unittest.TestCase):
cfg = validate_config(cfg)
normalize_config(cfg)
tokenizer = load_tokenizer(cfg)
load_model(cfg, tokenizer, inference=False)
ModelLoader(cfg, tokenizer, inference=False).load()
@with_temp_dir
def test_mistral_multipack(self, temp_dir):
@@ -83,7 +83,7 @@ class TestModelPatches(unittest.TestCase):
cfg = validate_config(cfg)
normalize_config(cfg)
tokenizer = load_tokenizer(cfg)
load_model(cfg, tokenizer, inference=False)
ModelLoader(cfg, tokenizer, inference=False).load()
assert (
"torch.jit"

View File

@@ -10,7 +10,7 @@ import pytest
import torch
from accelerate.state import PartialState
from axolotl.monkeypatch.attention.ring_attn import (
from axolotl.monkeypatch.ring_attn import (
get_ring_attn_group,
register_ring_attn,
set_ring_attn_group,
@@ -84,16 +84,16 @@ class TestRingAttention:
def test_get_ring_attn_group_no_registration(
self, mock_world_size, mock_rank, partial_state
):
"""Test that get_ring_attn_group returns None when no group has been registered."""
"""Test that get_ring_attn_group raises RuntimeError when no group has been registered."""
# Setup mocks
mock_world_size.return_value = 4
mock_rank.return_value = 0
# Get the group without registration
group = get_ring_attn_group()
# Verify that None was returned
assert group is None
# Verify that RuntimeError is raised when no group is registered
with pytest.raises(
RuntimeError, match="register_ring_attn\\(\\) not yet called"
):
get_ring_attn_group()
@patch("torch.distributed.new_group")
@patch("torch.distributed.get_rank")
@@ -313,18 +313,21 @@ class TestApplySequenceParallelism:
# Mock the process group
monkeypatch.setattr(
"axolotl.monkeypatch.attention.ring_attn.get_ring_attn_group",
"axolotl.monkeypatch.ring_attn.get_ring_attn_group",
MagicMock,
)
# Mock update_ring_attn_params
monkeypatch.setattr(
"axolotl.monkeypatch.attention.ring_attn.update_ring_attn_params",
"axolotl.monkeypatch.ring_attn.update_ring_attn_params",
lambda **kwargs: None,
)
def test_world_size_one(self, sequence_parallel_batch):
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
def test_world_size_one(self, mock_get_ring_attn_group, sequence_parallel_batch):
"""Test that function returns original batch when world size is 1."""
mock_get_ring_attn_group.return_value = 0
result, _, _ = apply_sequence_parallelism(
batch=sequence_parallel_batch,
local_rank=0,
@@ -336,8 +339,11 @@ class TestApplySequenceParallelism:
# Should return the original batch unchanged
assert result == sequence_parallel_batch
def test_batch_ring_rank0(self, sequence_parallel_batch):
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
def test_batch_ring_rank0(self, mock_get_ring_attn_group, sequence_parallel_batch):
"""Test BATCH_RING sharding for rank 0 in a 2-process group."""
mock_get_ring_attn_group.return_value = 0
batch = sequence_parallel_batch
seq_len = batch["input_ids"].size(1)
@@ -359,8 +365,11 @@ class TestApplySequenceParallelism:
result["position_ids"], batch["position_ids"][:, : seq_len // 2]
)
def test_batch_ring_rank1(self, sequence_parallel_batch):
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
def test_batch_ring_rank1(self, mock_get_ring_attn_group, sequence_parallel_batch):
"""Test BATCH_RING sharding for rank 1 in a 2-process group."""
mock_get_ring_attn_group.return_value = 0
batch = sequence_parallel_batch
seq_len = batch["input_ids"].size(1)
original_input_ids = batch["input_ids"].clone()
@@ -419,8 +428,13 @@ class TestApplySequenceParallelism:
# assert torch.equal(result_rank0["input_ids"], rank0_expected)
# assert torch.equal(result_rank1["input_ids"], rank1_expected)
def test_partial_application(self, sequence_parallel_batch):
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
def test_partial_application(
self, mock_get_ring_attn_group, sequence_parallel_batch
):
"""Test that we can create a partially applied version of the function."""
mock_get_ring_attn_group.return_value = 0
batch = sequence_parallel_batch
original_input_ids = batch["input_ids"].clone()

View File

@@ -6,8 +6,8 @@ import tempfile
import pytest
import torch
from axolotl.loaders import ModelLoader, load_tokenizer
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import ModelLoader, load_model, load_tokenizer
@pytest.fixture(name="temp_dir")
@@ -58,6 +58,8 @@ class TestLoadModelUtils:
ModelLoader(
cfg=self.cfg,
tokenizer="",
inference=False,
reference_model=True,
)
)
@@ -71,13 +73,8 @@ class TestLoadModelUtils:
):
self.cfg.output_dir = temp_dir
self.model_loader.tokenizer = load_tokenizer(self.cfg) # pylint: disable=all
self.model_loader.model, _ = load_model(
self.cfg,
self.model_loader.tokenizer,
inference=False,
reference_model=True,
)
self.model_loader.convert_embedding_modules_dtype(
self.model_loader.load()
self.model_loader._convert_embedding_modules_dtype(
embedding_modules, dist_dtype, before_kbit_train_or_finetune
)
for name, module in self.model_loader.model.named_modules():

View File

@@ -9,11 +9,11 @@ from typing import Optional
import pytest
from pydantic import ValidationError
from axolotl.loaders.utils import check_model_config
from axolotl.utils import is_comet_available
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.models import check_model_config
from axolotl.utils.schemas.config import AxolotlConfigWCapabilities
from axolotl.utils.wandb_ import setup_wandb_env_vars
@@ -1215,6 +1215,20 @@ class TestValidation(BaseValidation):
cfg, capabilities=capabilities, env_capabilities=env_capabilities
)
def test_cfg_throws_error_with_s2_attention_and_sample_packing(self, minimal_cfg):
    """Enabling `s2_attention` together with `sample_packing` must fail validation."""
    conflicting_options = {
        "s2_attention": True,
        "sample_packing": True,
    }
    test_cfg = DictDefault(conflicting_options | minimal_cfg)

    expected_message = (
        r".*shifted-sparse attention does not currently support sample packing*"
    )
    with pytest.raises(ValidationError, match=expected_message):
        validate_config(test_cfg)
class TestTorchCompileValidation(BaseValidation):
"""

View File

@@ -1,7 +1,8 @@
"""
Test suite for functions in the axolotl.utils.data.utils module, focusing on the deduplicate_and_log_datasets function.
"""Test suite for functions in the `axolotl.utils.data.utils` module, focusing on the
`deduplicate_and_log_datasets` function.
Additionally, this test suite includes tests for functions that indirectly call deduplicate_and_log_datasets during the execution of the preprocess command.
Additionally, this test suite includes tests for functions that indirectly call
`deduplicate_and_log_datasets` during the execution of the preprocess command.
"""
import hashlib
@@ -11,20 +12,19 @@ from unittest.mock import patch
import pytest
from datasets import Dataset
from axolotl.loaders import load_processor, load_tokenizer
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.data.utils import deduplicate_and_log_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_processor, load_tokenizer
from tests.constants import ALPACA_MESSAGES_CONFIG_REVISION
from tests.hf_offline_utils import enable_hf_offline
def verify_deduplication(actual_dataset, expected_dataset, dataset_name):
"""
Validates deduplication results and size consistency.
"""Validates deduplication results and size consistency.
Parameters:
- actual_dataset: Deduplicated dataset.
@@ -49,9 +49,7 @@ def verify_deduplication(actual_dataset, expected_dataset, dataset_name):
class TestDeduplicateIndividualFunctions(unittest.TestCase):
"""
test class for deduplication function in data utils
"""
"""Test class for deduplication function in data utils"""
def setUp(self):
# Sample data with duplicates
@@ -248,7 +246,7 @@ class TestDeduplicateRLDataset:
# pylint: disable=duplicate-code
with (
patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
patch("axolotl.loaders.load_tokenizer") as mock_load_tokenizer,
):
# Set up the mock to return different values on successive calls
mock_load_dataset.side_effect = [
@@ -272,7 +270,7 @@ class TestDeduplicateRLDataset:
# pylint: disable=duplicate-code
with (
patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
patch("axolotl.loaders.load_tokenizer") as mock_load_tokenizer,
):
# Set up the mock to return different values on successive calls
mock_load_dataset.side_effect = [
@@ -411,7 +409,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
class TestWrongCollisions(unittest.TestCase):
"""Creating mock datasets for testing wrong collisions"""
"""Creating mock datasets for testing wrong collisions."""
def setUp(self):
self.train_data = {"text": ["sample 5", "sample 6"], "label": [1, 2]}

View File

@@ -1,18 +1,18 @@
"""Module for testing models utils file."""
"""Module for `axolotl.loaders`."""
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock
import pytest
from transformers import BitsAndBytesConfig, PreTrainedTokenizerBase
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.utils.import_utils import is_torch_mps_available
from axolotl.loaders import ModelLoader
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import ModelLoader, load_model
class TestModelsUtils:
"""Testing module for models utils."""
"""Testing module for `axolotl.loaders`."""
def setup_method(self) -> None:
# load config
@@ -50,7 +50,8 @@ class TestModelsUtils:
device_map = self.cfg.device_map
if is_torch_mps_available():
device_map = "mps"
self.model_loader.set_device_map_config()
# pylint: disable=protected-access
self.model_loader._set_device_map_config()
if is_deepspeed_zero3_enabled():
assert "device_map" not in self.model_loader.model_kwargs
else:
@@ -59,29 +60,6 @@ class TestModelsUtils:
# check torch_dtype
assert self.cfg.torch_dtype == self.model_loader.model_kwargs["torch_dtype"]
def test_cfg_throws_error_with_s2_attention_and_sample_packing(self):
cfg = DictDefault(
{
"s2_attention": True,
"sample_packing": True,
"base_model": "",
"model_type": "AutoModelForCausalLM",
}
)
# Mock out call to HF hub
with patch(
"axolotl.utils.models.load_model_config"
) as mocked_load_model_config:
mocked_load_model_config.return_value = {}
with pytest.raises(ValueError) as exc:
# Should error before hitting tokenizer, so we pass in an empty str
load_model(cfg, tokenizer="") # type: ignore
assert (
"shifted-sparse attention does not currently support sample packing"
in str(exc.value)
)
@pytest.mark.parametrize("adapter", ["lora", "qlora", None])
@pytest.mark.parametrize("load_in_8bit", [True, False])
@pytest.mark.parametrize("load_in_4bit", [True, False])
@@ -99,7 +77,8 @@ class TestModelsUtils:
self.cfg.gptq = gptq
self.cfg.adapter = adapter
self.model_loader.set_quantization_config()
# pylint: disable=protected-access
self.model_loader._set_quantization_config()
if "quantization_config" in self.model_loader.model_kwargs or self.cfg.gptq:
assert not (
hasattr(self.model_loader.model_kwargs, "load_in_8bit")

View File

@@ -2,9 +2,9 @@
tests for loading loras
"""
from axolotl.loaders import ModelLoader, load_tokenizer
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
# pylint: disable=duplicate-code
minimal_config = DictDefault(
@@ -46,7 +46,7 @@ class TestLoRALoad:
cfg = validate_config(cfg)
normalize_config(cfg)
tokenizer = load_tokenizer(cfg)
load_model(cfg, tokenizer)
ModelLoader(cfg, tokenizer).load()
def test_load_lora_weights_empty_dropout(self):
cfg = DictDefault(
@@ -67,4 +67,4 @@ class TestLoRALoad:
normalize_config(cfg)
assert cfg.lora_dropout == 0.0
tokenizer = load_tokenizer(cfg)
load_model(cfg, tokenizer)
ModelLoader(cfg, tokenizer).load()

View File

@@ -6,8 +6,8 @@ import unittest
import pytest
from axolotl.loaders import load_tokenizer
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_tokenizer
from tests.hf_offline_utils import enable_hf_offline