Compare commits

...

44 Commits

Author SHA1 Message Date
Dan Saunders
4ac65462f0 precommit 2025-03-21 16:43:14 +00:00
Dan Saunders
ce35b2a95f precommit 2025-03-21 16:36:56 +00:00
Dan Saunders
ab3b36339a fix tests 2025-03-21 16:36:54 +00:00
Dan Saunders
22cfa42961 small updates 2025-03-21 16:36:34 +00:00
Dan Saunders
0b2c2ed68c refactors, SP mixin 2025-03-21 16:36:34 +00:00
Dan Saunders
2f0b4626b9 review comments, docstrings 2025-03-21 16:36:32 +00:00
Dan Saunders
a26985c53c small changes 2025-03-21 16:36:17 +00:00
Dan Saunders
c1a58339e8 add SP doc, review comments 2025-03-21 16:36:17 +00:00
Dan Saunders
411df76a97 bugfix 2025-03-21 16:36:17 +00:00
Dan Saunders
a09d1ccbf2 removing print statement 2025-03-21 16:36:17 +00:00
Dan Saunders
2727d86544 non-seq2se1 collator fix 2025-03-21 16:36:17 +00:00
Dan Saunders
64c203cdef sampler / dataloader refactor 2025-03-21 16:36:17 +00:00
Dan Saunders
7d7042f602 test fix 2025-03-21 16:36:17 +00:00
Dan Saunders
d187f1f8e2 using field validator instead of model validator 2025-03-21 16:36:17 +00:00
Dan Saunders
1cced52719 rename file, delete another 2025-03-21 16:36:17 +00:00
Dan Saunders
11321b17e7 removing flash-attn from requirements.txt (in setup.py extras already) 2025-03-21 16:36:17 +00:00
Wing Lian
7a1a211c99 move ring flash attn to extras with flash-attn (#2414) 2025-03-21 16:36:17 +00:00
Dan Saunders
e1a02a32b5 fix 2025-03-21 16:36:17 +00:00
Dan Saunders
a6ef6c7764 fix 2025-03-21 16:36:17 +00:00
Dan Saunders
cb3a9e99a3 gracefully handle no ring-flash-attn 2025-03-21 16:36:17 +00:00
Dan Saunders
3ae47ec7de actually isolate CLI tests 2025-03-21 16:36:17 +00:00
Dan Saunders
e36dc763ab isolate cli tests 2025-03-21 16:36:17 +00:00
Dan Saunders
03027cf6bf pernicious Fire CLI bugfix 2025-03-21 16:36:16 +00:00
Dan Saunders
0ade60d455 another import scoping change 2025-03-21 16:35:56 +00:00
Dan Saunders
02e1a42f04 scoping down problematic import 2025-03-21 16:35:56 +00:00
Dan Saunders
919b88f11b update config.qmd and rename option 2025-03-21 16:35:55 +00:00
Dan Saunders
345a9dd831 removing some obvious comments 2025-03-21 16:35:38 +00:00
Dan Saunders
4ff97bc9d4 eval dataloader and sampler changes 2025-03-21 16:35:38 +00:00
Dan Saunders
d0e178d52f remove debug logs and simplify 2025-03-21 16:35:38 +00:00
Dan Saunders
5731cdc0cf fixing sample packing 2025-03-21 16:35:38 +00:00
Dan Saunders
b7738d57c4 working multi-group SP 2025-03-21 16:35:38 +00:00
Dan Saunders
698e599bf7 precommit fixes 2025-03-21 16:35:38 +00:00
Dan Saunders
1d339e4007 fixes 2025-03-21 16:35:38 +00:00
Dan Saunders
4190ad0647 updates 2025-03-21 16:35:36 +00:00
Dan Saunders
b44a207248 update 2025-03-21 16:35:10 +00:00
Dan Saunders
51c326150b pytest 2025-03-21 16:35:10 +00:00
Dan Saunders
14baaf6e0a updates 2025-03-21 16:35:10 +00:00
Dan Saunders
f487910444 removing unused code 2025-03-21 16:35:08 +00:00
Dan Saunders
c5071dfd8a fix req 2025-03-21 16:34:12 +00:00
Dan Saunders
e323145ba9 remove errant file 2025-03-21 16:34:12 +00:00
Dan Saunders
7efc787ac8 cleanup 2025-03-21 16:34:12 +00:00
Dan Saunders
dce61cdab1 progress on ring attn impl 2025-03-21 16:34:12 +00:00
Dan Saunders
bd952de9d2 progress on ring attn impl 2025-03-21 16:34:10 +00:00
Dan Saunders
3f8a43cab6 adding easy_context as integration for now 2025-03-21 16:33:46 +00:00
31 changed files with 1532 additions and 648 deletions

View File

@@ -98,8 +98,9 @@ jobs:
- name: Run tests
run: |
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
pytest -v tests/patched/
pytest -v tests/cli/
- name: cleanup pip cache
run: |
@@ -172,8 +173,9 @@ jobs:
- name: Run tests
run: |
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
pytest -v tests/patched/
pytest -v tests/cli/
- name: cleanup pip cache
run: |

View File

@@ -33,9 +33,9 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
RUN pip install packaging==23.2 setuptools==75.8.0
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
fi
RUN python scripts/unsloth_install.py | sh

View File

@@ -3,9 +3,10 @@ set -e
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli /workspace/axolotl/tests/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/lora_kernels # running these with the other patches causes a failure
pytest -v --durations=10 --ignore=tests/e2e/patched/lora_kernels /workspace/axolotl/tests/e2e/patched
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
pytest -v --durations=10 /workspace/axolotl/tests/cli
pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ --ignore=tests/cli /workspace/axolotl/tests/e2e/

View File

@@ -32,6 +32,9 @@ tokenizer_legacy:
resize_token_embeddings_to_32x:
# Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
shrink_embeddings:
# Whether to load the model with randomly initialized weights. Useful for
# pre-training a model from scratch or debugging purposes.
random_init_weights:
# (Internal use only)
# Used to identify which the model is based on
@@ -617,6 +620,14 @@ ddp_timeout:
ddp_bucket_cap_mb:
ddp_broadcast_buffers:
# Sequence parallelism
# Set to a divisor of the number of GPUs available to split sequences into chunks of equal size.
# Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM.
# E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized
# subsequences, or set to 4 to split into four equal-sized subsequences.
# See https://axolotl-ai-cloud.github.io/axolotl/docs/sequence_parallelism.html for more details.
sequence_parallel_degree:
# Path to torch distx for optim 'adamw_anyprecision'
torchdistx_path:

View File

@@ -0,0 +1,90 @@
---
title: Sequence Parallelism
description: Train with long sequences split across multiple GPUs.
---
# Sequence Parallelism
Sequence parallelism is a technique that splits sequences across multiple GPUs,
allowing you to train with very long sequences that wouldn't fit on a single GPU. Each
GPU processes a different portion of the sequence, and the results are aggregated
through a ring communication pattern.
## When to Use Sequence Parallelism
Use sequence parallelism when:
- You need to train with sequence lengths that don't fit into a single GPU's memory
- You have multiple GPUs available
- You're experiencing OOM (Out Of Memory) errors with long sequences
## Configuration
To enable sequence parallelism, add the following to your configuration file:
```yaml
# Set to a divisor (> 1) of the number of GPUs available
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
```
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
- With 8 GPUs, valid values would be 2, 4, or 8
- With 4 GPUs, valid values would be 2 or 4
## Implementation Details
When sequence parallelism is enabled:
1. Each sequence is divided into equal chunks across the GPUs in a sequence parallel group
2. The data collator handles the chunking of input_ids, attention_mask, labels, and position_ids
3. Position IDs are adjusted to maintain proper relative positions, especially for packed sequences
4. The trainer uses special ring communication patterns for attention operations
## Requirements
To use sequence parallelism, you need:
- Multiple GPUs (at least 2)
- The `ring-flash-attn` package. Install with:
- `pip install axolotl[ring-flash-attn]` (preferred)
- `pip install ring-flash-attn>=0.1.4`
## Limitations
- Flash attention must be enabled for this to work (`flash_attention: true` in config YAML)
- May have a small performance overhead due to communication between GPUs
## Example
```yaml
# Example config with sequence parallelism
base_model: meta-llama/Llama-3-8B-Instruct
sequence_len: 8192
sequence_parallel_degree: 2 # Split each sequence into 4 parts
flash_attention: true # Required with sequence parallelism
...
```
This will train the Llama 3 8B model with 8K context length, with each sequence split
into 2 subsequences of length 4096 across 2 GPUs.
## Sample Packing with Sequence Parallelism
Sequence parallelism is compatible with Axolotl's sample packing functionality. When using both features together:
1. Samples are first packed together
2. The packed sequences are then divided across GPUs in the sequence parallel group
3. Position IDs are automatically adjusted to maintain proper relative positions
## Effect on Batch Size
When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:
- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
- The number of batches processed per step decreases
For example:
- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4

View File

@@ -4,7 +4,6 @@
bitsandbytes==0.45.3
triton>=3.0.0
mamba-ssm==1.2.0.post1
flash-attn==2.7.4.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3
liger-kernel==0.5.3
@@ -36,6 +35,7 @@ einops
colorama
numba
numpy>=1.24.4,<=2.0.1
# qlora things
evaluate==0.4.1
scipy

View File

@@ -17,11 +17,7 @@ def parse_requirements():
lines = [r.strip() for r in requirements_file.readlines()]
for line in lines:
is_extras = (
"flash-attn" in line
or "flash-attention" in line
or "deepspeed" in line
or "mamba-ssm" in line
or "lion-pytorch" in line
"deepspeed" in line or "mamba-ssm" in line or "lion-pytorch" in line
)
if line.startswith("--extra-index-url"):
# Handle custom index URLs
@@ -39,7 +35,6 @@ def parse_requirements():
"bitsandbytes",
"triton",
"mamba-ssm",
"flash-attn",
"xformers",
"autoawq",
"liger-kernel",
@@ -124,9 +119,8 @@ setup(
],
},
extras_require={
"flash-attn": [
"flash-attn==2.7.4.post1",
],
"flash-attn": ["flash-attn==2.7.4.post1"],
"ring-flash-attn": ["ring-flash-attn>=0.1.4", "yunchang==0.6.0"],
"deepspeed": [
"deepspeed==0.16.4",
"deepspeed-kernels",

View File

@@ -23,7 +23,7 @@ from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__)
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
"""
Trains a `transformers` model by first loading the dataset(s) specified in the
`axolotl` config, and then calling `axolotl.train.train`. Also runs the plugin
@@ -44,16 +44,13 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
del model, tokenizer, trainer
plugin_manager = PluginManager.get_instance()
del model
del tokenizer
del trainer
plugin_manager.post_train_unload(cfg)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
"""
Parses `axolotl` config, CLI args, and calls `do_train`.

View File

@@ -36,7 +36,7 @@ from transformers import (
from transformers.training_args import OptimizerNames
from trl.trainer.utils import RewardDataCollatorWithPadding
from axolotl.core.trainers.base import (
from axolotl.core.trainers import (
AxolotlCPOTrainer,
AxolotlKTOTrainer,
AxolotlMambaTrainer,
@@ -762,6 +762,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self.cfg.kd_top_k_before_softmax
)
training_arguments_kwargs["sequence_parallel_degree"] = (
self.cfg.sequence_parallel_degree
)
if self.cfg.reward_model:
training_args_cls = AxolotlRewardConfig
elif self.cfg.process_reward_model:
@@ -845,9 +849,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
):
if training_args.pretraining:
if self.cfg.pretraining_sample_concatenation is False:
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
if self.cfg.micro_batch_size > 1:
if (
self.cfg.pretraining_sample_concatenation is False
or self.cfg.micro_batch_size > 1
):
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
return None
@@ -875,9 +880,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if "max_length" in kwargs:
kwargs.pop("max_length")
elif use_batch_sampler_collator:
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
collator = V2BatchSamplerDataCollatorForSeq2Seq
elif (
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES or (
self.cfg.model_config_type in ["llama"]
and self.cfg.flash_attention is not True
):
@@ -908,6 +911,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
collator = DataCollatorForSeq2Seq
kwargs["return_tensors"] = "pt"
if issubclass(collator, DataCollatorForSeq2Seq):
kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree
return collator(
*collator_args,

View File

@@ -0,0 +1,18 @@
"""Init for axolotl.core.trainers"""
# pylint: disable=unused-import
# flake8: noqa
from .base import AxolotlTrainer
from .dpo.trainer import AxolotlDPOTrainer
from .grpo.trainer import AxolotlGRPOTrainer
from .mamba import AxolotlMambaTrainer
from .relora import ReLoRATrainer
from .trl import (
AxolotlCPOTrainer,
AxolotlKTOTrainer,
AxolotlORPOTrainer,
AxolotlPRMTrainer,
AxolotlRewardTrainer,
TRLPPOTrainer,
)

View File

@@ -1,365 +1,47 @@
"""
module for customized trainers
"""
"""Module for customized trainers"""
# pylint: disable=too-many-lines
from __future__ import annotations
# pylint: disable=too-many-lines
import logging
import os
from collections import defaultdict
from functools import wraps
from typing import Dict, Literal, Optional
from typing import Any, Literal
import datasets
import torch
from datasets import Dataset
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data import (
BatchSampler,
DataLoader,
RandomSampler,
Sampler,
SequentialSampler,
)
from transformers import Trainer
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
from transformers.utils import is_sagemaker_mp_enabled
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
from trl.trainer.utils import pad_to_length
from typing_extensions import override
from axolotl.integrations.base import BaseOptimizerFactory
from axolotl.monkeypatch.relora import ReLoRAScheduler
from axolotl.core.trainers.mixins import (
OptimizerMixin,
SchedulerMixin,
SequenceParallelMixin,
)
from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging,
sanitize_kwargs_for_tagging,
)
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.schedulers import (
RexLR,
get_cosine_schedule_with_min_lr,
get_cosine_schedule_with_quadratic_warmup,
get_cosine_schedule_with_warmup_decay_constant,
)
if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
LOG = logging.getLogger("axolotl.core.trainer_builder")
LOG = logging.getLogger(__name__)
def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
if isinstance(tag_names, str):
tag_names = [tag_names]
if kwargs is not None:
if "tags" not in kwargs:
kwargs["tags"] = tag_names
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
kwargs["tags"].extend(tag_names)
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
tag_names.append(kwargs["tags"])
kwargs["tags"] = tag_names
return kwargs
def _sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None):
if isinstance(dataset_tags, str):
dataset_tags = [dataset_tags]
if (dataset_tags is not None) and (kwargs is not None):
if "dataset_tags" not in kwargs:
kwargs["dataset_tags"] = dataset_tags
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list):
kwargs["dataset_tags"].extend(dataset_tags)
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str):
dataset_tags.append(kwargs["dataset_tags"])
kwargs["dataset_tags"] = dataset_tags
return kwargs
class SchedulerMixin(Trainer):
"""
Mixin class for scheduler setup in CausalTrainer.
"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
):
"""
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
passed as an argument.
Args:
num_training_steps (int): The number of training steps to do.
optimizer (torch.optim.Optimizer): The training optimizer
"""
use_cosine_quadratic = (
self.args.lr_scheduler_type == "cosine"
and self.args.lr_quadratic_warmup is True
)
use_cosine_min_lr = (
self.args.lr_scheduler_type == "cosine"
and self.args.cosine_min_lr_ratio is not None
)
# fmt: off
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
# fmt: on
if self.args.alternate_lr_scheduler_type == "one_cycle":
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
pct_start = num_warmup_steps / num_training_steps
extra_lr_kwargs = {}
if "pct_start" not in self.args.lr_scheduler_kwargs:
extra_lr_kwargs["pct_start"] = pct_start
if "anneal_strategy" not in self.args.lr_scheduler_kwargs:
extra_lr_kwargs["anneal_strategy"] = "cos"
self.lr_scheduler = OneCycleLR(
optimizer,
max_lr=self.args.learning_rate,
total_steps=num_training_steps,
**extra_lr_kwargs,
**self.args.lr_scheduler_kwargs,
)
elif self.args.alternate_lr_scheduler_type == "rex":
if use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = RexLR(
optimizer=optimizer,
max_lr=self.args.learning_rate,
min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio),
total_steps=num_training_steps,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
)
elif use_cosine_quadratic:
if use_cosine_min_lr:
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
)
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
)
else:
return super().create_scheduler(num_training_steps, optimizer=optimizer)
else:
if use_cosine_quadratic:
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
if use_cosine_min_lr:
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
return self.lr_scheduler
class OptimizerMixin(Trainer):
"""
Mixin class for shared handling of building custom optimizers
"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
def create_optimizer_grouped_parameters(
self, opt_model, optimizer_kwargs
) -> list[dict]:
decay_parameters = self.get_decay_parameter_names(opt_model)
params: dict = {
"to_weight_decay": {}, # LayerNorm and bias
"embeddings": {}, # lm_head, embed_tokens,
"no_weight_decay": {},
}
lr_groups_lookup = {}
lr_groups_learning_rates = {}
if self.args.lr_groups:
for lr_group in self.args.lr_groups:
group_name = lr_group["name"]
group_modules = lr_group["modules"]
for module in group_modules:
lr_groups_lookup[module] = group_name
lr_groups_learning_rates[group_name] = lr_group["lr"]
params[f"to_weight_decay_{group_name}"] = {}
for name, param in opt_model.named_parameters():
if not param.requires_grad:
continue
if name.endswith("modules_to_save.default.weight") or any(
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
):
params["embeddings"][name] = param
elif name in decay_parameters:
lr_group_modules = [
group_modules
for group_modules in lr_groups_lookup
if group_modules in name
]
if lr_groups_lookup and any(lr_group_modules):
lr_group_module = lr_group_modules[0]
group_name = lr_groups_lookup[lr_group_module]
params[f"to_weight_decay_{group_name}"][name] = param
else:
params["to_weight_decay"][name] = param
else:
params["no_weight_decay"][name] = param
optimizer_grouped_parameters = []
if params["to_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["to_weight_decay"].values()),
"weight_decay": self.args.weight_decay,
"lr": optimizer_kwargs["lr"],
}
)
if params["embeddings"]:
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
if self.args.embedding_lr_scale:
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
elif self.args.embedding_lr:
lr = self.args.embedding_lr # pylint: disable=invalid-name
optimizer_grouped_parameters.append(
{
"params": list(params["embeddings"].values()),
"weight_decay": 0.0,
"lr": lr,
}
)
if params["no_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["no_weight_decay"].values()),
"weight_decay": 0.0,
"lr": optimizer_kwargs["lr"],
}
)
for group_name, group_lr in lr_groups_learning_rates.items():
if params[f"to_weight_decay_{group_name}"]:
optimizer_grouped_parameters.append(
{
"params": list(
params[f"to_weight_decay_{group_name}"].values()
),
"weight_decay": self.args.weight_decay,
"lr": group_lr,
}
)
return optimizer_grouped_parameters
def create_optimizer(self):
if (
self.args.loraplus_lr_ratio is None
and self.args.embedding_lr_scale is None
and self.args.embedding_lr is None
and self.args.lr_groups is None
and self.optimizer_cls_and_kwargs is None
):
return super().create_optimizer()
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if (
not self.optimizer
and self.optimizer_cls_and_kwargs is not None
and issubclass(self.optimizer_cls_and_kwargs[0], BaseOptimizerFactory)
):
optimizer_factory_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
self.optimizer = optimizer_factory_cls()(
opt_model, self.args, **optimizer_kwargs
)
if not self.optimizer:
if self.optimizer_cls_and_kwargs is not None:
optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
else:
optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(
self.args, opt_model
)
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
opt_model, optimizer_kwargs
)
if self.args.loraplus_lr_ratio is not None:
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
loraplus_lr_embedding = getattr(
self.args, "loraplus_lr_embedding", 1e-6
)
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
loraplus_lr_ratio=loraplus_lr_ratio,
loraplus_lr_embedding=loraplus_lr_embedding,
**optimizer_kwargs,
)
else:
# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
# e.g. for GaLore optimizer.
if "params" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("params")
# Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
# e.g. for LOMO optimizer.
if "model" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("model")
# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
# to avoid arguments conflicts.
if "optimizer_dict" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop(
"optimizer_dict"
)
self.optimizer = optimizer_cls(
optimizer_grouped_parameters, **optimizer_kwargs
)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum(
{
p.data_ptr(): p.numel() for p in module.parameters()
}.values()
)
LOG.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(
module, "weight", {"optim_bits": 32}
)
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
LOG.info(f"skipped: {skipped/2**20}M params")
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
self.optimizer
)
return self.optimizer
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
"""
Extend the base Trainer for axolotl helpers
"""
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trainer):
"""Extend the base Trainer for axolotl helpers"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
tag_names = ["axolotl"]
@@ -376,12 +58,18 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
self.eval_data_collator = eval_data_collator
self.dataset_tags = dataset_tags
self._signature_columns = None # workaround for pylint
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(lambda: defaultdict(list))
if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
# Initialize sequence parallelism if enabled
if self.args.sequence_parallel_degree > 1:
self._setup_sequence_parallel()
def _wrap_model(self, model, training=True, dataloader=None):
if self.args.torch_compile:
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
@@ -394,8 +82,20 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
)
return super()._wrap_model(model, training=training, dataloader=dataloader)
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and not self.args.pretraining:
def _create_multipack_sampler(
self, base_sampler: Sampler, dataset: Dataset
) -> MultipackBatchSampler:
"""
Helper method to create a `MultipackBatchSampler` for multipacking sequences
for training.
Args:
base_sampler: Sampler to wrap with `MultipackBatchSampler`.
dataset: Dataset to sample from.
Returns:
Multipack (sample packing) batch sampler.
"""
if self.args.multipack_real_batches:
batch_size = self.args.per_device_train_batch_size
batch_max_len = self.args.max_seq_length
@@ -406,130 +106,223 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
)
batch_max_len = train_batch_size * self.args.max_seq_length
if self.args.curriculum_sampling:
sampler = SequentialSampler(self.train_dataset)
else:
sampler = RandomSampler(self.train_dataset)
return MultipackBatchSampler(
sampler,
lengths=get_dataset_lengths(self.train_dataset),
base_sampler,
lengths=get_dataset_lengths(dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
batch_max_len=batch_max_len,
batch_size=batch_size,
group_size=self.args.sample_packing_group_size,
bin_size=self.args.sample_packing_bin_size,
drop_last=True,
)
if self.args.curriculum_sampling:
return SequentialSampler(self.train_dataset)
def _get_train_sampler(self) -> Sampler | None:
"""
Helper method to get the sampler for training. Handles cases for sequence
parallelism, sample packing, and curriculum sampling (sequential).
Returns:
If the dataset is non-empty, a sampler is returned, the type of which
depends on the passed training args.
"""
use_sample_packing = self.args.sample_packing and not self.args.pretraining
# Determine the base sampler first
if self.args.sequence_parallel_degree > 1:
base_sampler = self._sp_get_train_sampler(self.train_dataset)
elif self.args.curriculum_sampling:
base_sampler = SequentialSampler(self.train_dataset)
elif use_sample_packing:
base_sampler = RandomSampler(self.train_dataset)
else:
# Default to parent class implementation for standard random sampling
return super()._get_train_sampler()
def _get_eval_sampler(
self, eval_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and self.args.eval_sample_packing is not False:
if self.args.multipack_real_batches:
batch_size = self.args.per_device_eval_batch_size
batch_max_len = self.args.max_seq_length
# Apply multipack wrapper if needed
if use_sample_packing:
return self._create_multipack_sampler(
base_sampler=base_sampler,
dataset=self.train_dataset,
)
return base_sampler
def _get_eval_sampler(self, eval_dataset: Dataset | None = None) -> Sampler | None:
"""
Helper method to get the sampler for evaluation. Handles sequence parallelism
and sample packing cases.
Returns:
If the dataset is non-empty, a sampler is returned, the type of which
depends on the passed training args.
"""
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
# Multipacking enabled if training is enabled and eval is not explicitly disabled
use_multipack = (
self.args.sample_packing and self.args.eval_sample_packing is not False
)
# Determine the base sampler
if self.args.sequence_parallel_degree > 1:
base_sampler = self._sp_get_eval_sampler(eval_dataset)
elif use_multipack:
base_sampler = SequentialSampler(eval_dataset)
else:
batch_size = 1
batch_max_len = (
self.args.per_device_eval_batch_size * self.args.max_seq_length
)
return MultipackBatchSampler(
SequentialSampler(eval_dataset),
lengths=get_dataset_lengths(self.eval_dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
batch_max_len=batch_max_len,
batch_size=batch_size,
group_size=self.args.sample_packing_group_size,
bin_size=self.args.sample_packing_bin_size,
drop_last=True,
)
return super()._get_eval_sampler(eval_dataset)
def get_train_dataloader(self) -> DataLoader:
if self.args.sample_packing and not self.args.pretraining:
train_dataset = self.train_dataset
if "length" in train_dataset.features.keys():
train_dataset = train_dataset.remove_columns(["length"])
data_collator = self.data_collator
dataloader_params = {
"batch_size": self._train_batch_size,
"collate_fn": data_collator,
# Apply multipack wrapper if needed
if use_multipack:
return self._create_multipack_sampler(
base_sampler=base_sampler,
dataset=eval_dataset,
)
return base_sampler
def _create_dataloader_params(self, is_eval=False, custom_batch_size=None):
"""Create common dataloader parameters for train or eval."""
batch_size = custom_batch_size or (
self.args.eval_batch_size if is_eval else self._train_batch_size
)
params = {
"batch_size": batch_size,
"collate_fn": self.data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params["prefetch_factor"] = (
self.args.dataloader_prefetch_factor
)
sampler = self._get_train_sampler()
# Add persistent workers only for training
if not is_eval and hasattr(self.args, "dataloader_persistent_workers"):
params["persistent_workers"] = self.args.dataloader_persistent_workers
# Add prefetch factor if specified
if self.args.dataloader_prefetch_factor:
params["prefetch_factor"] = self.args.dataloader_prefetch_factor
return params
def _prepare_dataloader(
self, dataset, sampler, is_eval=False, custom_batch_size=None
):
"""Prepare a dataloader with the given dataset and sampler."""
# Get base parameters
dataloader_params = self._create_dataloader_params(is_eval, custom_batch_size)
# Add sampler configuration
if not isinstance(dataset, torch.utils.data.IterableDataset):
if isinstance(sampler, BatchSampler):
# batch_size and batch_sampler are mutually exclusive
dataloader_params["batch_sampler"] = sampler
del dataloader_params["batch_size"]
else:
dataloader_params["sampler"] = sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last
if not is_eval:
dataloader_params["worker_init_fn"] = seed_worker
self.accelerator.even_batches = False
return self.accelerator.prepare_data_loader(
DataLoader(train_dataset, **dataloader_params)
)
return super().get_train_dataloader()
# Create the dataloader
dataloader = DataLoader(dataset, **dataloader_params)
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
if self.args.sample_packing and (
(not is_eval and not self.args.pretraining)
or (is_eval and self.args.eval_sample_packing is not False)
):
self.accelerator.even_batches = False
# Return unprepared dataloader if using sequence parallelism
if self.args.sequence_parallel_degree > 1:
return dataloader
# Otherwise prepare with accelerator
return self.accelerator.prepare_data_loader(dataloader)
def get_train_dataloader(self) -> DataLoader:
"""Get dataloader for training"""
train_dataset = self.train_dataset
data_collator = self.data_collator # type: ignore
# Handle dataset preprocessing
if isinstance(train_dataset, datasets.Dataset):
if self.args.sample_packing and not self.args.pretraining:
train_dataset = train_dataset.remove_columns(["length"])
if not self.args.sample_packing or self.args.pretraining:
train_dataset = self._remove_unused_columns(
train_dataset, description="training"
)
else:
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
data_collator,
description="training",
)
# Get sampler and create dataloader
sampler = self._get_train_sampler()
return self._prepare_dataloader(train_dataset, sampler, is_eval=False)
def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader:
"""Get dataloader for evaluation"""
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
# Handle special case: sample packing is enabled but eval_sample_packing is False
if self.args.sample_packing and self.args.eval_sample_packing is False:
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.eval_data_collator
)
if eval_dataset:
if "length" in eval_dataset.column_names:
eval_dataset = eval_dataset.remove_columns(["length"])
dataloader = super().get_eval_dataloader(eval_dataset)
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.train_data_collator
)
return dataloader
if self.args.sample_packing and self.args.eval_sample_packing is not False:
eval_dataset = (
eval_dataset if eval_dataset is not None else self.eval_dataset
# Handle sample packing or sequence parallelism
if (
self.args.sample_packing
and self.args.eval_sample_packing is not False
or self.args.sequence_parallel_degree > 1
):
# Get appropriate data collator
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.eval_data_collator
if hasattr(self, "eval_data_collator") and self.eval_data_collator
else self.data_collator
)
eval_sampler = self._get_eval_sampler(eval_dataset)
if "length" in eval_dataset.column_names:
eval_dataset = eval_dataset.remove_columns(["length"])
data_collator = self.data_collator
dataloader_params = {
"batch_size": self.args.eval_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params["prefetch_factor"] = (
self.args.dataloader_prefetch_factor
)
if isinstance(eval_sampler, BatchSampler):
dataloader_params["batch_sampler"] = eval_sampler
del dataloader_params["batch_size"]
# Handle dataset preprocessing for SP
if self.args.sequence_parallel_degree > 1:
if isinstance(eval_dataset, datasets.Dataset):
eval_dataset = self._remove_unused_columns(
eval_dataset, description="evaluation"
)
else:
dataloader_params["sampler"] = eval_sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last
self.accelerator.even_batches = False
return self.accelerator.prepare_data_loader(
DataLoader(eval_dataset, **dataloader_params)
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
self.data_collator, description="evaluation"
)
# Use eval_batch_size for sample packing, per_device_eval_batch_size otherwise
batch_size = (
self.args.eval_batch_size
if self.args.sample_packing
else self.args.per_device_eval_batch_size
)
sampler = self._get_eval_sampler(eval_dataset)
dataloader = self._prepare_dataloader(
eval_dataset, sampler, is_eval=True, custom_batch_size=batch_size
)
return dataloader
return super().get_eval_dataloader(eval_dataset)
def _get_bench_sampler(
self, bench_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
) -> torch.utils.data.Sampler | None:
if self.args.world_size <= 1:
return SequentialSampler(bench_dataset)
return None
@@ -554,6 +347,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
return DataLoader(bench_dataset, **dataloader_params)
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
@override
def compute_loss(
self, model, inputs, return_outputs=False, num_items_in_batch=None
):
@@ -570,6 +364,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
return_outputs=return_outputs,
num_items_in_batch=num_items_in_batch,
)
return super().compute_loss(
model,
inputs,
@@ -744,10 +539,10 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = _sanitize_kwargs_for_ds_tagging(
kwargs = sanitize_kwargs_for_ds_tagging(
dataset_tags=self.dataset_tags, kwargs=kwargs
)
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
kwargs = sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
return super().push_to_hub(*args, **kwargs)
@@ -764,15 +559,13 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
return res
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
"""
Log `logs` on the various objects watching training, including stored metrics.
Args:
logs (`Dict[str, float]`):
The values to log.
start_time (`Optional[float]`):
The start of training.
logs: The values to log.
start_time: The start of training.
"""
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
@@ -784,7 +577,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
return super().log(logs, start_time)
def store_metrics(
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train"
) -> None:
for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value)
@@ -797,110 +590,26 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
os.makedirs(output_dir, exist_ok=True)
return super()._save_checkpoint(model, trial, **kwargs)
class AxolotlMambaTrainer(AxolotlTrainer):
"""
Mamba specific trainer to handle loss calculation
"""
tag_names = ["axolotl", "mamba"]
def compute_loss(
def training_step(
self,
model,
inputs,
return_outputs=False, # pylint: disable=unused-argument
num_items_in_batch=None, # pylint: disable=unused-argument
):
input_ids = inputs.pop("input_ids")
lm_logits = model(input_ids).logits
labels = input_ids.to(lm_logits.device)
shift_logits = lm_logits[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = torch.nn.CrossEntropyLoss()
lm_loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
)
return lm_loss
class ReLoRATrainer(AxolotlTrainer):
model: nn.Module,
inputs: dict[str, torch.Tensor | Any],
num_items_in_batch: int | None = None,
) -> torch.Tensor:
"""
Trainer subclass that uses the OneCycleLR scheduler
Perform a training step on a batch of inputs. Overrides the
`transformers.trainer.Trainer` method to handle sequence parallelism if
enabled.
Args:
model: Model to perform training step for.
inputs: Dictionary mapping.
"""
# Set up sequence parallelism for this step if enabled
if self.args.sequence_parallel_degree > 1:
self._update_ring_flash_attn_params(inputs)
tag_names = ["axolotl", "relora"]
# Proceed with normal training step
loss = super().training_step(model, inputs, num_items_in_batch)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lr_scheduler = None
def create_scheduler(
self,
num_training_steps: int,
optimizer: Optional[torch.optim.Optimizer] = None,
):
optimizer = self.optimizer if optimizer is None else optimizer
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
if self.args.relora_steps:
warmup_steps = (
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
)
anneal_steps = (
self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
)
self.lr_scheduler = ReLoRAScheduler(
optimizer,
lr_scheduler,
self.args.relora_steps,
anneal_steps,
warmup_steps,
)
else:
self.lr_scheduler = lr_scheduler
return self.lr_scheduler
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
"""
Extend the base ORPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "orpo"]
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
"""
Extend the base KTOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "kto"]
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
"""
Extend the base CPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "cpo"]
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
"""
Extend the base RewardTrainer for axolotl helpers
"""
tag_names = ["axolotl", "reward"]
class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer):
"""
Extend the base trl.PRMTrainer for axolotl helpers
"""
tag_names = ["axolotl", "prm"]
return loss

View File

@@ -13,10 +13,10 @@ from transformers import Trainer
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer
from axolotl.core.trainers.base import (
SchedulerMixin,
_sanitize_kwargs_for_ds_tagging,
_sanitize_kwargs_for_tagging,
from axolotl.core.trainers.mixins import SchedulerMixin
from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging,
sanitize_kwargs_for_tagging,
)
if is_sagemaker_mp_enabled():
@@ -74,10 +74,10 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = _sanitize_kwargs_for_ds_tagging(
kwargs = sanitize_kwargs_for_ds_tagging(
dataset_tags=self.dataset_tags, kwargs=kwargs
)
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
kwargs = sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
return super().push_to_hub(*args, **kwargs)

View File

@@ -0,0 +1,32 @@
"""Module for mamba trainer"""
import torch
from axolotl.core.trainers.base import AxolotlTrainer
class AxolotlMambaTrainer(AxolotlTrainer):
"""Mamba specific trainer to handle loss calculation"""
tag_names = ["axolotl", "mamba"]
def compute_loss(
self,
model,
inputs,
return_outputs=False, # pylint: disable=unused-argument
num_items_in_batch=None, # pylint: disable=unused-argument
):
input_ids = inputs.pop("input_ids")
lm_logits = model(input_ids).logits
labels = input_ids.to(lm_logits.device)
shift_logits = lm_logits[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = torch.nn.CrossEntropyLoss()
lm_loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
)
return lm_loss

View File

@@ -0,0 +1,8 @@
"""Init for axolotl.core.trainers.mixins"""
# pylint: disable=unused-import
# flake8: noqa
from .optimizer import OptimizerMixin
from .scheduler import SchedulerMixin
from .sequence_parallel import SequenceParallelMixin

View File

@@ -0,0 +1,201 @@
"""Module for Axolotl trainer optimizer mixin"""
import logging
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from transformers.trainer import Trainer
from transformers.utils import is_sagemaker_mp_enabled
from axolotl.integrations.base import BaseOptimizerFactory
if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
LOG = logging.getLogger(__name__)
class OptimizerMixin(Trainer):
"""Mixin class for shared handling of building custom optimizers"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
def create_optimizer_grouped_parameters(
self, opt_model, optimizer_kwargs
) -> list[dict]:
decay_parameters = self.get_decay_parameter_names(opt_model)
params: dict = {
"to_weight_decay": {}, # LayerNorm and bias
"embeddings": {}, # lm_head, embed_tokens,
"no_weight_decay": {},
}
lr_groups_lookup = {}
lr_groups_learning_rates = {}
if self.args.lr_groups:
for lr_group in self.args.lr_groups:
group_name = lr_group["name"]
group_modules = lr_group["modules"]
for module in group_modules:
lr_groups_lookup[module] = group_name
lr_groups_learning_rates[group_name] = lr_group["lr"]
params[f"to_weight_decay_{group_name}"] = {}
for name, param in opt_model.named_parameters():
if not param.requires_grad:
continue
if name.endswith("modules_to_save.default.weight") or any(
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
):
params["embeddings"][name] = param
elif name in decay_parameters:
lr_group_modules = [
group_modules
for group_modules in lr_groups_lookup
if group_modules in name
]
if lr_groups_lookup and any(lr_group_modules):
lr_group_module = lr_group_modules[0]
group_name = lr_groups_lookup[lr_group_module]
params[f"to_weight_decay_{group_name}"][name] = param
else:
params["to_weight_decay"][name] = param
else:
params["no_weight_decay"][name] = param
optimizer_grouped_parameters = []
if params["to_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["to_weight_decay"].values()),
"weight_decay": self.args.weight_decay,
"lr": optimizer_kwargs["lr"],
}
)
if params["embeddings"]:
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
if self.args.embedding_lr_scale:
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
elif self.args.embedding_lr:
lr = self.args.embedding_lr # pylint: disable=invalid-name
optimizer_grouped_parameters.append(
{
"params": list(params["embeddings"].values()),
"weight_decay": 0.0,
"lr": lr,
}
)
if params["no_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["no_weight_decay"].values()),
"weight_decay": 0.0,
"lr": optimizer_kwargs["lr"],
}
)
for group_name, group_lr in lr_groups_learning_rates.items():
if params[f"to_weight_decay_{group_name}"]:
optimizer_grouped_parameters.append(
{
"params": list(
params[f"to_weight_decay_{group_name}"].values()
),
"weight_decay": self.args.weight_decay,
"lr": group_lr,
}
)
return optimizer_grouped_parameters
def create_optimizer(self):
if (
self.args.loraplus_lr_ratio is None
and self.args.embedding_lr_scale is None
and self.args.embedding_lr is None
and self.args.lr_groups is None
and self.optimizer_cls_and_kwargs is None
):
return super().create_optimizer()
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if (
not self.optimizer
and self.optimizer_cls_and_kwargs is not None
and issubclass(self.optimizer_cls_and_kwargs[0], BaseOptimizerFactory)
):
optimizer_factory_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
self.optimizer = optimizer_factory_cls()(
opt_model, self.args, **optimizer_kwargs
)
if not self.optimizer:
if self.optimizer_cls_and_kwargs is not None:
optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
else:
optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(
self.args, opt_model
)
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
opt_model, optimizer_kwargs
)
if self.args.loraplus_lr_ratio is not None:
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
loraplus_lr_embedding = getattr(
self.args, "loraplus_lr_embedding", 1e-6
)
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
loraplus_lr_ratio=loraplus_lr_ratio,
loraplus_lr_embedding=loraplus_lr_embedding,
**optimizer_kwargs,
)
else:
# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
# e.g. for GaLore optimizer.
if "params" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("params")
# Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
# e.g. for LOMO optimizer.
if "model" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("model")
# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
# to avoid arguments conflicts.
if "optimizer_dict" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop(
"optimizer_dict"
)
self.optimizer = optimizer_cls(
optimizer_grouped_parameters, **optimizer_kwargs
)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum(
{
p.data_ptr(): p.numel() for p in module.parameters()
}.values()
)
LOG.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(
module, "weight", {"optim_bits": 32}
)
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
LOG.info(f"skipped: {skipped/2**20}M params")
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
self.optimizer
)
return self.optimizer

View File

@@ -0,0 +1,113 @@
"""Module for Axolotl trainer scheduler mixin"""
import logging
import torch
from torch.optim.lr_scheduler import OneCycleLR
from transformers.trainer import Trainer
from axolotl.utils.schedulers import (
RexLR,
get_cosine_schedule_with_min_lr,
get_cosine_schedule_with_quadratic_warmup,
get_cosine_schedule_with_warmup_decay_constant,
)
LOG = logging.getLogger(__name__)
class SchedulerMixin(Trainer):
"""
Mixin class for scheduler setup in CausalTrainer.
"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
):
"""
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
passed as an argument.
Args:
num_training_steps (int): The number of training steps to do.
optimizer (torch.optim.Optimizer): The training optimizer
"""
use_cosine_quadratic = (
self.args.lr_scheduler_type == "cosine"
and self.args.lr_quadratic_warmup is True
)
use_cosine_min_lr = (
self.args.lr_scheduler_type == "cosine"
and self.args.cosine_min_lr_ratio is not None
)
# fmt: off
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
# fmt: on
if self.args.alternate_lr_scheduler_type == "one_cycle":
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
pct_start = num_warmup_steps / num_training_steps
extra_lr_kwargs = {}
if "pct_start" not in self.args.lr_scheduler_kwargs:
extra_lr_kwargs["pct_start"] = pct_start
if "anneal_strategy" not in self.args.lr_scheduler_kwargs:
extra_lr_kwargs["anneal_strategy"] = "cos"
self.lr_scheduler = OneCycleLR(
optimizer,
max_lr=self.args.learning_rate,
total_steps=num_training_steps,
**extra_lr_kwargs,
**self.args.lr_scheduler_kwargs,
)
elif self.args.alternate_lr_scheduler_type == "rex":
if use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = RexLR(
optimizer=optimizer,
max_lr=self.args.learning_rate,
min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio),
total_steps=num_training_steps,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
)
elif use_cosine_quadratic:
if use_cosine_min_lr:
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
)
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
)
else:
return super().create_scheduler(num_training_steps, optimizer=optimizer)
else:
if use_cosine_quadratic:
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
if use_cosine_min_lr:
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
return self.lr_scheduler

View File

@@ -0,0 +1,131 @@
"""Module for Axolotl trainer sequence parallelism mixin"""
import logging
from typing import Any
import torch
import torch.distributed as dist
import torch.nn.functional as F
from datasets import Dataset
from torch.utils.data import DistributedSampler, Sampler
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
LOG = logging.getLogger(__name__)
try:
from ring_flash_attn import update_ring_flash_attn_params
except ImportError:
# We pass silently here, but raise an ImportError in our Axolotl config validation
# if cfg.sequence_parallel_degree > 1 and `ring-flash-attn` is not installed.
pass
class SequenceParallelMixin:
"""
Mixin class for sequence parallelism support in trainers.
This mixin provides functionality for handling sequence parallelism,
including creating appropriate samplers, managing data partitioning,
and updating ring flash attention parameters during training.
"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
def _setup_sequence_parallel(self):
"""Set up sequence parallelism environment."""
self.ring_attn_group = get_ring_attn_group()
def _create_sequence_parallel_sampler(
self,
dataset: Dataset,
shuffle: bool = True,
is_eval: bool = False,
) -> DistributedSampler:
"""
Helper method to create sampler for sequence parallelism (SP).
We create a distributed sampler with rank equal to the SP group ID, which
means that all ranks in the SP group receive the same sample / set of samples
per training step. We also set the number of replicas equal to the number of
SP groups, which is a bit of a hack / unintended use, but works!
Args:
dataset: Dataset to sample from.
shuffle: Whether to shuffle the dataset.
is_eval: Whether we are creating a sampler for evaluation or training.
Returns:
Distributed sampler.
"""
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
return DistributedSampler(
dataset,
num_replicas=num_sp_groups,
rank=sp_group_id,
seed=self.args.seed if shuffle else None,
shuffle=shuffle,
drop_last=not is_eval,
)
def _sp_get_train_sampler(self, dataset) -> Sampler | None:
"""
Get a training sampler configured for sequence parallelism.
Args:
dataset: The training dataset
Returns:
Configured sequence parallel sampler.
"""
return self._create_sequence_parallel_sampler(
dataset,
shuffle=not self.args.curriculum_sampling,
)
def _sp_get_eval_sampler(self, eval_dataset) -> Sampler | None:
"""
Get an evaluation sampler configured for sequence parallelism.
Args:
eval_dataset: The evaluation dataset.
Returns:
Configured sequence parallel sampler.
"""
return self._create_sequence_parallel_sampler(
eval_dataset, shuffle=False, is_eval=True
)
def _update_ring_flash_attn_params(self, inputs: dict[str, torch.Tensor | Any]):
"""
Calculate the cu_seqlens for the current forward pass and pass the value to
the substituted ring_flash_attn. This is accomplished by using the passed
`input_ids`.
Args:
inputs: Current batch of inputs.
"""
# At this point, inputs should already be partitioned by the sequence
# parallel data collator
batch_size = inputs["input_ids"].shape[0]
seq_len = inputs["input_ids"].shape[1]
packed_seq_lens = [seq_len] * batch_size
# Calculate the full sequence length across all GPUs in this SP group
total_seq_len = seq_len * self.args.sequence_parallel_degree
cu_seqlens = torch.cumsum(
torch.tensor(
packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
),
dim=-1,
dtype=torch.int32,
)
cu_seqlens = F.pad(
F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
)
update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)

View File

@@ -0,0 +1,43 @@
"""Module for ReLoRA trainer"""
import torch
from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.monkeypatch.relora import ReLoRAScheduler
class ReLoRATrainer(AxolotlTrainer):
"""Trainer subclass that uses the `OneCycleLR` scheduler"""
tag_names = ["axolotl", "relora"]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lr_scheduler = None
def create_scheduler(
self,
num_training_steps: int,
optimizer: torch.optim.Optimizer | None = None,
):
optimizer = self.optimizer if optimizer is None else optimizer
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
if self.args.relora_steps:
warmup_steps = (
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
)
anneal_steps = (
self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
)
self.lr_scheduler = ReLoRAScheduler(
optimizer,
lr_scheduler,
self.args.relora_steps,
anneal_steps,
warmup_steps,
)
else:
self.lr_scheduler = lr_scheduler
return self.lr_scheduler

View File

@@ -1,16 +1,23 @@
"""
module for TRL PPO training
"""
"""Module for TRL PPO trainer"""
import torch
from tqdm import tqdm
from trl import PPOTrainer
from trl import (
CPOTrainer,
KTOTrainer,
ORPOTrainer,
PPOTrainer,
PRMTrainer,
RewardTrainer,
)
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
class TRLPPOTrainer(PPOTrainer):
"""
wrapper for ppo trainer to handle customizations
"""
"""Wrapper for TRL PPO trainer to handle customizations"""
tag_names = ["axolotl", "ppo"]
def train(
self,
@@ -31,9 +38,7 @@ class TRLPPOTrainer(PPOTrainer):
"batch_size": 16,
}
for epoch, batch in tqdm( # pylint: disable=unused-variable
enumerate(self.dataloader)
):
for _, batch in tqdm(enumerate(self.dataloader)):
query_tensors = batch["input_ids"]
# generate model response
@@ -65,3 +70,43 @@ class TRLPPOTrainer(PPOTrainer):
rewards,
columns_to_log=["query", "response", "ref_response", "ref_rewards"],
)
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
"""
Extend the base ORPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "orpo"]
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
"""
Extend the base KTOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "kto"]
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
"""
Extend the base CPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "cpo"]
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
"""
Extend the base RewardTrainer for axolotl helpers
"""
tag_names = ["axolotl", "reward"]
class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer):
"""
Extend the base trl.PRMTrainer for axolotl helpers
"""
tag_names = ["axolotl", "prm"]

View File

@@ -0,0 +1,33 @@
"""Utils for Axolotl trainers"""
def sanitize_kwargs_for_tagging(tag_names, kwargs=None):
if isinstance(tag_names, str):
tag_names = [tag_names]
if kwargs is not None:
if "tags" not in kwargs:
kwargs["tags"] = tag_names
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
kwargs["tags"].extend(tag_names)
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
tag_names.append(kwargs["tags"])
kwargs["tags"] = tag_names
return kwargs
def sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None):
if isinstance(dataset_tags, str):
dataset_tags = [dataset_tags]
if (dataset_tags is not None) and (kwargs is not None):
if "dataset_tags" not in kwargs:
kwargs["dataset_tags"] = dataset_tags
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list):
kwargs["dataset_tags"].extend(dataset_tags)
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str):
dataset_tags.append(kwargs["dataset_tags"])
kwargs["dataset_tags"] = dataset_tags
return kwargs

View File

@@ -207,14 +207,19 @@ class AxolotlTrainingMixins:
},
)
sequence_parallel_degree: Optional[int] = field(
default=1,
metadata={"help": "The number of workers to use in sequence parallelism"},
)
@dataclass
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
"""
Training arguments for Causal trainer
This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value
so it can't be used as a mixin.
This code is duplicated due to HF TrainingArguments not setting output_dir with a
default value so it can't be used as a mixin.
"""

View File

@@ -0,0 +1,89 @@
"""
Ring attention group registration and flash attention patching.
Make use of the `ring-flash-attn` (https://github.com/zhuzilin/ring-flash-attention)
package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patch in
their sequence parallel version of Flash Attention 2.
"""
import torch.distributed as dist
from accelerate.logging import get_logger
from axolotl.logging_config import configure_logging
configure_logging()
LOG = get_logger(__name__)
RING_ATTN_GROUP = None
def get_ring_attn_group() -> dist.ProcessGroup:
"""
Getter for ring attention group on this rank.
Returns:
The process group for ring attention for this rank.
"""
return RING_ATTN_GROUP
def set_ring_attn_group(ring_attn_group: dist.ProcessGroup):
"""
Setter for ring attention group on this rank.
Args:
Process group for ring attention.
"""
global RING_ATTN_GROUP # pylint: disable=global-statement
RING_ATTN_GROUP = ring_attn_group
def register_ring_attn(sequence_parallel_degree: int):
"""
Create ring attention group and substitute flash attn with ring flash attn.
Args:
sequence_parallel_degree: Sequence parallelism factor.
"""
LOG.info(
"Enabling ring attention sequence parallelism: "
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
)
world_size = dist.get_world_size()
assert sequence_parallel_degree <= world_size, (
f"sequence_parallel_degree ({sequence_parallel_degree}) "
f"must be less than or equal to world_size ({world_size})"
)
assert world_size % sequence_parallel_degree == 0, (
f"sequence_parallel_degree ({sequence_parallel_degree}) "
f"must evenly divide world_size ({world_size})"
)
# Detailed logging of group formation
rank = dist.get_rank()
group_assignments = {}
for i in range(world_size // sequence_parallel_degree):
ring_attn_ranks = list(
range(
i * sequence_parallel_degree,
(i + 1) * sequence_parallel_degree,
)
)
group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
# Track which GPUs are in which groups
for r in ring_attn_ranks:
group_assignments[r] = i
if rank in ring_attn_ranks:
set_ring_attn_group(group)
# Log the GPU group assignments
if rank == 0:
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
from ring_flash_attn import substitute_hf_flash_attn
substitute_hf_flash_attn(get_ring_attn_group(), sequence_parallel_degree)

View File

@@ -169,7 +169,7 @@ def execute_training(
cfg: DictDefault, trainer: Any, resume_from_checkpoint: str | None
):
"""
Execute the training process with appropriate backend configurations.
Execute the training process with appropriate SDP kernel configurations.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
@@ -177,9 +177,6 @@ def execute_training(
resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
"""
LOG.info("Starting trainer...")
if cfg.group_by_length:
LOG.info("hang tight... sorting dataset for group_by_length")
if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel(
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...

View File

@@ -1,14 +1,59 @@
"""
DataCollator for axolotl to pad labels and position_ids for packed sequences
Data collators for axolotl to pad labels and position_ids for packed sequences. Also
includes logic for handling sequence parallelism collation.
"""
import logging
from dataclasses import dataclass
from typing import Any, Optional, Union
import numpy as np
import torch
import torch.distributed as dist
from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
logger = logging.getLogger(__name__)
def adjust_position_ids_for_slice(
position_ids: torch.Tensor, start_idx: int
) -> torch.Tensor:
"""
Adjust position IDs for a sliced sequence to maintain proper relative positions.
This handles the case where position IDs might not be contiguous due to sample
packing.
"""
# Convert to tensor if not already
# Find the boundaries between samples (where position_ids reset)
adjusted_pos_ids = position_ids.clone()
# Process each sequence in the batch
for i in range(position_ids.shape[0]):
seq = position_ids[i]
# Find sample boundaries
boundaries = []
for j in range(1, len(seq)):
if seq[j] < seq[j - 1]:
boundaries.append(j)
# No need to adjust if there are no boundaries or this is a single sample
if not boundaries:
adjusted_pos_ids[i] = seq - start_idx
continue
# Adjust each segment separately
prev_boundary = 0
for boundary in boundaries:
adjusted_pos_ids[i, prev_boundary:boundary] -= start_idx
prev_boundary = boundary
# Last segment
adjusted_pos_ids[i, prev_boundary:] -= start_idx
return adjusted_pos_ids
@dataclass
class DataCollatorForSeq2Seq:
@@ -43,6 +88,8 @@ class DataCollatorForSeq2Seq:
The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
return_tensors (`str`):
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
sequence_parallel_degree (`int`):
The degree of sequence parallelism. Default to 1 for no sequence parallelism.
"""
tokenizer: PreTrainedTokenizerBase
@@ -53,6 +100,16 @@ class DataCollatorForSeq2Seq:
label_pad_token_id: int = -100
position_pad_token_id: int = 0
return_tensors: str = "pt"
sequence_parallel_degree: int = 1
def __post_init__(self):
if self.sequence_parallel_degree > 1:
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
# Get information about our position in the SP group
sp_group = get_ring_attn_group()
self.local_rank = dist.get_rank(group=sp_group)
self.local_world_size = dist.get_world_size(group=sp_group)
def __call__(self, features, return_tensors=None):
labels = None
@@ -119,8 +176,43 @@ class DataCollatorForSeq2Seq:
)
features["decoder_input_ids"] = decoder_input_ids
if self.sequence_parallel_degree > 1:
features = self.apply_sequence_parallelism(features)
return features
def apply_sequence_parallelism(
self, batch: dict[str, torch.Tensor]
) -> torch.Tensor:
"""
Apply sequence parallelism slicing to a batch.
Args:
batch: Batch dictionary from parent collator.
Returns:
Sliced batch dictionary.
"""
keys_to_slice = ["input_ids", "attention_mask", "labels", "position_ids"]
for key in keys_to_slice:
if key in batch:
seq_len = batch[key].shape[1]
slice_size = seq_len // self.local_world_size
start_idx = self.local_rank * slice_size
end_idx = (
start_idx + slice_size
if self.local_rank < self.local_world_size - 1
else seq_len
)
batch[key] = batch[key][:, start_idx:end_idx]
# Special handling for position_ids
if key == "position_ids" and self.local_rank > 0:
batch[key] = adjust_position_ids_for_slice(batch[key], start_idx)
return batch
@dataclass
class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
@@ -148,6 +240,7 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
np.array(item[feature]) for item in features_ if feature in item
]
out_features[i][feature] = np.concatenate(arrays)
return super().__call__(out_features, return_tensors=return_tensors)
@@ -177,6 +270,7 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
np.array(item[feature]) for item in features_ if feature in item
]
out_features[i][feature] = np.concatenate(arrays)
return super().__call__(out_features, return_tensors=return_tensors)

View File

@@ -125,6 +125,9 @@ def normalize_config(cfg):
with open(ds_config_path, encoding="utf-8") as f:
cfg.deepspeed = json.load(f)
if cfg.sequence_parallel_degree is None:
cfg.sequence_parallel_degree = 1
if cfg.saves_per_epoch:
save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
if save_steps < 1.0: # prevent saves on every step

View File

@@ -67,7 +67,12 @@ from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrap
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
LOG = logging.getLogger("axolotl")
LOG = logging.getLogger(__name__)
MULTIMODEL_AUTO_MODEL_MAPPING = {
"llava": LlavaForConditionalGeneration,
"mllama": MllamaForConditionalGeneration,
}
# copied from accelerator.FullyShardedDataParallelPlugin
@@ -476,7 +481,7 @@ class ModelLoader:
else:
self.text_model_config = self.model_config
self.AutoModelLoader = AutoModelForCausalLM # pylint: disable=invalid-name
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
def apply_patches(self) -> None:
# load any patches from plugins
@@ -547,6 +552,14 @@ class ModelLoader:
patch_self_attn_lora(self.cfg)
if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:
from axolotl.monkeypatch.attention.ring_attn import register_ring_attn
# Initialize ring attn for sequence parallelism. This must be done after
# model init but before the first forward pass, since it modifies flash
# attn to use ring comm for SP training across multiple GPUs.
register_ring_attn(self.cfg.sequence_parallel_degree)
def patch_attention(self) -> None:
if hasattr(self.model_config, "model_type"):
if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
@@ -603,7 +616,7 @@ class ModelLoader:
patch_self_attn_lora()
def patch_llama_derived_model(self) -> None:
def patch_llama_derived_model(self):
"""Modify all llama derived models in one block"""
self.patch_loss_llama()
@@ -653,24 +666,15 @@ class ModelLoader:
"Shifted-sparse attention not currently implemented without flash attention."
)
def set_auto_model_loader(self) -> None:
"""set self.AutoModelLoader
- default value: AutoModelForCausalLM (set at __init__)
- when using a multi modality model, self.AutoModelLoader should
be set according to model type of the model
def set_auto_model_loader(self):
"""
Set self.auto_model_loader. Defaults to `transformers.AutoModelForCausalLM`
(set at `__init__`). When using a multimodal model, `self.auto_model_loader`
should be set according to the type of the model.
"""
if self.cfg.is_multimodal:
if self.model_config.model_type == "llava":
self.AutoModelLoader = ( # pylint: disable=invalid-name
LlavaForConditionalGeneration
)
elif self.model_config.model_type == "mllama":
self.AutoModelLoader = ( # pylint: disable=invalid-name
MllamaForConditionalGeneration
)
else:
self.AutoModelLoader = (
AutoModelForVision2Seq # pylint: disable=invalid-name
self.auto_model_loader = MULTIMODEL_AUTO_MODEL_MAPPING.get(
self.model_config.model_type, AutoModelForVision2Seq
)
def set_device_map_config(self) -> None:
@@ -695,7 +699,7 @@ class ModelLoader:
from accelerate import infer_auto_device_map
with init_empty_weights():
model_canvas = self.AutoModelLoader.from_config(
model_canvas = self.auto_model_loader.from_config(
self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
)
@@ -916,7 +920,23 @@ class ModelLoader:
if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config
self.model = self.AutoModelLoader.from_pretrained(
# Load model with random initialization if specified
if self.cfg.random_init_weights:
# AutoModel classes support the from_config method
if self.auto_model_loader in [
AutoModelForCausalLM,
AutoModelForVision2Seq,
]:
self.model = self.auto_model_loader.from_config(
config=self.model_config,
)
else:
self.model = self.auto_model_loader(
config=self.model_config,
)
else:
self.model = self.auto_model_loader.from_pretrained(
self.base_model,
config=self.model_config,
**self.model_kwargs,
@@ -958,7 +978,7 @@ class ModelLoader:
if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config
if self.cfg.gptq:
self.model = self.AutoModelLoader.from_pretrained(
self.model = self.auto_model_loader.from_pretrained(
self.base_model,
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
@@ -991,7 +1011,7 @@ class ModelLoader:
if self.cfg.gptq:
if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config
self.model = self.AutoModelLoader.from_pretrained(
self.model = self.auto_model_loader.from_pretrained(
self.base_model,
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
@@ -1011,7 +1031,7 @@ class ModelLoader:
if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config
self.model = self.AutoModelLoader.from_pretrained(
self.model = self.auto_model_loader.from_pretrained(
self.base_model,
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
@@ -1307,7 +1327,7 @@ def load_model(
"""
Load a model for a given configuration and tokenizer.
"""
loader = ModelLoader(
model_loader = ModelLoader(
cfg,
tokenizer,
processor=processor,
@@ -1315,7 +1335,7 @@ def load_model(
reference_model=reference_model,
**kwargs,
)
return loader.load_model()
return model_loader.load_model()
def load_adapter(model, cfg, adapter, inference=False):

View File

@@ -104,9 +104,7 @@ def allocate(
class MultipackBatchSampler(BatchSampler):
"""
Batch Sampler class for multipack
"""
"""Batch sampler class for multipack"""
def __init__(
self,

View File

@@ -1,4 +1,4 @@
"""Main Axolotl input configuration Pydantic models"""
"""Module with Pydantic models for configuration."""
# pylint: disable=too-many-lines
@@ -245,6 +245,8 @@ class AxolotlInputConfig(
val_set_size: float | None = Field(default=0.0)
sequence_parallel_degree: int | None = None
special_tokens: SpecialTokensConfig | None = None
tokens: list[str] | None = None
added_tokens_overrides: dict[int, str] | None = None
@@ -1102,6 +1104,29 @@ class AxolotlInputConfig(
return data
@field_validator("sequence_parallel_degree", mode="before")
@classmethod
def check_sequence_parallel_config(cls, value, info):
if not value:
value = 1
if value > 1:
if not info.data.get("flash_attention"):
raise ValueError(
"flash_attention: true must be set with sequence_parallel_degree > 1"
)
try:
import ring_flash_attn # noqa: F401 # pylint:disable=unused-import
except ImportError as exception:
raise ImportError(
"sequence_parallel_degree > 1 but ring_flash_attn is not installed. "
"Please install it with `pip install axolotl[ring-flash-attn] "
"or `pip install ring-flash-attn>=0.1.4`."
) from exception
return value
class AxolotlConfigWCapabilities(AxolotlInputConfig):
"""wrapper to valdiate gpu capabilities with the configured options"""

View File

@@ -346,7 +346,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)",
)
elif cfg.sample_packing:
elif cfg.sample_packing or cfg.sequence_parallel_degree > 1:
drop_long_kwargs = {}
if filter_map_kwargs:
drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
@@ -356,7 +356,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
**filter_map_kwargs,
**drop_long_kwargs,
)
if cfg.eval_sample_packing is not False:
if cfg.eval_sample_packing or cfg.sequence_parallel_degree > 1:
if eval_dataset:
eval_dataset = eval_dataset.map(
add_position_ids,
@@ -443,6 +443,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
- 1
)
* cfg.num_epochs
* cfg.sequence_parallel_degree
)
LOG.debug(
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}",
@@ -473,7 +474,11 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
# FIXME: is there a bug here somewhere? the total num steps depends
# on the agreed on value for sample_packing_eff_est
total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs))
total_num_steps = int(
math.floor(
data_loader_len * cfg.num_epochs * cfg.sequence_parallel_degree
)
)
def calc_sample_packing_eff_est(estimates: List[float]):
LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
@@ -494,7 +499,12 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
)
else:
total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
math.ceil(
len(train_dataset)
* cfg.num_epochs
* cfg.sequence_parallel_degree
/ cfg.batch_size
)
)
LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True)
return total_num_steps

View File

@@ -0,0 +1,207 @@
"""Tests for sequence parallelism functionality."""
# pylint: disable=redefined-outer-name,unused-argument
from unittest.mock import MagicMock, patch
import pytest
import torch
from accelerate.state import PartialState
from axolotl.utils.dict import DictDefault
# Use a single patch for ring_flash_attn if it's not available
ring_flash_attn_mock = MagicMock()
with patch.dict("sys.modules", {"ring_flash_attn": ring_flash_attn_mock}):
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
from axolotl.utils.collators.batching import adjust_position_ids_for_slice
@pytest.fixture
def partial_state():
"""Create a real PartialState instance for testing."""
state = PartialState()
return state
@pytest.fixture(name="cfg")
def fixture_cfg():
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-3,
"output_dir": "./model-out",
"sequence_len": 512,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
}
)
return cfg
class TestSequenceParallelHelpers:
"""Test helper functions used in sequence parallelism."""
def test_adjust_position_ids_for_slice(self, partial_state):
"""Test position_ids adjustment for sequence slices."""
# Create sample position_ids with multiple sequences
position_ids = torch.tensor(
[
# First sequence with 2 samples
[0, 1, 2, 3, 4, 0, 1, 2, 3],
# Second sequence with 3 samples
[0, 1, 2, 0, 1, 2, 3, 0, 1],
]
)
# Adjust as if this was the second slice (start_idx = 4)
adjusted = adjust_position_ids_for_slice(position_ids, start_idx=4)
# For first sequence: [0,1,2,3,4,0,1,2,3] -> [-4,-3,-2,-1,0,-4,-3,-2,-1]
# For second sequence: [0,1,2,0,1,2,3,0,1] -> [-4,-3,-2,-4,-3,-2,-1,-4,-3]
expected_first_seq = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3]) - 4
expected_second_seq = torch.tensor([0, 1, 2, 0, 1, 2, 3, 0, 1]) - 4
assert torch.all(adjusted[0] == expected_first_seq)
assert torch.all(adjusted[1] == expected_second_seq)
class TestRingAttention:
"""Tests for the ring attention functionality."""
@patch("torch.distributed.new_group")
@patch("torch.distributed.get_rank")
@patch("torch.distributed.get_world_size")
def test_register_ring_attn(
self, mock_world_size, mock_rank, mock_new_group, partial_state
):
"""Test that ring attention groups are created correctly."""
from axolotl.monkeypatch.attention.ring_attn import register_ring_attn
# Setup mocks
mock_world_size.return_value = 8 # 8 GPUs total
mock_rank.return_value = 3 # GPU #3
mock_group = MagicMock()
mock_new_group.return_value = mock_group
# Call register_ring_attn with size 4
register_ring_attn(sequence_parallel_degree=4)
# Verify the number of calls without examining the arguments
assert mock_new_group.call_count == 2
# Just verify that new_group was called
mock_new_group.assert_called()
@patch("torch.distributed.get_rank")
@patch("torch.distributed.get_world_size")
def test_get_ring_attn_group_no_registration(
self, mock_world_size, mock_rank, partial_state
):
"""Test that get_ring_attn_group returns None when no group has been registered."""
# Setup mocks
mock_world_size.return_value = 4
mock_rank.return_value = 0
# Get the group without registration
group = get_ring_attn_group()
# Verify that None was returned
assert group is None
# Mock a simplified DataCollator test
@patch("axolotl.monkeypatch.attention.ring_attn.get_ring_attn_group")
@patch("torch.distributed.get_rank")
@patch("torch.distributed.get_world_size")
def test_sequence_parallel_slicing(
mock_world_size, mock_rank, mock_get_group, partial_state
):
"""Test the basic sequence slicing logic without full collator instantiation."""
# Setup mocks
mock_get_group.return_value = MagicMock()
mock_rank.return_value = 1 # Second GPU
mock_world_size.return_value = 4 # 4 GPUs total
# Create a sample batch
batch = {
"input_ids": torch.tensor(
[
[101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112],
[201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212],
]
),
"attention_mask": torch.ones(2, 12),
}
# Simplified slicing logic from SequenceParallelDataCollator
def slice_batch(batch, rank, world_size):
result = {}
for key in batch:
seq_len = batch[key].shape[1]
slice_size = seq_len // world_size
start_idx = rank * slice_size
end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len
result[key] = batch[key][:, start_idx:end_idx]
return result
# Slice the batch
result = slice_batch(
batch, rank=mock_rank.return_value, world_size=mock_world_size.return_value
)
# Check slicing
assert result["input_ids"].shape == (2, 3) # 12 tokens / 4 GPUs = 3 tokens per GPU
expected_input_ids = torch.tensor(
[
[104, 105, 106], # Second slice of first sequence
[204, 205, 206], # Second slice of second sequence
]
)
assert torch.all(result["input_ids"] == expected_input_ids)
@patch.dict("sys.modules", {"ring_flash_attn": MagicMock()})
def test_config_validation_with_valid_inputs(cfg):
"""Test that valid sequence parallelism configurations pass validation."""
# Import the actual model class with appropriate mocks
from axolotl.utils.schemas.config import AxolotlInputConfig
# Valid configuration: sequence_parallel_degree > 1 and flash_attention is True
cfg = cfg | {
"sequence_parallel_degree": 2,
"flash_attention": True,
}
# Should validate without errors
config = AxolotlInputConfig(**cfg)
assert config.sequence_parallel_degree == 2
assert config.flash_attention is True
def test_config_validation_with_invalid_inputs(cfg):
"""Test that invalid sequence parallelism configurations fail validation."""
from axolotl.utils.schemas.config import AxolotlInputConfig
# Invalid configuration: sequence_parallel_degree > 1 but flash_attention is False
cfg = cfg | {
"sequence_parallel_degree": 2,
"flash_attention": False,
}
# Should raise ValidationError
with pytest.raises(ValueError) as excinfo:
AxolotlInputConfig(**cfg)
# Verify error message
assert "flash_attention: true must be set" in str(excinfo.value)

View File

@@ -12,6 +12,7 @@ from constants import ALPACA_MESSAGES_CONFIG_REVISION, SPECIAL_TOKENS
from datasets import Dataset
from transformers import AutoTokenizer
from axolotl.utils.config import normalize_config
from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.data.utils import deduplicate_and_log_datasets
@@ -262,6 +263,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
self.cfg_1 = DictDefault(
{
"base_model": "huggyllama/llama-7b",
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024,
"dataset_exact_deduplication": True,
@@ -282,6 +284,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
"num_epochs": 1,
}
)
normalize_config(self.cfg_1)
def test_prepare_dataset_with_deduplication_train(self):
"""Verify that prepare_dataset function processes the dataset correctly with deduplication."""