Compare commits

...

19 Commits

Author SHA1 Message Date
coderabbitai[bot]
e23a5c9fda 📝 Add docstrings to 775-option-to-drop-vs-truncate-on-rows-longer-than-context-length
Docstring generation was requested by @mhenrichsen.

* https://github.com/axolotl-ai-cloud/axolotl/pull/2662#issuecomment-2883401776

The following files were modified:

* `src/axolotl/utils/data/pretraining.py`
* `src/axolotl/utils/data/rl.py`
* `src/axolotl/utils/data/utils.py`
* `src/axolotl/utils/trainer.py`
* `tests/test_data.py`
* `tests/test_trainer_utils.py`
2025-05-15 11:02:45 +00:00
mhenrichsen
5d7a61576d Refactor sequence length overflow handling in pretraining module
- Introduced DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING constant in utils.py.
- Updated encode_packed_pretraining function to use this constant instead of a hardcoded value.
2025-05-15 12:55:09 +02:00
mhenrichsen
5ecf22b54e Merge branch 'main' of github.com:axolotl-ai-cloud/axolotl into 775-option-to-drop-vs-truncate-on-rows-longer-than-context-length 2025-05-14 13:36:43 +02:00
mhenrichsen
9c5b8da22f fix merge conflicts 2025-05-14 13:33:42 +02:00
Wing Lian
c0a0c7534c Activation checkpointing with offloading to disk with prefetch (#2663)
* offload activations to disk instead of CPU RAM

* add prefetch

* Disco :dance:

* include offload_disk in e2e test for AC

* document and make sure to cleanup

* fix annotation to match docs

* fix docs build

* address PR feedback
2025-05-13 16:39:39 -04:00
Wing Lian
7fa1089cea Atropos support (#2666) [skip ci]
* allow peft+liger+grpo and custom vllm serve for atropos support

* set trainer class for RL
2025-05-13 08:30:58 -04:00
mhenrichsen
fea6649518 increased test coverage 2025-05-13 08:58:34 +02:00
mhenrichsen
124ad2b968 lint 2025-05-13 08:35:16 +02:00
Dan Saunders
80304c26a7 SP GRPO support + batch SP fixes (#2643)
* ctx manager for SP

* updates

* update

* further simplifying

* simplifying

* simplifying

* reorg

* batch api HF adapter for ring-flash-attn; cleanup and improvements

* update

* adding all batch ring-flash-attn methods via single adapter

* fix

* fixes for batch API funcs, simplify

* fix

* grpo sp support

* progress

* stronger subclassing of TRL GRPO trainer; custom distributed sampler

* subclassing constructor

* progress

* finalizing SP + GRPO trainer

* minimize diffs to GRPO trainer

* remove (most of) the custom GRPO trainer logic

* debug

* debug

* update

* update

* update

* progress

* cleanup

* cleanup

* minor changes

* update

* update

* update

* small changes

* updates

* cleanup; torch.compile ring_flash_attn functions to prevent numerical instability; lint

* spacing

* cleanup; log in pydantic model config only on main process

* remove comment

* fix sp sampler, update to latest upstream code, doc

* add docs

* update quartodoc autodoc contents

* fix, simplifications

* fixes + simplifications

* review comments

* lint

* removing main process only logs in favor of #2608

* fixes, additional smoke test

* updates

* more tests

* update

* fix grad accum bug (sort of)

* lint, tests

* todo
2025-05-12 17:52:40 -04:00
mhenrichsen
767c2340f1 docstring for tests 2025-05-12 22:57:43 +02:00
mhenrichsen
f6623c34cc Linting fix 2025-05-12 22:53:30 +02:00
mhenrichsen
5dd8f0b2b8 Fixes comments from winglian 2025-05-12 22:43:15 +02:00
NanoCode012
67c4ea9c7c fix: disable auto lora kernel if dropout nonzero (#2655) [skip ci]
* fix: disable auto lora kernel if dropout nonzero

* Add comment from PR feedback

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-05-12 16:23:53 -04:00
Wing Lian
526ddb886d guard on deleting secrets from env (#2653) [skip ci] 2025-05-12 14:18:42 -04:00
Wing Lian
f34eef546a update doc and use P2P=LOC for brittle grpo test (#2649)
* update doc and skip brittle grpo test

* fix the path to run the multigpu tests

* increase timeout, use LOC instead of NVL

* typo

* use hf cache from s3 backed cloudfront

* mark grpo as flaky test due to vllm start
2025-05-12 14:17:25 -04:00
mhenrichsen
be3c6bbd85 fix linting issues 2025-05-12 14:46:57 +02:00
mhenrichsen
f07db4f853 Refactor truncation logic in drop_long_rl_seq function
- Simplified the truncation process for chosen and rejected responses to ensure they fit within the specified sequence length while preserving the prompt.
- Improved readability by restructuring the code and removing redundant checks.
- Ensured that the function returns the sample correctly after processing, maintaining compatibility with existing handling options.
2025-05-12 14:40:10 +02:00
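For illustration, the truncation described in this commit amounts to keeping the full prompt and trimming each response to the remaining token budget. A minimal sketch with hypothetical names (the actual implementation lives in `src/axolotl/utils/data/rl.py`):

def truncate_preference_pair(prompt_ids, chosen_ids, rejected_ids, sequence_len):
    # Keep the prompt intact; trim each response so prompt + response fits.
    budget = max(sequence_len - len(prompt_ids), 0)
    return chosen_ids[:budget], rejected_ids[:budget]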
mhenrichsen
17a5838d38 lint 2025-05-12 14:36:43 +02:00
mhenrichsen
9f68918f13 Implement configurable handling of excess tokens in datasets
- Added `excess_token_handling` option to the configuration, allowing users to choose between "drop" and "truncate" for handling tokens exceeding the maximum sequence length.
- Introduced `truncate_or_drop_long_seq` function to manage both single and batched samples based on the selected handling method.
- Updated relevant dataset processing functions to utilize the new handling option, ensuring backward compatibility with existing "drop" behavior.
- Enhanced logging to reflect truncation actions in dataset processing.

This change improves flexibility in managing sequence lengths during training and evaluation.
2025-05-12 14:08:43 +02:00
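As a rough sketch of the drop-vs-truncate behavior this commit describes (function and key names here are illustrative, not the exact implementation):

def handle_long_seq(sample, sequence_len, handling="drop"):
    # Return the (possibly truncated) sample, or None to signal a drop.
    if len(sample["input_ids"]) <= sequence_len:
        return sample
    if handling == "truncate":
        for key in ("input_ids", "attention_mask", "labels"):
            if key in sample:
                sample[key] = sample[key][:sequence_len]
        return sample
    return None  # "drop": the caller filters this sample out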
45 changed files with 3069 additions and 624 deletions

View File

@@ -3,7 +3,7 @@ name: docker-multigpu-tests-biweekly
on:
pull_request:
paths:
- 'tests/e2e/multigpu/*.py'
- 'tests/e2e/multigpu/**.py'
- 'requirements.txt'
- 'setup.py'
- 'pyproject.toml'

View File

@@ -44,96 +44,102 @@ jobs:
env:
SKIP: no-commit-to-branch
preload-cache:
name: Preload HF cache
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.6.0"]
timeout-minutes: 20
env:
AXOLOTL_IS_CI_CACHE_PRELOAD: "1"
steps:
- name: Check out repository code
uses: actions/checkout@v4
- name: Restore HF cache
id: hf-cache-restore
uses: actions/cache/restore@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-v2
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
- name: Install PyTorch
run: |
pip3 install torch==${{ matrix.pytorch_version }}
- name: Install dependencies
run: |
pip3 show torch
pip3 install --no-build-isolation -U -e .
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Make sure PyTorch version wasn't clobbered
run: |
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
- name: Ensure axolotl CLI was installed
run: |
axolotl --help
- name: Pre-Download dataset fixture
run: |
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
- name: Run tests
run: |
pytest -v tests/conftest.py
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
files: ./coverage.xml
flags: unittests,pytorch-${{ matrix.pytorch_version }}
fail_ci_if_error: false
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
- name: Save HF cache
id: hf-cache
uses: actions/cache/save@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
# preload-cache:
# name: Preload HF cache
# runs-on: ubuntu-latest
# strategy:
# fail-fast: false
# matrix:
# python_version: ["3.11"]
# pytorch_version: ["2.6.0"]
# timeout-minutes: 20
#
# env:
# AXOLOTL_IS_CI_CACHE_PRELOAD: "1"
#
# steps:
# - name: Check out repository code
# uses: actions/checkout@v4
#
# - name: Restore HF cache
# id: hf-cache-restore
# uses: actions/cache/restore@v4
# with:
# path: |
# /home/runner/.cache/huggingface/hub/datasets--*
# /home/runner/.cache/huggingface/hub/models--*
# key: ${{ runner.os }}-hf-hub-cache-v2
#
# - name: Restore Cache from S3
# id: hf-cache-restore-s3
# run: |
# mkdir -p /home/runner/.cache/huggingface/hub
# curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
#
# - name: Setup Python
# uses: actions/setup-python@v5
# with:
# python-version: ${{ matrix.python_version }}
# cache: 'pip' # caching pip dependencies
#
# - name: upgrade pip
# run: |
# pip3 install --upgrade pip
# pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
#
# - name: Install PyTorch
# run: |
# pip3 install torch==${{ matrix.pytorch_version }}
#
# - name: Install dependencies
# run: |
# pip3 show torch
# pip3 install --no-build-isolation -U -e .
# python scripts/unsloth_install.py | sh
# python scripts/cutcrossentropy_install.py | sh
# pip3 install -r requirements-dev.txt -r requirements-tests.txt
#
# - name: Make sure PyTorch version wasn't clobbered
# run: |
# python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
#
# - name: Ensure axolotl CLI was installed
# run: |
# axolotl --help
#
# - name: Pre-Download dataset fixture
# run: |
# huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
#
# - name: Run tests
# run: |
# pytest -v tests/conftest.py
#
# - name: Upload coverage to Codecov
# uses: codecov/codecov-action@v5
# with:
# token: ${{ secrets.CODECOV_TOKEN }}
# files: ./coverage.xml
# flags: unittests,pytorch-${{ matrix.pytorch_version }}
# fail_ci_if_error: false
#
# - name: cleanup pip cache
# run: |
# find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
#
# - name: Save HF cache
# id: hf-cache
# uses: actions/cache/save@v4
# with:
# path: |
# /home/runner/.cache/huggingface/hub/datasets--*
# /home/runner/.cache/huggingface/hub/models--*
# key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
pytest:
name: PyTest
runs-on: ubuntu-latest
needs: [preload-cache]
# needs: [preload-cache]
strategy:
fail-fast: false
matrix:
@@ -145,14 +151,20 @@ jobs:
- name: Check out repository code
uses: actions/checkout@v4
- name: Restore HF cache
id: hf-cache-restore
uses: actions/cache/restore@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-v2
# - name: Restore HF cache
# id: hf-cache-restore
# uses: actions/cache/restore@v4
# with:
# path: |
# /home/runner/.cache/huggingface/hub/datasets--*
# /home/runner/.cache/huggingface/hub/models--*
# key: ${{ runner.os }}-hf-hub-cache-v2
- name: Restore Cache from S3
id: hf-cache-restore-s3
run: |
mkdir -p /home/runner/.cache/huggingface/hub
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
- name: Setup Python
uses: actions/setup-python@v5
@@ -210,7 +222,7 @@ jobs:
pytest-sdist:
name: PyTest from Source Dist
runs-on: ubuntu-latest
needs: [preload-cache]
# needs: [preload-cache]
strategy:
fail-fast: false
matrix:
@@ -222,14 +234,20 @@ jobs:
- name: Check out repository code
uses: actions/checkout@v4
- name: Restore HF cache
id: hf-cache-restore
uses: actions/cache/restore@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-v2
# - name: Restore HF cache
# id: hf-cache-restore
# uses: actions/cache/restore@v4
# with:
# path: |
# /home/runner/.cache/huggingface/hub/datasets--*
# /home/runner/.cache/huggingface/hub/models--*
# key: ${{ runner.os }}-hf-hub-cache-v2
- name: Restore Cache from S3
id: hf-cache-restore-s3
run: |
mkdir -p /home/runner/.cache/huggingface/hub
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
- name: Setup Python
uses: actions/setup-python@v5

View File

@@ -57,8 +57,10 @@ async def handler(job):
logger.info("Training Complete.")
# Cleanup
del os.environ["WANDB_API_KEY"]
del os.environ["HF_TOKEN"]
if "WANDB_API_KEY" in os.environ:
del os.environ["WANDB_API_KEY"]
if "HF_TOKEN" in os.environ:
del os.environ["HF_TOKEN"]
runpod.serverless.start({"handler": handler, "return_aggregate_stream": True})
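The guarded deletions above could also be written with the pop-with-default idiom (a stylistic alternative, not the committed code):

os.environ.pop("WANDB_API_KEY", None)
os.environ.pop("HF_TOKEN", None)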

View File

@@ -48,8 +48,23 @@ quartodoc:
contents:
- core.trainers.base
- core.trainers.trl
- core.trainers.mamba
- core.trainers.relora
- core.trainers.dpo.trainer
- core.trainers.grpo.trainer
- core.trainers.grpo.sampler
- core.trainers.utils
- title: Mixins
desc: Mixin classes for augmenting trainers
contents:
- core.trainers.mixins.optimizer
- core.trainers.mixins.rng_state_loader
- core.trainers.mixins.scheduler
- core.trainers.mixins.sequence_parallel
- title: Context Managers
desc: Context managers for altering trainer behaviors
contents:
- utils.ctx_managers.sequence_parallel
- title: Prompt Strategies
desc: Prompt formatting strategies
contents:
@@ -86,7 +101,7 @@ quartodoc:
- kernels.swiglu
- kernels.quantize
- kernels.utils
- title: MonkeyPatches
- title: Monkey Patches
desc: Runtime patches for model optimizations
contents:
- monkeypatch.llama_attn_hijack_flash
@@ -124,7 +139,8 @@ quartodoc:
- utils.optimizers.adopt
- utils.data.pretraining
- utils.data.sft
- utils.gradient_checkpointing.unsloth
- utils.gradient_checkpointing.offload_cpu
- utils.gradient_checkpointing.offload_disk
- title: Schemas
desc: Pydantic data models for Axolotl config
contents:

View File

@@ -6,7 +6,7 @@ from .single_gpu import GPU_CONFIG, VOLUME_CONFIG, app, cicd_image, run_cmd
@app.function(
image=cicd_image,
gpu=GPU_CONFIG,
timeout=60 * 60,
timeout=90 * 60, # 90 min
cpu=8.0,
memory=131072,
volumes=VOLUME_CONFIG,

View File

@@ -19,7 +19,7 @@ coverage:
if_no_uploads: error
if_not_found: success
if_ci_failed: error
only_pulls: false
only_pulls: true
flags: null
paths: null
patch:

View File

@@ -332,6 +332,8 @@ dataset_shard_idx:
# The maximum length of an input to train with, this should typically be less than 2048
# as most models have a token/context limit of 2048
sequence_len: 2048
# How to handle sequences that overflow the sequence_len: 'drop' (default, removes sample) or 'truncate' (cuts off excess tokens).
sequence_len_overflow_handling: drop
# Pad inputs so each step uses constant sized buffers
# This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
pad_to_sequence_len:
@@ -505,6 +507,7 @@ save_strategy: # Set to `"no"` to skip checkpoint saves, `"epoch"` at end of eac
save_steps: # Leave empty to save at each epoch, integer for every N steps. float for fraction of total steps
saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps
save_total_limit: # Checkpoints saved at a time
save_only_model: # Save only the model weights, skipping the optimizer. Using this means you can't resume from checkpoints.
# Maximum number of iterations to train for. It precedes num_epochs which means that
# if both are set, num_epochs will not be guaranteed.
# e.g., when 1 epoch is 1000 steps => `num_epochs: 2` and `max_steps: 100` will train for 100 steps
@@ -538,7 +541,7 @@ train_on_inputs: false
# Note that training loss may have an oscillating pattern with this enabled.
group_by_length: false
# Whether to use gradient checkpointing. Available options are: true, false, "offload".
# Whether to use gradient checkpointing. Available options are: true, false, "offload", "offload_disk".
# https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
gradient_checkpointing: false
# additional kwargs to pass to the trainer for gradient checkpointing

View File

@@ -3,8 +3,6 @@ title: Sequence Parallelism
description: Train with long sequences split across multiple GPUs.
---
# Sequence Parallelism
Sequence parallelism is a technique that splits sequences across multiple GPUs,
allowing you to train with very long sequences that wouldn't fit on a single GPU. Each
GPU processes a different portion of the sequence, and the results are aggregated
@@ -27,7 +25,7 @@ To enable sequence parallelism, add the following to your configuration file:
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1
# Optional; one of "varlen_llama3", "batch_ring", "batch_zigzag", "batch_stripe". Defaults to
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
# "varlen_llama3" when `sample_packing: true`, and "batch_ring" otherwise.
ring_attn_func:
```
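To illustrate the splitting (numbers chosen for illustration, not taken from the docs): with sequence_parallel_degree: 4, a 32,768-token sample is sharded into four 8,192-token chunks, one per GPU in the sequence parallel group, and attention across chunk boundaries is handled by the ring attention function selected above.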

View File

@@ -82,6 +82,12 @@ class VllmServeCliArgs:
"hardware support this feature."
},
)
serve_module: Optional[str] = field(
default=None,
metadata={
"help": "Module to serve. If not set, the default module will be used."
},
)
@dataclass

View File

@@ -6,7 +6,6 @@ from pathlib import Path
from typing import Union
from trl.scripts.vllm_serve import ScriptArguments
from trl.scripts.vllm_serve import main as vllm_serve_main
from axolotl.cli.config import load_cfg
@@ -28,6 +27,9 @@ def do_vllm_serve(
cfg = load_cfg(config)
model = cfg.base_model
serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve")
vllm_serve_main = getattr(__import__(serve_module, fromlist=["main"]), "main")
tensor_parallel_size = (
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
)
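The dynamic import above reduces to the previous hard-coded behavior when serve_module is unset. A sketch of both paths (the custom module name is hypothetical and must expose a main callable):

# default path, equivalent to the old static import:
from trl.scripts.vllm_serve import main as vllm_serve_main

# custom path via the new serve_module CLI arg:
serve_module = "my_project.custom_vllm_serve"  # hypothetical module
vllm_serve_main = getattr(__import__(serve_module, fromlist=["main"]), "main")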

View File

@@ -14,6 +14,7 @@ from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_processor, load_tokenizer
from axolotl.utils.schemas.enums import RLType
from axolotl.utils.tokenization import check_dataset_labels
LOG = logging.getLogger(__name__)
@@ -133,7 +134,7 @@ def load_preference_datasets(
total_num_steps: Optional[int] = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)
if cfg.rl == "grpo":
if cfg.rl is RLType.GRPO:
total_num_steps = None
if cli_args.debug or cfg.debug:
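The string comparisons replaced throughout this diff rely on an RLType enum; a plausible shape, inferred from the members used here (whether it subclasses str is an assumption — the actual definition lives in axolotl.utils.schemas.enums):

from enum import Enum

class RLType(str, Enum):
    DPO = "dpo"
    IPO = "ipo"
    ORPO = "orpo"
    KTO = "kto"
    SIMPO = "simpo"
    GRPO = "grpo"

Because enum members are singletons, identity checks like cfg.rl is RLType.GRPO are safe where the old code compared raw strings.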

View File

@@ -87,7 +87,7 @@ from axolotl.utils.collators import (
)
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.models import ensure_dtype
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
from axolotl.utils.schemas.enums import CustomSupportedOptimizers, RLType
try:
import torch._dynamo # pylint: disable=ungrouped-imports
@@ -353,7 +353,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["warmup_steps"] = warmup_steps
training_arguments_kwargs["logging_steps"] = logging_steps
if self.cfg.seed:
if self.cfg.seed is not None:
training_arguments_kwargs["seed"] = self.cfg.seed
if self.cfg.gradient_checkpointing:
@@ -547,8 +547,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
report_to = []
if self.cfg.use_wandb:
report_to.append("wandb")
if self.cfg.wandb_name:
training_arguments_kwargs["run_name"] = self.cfg.wandb_name
if self.cfg.use_mlflow:
report_to.append("mlflow")
if self.cfg.use_tensorboard:
@@ -821,14 +819,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
data_collator_kwargs = {
"padding": True, # True/"longest" is the default
}
multiple = 64
if self.cfg.pad_to_sequence_len:
data_collator_kwargs["pad_to_multiple_of"] = 64 * math.ceil(
self.cfg.sequence_len / 64
data_collator_kwargs["pad_to_multiple_of"] = multiple * math.ceil(
self.cfg.sequence_len / multiple
)
else:
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
data_collator_kwargs["pad_to_multiple_of"] = 64
data_collator_kwargs["pad_to_multiple_of"] = multiple
if self.cfg.reward_model:
data_collator_kwargs["max_length"] = self.cfg.sequence_len
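Worked example of the padding arithmetic above (values chosen for illustration): with sequence_len: 2000, pad_to_multiple_of = 64 * ceil(2000 / 64) = 64 * 32 = 2048, i.e. the sequence length rounded up to the next multiple of 64; with pad_to_sequence_len unset, each batch is instead simply padded to a multiple of 64.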
@@ -1034,6 +1033,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
training_args_kwargs["dataloader_prefetch_factor"] = (
self.cfg.dataloader_prefetch_factor
)
if self.cfg.seed is not None:
training_args_kwargs["seed"] = self.cfg.seed
if self.cfg.gradient_checkpointing:
training_args_kwargs["gradient_checkpointing"] = (
self.cfg.gradient_checkpointing
@@ -1076,9 +1079,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.use_wandb:
training_args_kwargs["run_name"] = self.cfg.wandb_name
training_args_kwargs["sequence_parallel_degree"] = (
self.cfg.sequence_parallel_degree
)
training_args_cls = None
blocklist_args_kwargs = []
if self.cfg.rl == "simpo":
if self.cfg.rl is RLType.SIMPO:
training_args_cls = AxolotlCPOConfig
training_args_kwargs["loss_type"] = "simpo"
training_args_kwargs["max_length"] = self.cfg.sequence_len
@@ -1086,13 +1093,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.cpo_alpha is not None:
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
elif self.cfg.rl == "orpo":
elif self.cfg.rl is RLType.ORPO:
training_args_cls = AxolotlORPOConfig
training_args_kwargs["max_length"] = self.cfg.sequence_len
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
elif self.cfg.rl == "kto":
elif self.cfg.rl is RLType.KTO:
training_args_cls = AxolotlKTOConfig
training_args_kwargs["desirable_weight"] = (
@@ -1106,14 +1113,14 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
elif self.cfg.rl == "grpo":
elif self.cfg.rl is RLType.GRPO:
training_args_cls = GRPOStrategy.get_training_args_class()
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
else:
training_args_cls = AxolotlDPOConfig
if self.cfg.rl == "ipo":
if self.cfg.rl is RLType.IPO:
training_args_kwargs["loss_type"] = "ipo"
training_args_kwargs["max_length"] = self.cfg.sequence_len
training_args_kwargs["max_completion_length"] = None
@@ -1156,67 +1163,73 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
def build(self, total_num_steps):
training_args = self.build_training_arguments(total_num_steps)
dpo_trainer_kwargs = {}
if self.cfg.rl == "ipo":
trainer_kwargs = {}
if self.cfg.rl is RLType.IPO:
if self.cfg.dpo_label_smoothing:
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
if self.eval_dataset:
dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
trainer_kwargs["eval_dataset"] = self.eval_dataset
if self.cfg.adapter and self.peft_config:
dpo_trainer_kwargs["peft_config"] = self.peft_config
trainer_kwargs["peft_config"] = self.peft_config
if self.cfg.precompute_ref_log_probs is not None:
dpo_trainer_kwargs["precompute_ref_log_probs"] = (
trainer_kwargs["precompute_ref_log_probs"] = (
self.cfg.precompute_ref_log_probs
)
if self.cfg.rl == "grpo":
trainer_cls = GRPOStrategy.get_trainer_class()
if self.cfg.rl is RLType.GRPO:
trainer_cls = GRPOStrategy.get_trainer_class(
sequence_parallel=self.cfg.sequence_parallel_degree > 1
)
trainer_cls_args = [self.model]
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
dpo_trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
elif self.cfg.rl in ["dpo", "ipo"]:
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
trainer_cls = DPOStrategy.get_trainer_class()
trainer_cls_args = [self.model, self.model_ref]
elif self.cfg.rl == "orpo":
elif self.cfg.rl is RLType.ORPO:
trainer_cls = AxolotlORPOTrainer
trainer_cls_args = [self.model]
elif self.cfg.rl in ["kto"]:
elif self.cfg.rl is RLType.KTO:
trainer_cls = AxolotlKTOTrainer
trainer_cls_args = [self.model]
elif self.cfg.rl in ["simpo"]:
elif self.cfg.rl is RLType.SIMPO:
trainer_cls = AxolotlCPOTrainer
trainer_cls_args = [self.model]
else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
sig = inspect.signature(trainer_cls)
if "tokenizer" in sig.parameters.keys():
dpo_trainer_kwargs["tokenizer"] = self.tokenizer
trainer_kwargs["tokenizer"] = self.tokenizer
else:
dpo_trainer_kwargs["processing_class"] = self.tokenizer
trainer_kwargs["processing_class"] = self.tokenizer
if self.cfg.datasets is not None and (
trainer_cls is DPOStrategy.get_trainer_class()
):
dpo_trainer_kwargs["dataset_tags"] = [
trainer_kwargs["dataset_tags"] = [
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
]
dpo_trainer = trainer_cls(
trainer = trainer_cls(
*trainer_cls_args,
args=training_args,
train_dataset=self.train_dataset,
callbacks=self.get_callbacks(),
**dpo_trainer_kwargs,
**trainer_kwargs,
)
if self.cfg.fsdp:
ensure_dtype(dpo_trainer.model, dtype=self.cfg.torch_dtype)
if self.cfg.rl in ["dpo", "ipo"] and dpo_trainer.ref_model:
ensure_dtype(dpo_trainer.ref_model, dtype=self.cfg.torch_dtype)
ensure_dtype(trainer.model, dtype=self.cfg.torch_dtype)
if self.cfg.rl in [RLType.DPO, RLType.IPO] and trainer.ref_model:
ensure_dtype(trainer.ref_model, dtype=self.cfg.torch_dtype)
dpo_trainer = self.hook_post_create_trainer(dpo_trainer)
for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
dpo_trainer.add_callback(callback)
trainer = self.hook_post_create_trainer(trainer)
for callback in self.get_post_trainer_create_callbacks(trainer):
trainer.add_callback(callback)
return dpo_trainer
return trainer
class HFPPOTrainerBuilder(TrainerBuilderBase):

View File

@@ -5,7 +5,7 @@
from .base import AxolotlTrainer
from .dpo.trainer import AxolotlDPOTrainer
from .grpo.trainer import AxolotlGRPOTrainer
from .grpo.trainer import AxolotlGRPOSequenceParallelTrainer, AxolotlGRPOTrainer
from .mamba import AxolotlMambaTrainer
from .relora import ReLoRATrainer
from .trl import (

View File

@@ -373,15 +373,13 @@ class AxolotlTrainer(
num_items_in_batch=num_items_in_batch,
)
loss = super().compute_loss(
return super().compute_loss(
model,
inputs,
return_outputs=return_outputs,
num_items_in_batch=num_items_in_batch,
)
return loss
@staticmethod
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
concatenated_batch = {}

View File

@@ -1,14 +1,11 @@
"""
DPO Specific Strategy for training
"""
"""DPO Specific Strategy for training"""
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer
from axolotl.utils.schemas.enums import RLType
class DPOStrategy:
"""
Strategy for DPO training
"""
"""Strategy for DPO training"""
@classmethod
def get_trainer_class(cls):
@@ -23,7 +20,7 @@ class DPOStrategy:
@classmethod
def set_training_args_kwargs(cls, cfg):
training_args_kwargs = {}
if cfg.rl == "ipo":
if cfg.rl is RLType.IPO:
training_args_kwargs["loss_type"] = "ipo"
training_args_kwargs["max_length"] = cfg.sequence_len
training_args_kwargs["max_completion_length"] = None

View File

@@ -1,37 +1,41 @@
"""
GRPO Specific Strategy for training
"""
"""GRPO Specific Strategy for training"""
import importlib
import inspect
import logging
from typing import Any
from trl.trainer.grpo_trainer import RewardFunc
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig
from axolotl.core.trainers.grpo.trainer import (
AxolotlGRPOSequenceParallelTrainer,
AxolotlGRPOTrainer,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.schemas.trl import TRLConfig
LOG = logging.getLogger("axolotl")
LOG = logging.getLogger(__name__)
class GRPOStrategy:
"""
Strategy for GRPO training
"""
"""Strategy for GRPO training"""
@classmethod
def get_trainer_class(cls):
def get_trainer_class(
cls, sequence_parallel: bool
) -> type[AxolotlGRPOTrainer] | type[AxolotlGRPOSequenceParallelTrainer]:
if sequence_parallel:
return AxolotlGRPOSequenceParallelTrainer
return AxolotlGRPOTrainer
@classmethod
def get_training_args_class(cls):
from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig
def get_training_args_class(cls) -> type[AxolotlGRPOConfig]:
return AxolotlGRPOConfig
@classmethod
def set_training_args_kwargs(cls, cfg):
grpo_args_kwargs = {}
def set_training_args_kwargs(cls, cfg: DictDefault) -> dict[str, Any]:
grpo_args_kwargs: dict[str, Any] = {}
if not hasattr(cfg, "trl") or not cfg.trl:
return grpo_args_kwargs
@@ -40,8 +44,8 @@ class GRPOStrategy:
if trl.use_vllm:
grpo_args_kwargs["use_vllm"] = trl.use_vllm
grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host or trl.vllm.host
grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port or trl.vllm.port
grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host or trl.vllm.host # type: ignore[attr-defined]
grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port or trl.vllm.port # type: ignore[attr-defined]
if trl.vllm_server_timeout:
grpo_args_kwargs["vllm_server_timeout"] = trl.vllm_server_timeout
if trl.vllm_guided_decoding_regex:
@@ -102,17 +106,18 @@ class GRPOStrategy:
return grpo_args_kwargs
@classmethod
def set_trainer_args(cls, cfg):
def set_trainer_args(cls, cfg: DictDefault) -> list[Any]:
trainer_args = []
if cfg.trl and cfg.trl.reward_funcs:
reward_funcs = []
for reward_func_fqn in cfg.trl.reward_funcs:
reward_funcs.append(cls.get_reward_func(reward_func_fqn))
trainer_args.append(reward_funcs)
return trainer_args
@classmethod
def set_trainer_kwargs(cls, cfg):
def set_trainer_kwargs(cls, cfg: DictDefault) -> dict[str, Any]:
trainer_kwargs = {}
if cfg.trl and cfg.trl.reward_processing_classes:
trainer_kwargs["reward_processing_classes"] = (
@@ -126,7 +131,7 @@ class GRPOStrategy:
return None
@classmethod
def get_blocklist_args_kwargs(cls):
def get_blocklist_args_kwargs(cls) -> list[str]:
return ["dataset_num_proc"]
@classmethod
@@ -137,13 +142,13 @@ class GRPOStrategy:
Args:
reward_func_fqn (str): Fully qualified name of the reward function (e.g. r1_grpo.gsm8k_transform),
or a HF hub path to the reward model.
Raises:
ValueError: If the reward function does not accept at least two arguments.
Returns:
RewardFunc: A callable that accepts prompts and completions and returns rewards,
or a path to a reward model.
Raises:
ValueError: If the reward function does not accept at least two arguments.
"""
try:
# use importlib to dynamically load the reward function from the module

View File

@@ -11,6 +11,4 @@ from axolotl.core.training_args import AxolotlTrainingMixins
@dataclass
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
"""
Axolotl GRPO Config for GRPO training
"""
"""Axolotl GRPO Config for GRPO training"""

View File

@@ -0,0 +1,172 @@
"""Repeat random sampler (similar to the one implemented in
https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py) that adds
sequence parallelism functionality; i.e., duplicating data across ranks in the same
sequence parallel group.
"""
from typing import Iterator, Sized
import torch
from torch.utils.data import Sampler
class SequenceParallelRepeatRandomSampler(Sampler):
"""Sampler for GRPO training with sequence parallelism.
This sampler ensures:
- Ranks in the same sequence parallel (SP) group receive identical data.
- Each index is repeated multiple times for sampling different completions.
- Entire batches are repeated for reuse in multiple updates.
- Data is properly distributed across SP groups.
In the table below, the values represent dataset indices. Each SP group has
`sequence_parallel_degree = 2` GPUs working together on the same data. There are 2
SP groups (SP0 and SP1), with `world_size = 4` total GPUs.
Sequence Parallel Groups
| SP0 | SP1 |
| GPU 0 | GPU 1 | GPU 2 | GPU 3 |
global_step step <---> mini_repeat_count=3
<----------> batch_size=2 per SP group
grad_accum=2 ▲ ▲ 0 0 [0 0 0 1 1 1] [2 2 2 3 3 3] <- SP groups get different data
▼ | 0 1 [0 0 0 1 1 1] [2 2 2 3 3 3] <- Same data for each SP group GPU
|
| 1 2 [0 0 0 1 1 1] [2 2 2 3 3 3] <- Repeat same indices for iterations
num_iterations=2 ▼ 1 3 [0 0 0 1 1 1] [2 2 2 3 3 3] <- When using gradient accumulation
2 4 [4 4 4 5 5 5] [6 6 6 7 7 7] <- New batch of data indices
2 5 [4 4 4 5 5 5] [6 6 6 7 7 7]
...
Args:
dataset: Dataset to sample from.
mini_repeat_count: How many times to repeat each sample immediately.
world_size: Total number of processes.
rank: Rank of current process.
batch_size: Number of samples per batch.
repeat_count: How many times to repeat the full sampling process.
sequence_parallel_degree: Number of ranks in a sequence parallel group.
shuffle: Whether to shuffle the dataset.
seed: Random seed for shuffling.
drop_last: Whether to drop the last incomplete batch.
"""
def __init__(
self,
dataset: Sized,
mini_repeat_count: int,
world_size: int,
rank: int,
batch_size: int = 1,
repeat_count: int = 1,
sequence_parallel_degree: int = 1,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False,
):
self.dataset = dataset
self.mini_repeat_count = mini_repeat_count
self.batch_size = batch_size
self.repeat_count = repeat_count
self.shuffle = shuffle
self.seed = seed
self.drop_last = drop_last
self.epoch = 0
self.world_size = world_size
self.rank = rank
# Sequence parallelism parameters
self.sequence_parallel_degree = sequence_parallel_degree
self.num_sp_groups = world_size // sequence_parallel_degree
self.sp_group_id = rank // sequence_parallel_degree
# Adjust dataset size for distributed sampling
self.num_samples = len(self.dataset)
self.total_size = self.num_samples
# Calculate effective number of samples per SP group
if (
self.drop_last
and self.total_size % (self.num_sp_groups * self.batch_size) != 0
):
# Drop last incomplete batch if drop_last is True
self.num_samples_per_sp_group = (
self.total_size // self.batch_size // self.num_sp_groups
) * self.batch_size
else:
# Round up to include last batch if drop_last is False
self.num_samples_per_sp_group = (
(self.total_size + self.batch_size * self.num_sp_groups - 1)
// (self.batch_size * self.num_sp_groups)
* self.batch_size
)
if shuffle:
self.generator = torch.Generator()
self.generator.manual_seed(seed)
def __iter__(self) -> Iterator[int]:
"""Creates iterator over dataset indices.
Returns:
Iterator that yields indices into the dataset.
"""
# Deterministically shuffle based on epoch and seed
if self.shuffle:
indices = torch.randperm(
self.num_samples, generator=self.generator
).tolist()
else:
indices = list(range(self.num_samples))
# Add extra samples to make it evenly divisible by batch_size
if len(indices) % self.batch_size != 0:
padding = indices[: self.batch_size - len(indices) % self.batch_size]
indices += padding
# Subsample based on SP group ID
# Each SP group gets distinct batches of data
batch_indices = []
for i in range(0, len(indices), self.batch_size * self.num_sp_groups):
start_idx = i + self.sp_group_id * self.batch_size
end_idx = min(start_idx + self.batch_size, len(indices))
if start_idx < len(indices):
for j in range(self.batch_size):
if start_idx + j < end_idx:
batch_indices.append(indices[start_idx + j])
# Make sure batch_indices is exactly batch_size * num_batches_per_sp_group
if self.drop_last:
num_batches_per_sp_group = self.num_samples_per_sp_group // self.batch_size
target_len = self.batch_size * num_batches_per_sp_group
if len(batch_indices) > target_len:
batch_indices = batch_indices[:target_len]
# Apply the GRPO repeat pattern
final_indices = []
for _ in range(self.repeat_count):
for idx in batch_indices:
for _ in range(self.mini_repeat_count):
final_indices.append(idx)
return iter(final_indices)
def __len__(self) -> int:
"""Returns the total length of the iterable including repetitions.
Returns:
Total number of samples.
"""
# Total length including all repetitions
return (
self.num_samples_per_sp_group * self.mini_repeat_count * self.repeat_count
)
def set_epoch(self, epoch: int) -> None:
"""Sets the epoch for this sampler.
Args:
epoch: Epoch number to use for shuffling.
"""
self.epoch = epoch
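A quick sanity check of the sampler's repeat pattern, assuming the class above is importable (shuffle disabled so indices are deterministic):

sampler = SequenceParallelRepeatRandomSampler(
    dataset=list(range(8)),
    mini_repeat_count=3,   # e.g., num_generations
    world_size=4,
    rank=0,                # ranks 0 and 1 share SP group 0
    batch_size=2,
    repeat_count=2,
    sequence_parallel_degree=2,
    shuffle=False,
)
print(list(iter(sampler)))
# [0, 0, 0, 1, 1, 1, 4, 4, 4, 5, 5, 5, 0, 0, 0, 1, 1, 1, 4, 4, 4, 5, 5, 5]
# Rank 1 yields the same indices; ranks 2 and 3 (SP group 1) yield 2, 3, 6, 7 instead.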

View File

@@ -1,23 +1,63 @@
"""
Axolotl GRPO trainer
"""
"""Axolotl GRPO trainers (with and without sequence parallelism handling)"""
# pylint: disable=too-many-lines,duplicate-code,protected-access,no-member
import warnings
from contextlib import nullcontext
from typing import Any
from accelerate.utils import is_deepspeed_available, is_peft_model
import datasets
import torch
import torch.distributed as dist
import torch.utils.data
from accelerate.utils import (
broadcast_object_list,
gather,
gather_object,
is_peft_model,
)
from datasets import Dataset, IterableDataset
from torch import nn
from torch.utils.data import (
BatchSampler,
DataLoader,
Sampler,
)
from transformers import (
PreTrainedModel,
PreTrainedTokenizerBase,
Trainer,
TrainerCallback,
)
from transformers.trainer_utils import seed_worker
from transformers.utils import is_peft_available
from trl import GRPOTrainer
from trl.extras.profiling import profiling_decorator
from trl.data_utils import (
apply_chat_template,
is_conversational,
maybe_apply_chat_template,
)
from trl.extras.profiling import profiling_context, profiling_decorator
from trl.import_utils import is_deepspeed_available
from trl.models import unwrap_model_for_generation
from trl.trainer.grpo_config import GRPOConfig
from trl.trainer.grpo_trainer import RewardFunc, nanstd
from trl.trainer.utils import pad
from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.monkeypatch.attention.ring_attn.patch import get_ring_attn_group
if is_peft_available():
# pylint: disable=unused-import
from peft import PeftConfig
if is_deepspeed_available():
import deepspeed
class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
"""
Extend the base GRPOTrainer for axolotl helpers
"""
"""Extend the base GRPOTrainer for axolotl helpers"""
_tag_names = ["trl", "grpo", "axolotl"]
@@ -67,3 +107,600 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
# Reset cache on main process
if self.accelerator.is_main_process:
self.vllm_client.reset_prefix_cache()
class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
"""Extend the base GRPOTrainer for sequence parallelism handling"""
def __init__(
self,
model: str | PreTrainedModel,
reward_funcs: RewardFunc | list[RewardFunc],
args: GRPOConfig | None = None,
train_dataset: Dataset | IterableDataset | None = None,
eval_dataset: (
Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None
) = None,
processing_class: PreTrainedTokenizerBase | None = None,
reward_processing_classes: (
PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None
) = None,
callbacks: list[TrainerCallback] | None = None,
optimizers: tuple[
torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None
] = (None, None),
peft_config: "PeftConfig | None" = None,
):
# First call the superclass constructor with all arguments
super().__init__(
model=model,
reward_funcs=reward_funcs,
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
processing_class=processing_class,
reward_processing_classes=reward_processing_classes,
callbacks=callbacks,
optimizers=optimizers,
peft_config=peft_config,
)
# Get number of SP groups (number of processes divided by SP degree)
num_processes = self.accelerator.num_processes
num_sp_groups = num_processes // self.args.sequence_parallel_degree
# Calculate batch size per SP group (not per process)
sp_group_batch_size = self.args.per_device_train_batch_size * num_sp_groups
possible_values = [
n_gen
for n_gen in range(2, sp_group_batch_size + 1)
if (sp_group_batch_size) % n_gen == 0
]
if self.num_generations not in possible_values:
raise ValueError(
f"The batch size per SP group ({num_sp_groups} x "
f"{self.args.per_device_train_batch_size}) must be evenly divisible by "
f"the number of generations per prompt ({self.num_generations}). Given "
"the current configuration, the valid values for the number of "
f"generations are: {possible_values}."
)
if self.args.eval_strategy != "no":
# If sequence parallelism is enabled, calculate batch size per SP group
sp_group_eval_batch_size = args.per_device_eval_batch_size * num_sp_groups # type: ignore[union-attr]
possible_values = [
n_gen
for n_gen in range(2, sp_group_eval_batch_size + 1)
if (sp_group_eval_batch_size) % n_gen == 0
]
if self.num_generations not in possible_values:
raise ValueError(
f"With sequence parallelism (degree {self.args.sequence_parallel_degree}), "
f"the eval batch size per SP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) "
f"must be evenly divisible by the number of generations per prompt "
f"({self.num_generations}). Given the current eval batch size, "
f"the valid values for the number of generations are: {possible_values}."
)
# Initialize the SP group
self.sp_group = get_ring_attn_group()
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
self.local_rank = dist.get_rank(group=self.sp_group)
self.local_world_size = dist.get_world_size(group=self.sp_group)
def _get_train_sampler(self) -> Sampler:
effective_batch_size = (
self.args.per_device_train_batch_size
* self.world_size
* self.args.gradient_accumulation_steps
)
return SequenceParallelRepeatRandomSampler(
dataset=self.train_dataset,
mini_repeat_count=self.num_generations,
world_size=self.world_size,
rank=self.rank,
batch_size=effective_batch_size
// self.num_generations
// self.args.sequence_parallel_degree,
repeat_count=self.num_iterations * self.args.gradient_accumulation_steps,
sequence_parallel_degree=self.args.sequence_parallel_degree,
shuffle=True,
seed=self.args.seed,
drop_last=True,
)
def _create_dataloader_params(self, is_eval=False, custom_batch_size=None):
"""Create common dataloader parameters for train or eval."""
batch_size = custom_batch_size or (
self.args.eval_batch_size if is_eval else self._train_batch_size
)
params = {
"batch_size": batch_size,
"collate_fn": self.data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
# Add persistent workers only for training
if not is_eval and hasattr(self.args, "dataloader_persistent_workers"):
params["persistent_workers"] = self.args.dataloader_persistent_workers
# Add prefetch factor if specified
if self.args.dataloader_prefetch_factor:
params["prefetch_factor"] = self.args.dataloader_prefetch_factor
return params
def _prepare_dataloader(
self, dataset, sampler, is_eval=False, custom_batch_size=None
):
"""Prepare a dataloader with the given dataset and sampler."""
# Get base parameters
dataloader_params = self._create_dataloader_params(is_eval, custom_batch_size)
# Add sampler configuration
if not isinstance(dataset, torch.utils.data.IterableDataset):
if isinstance(sampler, BatchSampler):
# batch_size and batch_sampler are mutually exclusive
dataloader_params["batch_sampler"] = sampler
del dataloader_params["batch_size"]
else:
dataloader_params["sampler"] = sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last
if not is_eval:
dataloader_params["worker_init_fn"] = seed_worker
# Create the dataloader
dataloader = DataLoader(dataset, **dataloader_params)
if self.args.sample_packing and (
(not is_eval and not self.args.pretraining)
or (is_eval and self.args.eval_sample_packing is not False)
):
self.accelerator.even_batches = False
# Return unprepared dataloader if using sequence parallelism
# TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
# if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
# slice each batch along the sequence dimension).
if self.args.sequence_parallel_degree > 1:
return dataloader
# Otherwise prepare with accelerator
return self.accelerator.prepare_data_loader(dataloader)
def get_train_dataloader(self) -> DataLoader:
"""Get dataloader for training"""
train_dataset = self.train_dataset
# pylint: disable=access-member-before-definition
data_collator = self.data_collator # type: ignore
# Handle dataset preprocessing
if isinstance(train_dataset, datasets.Dataset):
# Add debug print before any modifications
if self.args.sample_packing and not self.args.pretraining:
train_dataset = train_dataset.remove_columns(["length"])
if not self.args.sample_packing or self.args.pretraining:
train_dataset = self._remove_unused_columns(
train_dataset, description="training"
)
else:
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
data_collator,
description="training",
)
# Get sampler and create dataloader
sampler = self._get_train_sampler()
dataloader = self._prepare_dataloader(train_dataset, sampler, is_eval=False)
return dataloader
def _generate_and_score_completions(
self, inputs: list[dict[str, torch.Tensor | Any]]
) -> dict[str, torch.Tensor | Any]:
device = self.accelerator.device
mode = "eval" if self.control.should_evaluate else "train"
prompts = [x["prompt"] for x in inputs]
prompts_text = [
maybe_apply_chat_template(example, self.processing_class)["prompt"]
for example in inputs
]
prompt_inputs = self.processing_class(
text=prompts_text,
return_tensors="pt",
padding=True,
padding_side="left",
add_special_tokens=False,
)
prompt_inputs = Trainer._prepare_inputs(self, prompt_inputs)
prompt_ids, prompt_mask = (
prompt_inputs["input_ids"],
prompt_inputs["attention_mask"],
)
if self.max_prompt_length is not None:
prompt_ids = prompt_ids[:, -self.max_prompt_length :]
prompt_mask = prompt_mask[:, -self.max_prompt_length :]
# Generate completions using either vLLM or regular generation
if self.args.use_vllm:
# First, have main process load weights if needed
# pylint: disable=access-member-before-definition
if self.state.global_step != self._last_loaded_step: # type: ignore[has-type]
self._move_model_to_vllm()
# pylint: disable=attribute-defined-outside-init
self._last_loaded_step = self.state.global_step
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
all_prompts_text = gather_object(prompts_text)
if self.accelerator.is_main_process:
if self.args.sequence_parallel_degree > 1:
# Calculate sequence parallel group information
world_size = self.accelerator.num_processes
sequence_parallel_degree = self.args.sequence_parallel_degree
num_sp_groups = world_size // sequence_parallel_degree
# Since processes in the same SP group have the same prompts, we need to ensure
# we only take one copy of each prompt from each SP group
ordered_set_of_prompts = []
for sp_group_id in range(num_sp_groups):
# Get the first process from each SP group (typically the group leader)
group_leader_rank = sp_group_id * sequence_parallel_degree
# Extract prompts from this SP group, accounting for num_generations duplicates
# We only need prompts from one rank in each SP group
group_prompts = all_prompts_text[
group_leader_rank
* len(prompts_text) : (group_leader_rank + 1)
* len(prompts_text) : self.num_generations
]
ordered_set_of_prompts.extend(group_prompts)
else:
# Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
# num_generations outputs for each one. This is faster than generating outputs for each duplicate
# prompt individually.
ordered_set_of_prompts = all_prompts_text[
:: self.num_generations * self.args.sequence_parallel_degree
]
with profiling_context(self, "vLLM.generate"):
completion_ids = self.vllm_client.generate(
prompts=ordered_set_of_prompts,
n=self.num_generations,
repetition_penalty=self.repetition_penalty,
temperature=self.temperature,
top_p=self.top_p,
top_k=-1 if self.top_k is None else self.top_k,
min_p=0.0 if self.min_p is None else self.min_p,
max_tokens=self.max_completion_length,
guided_decoding_regex=self.guided_decoding_regex,
)
else:
completion_ids = [None] * (
len(all_prompts_text) // self.args.sequence_parallel_degree
)
# Broadcast the completions from the main process to all processes
completion_ids = broadcast_object_list(completion_ids, from_process=0)
# Determine the appropriate slice based on sequence parallelism
if self.args.sequence_parallel_degree > 1:
# Calculate SP group ID (which group of ranks this rank belongs to)
sp_group_id = self.accelerator.process_index // self.local_world_size
# Calculate the start index for this SP group
sp_group_start = sp_group_id * len(prompts) * self.local_world_size
# All ranks in the same SP group get the same data slice
process_slice = slice(
sp_group_start,
sp_group_start + len(prompts),
)
completion_ids = completion_ids[process_slice]
else:
# Original behavior for non-sequence parallel case
process_slice = slice(
self.accelerator.process_index * len(prompts),
(self.accelerator.process_index + 1) * len(prompts),
)
completion_ids = completion_ids[process_slice]
# Pad the completions, and concatenate them with the prompts
completion_ids = [
torch.tensor(ids, device=device) for ids in completion_ids
]
completion_ids = pad(
completion_ids, padding_value=self.processing_class.pad_token_id
)
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
else:
# Regular generation path
with unwrap_model_for_generation(
self.model_wrapped,
self.accelerator,
gather_deepspeed3_params=self.args.ds3_gather_for_generation,
) as unwrapped_model:
prompt_completion_ids = unwrapped_model.generate(
prompt_ids,
attention_mask=prompt_mask,
generation_config=self.generation_config,
)
# Compute prompt length and extract completion ids
prompt_length = prompt_ids.size(1)
prompt_ids = prompt_completion_ids[:, :prompt_length]
completion_ids = prompt_completion_ids[:, prompt_length:]
# Mask everything after the first EOS token
is_eos = completion_ids == self.processing_class.eos_token_id
eos_idx = torch.full(
(is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device
)
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
sequence_indices = torch.arange(is_eos.size(1), device=device).expand(
is_eos.size(0), -1
)
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
# If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
if self.args.mask_truncated_completions:
truncated_completions = ~is_eos.any(dim=1)
completion_mask = (
completion_mask * (~truncated_completions).unsqueeze(1).int()
)
# Concatenate prompt_mask with completion_mask for logit computation
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C)
logits_to_keep = completion_ids.size(
1
) # we only need to compute the logits for the completion tokens
batch_size = (
self.args.per_device_train_batch_size
if mode == "train"
else self.args.per_device_eval_batch_size
)
with torch.no_grad():
# When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its
# computation here, and use per_token_logps.detach() instead.
if self.num_iterations > 1:
old_per_token_logps = self._get_per_token_logps(
self.model,
prompt_completion_ids,
attention_mask,
logits_to_keep,
batch_size,
)
else:
old_per_token_logps = None
if self.beta == 0.0:
ref_per_token_logps = None
elif self.ref_model is not None:
ref_per_token_logps = self._get_per_token_logps(
self.ref_model,
prompt_completion_ids,
attention_mask,
logits_to_keep,
batch_size,
)
else:
with self.accelerator.unwrap_model(self.model).disable_adapter():
ref_per_token_logps = self._get_per_token_logps(
self.model,
prompt_completion_ids,
attention_mask,
logits_to_keep,
batch_size,
)
# Decode the generated completions
completions_text = self.processing_class.batch_decode(
completion_ids, skip_special_tokens=True
)
if is_conversational(inputs[0]):
completions = []
for prompt, completion in zip(prompts, completions_text):
bootstrap = (
prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
)
completions.append(
[{"role": "assistant", "content": bootstrap + completion}]
)
else:
completions = completions_text
rewards_per_func = torch.zeros(
len(prompts), len(self.reward_funcs), device=device
)
for i, (reward_func, reward_processing_class, reward_func_name) in enumerate(
zip(
self.reward_funcs,
self.reward_processing_classes,
self.reward_func_names,
)
):
with profiling_context(self, reward_func_name):
if isinstance(
reward_func, nn.Module
): # Module instead of PretrainedModel for compat with compiled models
if is_conversational(inputs[0]):
messages = [
{"messages": p + c} for p, c in zip(prompts, completions)
]
texts = [
apply_chat_template(x, reward_processing_class)["text"]
for x in messages
]
else:
texts = [p + c for p, c in zip(prompts, completions)]
reward_inputs = reward_processing_class(
text=texts,
return_tensors="pt",
padding=True,
padding_side="right",
add_special_tokens=False,
)
reward_inputs = Trainer._prepare_inputs(self, reward_inputs)
with torch.inference_mode():
rewards_per_func[:, i] = reward_func(**reward_inputs).logits[
:, 0
] # Shape (B*G,)
else:
# Repeat all input columns (but "prompt" and "completion") to match the number of generations
keys = [
key for key in inputs[0] if key not in ["prompt", "completion"]
]
reward_kwargs = {
key: [example[key] for example in inputs] for key in keys
}
output_reward_func = reward_func(
prompts=prompts, completions=completions, **reward_kwargs
)
# Convert None values to NaN
output_reward_func = [
reward if reward is not None else torch.nan
for reward in output_reward_func
]
rewards_per_func[:, i] = torch.tensor(
output_reward_func, dtype=torch.float32, device=device
)
# If all reward functions return None for a given row, issue a detailed warning
if torch.isnan(rewards_per_func).all(dim=1).any():
nan_row_idx = (
torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0]
)
row_reward_kwargs = {
key: value[nan_row_idx] for key, value in reward_kwargs.items()
}
row_reward_kwargs["prompt"] = prompts[nan_row_idx]
row_reward_kwargs["completion"] = completions[nan_row_idx]
warnings.warn(
f"All reward functions returned None for the following kwargs: {row_reward_kwargs}. "
"Please ensure that at least one reward function returns a valid reward."
)
# Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
# completions may be distributed across processes
rewards_per_func = gather(rewards_per_func)
# Apply weights to each reward function's output and sum
rewards = (
rewards_per_func * self.reward_weights.to(device).unsqueeze(0)
).nansum(dim=1)
# Compute grouped-wise rewards
mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
# Normalize the rewards to compute the advantages
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(
self.num_generations, dim=0
)
std_grouped_rewards = std_grouped_rewards.repeat_interleave(
self.num_generations, dim=0
)
advantages = rewards - mean_grouped_rewards
if self.args.scale_rewards:
advantages = advantages / (std_grouped_rewards + 1e-4)
# Slice to keep only the local part of the data
if self.args.sequence_parallel_degree > 1:
# Calculate SP group ID (which group of ranks this rank belongs to)
sp_group_id = self.accelerator.process_index // self.local_world_size
# Calculate the start index for this SP group
sp_group_start = sp_group_id * len(prompts) * self.local_world_size
# All ranks in the same SP group get the same data slice
process_slice = slice(
sp_group_start,
sp_group_start + len(prompts),
)
else:
# Original behavior for non-sequence parallel case
process_slice = slice(
self.accelerator.process_index * len(prompts),
(self.accelerator.process_index + 1) * len(prompts),
)
advantages = advantages[process_slice]
# Log the metrics
if mode == "train":
self._total_train_tokens += (
self.accelerator.gather_for_metrics(attention_mask.sum()).sum().item()
)
self._metrics[mode]["num_tokens"] = [self._total_train_tokens]
# log completion lengths, mean, min, max
agg_completion_mask = self.accelerator.gather_for_metrics(
completion_mask.sum(1)
)
self._metrics[mode]["completions/mean_length"].append(
agg_completion_mask.float().mean().item()
)
self._metrics[mode]["completions/min_length"].append(
agg_completion_mask.float().min().item()
)
self._metrics[mode]["completions/max_length"].append(
agg_completion_mask.float().max().item()
)
# identify sequences that terminated with EOS and log their lengths
agg_terminated_with_eos = self.accelerator.gather_for_metrics(is_eos.any(dim=1))
term_completion_mask = agg_completion_mask[agg_terminated_with_eos]
clipped_completions_ratio = 1 - len(term_completion_mask) / len(
agg_completion_mask
)
self._metrics[mode]["completions/clipped_ratio"].append(
clipped_completions_ratio
)
if len(term_completion_mask) == 0:
# edge case where no completed sequences are found
term_completion_mask = torch.zeros(1, device=device)
self._metrics[mode]["completions/mean_terminated_length"].append(
term_completion_mask.float().mean().item()
)
self._metrics[mode]["completions/min_terminated_length"].append(
term_completion_mask.float().min().item()
)
self._metrics[mode]["completions/max_terminated_length"].append(
term_completion_mask.float().max().item()
)
# Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
for i, reward_func_name in enumerate(self.reward_func_names):
mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
self._metrics[mode][f"rewards/{reward_func_name}/mean"].append(mean_rewards)
std_rewards = nanstd(rewards_per_func[:, i]).item()
self._metrics[mode][f"rewards/{reward_func_name}/std"].append(std_rewards)
self._metrics[mode]["reward"].append(mean_grouped_rewards.mean().item())
self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item())
# Log prompt and completion texts
self._textual_logs["prompt"].extend(gather_object(prompts_text))
self._textual_logs["completion"].extend(gather_object(completions_text))
for i, name in enumerate(self.reward_func_names):
self._textual_logs["rewards"][name].extend(rewards_per_func[:, i].tolist())
return {
"prompt_ids": prompt_ids,
"prompt_mask": prompt_mask,
"completion_ids": completion_ids,
"completion_mask": completion_mask,
"advantages": advantages,
"old_per_token_logps": old_per_token_logps,
"ref_per_token_logps": ref_per_token_logps,
}
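[Editor's note] A minimal standalone sketch of the group-wise advantage normalization above, assuming `num_generations` completions per prompt and a hypothetical flat reward tensor:

import torch

num_generations = 4
# hypothetical rewards: 2 prompts x 4 completions each, flattened to shape (8,)
rewards = torch.tensor([1.0, 0.0, 0.5, 0.5, 2.0, 2.0, 1.0, 3.0])

mean_grouped = rewards.view(-1, num_generations).mean(dim=1)
std_grouped = rewards.view(-1, num_generations).std(dim=1)

# Broadcast each group's statistics back over its own completions
mean_grouped = mean_grouped.repeat_interleave(num_generations, dim=0)
std_grouped = std_grouped.repeat_interleave(num_generations, dim=0)

advantages = rewards - mean_grouped
advantages = advantages / (std_grouped + 1e-4)  # only when scale_rewards is set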

View File

@@ -6,4 +6,4 @@
from .optimizer import OptimizerMixin
from .rng_state_loader import RngLoaderMixin
from .scheduler import SchedulerMixin
from .sequence_parallel import SequenceParallelContextManager, SequenceParallelMixin
from .sequence_parallel import SequenceParallelMixin

View File

@@ -1,85 +1,13 @@
"""
Module for Axolotl trainer sequence parallelism mixin and training context manager
"""
"""Module for Axolotl trainer sequence parallelism mixin"""
import functools
import logging
import torch
import torch.distributed as dist
from datasets import Dataset
from torch import nn
from torch.utils.data import DistributedSampler, Sampler
from torch.utils.hooks import RemovableHandle
from axolotl.monkeypatch.attention.ring_attn import (
RingAttnFunc,
get_ring_attn_group,
update_ring_attn_params,
)
LOG = logging.getLogger(__name__)
def apply_sequence_parallelism(
batch: dict[str, torch.Tensor],
local_rank: int,
local_world_size: int,
ring_attn_func: RingAttnFunc,
) -> dict[str, torch.Tensor]:
"""
Apply sequence parallelism slicing to a batch.
Args:
batch: Batch dictionary (e.g., input_ids, attention_mask, etc.)
local_rank: Local rank in the sequence parallel group
local_world_size: World size of the sequence parallel group
ring_attn_func: The ring attention function to use
Returns:
Sliced batch dictionary.
"""
# Update ring attention params if needed
if batch.get("position_ids") is not None:
update_ring_attn_params(position_ids=batch["position_ids"])
# Slice batch for sequence parallel processing
total_seq_len = batch["input_ids"].size(1)
for key in batch:
if (
key in batch
and isinstance(batch[key], torch.Tensor)
and batch[key].dim() > 1
and batch[key].size(1) == total_seq_len
):
if ring_attn_func in [
RingAttnFunc.VARLEN_LLAMA3,
RingAttnFunc.BATCH_RING,
]:
# Split in sequential fashion and grab this rank's chunk
batch[key] = (
batch[key].chunk(local_world_size, dim=1)[local_rank].contiguous()
)
elif ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
chunks = batch[key].chunk(2 * local_world_size, dim=1)
# Take rank's chunk and opposing chunk for zigzag pattern
selected_chunks = [
chunks[local_rank],
chunks[2 * local_world_size - local_rank - 1],
]
batch[key] = torch.cat(selected_chunks, dim=1).contiguous()
elif ring_attn_func is RingAttnFunc.BATCH_STRIPE:
# Split into striped data and stack
tensor = torch.stack(
batch[key].split(local_world_size, dim=1),
dim=1,
).transpose(1, 2)
batch[key] = tensor[:, local_rank].contiguous()
return batch
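[Editor's note] A toy illustration of the BATCH_ZIGZAG pairing removed above: with a sequence split into 2 * local_world_size chunks, rank r takes chunk r plus its mirror chunk, which balances causal-attention cost across ranks.

import torch

local_world_size = 2
seq = torch.arange(8).unsqueeze(0)  # shape (1, 8)
chunks = seq.chunk(2 * local_world_size, dim=1)
for local_rank in range(local_world_size):
    selected = [chunks[local_rank], chunks[2 * local_world_size - local_rank - 1]]
    print(local_rank, torch.cat(selected, dim=1))
# rank 0 -> [[0, 1, 6, 7]], rank 1 -> [[2, 3, 4, 5]]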
class SequenceParallelMixin:
"""
@@ -157,157 +85,3 @@ class SequenceParallelMixin:
return self._create_sequence_parallel_sampler(
eval_dataset, shuffle=False, is_eval=True
)
class SequenceParallelContextManager:
"""
Context manager for sequence parallelism operations.
This class provides a context that will automatically apply sequence parallelism
during model forward passes using a pre-forward hook, and gather outputs from
across the sequence parallelism group using a post-forward hook.
"""
def __init__(
self,
model: nn.Module,
sequence_parallel_degree: int,
ring_attn_func: RingAttnFunc,
):
self.model = model
self.sequence_parallel_degree = sequence_parallel_degree
self.ring_attn_func = ring_attn_func
self.process_group = get_ring_attn_group()
# Initialize sequence parallel group details
self.local_rank = dist.get_rank(self.process_group)
self.local_world_size = dist.get_world_size(self.process_group)
# Will store hook handles for removal
self.hook_handles: list[RemovableHandle] = []
# Create a partially applied version of the apply_sequence_parallelism function
# with pre-configured params
self.apply_sequence_parallelism = functools.partial(
apply_sequence_parallelism,
local_rank=self.local_rank,
local_world_size=self.local_world_size,
ring_attn_func=self.ring_attn_func,
)
def __enter__(self):
# Forward pre-hook to apply sequence parallelism
def sequence_parallel_pre_hook(_, args, kwargs):
# Apply sequence parallelism to kwargs
kwargs = self.apply_sequence_parallelism(batch=kwargs)
return args, kwargs
# Forward post-hook to gather outputs
def sequence_parallel_post_hook(_, __, output):
# Gather the sharded outputs
return self.gather_outputs(output)
# Register both hooks
self.hook_handles.append(
self.model.register_forward_pre_hook(
sequence_parallel_pre_hook, with_kwargs=True
)
)
self.hook_handles.append(
self.model.register_forward_hook(sequence_parallel_post_hook)
)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
# Remove all hooks
for handle in self.hook_handles:
handle.remove()
self.hook_handles = []
def gather_outputs(self, output):
"""Gather sharded outputs from all ranks and reconstruct the full tensor."""
# Handle different output formats (dict, tensor, etc.)
if isinstance(output, dict):
gathered_output = {}
for key, value in output.items():
if isinstance(value, torch.Tensor) and value.dim() > 1:
# Gather logits or other sequence-sharded tensors
gathered_value = self.gather_tensor(value)
gathered_output[key] = gathered_value
else:
gathered_value = value.clone()
dist.all_reduce(
gathered_value, op=dist.ReduceOp.SUM, group=self.process_group
)
gathered_output[key] = gathered_value
return gathered_output
if isinstance(output, torch.Tensor):
return self.gather_tensor(output)
return output
def gather_tensor(self, tensor):
"""Gather a sharded tensor from all ranks."""
# Prepare tensors for all_gather
world_size = self.local_world_size
# Create list to store tensors from all ranks
gathered_tensors = [torch.zeros_like(tensor) for _ in range(world_size)]
# All-gather operation
dist.all_gather(gathered_tensors, tensor, group=self.process_group)
# Concatenate along sequence dimension (typically dim=1)
if self.ring_attn_func in [RingAttnFunc.VARLEN_LLAMA3, RingAttnFunc.BATCH_RING]:
# Simple concatenation for standard sharding
return torch.cat(gathered_tensors, dim=1)
if self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
# Each rank has a pattern of (rank, world_size*2-rank-1)
reconstituted_tensors = [None] * (world_size * 2)
# First, split each gathered tensor into its two chunks
for rank, gathered_tensor in enumerate(gathered_tensors):
# Each tensor contains two chunks in the sequence dimension
chunk_size = gathered_tensor.size(1) // 2
chunk1, chunk2 = gathered_tensor.split(chunk_size, dim=1)
# Place chunks in their original positions
reconstituted_tensors[rank] = chunk1
reconstituted_tensors[world_size * 2 - rank - 1] = chunk2
# Concatenate the reconstituted tensors in the correct order
return torch.cat(reconstituted_tensors, dim=1)
# Otherwise, RingAttnFunc.BATCH_STRIPE
# In striping, each rank has every world_size-th slice
batch_size = tensor.size(0)
hidden_dim = tensor.size(-1)
# First, determine the full sequence length
total_seq_len = 0
for t in gathered_tensors:
total_seq_len += t.size(1)
# Create a tensor to hold the unstriped result
result = torch.zeros(
batch_size,
total_seq_len,
hidden_dim,
dtype=tensor.dtype,
device=tensor.device,
)
# For each rank's tensor, distribute its slices to the correct positions
for rank, gathered_tensor in enumerate(gathered_tensors):
# The rank's tensor contains every world_size-th slice
# starting from its rank position
seq_len = gathered_tensor.size(1)
for i in range(seq_len):
# Calculate the position in the full tensor
pos = i * world_size + rank
if pos < total_seq_len:
result[:, pos] = gathered_tensor[:, i]
return result
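[Editor's note] For the BATCH_STRIPE reconstruction above: rank r holds positions r, r + world_size, r + 2 * world_size, ..., and the loop inverts that interleaving. A tiny self-contained check with hypothetical shapes:

import torch

world_size = 2
full = torch.arange(6.0).view(1, 6, 1)  # (batch=1, seq=6, hidden=1)
shards = [full[:, r::world_size] for r in range(world_size)]  # per-rank stripes

result = torch.zeros_like(full)
for rank, shard in enumerate(shards):
    for i in range(shard.size(1)):
        result[:, i * world_size + rank] = shard[:, i]
assert torch.equal(result, full)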

View File

@@ -9,7 +9,7 @@ from PIL.Image import Resampling
from transformers import TrainingArguments
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
from axolotl.utils.schemas.enums import RingAttnFunc
@dataclass

View File

@@ -4,7 +4,6 @@
# flake8: noqa
from .patch import (
RingAttnFunc,
get_ring_attn_group,
register_ring_attn,
set_ring_attn_group,

View File

@@ -16,11 +16,7 @@ import torch
import torch.distributed as dist
import transformers
import transformers.modeling_flash_attention_utils
from ring_flash_attn import (
ring_flash_attn_func,
stripe_flash_attn_func,
zigzag_ring_flash_attn_func,
)
from ring_flash_attn import ring_flash_attn_func
from ring_flash_attn.adapters.hf_adapter import check_params
from transformers.modeling_flash_attention_utils import (
_flash_supports_window_size,
@@ -28,12 +24,12 @@ from transformers.modeling_flash_attention_utils import (
)
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
from axolotl.utils.schemas.enums import RingAttnFunc
RING_ATTN_FUNC_MAPPING = {
RingAttnFunc.BATCH_RING: ring_flash_attn_func,
RingAttnFunc.BATCH_ZIGZAG: zigzag_ring_flash_attn_func,
RingAttnFunc.BATCH_STRIPE: stripe_flash_attn_func,
RingAttnFunc.BATCH_RING: torch.compile(ring_flash_attn_func),
# RingAttnFunc.BATCH_ZIGZAG: torch.compile(zigzag_ring_flash_attn_func),
# RingAttnFunc.BATCH_STRIPE: torch.compile(stripe_flash_attn_func),
}

View File

@@ -6,13 +6,12 @@ package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patc
their sequence parallel version of Flash Attention 2.
"""
from enum import Enum
import torch
import torch.distributed as dist
from accelerate.logging import get_logger
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.utils.schemas.enums import RingAttnFunc
LOG = get_logger(__name__)
@@ -41,17 +40,6 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
RING_ATTN_GROUP = ring_attn_group
class RingAttnFunc(str, Enum):
"""Enum class for supported `ring-flash-attn` implementations"""
# VARLEN_RING = "varlen_ring"
# VARLEN_ZIGZAG = "varlen_zigzag"
VARLEN_LLAMA3 = "varlen_llama3"
BATCH_RING = "batch_ring"
BATCH_ZIGZAG = "batch_zigzag"
BATCH_STRIPE = "batch_stripe"
def register_ring_attn(
sequence_parallel_degree: int,
heads_k_stride: int | None,
@@ -117,11 +105,7 @@ def register_ring_attn(
substitute_hf_flash_attn(
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride or 1
)
elif ring_attn_func in [
RingAttnFunc.BATCH_RING,
RingAttnFunc.BATCH_ZIGZAG,
RingAttnFunc.BATCH_STRIPE,
]:
elif ring_attn_func is RingAttnFunc.BATCH_RING:
from axolotl.monkeypatch.attention.ring_attn.adapters.batch import (
substitute_hf_flash_attn,
)

View File

@@ -7,7 +7,7 @@ import os
import signal
import sys
import weakref
from contextlib import nullcontext
from contextlib import ExitStack
from pathlib import Path
from typing import Any, Dict
@@ -27,14 +27,13 @@ from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
fix_untrained_tokens,
)
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.core.trainers.mixins.sequence_parallel import (
SequenceParallelContextManager,
)
from axolotl.integrations.base import PluginManager
from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.freeze import freeze_layers_except
from axolotl.utils.models import load_model, load_processor, load_tokenizer
from axolotl.utils.schemas.enums import RLType
from axolotl.utils.trainer import setup_trainer
try:
@@ -107,7 +106,7 @@ def setup_reference_model(
Reference model if needed for RL training, `None` otherwise.
"""
model_ref = None
if cfg.rl and cfg.rl != "orpo":
if cfg.rl and cfg.rl != RLType.ORPO:
if cfg.adapter and not cfg.rl_adapter_ref_model:
# use built-in trl autounwrap
LOG.debug("Passing model_ref: None to RL trainer")
@@ -188,28 +187,32 @@ def execute_training(
trainer: The configured trainer object.
resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
"""
# Define the context managers to use
flash_context = (
torch.backends.cuda.sdp_kernel(
enable_flash=True,
enable_math=True,
enable_mem_efficient=True,
)
if cfg.flash_optimum
else nullcontext()
)
sequence_parallel_context = (
SequenceParallelContextManager(
model=trainer.model,
sequence_parallel_degree=cfg.sequence_parallel_degree,
ring_attn_func=cfg.ring_attn_func,
)
if cfg.sequence_parallel_degree > 1
else nullcontext()
)
with ExitStack() as stack:
# Define the context managers to use
if cfg.flash_optimum:
stack.enter_context(
torch.backends.cuda.sdp_kernel(
enable_flash=True,
enable_math=True,
enable_mem_efficient=True,
)
)
LOG.info("Starting trainer...")
with flash_context, sequence_parallel_context:
if cfg.sequence_parallel_degree > 1:
models = [trainer.model]
if hasattr(trainer, "ref_model"):
models.append(trainer.ref_model)
stack.enter_context(
SequenceParallelContextManager(
models=models,
sequence_parallel_degree=cfg.sequence_parallel_degree,
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
ring_attn_func=cfg.ring_attn_func,
)
)
LOG.info("Starting trainer...")
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
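[Editor's note] The refactor swaps two eagerly-built context managers for conditional ExitStack entries; each context is entered only when its flag is set, and all entered contexts unwind in reverse order. A minimal sketch with hypothetical flags:

from contextlib import ExitStack, contextmanager

@contextmanager
def managed(name):
    print(f"enter {name}")
    yield
    print(f"exit {name}")

use_flash, use_sp = True, False  # stand-ins for cfg.flash_optimum / sequence_parallel_degree > 1
with ExitStack() as stack:
    if use_flash:
        stack.enter_context(managed("flash"))
    if use_sp:
        stack.enter_context(managed("sequence_parallel"))
    print("training...")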

View File

@@ -0,0 +1,6 @@
"""Init for context manager submodule"""
# pylint: disable=unused-import
# flake8: noqa
from .sequence_parallel import SequenceParallelContextManager

View File

@@ -0,0 +1,335 @@
"""Module for Axolotl trainer sequence parallelism manager and utilities"""
import functools
import torch
import torch.distributed as dist
from torch import nn
from torch.utils.hooks import RemovableHandle
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import ModelOutput
from axolotl.monkeypatch.attention.ring_attn.patch import (
get_ring_attn_group,
update_ring_attn_params,
)
from axolotl.utils.schemas.enums import RingAttnFunc
# TODO(djsaunde): implement zigzag, stripe patterns here (and elsewhere) in this
# module. Currently, we just focus on batch ring and varlen llama3 for simplicity.
def apply_sequence_parallelism(
batch: dict[str, torch.Tensor],
local_rank: int,
local_world_size: int,
gradient_accumulation_steps: int,
ring_attn_func: RingAttnFunc, # pylint: disable=unused-argument
) -> tuple[dict[str, torch.Tensor], int, int]:
"""
Apply sequence parallelism slicing to a batch.
Special handling is implemented for integer logits_to_keep, which indicates
to only keep the last N tokens in the sequence during generation.
Args:
batch: Batch dictionary (e.g., input_ids, attention_mask, etc.).
local_rank: Local rank in the sequence parallel group.
local_world_size: World size of the sequence parallel group.
gradient_accumulation_steps: Number of steps to accumulate gradients over.
ring_attn_func: Which ring attention function to use. Currently unused, but
related to above TODO.
Returns:
tuple of:
- Batch dictionary with sliced tensors.
- The original sequence length before padding.
- The number of padding tokens added.
"""
original_seq_len = batch["input_ids"].size(1)
# Update ring attention params if needed
if batch.get("position_ids") is not None:
update_ring_attn_params(position_ids=batch["position_ids"])
else:
# If position_ids aren't already in the batch, create them
batch["position_ids"] = torch.arange(
0,
original_seq_len,
dtype=torch.long,
device=batch["input_ids"].device,
).expand(batch["input_ids"].size(0), -1)
if "logits_to_keep" in batch and isinstance(batch["logits_to_keep"], int):
logits_to_keep = batch["logits_to_keep"]
# Calculate which positions in the full sequence contain the last N tokens
start_position = max(0, original_seq_len - logits_to_keep)
chunk_size = original_seq_len // local_world_size
rank_start = local_rank * chunk_size
rank_end = rank_start + chunk_size
# Create a boolean mask tensor for this rank's chunk
mask = torch.zeros(
chunk_size,
dtype=torch.bool,
device=batch["input_ids"].device,
)
if rank_end > start_position:
# Calculate how many of the last N tokens fall within this rank's range
tokens_in_rank = min(rank_end, original_seq_len) - max(
rank_start, start_position
)
# Calculate where these tokens start in the local chunk
local_start_idx = max(0, start_position - rank_start)
# Set the appropriate positions in the mask to True
mask[local_start_idx : local_start_idx + tokens_in_rank] = True
# Replace the integer with the boolean mask
batch["logits_to_keep"] = mask
# Add padding to make sequence length divisible by local_world_size
total_seq_len = original_seq_len
pad_len = 0
divisor = min(local_world_size, 64)
if total_seq_len % divisor != 0:
pad_len = divisor - (total_seq_len % divisor)
# Apply padding to all relevant tensors
for key in batch:
if (
isinstance(batch[key], torch.Tensor)
and batch[key].dim() > 1
and batch[key].size(1) == total_seq_len
):
# Create padding tensor
pad_value = -100 if key == "labels" else 0
padding = torch.full(
(batch[key].size(0), pad_len, *batch[key].shape[2:]),
pad_value,
dtype=batch[key].dtype,
device=batch[key].device,
)
# Concatenate padding to the right side of the tensor
batch[key] = torch.cat([batch[key], padding], dim=1)
if key == "logits_to_keep":
# Create padding tensor
padding = torch.ones(
1,
dtype=batch[key].dtype,
device=batch[key].device,
)
# Concatenate padding to the right side of the tensor
batch[key] = torch.cat([batch[key], padding], dim=0)
# Update the total sequence length after padding
total_seq_len = batch["input_ids"].size(1)
# Slice batch for sequence parallel
for key in batch:
if not isinstance(batch[key], torch.Tensor) or batch[key].dim() <= 1:
continue
# Split in sequential fashion and grab this rank's chunk
if batch[key].size(1) == total_seq_len:
batch[key] = (
batch[key].chunk(local_world_size, dim=1)[local_rank].contiguous()
)
elif key == "logits_to_keep":
batch[key] = (
batch[key].chunk(local_world_size, dim=0)[local_rank].contiguous()
)
# Handle num_items_in_batch
if "num_items_in_batch" in batch:
        # Approximation; this is needed since num_items_in_batch may be counted across
        # all samples in a gradient-accumulated batch rather than on a per-step basis.
batch["num_items_in_batch"] = (
batch["labels"] != -100
).sum() * gradient_accumulation_steps
return batch, original_seq_len, pad_len
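[Editor's note] The integer-to-mask conversion for logits_to_keep above maps "keep the last N tokens" onto each rank's local chunk. A standalone sketch, assuming seq_len=8, world_size=2, logits_to_keep=3:

import torch

seq_len, world_size, logits_to_keep = 8, 2, 3
start_position = max(0, seq_len - logits_to_keep)  # 5
chunk_size = seq_len // world_size  # 4
for rank in range(world_size):
    rank_start, rank_end = rank * chunk_size, (rank + 1) * chunk_size
    mask = torch.zeros(chunk_size, dtype=torch.bool)
    if rank_end > start_position:
        tokens_in_rank = min(rank_end, seq_len) - max(rank_start, start_position)
        local_start = max(0, start_position - rank_start)
        mask[local_start : local_start + tokens_in_rank] = True
    print(rank, mask)
# rank 0 -> all False; rank 1 -> [False, True, True, True] (global positions 5..7)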
class SequenceParallelContextManager:
"""Context manager for sequence parallelism operations.
This class provides a context that will automatically apply sequence parallelism
during model forward passes using a pre-forward hook, and gather outputs from
across the sequence parallelism group using a post-forward hook.
Args:
        models: List of models to register sequence parallelism pre- and
            post-forward hooks on.
sequence_parallel_degree: Number of processes to split sequences over.
gradient_accumulation_steps: Number of steps to accumulate gradients over.
ring_attn_func: Which ring attention function to use. Currently unused.
"""
def __init__(
self,
models: list[nn.Module],
sequence_parallel_degree: int,
gradient_accumulation_steps: int,
ring_attn_func: RingAttnFunc,
):
self.models = models
self.sequence_parallel_degree = sequence_parallel_degree
self.gradient_accumulation_steps = gradient_accumulation_steps
self.ring_attn_func = ring_attn_func
self.process_group = get_ring_attn_group()
# Initialize sequence parallel group details
self.local_rank = dist.get_rank(self.process_group)
self.local_world_size = dist.get_world_size(self.process_group)
# Will store hook handles for removal
self.hook_handles: list[RemovableHandle] = []
# Store original sequence length and padding information
self.original_seq_len = 0
self.pad_len = 0
# Create a partially applied version of the apply_sequence_parallelism function
self.apply_sequence_parallelism = functools.partial(
apply_sequence_parallelism,
local_rank=self.local_rank,
local_world_size=self.local_world_size,
gradient_accumulation_steps=self.gradient_accumulation_steps,
ring_attn_func=self.ring_attn_func,
)
def __enter__(self):
# Forward pre-hook to apply sequence parallelism
def sequence_parallel_pre_hook(_, args, kwargs):
# Apply sequence parallelism to kwargs and get original sequence length and padding info
kwargs, self.original_seq_len, self.pad_len = (
self.apply_sequence_parallelism(batch=kwargs)
)
return args, kwargs
# Forward post-hook to gather outputs
def sequence_parallel_post_hook(_, __, output: ModelOutput) -> ModelOutput:
# Gather the sharded outputs
output = self.gather_outputs(output)
# Remove padding if it was added
if self.pad_len > 0:
for key, value in output.items():
if isinstance(value, torch.Tensor) and value.dim() > 1:
if value.size(1) == self.original_seq_len + self.pad_len:
# Slice to remove padding
output[key] = value[:, : self.original_seq_len].contiguous()
return output
# Register both hooks
for model in self.models:
self.hook_handles.append(
model.register_forward_pre_hook(
sequence_parallel_pre_hook, with_kwargs=True
)
)
self.hook_handles.append(
model.register_forward_hook(sequence_parallel_post_hook)
)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
# Remove all hooks
for handle in self.hook_handles:
handle.remove()
self.hook_handles = []
def gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast:
"""Gather sharded outputs from all ranks and reconstruct the full tensor."""
for key, value in output.items():
if isinstance(value, torch.Tensor) and value.dim() > 1:
output[key] = AllGatherWithGrad.apply(value, self.process_group)
return output
class AllGatherWithGrad(torch.autograd.Function):
"""Custom autograd function for all-gather to preserve gradients."""
@staticmethod
def forward(
ctx: torch.autograd.function.FunctionCtx,
input_tensor: torch.Tensor,
group: dist.ProcessGroup,
) -> torch.Tensor:
"""
Forward pass of all-gather of data with sequence dimension.
Args:
ctx: `torch.autograd` function context.
input_tensor: Tensor from model output with sequence dimension.
group: `torch.distributed` process group.
Returns:
Tensor from gathering the `input_tensor` from across the process group and
concatenating along the sequence dimension.
"""
ctx.group = group
ctx.rank = dist.get_rank(group)
world_size = dist.get_world_size(group)
# Gather shape metadata
local_shape = torch.tensor(list(input_tensor.shape), device=input_tensor.device)
all_shapes = [torch.zeros_like(local_shape) for _ in range(world_size)]
dist.all_gather(all_shapes, local_shape, group=group)
# Store sequence lengths for backward pass
seq_lens = [int(shape[1].item()) for shape in all_shapes]
ctx.seq_lens = seq_lens
# Perform all_gather operation
gathered = [
torch.zeros(
tuple(shape.tolist()),
dtype=input_tensor.dtype,
device=input_tensor.device,
)
for shape in all_shapes
]
dist.all_gather(gathered, input_tensor, group=group)
# Concatenate tensors along sequence dimension
result = torch.cat(gathered, dim=1)
return result
@staticmethod
def backward(
ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor
) -> tuple[torch.Tensor, None]:
"""
Backward pass for all-gather operation.
Extracts the gradient slice corresponding to this rank's original input
from the full gradient tensor.
Args:
ctx: `torch.autograd` function context.
grad_output: Gradient from subsequent layers with respect to the
concatenated output tensor.
Returns:
Tuple containing the gradient slice for this rank's input tensor and `None`
for the process group parameter which doesn't require gradients.
"""
rank = ctx.rank
seq_lens = ctx.seq_lens
# Extract gradient for this rank's chunk
offset = sum(seq_lens[:rank])
grad_slice = grad_output[:, offset : offset + seq_lens[rank]].contiguous()
return grad_slice, None
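[Editor's note] The backward pass above hands each rank only the gradient slice for its own shard; the offsets come from the gathered per-rank sequence lengths. A quick sketch of that arithmetic, no process group required:

import torch

seq_lens = [4, 4, 3]  # hypothetical per-rank shard lengths
grad_output = torch.randn(2, sum(seq_lens), 16)  # (batch, total_seq, hidden)
for rank in range(len(seq_lens)):
    offset = sum(seq_lens[:rank])
    grad_slice = grad_output[:, offset : offset + seq_lens[rank]].contiguous()
    print(rank, grad_slice.shape)
# rank 0 -> (2, 4, 16), rank 1 -> (2, 4, 16), rank 2 -> (2, 3, 16)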

View File

@@ -11,6 +11,7 @@ from torch.utils.data import RandomSampler
from transformers import PreTrainedTokenizerBase
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.data.utils import DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.trainer import process_pretraining_datasets_for_packing
@@ -250,6 +251,22 @@ def encode_packed_pretraining(
# pylint: disable=duplicate-code
# tokenize all the examples
# rows get split with stride (overlap)
"""
Encodes and packs input examples into fixed-length batches for pretraining with optional multipack attention.
Wraps and processes input examples into a dataset, applies sequence packing with configurable overflow handling, and batches the data using a multipack sampler. Each batch is collated and features are aggregated into lists keyed by feature name.
Args:
collate_fn: Function to collate individual feature dictionaries into batch tensors.
ds_wrapper: Callable that wraps a Hugging Face Dataset for further processing.
examples: Dictionary of input examples to encode and pack.
max_seq_length: Maximum sequence length for each packed sequence.
batch_size: Number of sequences to pack per batch.
multipack_attn: If True, enables multipack attention and drops attention masks.
Returns:
Dictionary where each key is a feature name and each value is a list of packed feature tensors.
"""
train_dataset = ds_wrapper(Dataset.from_dict(examples))[0]
train_dataset = process_pretraining_datasets_for_packing(
@@ -259,6 +276,10 @@ def encode_packed_pretraining(
# FIXME using attention mask unpad/pad with trainer and packed pretraining is broken atm
# workaround by using the position id logic for now in trainer
drop_attention_mask=multipack_attn,
# pass through handling mode from config via ds_wrapper function
handling=getattr(ds_wrapper, "cfg", {}).get(
"sequence_len_overflow_handling", DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING
),
)
sampler = MultipackBatchSampler(

View File

@@ -18,8 +18,9 @@ from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.models import load_tokenizer
from axolotl.utils.schemas.enums import RLType
LOG = logging.getLogger("axolotl")
LOG = logging.getLogger(__name__)
def _get_path(ds_hash, cfg):
@@ -78,9 +79,34 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs):
def drop_long_rl_seq(
sample, rl, tokenizer, sequence_len # pylint: disable=invalid-name
sample,
rl,
tokenizer,
sequence_len,
handling="drop", # Use the default handling mode
):
if rl in ("dpo", "ipo", "orpo", "simpo"):
"""
Handles samples exceeding a maximum sequence length for various RL dataset types by either truncating or dropping them.
Depending on the RL type and the `handling` mode, this function either truncates response fields to fit within the specified sequence length or determines whether the sample should be dropped. For DPO, IPO, ORPO, and SIMPO types, both "chosen" and "rejected" responses are considered; for KTO, the "completion" is considered. For GRPO, samples are always retained. If truncation is not possible (e.g., the prompt alone exceeds the limit), the sample is returned unchanged for mapping, or dropped during filtering.
Args:
sample: A dictionary representing a single dataset sample.
rl: The RLType indicating the dataset type.
tokenizer: The tokenizer used to compute token lengths and perform truncation.
sequence_len: The maximum allowed sequence length.
handling: Specifies how to handle overlong sequences ("drop" or "truncate").
Returns:
For "truncate": The modified sample with responses truncated as needed, or the original sample if truncation is not possible.
For "drop": True if the sample fits within the sequence length, otherwise False.
Raises:
ValueError: If required keys are missing for the specified RL type, or if the RL type is unknown.
"""
result = None
if rl in (RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO):
if not (
sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
):
@@ -96,11 +122,65 @@ def drop_long_rl_seq(
len_chosen = len(tokenizer(chosen, add_special_tokens=False)["input_ids"])
len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"])
return (len_prompt + len_chosen) <= sequence_len and (
len_prompt + len_rejected
) <= sequence_len
        # In truncate mode, modify the sample in place; in drop mode, return a boolean for filtering
if handling == "truncate":
# If both sequences fit, return sample unchanged
if (len_prompt + len_chosen) <= sequence_len and (
len_prompt + len_rejected
) <= sequence_len:
result = sample
else:
# Calculate maximum response length that can fit with the prompt
max_response_len = sequence_len - len_prompt
if rl == "kto":
                if max_response_len <= 0:
                    # The prompt alone exceeds sequence_len, so truncating the
                    # responses cannot produce a valid sample. Return the sample
                    # unchanged for map; the "drop" path (filter) removes such rows.
                    result = sample
else:
# Truncate the chosen and rejected responses if needed
if len_chosen > max_response_len:
chosen_tokens = tokenizer(chosen, add_special_tokens=False)[
"input_ids"
][:max_response_len]
sample["chosen"] = tokenizer.decode(
chosen_tokens, skip_special_tokens=True
)
if len_rejected > max_response_len:
rejected_tokens = tokenizer(rejected, add_special_tokens=False)[
"input_ids"
][:max_response_len]
sample["rejected"] = tokenizer.decode(
rejected_tokens, skip_special_tokens=True
)
result = sample
else: # handling == "drop"
result = (len_prompt + len_chosen) <= sequence_len and (
len_prompt + len_rejected
) <= sequence_len
elif rl == RLType.KTO:
if not (sample.get("prompt") and sample.get("completion")):
raise ValueError("Prompt and completion keys are required for KTO datasets")
@@ -112,15 +192,54 @@ def drop_long_rl_seq(
tokenizer(completion, add_special_tokens=False)["input_ids"]
)
return (len_prompt + len_completion) <= sequence_len
# Truncate first
if handling == "truncate":
# If sequence fits, return sample unchanged
if (len_prompt + len_completion) <= sequence_len:
result = sample
else:
# Calculate maximum completion length
max_completion_len = sequence_len - len_prompt
if rl == "grpo":
return True
if max_completion_len <= 0:
# Prompt too long, return sample for map
result = sample
else:
# Truncate the completion if needed
if len_completion > max_completion_len:
completion_tokens = tokenizer(
completion, add_special_tokens=False
)["input_ids"][:max_completion_len]
sample["completion"] = tokenizer.decode(
completion_tokens, skip_special_tokens=True
)
result = sample
else: # handling == "drop"
result = (len_prompt + len_completion) <= sequence_len
raise ValueError("Unknown RL type")
elif rl == RLType.GRPO:
        # GRPO does not enforce sequence-length limits here, so every sample passes:
        # return the sample unchanged for map ("truncate") and True for filter ("drop").
result = sample if handling == "truncate" else True
else:
raise ValueError("Unknown RL type")
return result
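[Editor's note] Because the return type now depends on `handling` (a modified sample for "truncate", a boolean for "drop"), callers must pair the function with Dataset.map or Dataset.filter accordingly. A sketch, assuming the function is importable from axolotl.utils.data.rl and using a toy dataset with an arbitrary HF tokenizer:

from functools import partial
from datasets import Dataset
from transformers import AutoTokenizer
from axolotl.utils.data.rl import drop_long_rl_seq
from axolotl.utils.schemas.enums import RLType

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # hypothetical choice
dataset = Dataset.from_dict({
    "prompt": ["Hi"],
    "chosen": ["Hello!"],
    "rejected": ["Go away. " * 2000],  # overlong on purpose
})

kwargs = dict(rl=RLType.DPO, tokenizer=tokenizer, sequence_len=2048)
truncated = dataset.map(partial(drop_long_rl_seq, handling="truncate", **kwargs))
filtered = dataset.filter(partial(drop_long_rl_seq, handling="drop", **kwargs))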
def load_prepare_preference_datasets(cfg):
"""
Loads, preprocesses, and prepares preference datasets for RL training and evaluation.
This function orchestrates the loading, transformation, sequence length handling, optional deduplication, and caching of datasets for Direct Preference Optimization (DPO) and related RL types. It supports configurable handling of overlong sequences (dropping or truncating), applies dataset-specific transformations, and manages train/validation/test splits as needed.
Args:
cfg: Configuration object specifying dataset sources, RL type, tokenizer, sequence length, and processing options.
Returns:
A tuple containing the prepared training and evaluation datasets.
"""
def load_split(dataset_cfgs, _cfg):
split_datasets: List[Any] = []
use_auth_token = _cfg.hf_use_auth_token
@@ -137,9 +256,9 @@ def load_prepare_preference_datasets(cfg):
if _type:
if isinstance(_type, DictDefault):
_type = "user_defined.default"
if _cfg.rl == "orpo":
if _cfg.rl is RLType.ORPO:
ds_transform_fn = load_orpo(_type, _cfg, dataset_idx=i)
elif _cfg.rl == "kto":
elif _cfg.rl is RLType.KTO:
ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
else:
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
@@ -150,7 +269,7 @@ def load_prepare_preference_datasets(cfg):
split_datasets[i] = map_dataset(
cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
)
elif _cfg.rl == "kto":
elif _cfg.rl is RLType.KTO:
ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
map_kwargs = {}
if isinstance(ds_transform_fn, tuple):
@@ -164,28 +283,46 @@ def load_prepare_preference_datasets(cfg):
split_datasets[i] = data_set
if not cfg.skip_prepare_dataset:
# Determine handling mode
handling = cfg.get("sequence_len_overflow_handling", "drop")
drop_long = partial(
drop_long_rl_seq,
rl=_cfg.rl,
tokenizer=tokenizer,
sequence_len=cfg.sequence_len,
handling=handling, # Pass the handling mode
)
prior_len = len(split_datasets[i])
split_datasets[i] = split_datasets[i].filter(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)
dropped = prior_len - len(split_datasets[i])
if dropped:
LOG.warning(
f"Dropped {dropped} long samples from dataset index {i}"
# Use map for truncate mode and filter for drop mode
if handling == "truncate":
split_datasets[i] = split_datasets[i].map(
drop_long, # Function now returns modified sample or original
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Truncating Long Sequences",
)
                    # Note: map preserves the row count; truncation only shortens sequences
LOG.info(
f"Processed dataset index {i} with truncation handling for sequence length {cfg.sequence_len}"
)
else: # handling == "drop"
split_datasets[i] = split_datasets[i].filter(
drop_long, # Function now returns boolean
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)
dropped = prior_len - len(split_datasets[i])
if dropped:
LOG.warning(
f"Dropped {dropped} long samples from dataset index {i}"
)
combined_datasets = concatenate_datasets(split_datasets)
combined_datasets = combined_datasets.shuffle(seed=cfg.seed)
combined_datasets = combined_datasets.shuffle(seed=cfg.seed or 42)
return combined_datasets
@@ -205,6 +342,8 @@ def load_prepare_preference_datasets(cfg):
eval_dataset = load_split(cfg.test_datasets, cfg)
if not eval_dataset:
if cfg.val_set_size:
seed = cfg.seed if cfg.seed is not None else 42
# ensure we end up with the same fingerprint by doing rank0 first and being able to cache
to_hash_train = (
train_dataset._fingerprint # pylint: disable=protected-access
@@ -213,7 +352,7 @@ def load_prepare_preference_datasets(cfg):
+ "|"
+ "train"
+ "|"
+ str(cfg.seed or 42)
+ str(seed)
)
to_hash_test = (
train_dataset._fingerprint # pylint: disable=protected-access
@@ -222,13 +361,13 @@ def load_prepare_preference_datasets(cfg):
+ "|"
+ "test"
+ "|"
+ str(cfg.seed or 42)
+ str(seed)
)
train_fingerprint = md5(to_hash_train)
test_fingerprint = md5(to_hash_test)
ds_w_test_split = train_dataset.train_test_split(
test_size=cfg.val_set_size,
seed=cfg.seed,
seed=seed,
shuffle=False,
train_new_fingerprint=train_fingerprint,
test_new_fingerprint=test_fingerprint,

View File

@@ -148,7 +148,7 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
ds_wrapper_partial,
max_tokens=cfg.sequence_len,
batch_size=cfg.micro_batch_size,
seed=cfg.seed or 42,
seed=cfg.seed if cfg.seed is not None else 42,
buffer_size=cfg.pretrain_multipack_buffer_size or 10_000,
)
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
@@ -416,6 +416,8 @@ def load_prepare_datasets(
)
if split == "train" and val_set_size:
seed = cfg.seed if cfg.seed is not None else 42
# ensure we end up with the same fingerprint by doing rank0 first and being able to cache
to_hash_train = (
dataset._fingerprint # pylint: disable=protected-access
@@ -424,7 +426,7 @@ def load_prepare_datasets(
+ "|"
+ "train"
+ "|"
+ str(cfg.seed or 42)
+ str(seed)
)
to_hash_test = (
dataset._fingerprint # pylint: disable=protected-access
@@ -433,7 +435,7 @@ def load_prepare_datasets(
+ "|"
+ "test"
+ "|"
+ str(cfg.seed or 42)
+ str(seed)
)
train_fingerprint = md5(to_hash_train)
test_fingerprint = md5(to_hash_test)
@@ -442,7 +444,7 @@ def load_prepare_datasets(
dataset = dataset.train_test_split(
test_size=val_set_size,
shuffle=False,
seed=cfg.seed or 42,
seed=seed,
train_new_fingerprint=train_fingerprint,
test_new_fingerprint=test_fingerprint,
)
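[Editor's note] The seed changes in this file all fix the same subtle bug: `cfg.seed or 42` treats an explicitly configured seed of 0 as falsy and silently replaces it with 42, while the explicit None check preserves it.

seed = 0
print(seed or 42)                        # 42 -- the configured 0 is lost
print(seed if seed is not None else 42)  # 0  -- only a missing seed falls back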

View File

@@ -13,10 +13,12 @@ from datasets import Dataset, IterableDataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.samplers.utils import get_dataset_lengths
from axolotl.utils.trainer import drop_long_seq
from axolotl.utils.trainer import truncate_or_drop_long_seq
LOG = logging.getLogger(__name__)
DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING = "drop"
class RetryStrategy(Enum):
"""
@@ -159,16 +161,33 @@ def deduplicate_and_log_datasets(
def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
"""
Processes a dataset to handle sequences exceeding a configured maximum length by either truncating or dropping them.
If the dataset lacks an "input_ids" column, the function returns the dataset unchanged. The handling mode is determined by the configuration parameter "sequence_len_overflow_handling", defaulting to "drop". In "truncate" mode, sequences longer than the maximum length are truncated; in "drop" mode, such sequences are removed from the dataset. The function logs information about sequence lengths and the number of samples affected when applicable.
Args:
dataset: The Huggingface Dataset to process.
cfg: Configuration object specifying sequence length parameters and handling mode.
Returns:
The processed dataset with long sequences either truncated or dropped according to the configuration.
"""
if "input_ids" not in dataset.column_names:
LOG.warning(
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is expected for RewardModeling."
)
return dataset
drop_long = functools.partial(
drop_long_seq,
# Get the handling method from config, default to "drop" for backward compatibility
handling = cfg.get("sequence_len_overflow_handling", "drop")
# Use the new function with the specified handling mode
seq_handler = functools.partial(
truncate_or_drop_long_seq,
sequence_len=cfg.sequence_len,
min_sequence_len=cfg.min_sample_len,
handling=handling,
)
try:
@@ -193,17 +212,31 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
drop_long_kwargs = {}
if filter_map_kwargs:
drop_long_kwargs["desc"] = "Dropping Long Sequences"
if handling == "truncate":
drop_long_kwargs["desc"] = "Truncating Long Sequences"
else: # handling == "drop"
drop_long_kwargs["desc"] = "Dropping Long Sequences"
dataset = dataset.filter(
drop_long,
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
)
if prior_len:
dropped = prior_len - len(dataset)
if dropped:
LOG.warning(f"Dropped {dropped} long samples from dataset")
if handling == "truncate":
# Use map for truncate mode
dataset = dataset.map(
seq_handler,
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
)
LOG.info(f"Truncated long samples in dataset to {cfg.sequence_len} tokens")
else: # handling == "drop"
# Use filter for drop mode
dataset = dataset.filter(
seq_handler,
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
)
if prior_len:
dropped = prior_len - len(dataset)
if dropped:
LOG.warning(f"Dropped {dropped} long samples from dataset")
return dataset
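[Editor's note] A hypothetical config fragment showing how the new handling mode is selected; only the fields relevant here are shown.

from axolotl.utils.dict import DictDefault

cfg = DictDefault({
    "sequence_len": 2048,
    "sequence_len_overflow_handling": "truncate",  # or "drop" (the default)
    "min_sample_len": None,
})
print(cfg.get("sequence_len_overflow_handling", "drop"))  # "truncate"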

View File

@@ -5,8 +5,11 @@ from functools import partial
from packaging import version
from axolotl.utils.gradient_checkpointing.unsloth import (
Unsloth_Offloaded_Gradient_Checkpointer,
from axolotl.utils.gradient_checkpointing.offload_cpu import (
CPU_Offloaded_Gradient_Checkpointer,
)
from axolotl.utils.gradient_checkpointing.offload_disk import (
Disco,
)
transformers_version = version.parse(importlib.metadata.version("transformers"))
@@ -26,12 +29,31 @@ def hf_grad_checkpoint_offload_wrapper(
decoder_layer, *args, use_reentrant=None
): # pylint: disable=unused-argument
if uses_gc_layers(decoder_layer):
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
return CPU_Offloaded_Gradient_Checkpointer.apply(
decoder_layer,
*args,
)
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
return CPU_Offloaded_Gradient_Checkpointer.apply(
(
decoder_layer.func.__self__
if isinstance(decoder_layer, partial)
else decoder_layer.__self__
),
*args,
)
def hf_grad_checkpoint_disk_offload_wrapper(
decoder_layer, *args, use_reentrant=None
): # pylint: disable=unused-argument
if uses_gc_layers(decoder_layer):
return Disco.apply(
decoder_layer,
*args,
)
return Disco.apply(
(
decoder_layer.func.__self__
if isinstance(decoder_layer, partial)

View File

@@ -1,4 +1,4 @@
"""Unsloth checkpointing"""
"""CPU offloaded checkpointing"""
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
@@ -26,7 +26,7 @@ else:
torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
class CPU_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
torch.autograd.Function
):
"""

View File

@@ -0,0 +1,531 @@
"""
DISCO - DIsk-based Storage and Checkpointing with Optimized prefetching
"""
# Copyright 2025 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import atexit
import concurrent.futures
import logging
import os
import queue
import shutil
import tempfile
import threading
import time
import uuid
from collections import deque
from concurrent.futures import Future
from typing import Dict
import torch
torch_cuda_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
# Setup logger
logger = logging.getLogger(__name__)
class DiskOffloadManager:
"""
Manages offloaded tensors and handles prefetching in a separate thread.
Includes synchronization to prevent race conditions.
"""
def __init__(
self,
prefetch_size: int = 3,
prefetch_to_gpu: bool = True,
save_workers: int = 4,
):
"""
Args:
prefetch_size: Maximum number of tensors to prefetch in the background.
prefetch_to_gpu: Whether to prefetch tensors directly to GPU memory.
save_workers: Maximum number of concurrent save operations.
"""
self.temp_dir = tempfile.mkdtemp(prefix="disco_")
# Track tensor paths and their status
self.tensor_paths: deque = deque() # Ordered history of tensor paths (LIFO)
self.file_locks: Dict[str, threading.Lock] = (
{}
) # Maps file_path -> threading.Lock()
# Maps file_path -> status ("saving", "ready", "prefetching", "loaded", "deleted")
self.file_status: Dict[str, str] = {}
self.max_prefetch = prefetch_size
self.prefetch_to_gpu = prefetch_to_gpu
# Thread synchronization
self.manager_lock = threading.RLock() # Used for thread-safe operations
# Prefetch queue and cache
self.prefetch_queue: queue.Queue = queue.Queue()
self.prefetch_cache: Dict[str, torch.Tensor] = {} # Maps file_path -> tensor
# Save queue and thread pool
self.save_queue: queue.Queue = queue.Queue()
self.save_pool = concurrent.futures.ThreadPoolExecutor(max_workers=save_workers)
self.save_futures: Dict[str, Future] = {}
self.save_semaphore = threading.Semaphore(
save_workers * 2
) # Limit concurrent save operations
# Start prefetch worker thread
self.stop_event = threading.Event()
# start multiple threads for prefetching
self.prefetch_worker_count = 2
self.prefetch_workers = []
for _ in range(self.prefetch_worker_count):
worker = threading.Thread(target=self._prefetch_worker, daemon=True)
worker.start()
self.prefetch_workers.append(worker)
# Start save worker thread
self.save_worker = threading.Thread(target=self._save_worker, daemon=True)
self.save_worker.start()
self.idx = 0
atexit.register(self.cleanup)
def _save_worker(self):
"""Background thread that processes the save queue"""
while not self.stop_event.is_set():
try:
save_item = self.save_queue.get(timeout=0.5)
if save_item is None:
continue
tensor, file_path = save_item
# Submit the save task to the thread pool
future = self.save_pool.submit(
self._save_tensor_to_disk, tensor, file_path
)
with self.manager_lock:
self.save_futures[file_path] = future
self.save_queue.task_done()
except queue.Empty:
time.sleep(0.01) # Small sleep to prevent CPU spinning
continue
def _save_tensor_to_disk(self, tensor: torch.Tensor, file_path: str):
"""Actually save the tensor to disk"""
try:
# Save tensor to disk
cpu_tensor = tensor.detach().cpu()
torch.save(cpu_tensor, file_path)
del cpu_tensor
with self.manager_lock:
# Mark file as ready
self.file_status[file_path] = "ready"
# Release semaphore
self.save_semaphore.release()
return True
except FileNotFoundError as e:
logger.error(f"Error saving tensor to {file_path}: {e}")
with self.manager_lock:
self.file_status[file_path] = "error"
# Release semaphore
self.save_semaphore.release()
return False
def _prefetch_worker(self):
"""Background thread that loads tensors from disk ahead of time"""
while not self.stop_event.is_set():
try:
file_path = self.prefetch_queue.get(timeout=0.5)
if file_path is None:
continue
# Check if file is available and not already in cache
with self.manager_lock:
                    if (
                        file_path not in self.file_status
                        or self.file_status[file_path] == "deleted"
                    ):
                        # Skip files that were never registered or already deleted
                        self.prefetch_queue.task_done()
                        continue
                    if file_path in self.prefetch_cache:
                        # Already prefetched; nothing to do
                        self.prefetch_queue.task_done()
                        continue
# If file is still being saved, wait for it
if (
self.file_status[file_path] == "saving"
and file_path in self.save_futures
):
# Re-queue this prefetch request with a little delay
self.prefetch_queue.task_done()
time.sleep(0.1)
self.prefetch_queue.put(file_path)
continue
# Mark file as being prefetched
self.file_status[file_path] = "prefetching"
# Load tensor from disk and store in cache
try:
if os.path.exists(file_path):
if self.prefetch_to_gpu:
tensor = torch.load(
file_path,
map_location=torch.device("cuda"),
weights_only=True,
)
else:
tensor = torch.load(file_path, weights_only=True)
with self.manager_lock:
self.prefetch_cache[file_path] = tensor
self.file_status[file_path] = "ready"
else:
with self.manager_lock:
if self.file_status.get(file_path) != "deleted":
logger.warning(
f"Prefetch error: File not found {file_path}"
)
self.file_status[file_path] = "missing"
except FileNotFoundError as e:
with self.manager_lock:
if self.file_status.get(file_path) != "deleted":
logger.warning(f"Prefetch error for {file_path}: {e}")
self.file_status[file_path] = "error"
self.prefetch_queue.task_done()
except queue.Empty:
time.sleep(0.01) # Small sleep to prevent CPU spinning
continue
def save_tensor(self, tensor: torch.Tensor):
"""Save tensor to disk asynchronously and return file path with thread-safe operations"""
# Generate unique file path
self.idx += 1
file_path: str = os.path.join(
self.temp_dir, f"{self.idx:06d}-{uuid.uuid4()}.pt"
)
with self.manager_lock:
# Mark file as being saved
self.file_locks[file_path] = threading.Lock()
self.file_status[file_path] = "saving"
# Add to history
self.tensor_paths.append(file_path)
# Acquire semaphore to limit concurrent save operations
self.save_semaphore.acquire() # pylint: disable=consider-using-with
# Queue tensor for saving in background
self.save_queue.put((tensor.detach(), file_path))
return file_path
def wait_for_save(self, file_path, timeout=None) -> None:
"""Wait for a tensor to be saved to disk"""
start_time = time.time()
while timeout is None or time.time() - start_time < timeout:
with self.manager_lock:
if self.file_status.get(file_path) == "ready":
return
if self.file_status.get(file_path) in ["error", "missing", "deleted"]:
return
if file_path in self.save_futures:
future = self.save_futures[file_path]
if future.done():
return
# Small sleep to prevent CPU spinning
time.sleep(0.01)
# Timeout
logger.warning(f"Timeout waiting for tensor to be saved: {file_path}")
return
def load_tensor(self, file_path, target_device="cuda"):
"""Load tensor from disk or prefetch cache with proper synchronization"""
# Wait for tensor to be saved if it's still in progress
self.wait_for_save(file_path)
tensor = None
# Try to get from cache first
with self.manager_lock:
# Check if tensor is already in cache
if file_path in self.prefetch_cache:
tensor = self.prefetch_cache[file_path]
del self.prefetch_cache[file_path]
self.file_status[file_path] = "loaded"
if tensor is not None:
# Ensure tensor is on correct device
if target_device != "cpu" and tensor.device.type == "cpu":
tensor = tensor.to(target_device, non_blocking=True)
return tensor
# If not in cache, load directly from disk
try:
if not os.path.exists(file_path):
logger.error(f"File not found for loading: {file_path}")
raise FileNotFoundError(f"File not found: {file_path}")
tensor = torch.load(file_path, weights_only=True)
with self.manager_lock:
self.file_status[file_path] = "loaded"
if target_device != "cpu":
tensor = tensor.to(target_device, non_blocking=True)
return tensor
except Exception as e:
logger.error(f"Error loading tensor from {file_path}: {e}")
raise
def _safe_delete_file(self, file_path):
"""Safely delete a file with proper synchronization"""
with self.manager_lock:
# Make sure any save operation is completed
if file_path in self.save_futures:
future = self.save_futures[file_path]
try:
if not future.done():
future.cancel()
del self.save_futures[file_path]
except FileNotFoundError as e:
logger.warning(
f"Error canceling save operation for {file_path}: {e}"
)
# Only delete if file exists and is not being prefetched
status = self.file_status.get(file_path)
if status in ["ready", "loaded", "error", "missing"]:
try:
if os.path.exists(file_path):
os.remove(file_path)
self.file_status[file_path] = "deleted"
return True
except FileNotFoundError as e:
logger.warning(f"Error deleting file {file_path}: {e}")
return False
def trigger_prefetch(self, n=None):
"""Trigger prefetching of the next N tensors with proper synchronization"""
if n is None:
n = self.max_prefetch
prefetch_paths = []
with self.manager_lock:
# Find files that are ready to be prefetched (not already in cache or being prefetched)
for path in reversed(self.tensor_paths):
if (
path not in self.prefetch_cache
and self.file_status.get(path) == "ready"
):
prefetch_paths.append(path)
if len(prefetch_paths) >= n:
break
# Queue files for prefetching
for path in prefetch_paths:
self.prefetch_queue.put(path)
def cleanup_tensor(self, file_path: str):
"""Clean up a specific tensor file after it's been used"""
with self.manager_lock:
if file_path in self.tensor_paths:
self.tensor_paths.remove(file_path)
# Remove from prefetch cache if present
if file_path in self.prefetch_cache:
del self.prefetch_cache[file_path]
# Remove from save futures if present
if file_path in self.save_futures:
future = self.save_futures[file_path]
if not future.done():
future.cancel()
del self.save_futures[file_path]
# Try to delete the file
self._safe_delete_file(file_path)
def cleanup(self):
"""Clean up all temp files and stop prefetch thread with proper synchronization"""
self.stop_event.set()
# Cancel all pending save operations
with self.manager_lock:
for _, future in self.save_futures.items():
if not future.done():
future.cancel()
self.save_futures.clear()
# Drain the save queue
while not self.save_queue.empty():
try:
self.save_queue.get_nowait()
self.save_queue.task_done()
except queue.Empty:
break
# Shutdown the save pool
self.save_pool.shutdown(wait=False)
# Join the save worker thread
if self.save_worker.is_alive():
self.save_worker.join(timeout=2.0)
# Join the prefetch worker threads
for thread in self.prefetch_workers:
if thread.is_alive():
thread.join(timeout=2.0)
# Clear cache and remove all temporary files
with self.manager_lock:
self.prefetch_cache.clear()
paths_to_delete = list(self.tensor_paths)
self.tensor_paths.clear()
# Delete all temporary files
for path in paths_to_delete:
self._safe_delete_file(path)
# Remove temp directory
try:
if os.path.exists(self.temp_dir):
shutil.rmtree(self.temp_dir, ignore_errors=True)
except FileNotFoundError as e:
logger.warning(f"Error removing temporary directory {self.temp_dir}: {e}")
class Disco(torch.autograd.Function):
"""
Disco: DIsk-based Storage and Checkpointing with Optimized prefetching
Advanced disk-based gradient checkpointer with prefetching.
"""
# Shared manager instance across all checkpointing operations
_manager = None
@staticmethod
def get_instance(prefetch_size=1, prefetch_to_gpu=True, save_workers=4):
"""Get or create the offload manager"""
if Disco._manager is None:
Disco._manager = DiskOffloadManager(
prefetch_size=prefetch_size,
prefetch_to_gpu=prefetch_to_gpu,
save_workers=save_workers,
)
return Disco._manager
@staticmethod
@torch_cuda_amp_custom_fwd
def forward(
ctx,
forward_function,
hidden_states,
*args,
prefetch_size=1,
prefetch_to_gpu=True,
save_workers=4,
):
"""Forward pass that offloads activations to disk asynchronously"""
# Get or create the manager
manager = Disco.get_instance(
prefetch_size=prefetch_size,
prefetch_to_gpu=prefetch_to_gpu,
save_workers=save_workers,
)
# Save tensor to disk asynchronously
file_path = manager.save_tensor(hidden_states)
# Run forward pass immediately without waiting for save to complete
with torch.no_grad():
output = forward_function(hidden_states, *args)
# Store what we need for backward
ctx.save_for_backward(torch.tensor([0])) # Dummy tensor
ctx.file_path = file_path
ctx.forward_function = forward_function
ctx.args = args
return output
@staticmethod
@torch_cuda_amp_custom_bwd
def backward(ctx, *grad_outputs):
"""Backward pass that loads activations from disk with prefetching"""
# Get the manager
manager = Disco._manager
# Trigger prefetching for future tensors
# This happens at the start of backward, so should have time to complete
manager.trigger_prefetch()
# Load hidden states from disk or prefetch cache
file_path = ctx.file_path
try:
# Ensure the file is saved before we try to load it
manager.wait_for_save(file_path)
hidden_states = manager.load_tensor(file_path)
hidden_states.requires_grad = True
# Compute gradients
with torch.enable_grad():
output = ctx.forward_function(hidden_states, *ctx.args)
# Handle tuple outputs properly
if isinstance(output, tuple):
if len(grad_outputs) == len(output):
torch.autograd.backward(output, grad_outputs)
else:
torch.autograd.backward(output, grad_outputs[0])
else:
torch.autograd.backward(output, grad_outputs[0])
# Clean up the file after we're done with it
manager.cleanup_tensor(file_path)
return (
(
None, # forward_function
hidden_states.grad, # hidden_states grad
)
+ (None,) * len(ctx.args) # for each arg
+ (
None, # prefetch_size
None, # prefetch_to_gpu
None, # save_workers
)
)
except Exception as e:
logger.error(f"Error in backward pass: {e}")
# Clean up the file even on error
manager.cleanup_tensor(file_path)
raise
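# Illustrative usage sketch -- an assumed call pattern, not part of this
# diff; `layer_forward`, `hidden_states`, and `attention_mask` are
# placeholder names. The hf_grad_checkpoint_disk_offload_wrapper wired up
# in the next file would route HF gradient checkpointing through this
# Function roughly like:
#
#     output = Disco.apply(
#         layer_forward,    # the checkpointed module's forward
#         hidden_states,    # activation tensor offloaded to disk
#         attention_mask,   # extra positional args are passed through
#     )
#
# The keyword-only knobs (prefetch_size, prefetch_to_gpu, save_workers)
# keep their defaults here, since autograd.Function.apply is typically
# called with positional arguments only.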

View File

@@ -70,9 +70,13 @@ from axolotl.utils.distributed import (
is_local_main_process,
is_main_process,
)
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
from axolotl.utils.gradient_checkpointing import (
hf_grad_checkpoint_disk_offload_wrapper,
hf_grad_checkpoint_offload_wrapper,
)
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
from axolotl.utils.schemas.enums import RLType
LOG = logging.getLogger(__name__)
PLUGIN_MANAGER = PluginManager.get_instance()
@@ -619,6 +623,10 @@ class ModelLoader:
if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
if self.cfg.gradient_checkpointing == "offload_disk":
transformers.modeling_utils.checkpoint = (
hf_grad_checkpoint_disk_offload_wrapper
)
if self.cfg.flash_attention:
self.patch_attention()
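A minimal config sketch (illustrative) selecting the new disk offload path; the other accepted values remain true/false, "unsloth", and "offload":

gradient_checkpointing: offload_disk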
@@ -1372,7 +1380,7 @@ class ModelLoader:
# then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
if (
self.cfg.adapter
and self.cfg.rl in ["dpo", "ipo", "kto"]
and self.cfg.rl in [RLType.DPO, RLType.IPO, RLType.KTO]
and not self.cfg.merge_lora
):
_, lora_config = load_lora(

View File

@@ -27,7 +27,7 @@ from axolotl.utils.schemas.datasets import (
StepwiseSupervisedDataset,
)
from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
from axolotl.utils.schemas.enums import ChatTemplate, RLType
from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType
from axolotl.utils.schemas.integrations import (
CometConfig,
GradioConfig,
@@ -178,7 +178,7 @@ class AxolotlInputConfig(
# torch_dtype: torch.dtype | None
gradient_checkpointing: Literal["unsloth", "offload"] | bool | None = Field(
gradient_checkpointing: Literal["offload", "offload_disk"] | bool | None = Field(
default=False
)
gradient_checkpointing_kwargs: dict[str, Any] | None = None
@@ -186,6 +186,12 @@ class AxolotlInputConfig(
unfrozen_parameters: list[str] | None = None
sequence_len: int = Field(default=512)
sequence_len_overflow_handling: Literal["drop", "truncate"] = Field(
default="drop",
json_schema_extra={
"description": "How to handle sequences that overflow the sequence_len: 'drop' (remove the sample) or 'truncate' (cut off excess tokens)."
},
)
min_sample_len: int | None = None
max_prompt_len: int = Field(
default=512,
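An illustrative config snippet combining the new option with sequence_len (values are examples):

sequence_len: 2048
sequence_len_overflow_handling: truncate  # default is "drop"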
@@ -260,7 +266,7 @@ class AxolotlInputConfig(
sequence_parallel_degree: int | None = None
heads_k_stride: int | None = None
ring_attn_func: str | None = None
ring_attn_func: RingAttnFunc | None = None
special_tokens: SpecialTokensConfig | None = None
tokens: list[str] | None = None
@@ -782,7 +788,7 @@ class AxolotlInputConfig(
@model_validator(mode="after")
def check_simpo_warmup(self):
if self.rl == "simpo" and self.warmup_ratio:
if self.rl is RLType.SIMPO and self.warmup_ratio:
raise ValueError(
"warmup_ratio is not supported with the simpo trainer. Please use `warmup_steps` instead"
)
@@ -1149,16 +1155,28 @@ class AxolotlInputConfig(
return data
# @model_validator(mode="before")
# @classmethod
# def check_grpo_peft_liger(cls, data):
# if (
# data.get("rl") == "grpo"
# and data.get("trl", {})
# and data.get("trl").get("use_liger_loss")
# and data.get("adapter")
# ):
# raise ValueError("PEFT + GRPO + Liger is not yet supported")
# return data
#
@model_validator(mode="before")
@classmethod
def check_grpo_peft_liger(cls, data):
def check_grpo_liger_sequence_parallel(cls, data):
if (
data.get("rl") == "grpo"
and data.get("trl", {})
and data.get("trl").get("use_liger_loss")
and data.get("adapter")
and data.get("sequence_parallel_degree", 1) > 1
):
raise ValueError("PEFT + GRPO + Liger is not yet supported")
raise ValueError("GRPO + SP + Liger not currently supported")
return data
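# Example of a config this validator now rejects (mirrors the test case
# added in the tests below):
#   rl: grpo
#   sequence_parallel_degree: 2
#   trl:
#     use_liger_loss: true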
@model_validator(mode="after")
@@ -1173,7 +1191,7 @@ class AxolotlInputConfig(
if self.sample_packing and self.micro_batch_size > 1:
raise ValueError(
"micro_batch_size must be set to 1 when sample_packing is enabled"
"micro_batch_size must be set to 1 when sample_packing is enabled "
"due to a `ring-flash-attn` requirement"
)
@@ -1205,16 +1223,8 @@ class AxolotlInputConfig(
if getattr(self, "sequence_parallel_degree", 1) == 1:
return self
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
if self.ring_attn_func is not None:
valid_funcs = list(RingAttnFunc)
if self.ring_attn_func in valid_funcs:
self.ring_attn_func = RingAttnFunc(self.ring_attn_func)
else:
raise ValueError(
f"ring_attn_func: {self.ring_attn_func} must be in {valid_funcs}"
)
self.ring_attn_func = RingAttnFunc(self.ring_attn_func)
else:
# Default ring attention function selection
sample_packing = getattr(self, "sample_packing", False)
@@ -1345,6 +1355,10 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
):
return data
# Skip if dropout is not 0, as auto enabling it would just disable it during runtime patch checks
if data.get("lora_dropout") != 0:
return data
# Check multi-GPU compatibility
capabilities = data.get("capabilities")
is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1

View File

@@ -6,12 +6,12 @@ from enum import Enum
class RLType(str, Enum):
"""RL trainer type configuration subset"""
dpo = "dpo" # pylint: disable=invalid-name
grpo = "grpo" # pylint: disable=invalid-name
ipo = "ipo" # pylint: disable=invalid-name
orpo = "orpo" # pylint: disable=invalid-name
kto = "kto" # pylint: disable=invalid-name
simpo = "simpo" # pylint: disable=invalid-name
DPO = "dpo" # pylint: disable=invalid-name
GRPO = "grpo" # pylint: disable=invalid-name
IPO = "ipo" # pylint: disable=invalid-name
ORPO = "orpo" # pylint: disable=invalid-name
KTO = "kto" # pylint: disable=invalid-name
SIMPO = "simpo" # pylint: disable=invalid-name
class ChatTemplate(str, Enum):
@@ -55,3 +55,14 @@ class CustomSupportedOptimizers(str, Enum):
adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name
came_pytorch = "came_pytorch" # pylint: disable=invalid-name
muon = "muon" # pylint: disable=invalid-name
class RingAttnFunc(str, Enum):
"""Enum class for supported `ring-flash-attn` implementations"""
# VARLEN_RING = "varlen_ring"
# VARLEN_ZIGZAG = "varlen_zigzag"
VARLEN_LLAMA3 = "varlen_llama3"
BATCH_RING = "batch_ring"
# BATCH_ZIGZAG = "batch_zigzag"
# BATCH_STRIPE = "batch_stripe"

View File

@@ -207,10 +207,18 @@ def add_length(sample):
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
"""
Drop samples whose sequence length is either too long (> sequence_len)
or too short (< min_sequence_len).
Works for both single-example (list[int]) or batched (list[list[int]]).
Determines whether samples should be kept based on sequence length constraints.
For a single example or a batch, returns True (or a list of booleans) if each sequence's length is within the specified range; otherwise, returns False (or a list with False for out-of-range sequences).
Args:
sample: A dictionary containing "input_ids" as a list of ints or a list of lists of ints.
sequence_len: Maximum allowed sequence length (inclusive).
min_sequence_len: Minimum allowed sequence length (inclusive).
Returns:
True if the single example is within the length range, False otherwise.
For batched input, returns a list of booleans indicating which sequences are within the range.
"""
min_sequence_len = min_sequence_len or 2
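# Illustrative return values (not part of the diff): with sequence_len=10
# and the default min_sequence_len=2,
#   drop_long_seq({"input_ids": list(range(12))})              -> False
#   drop_long_seq({"input_ids": [[1, 2, 3], list(range(12))]}) -> [True, False]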
@@ -235,7 +243,121 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
return results
def truncate_or_drop_long_seq(
sample, sequence_len=2048, min_sequence_len=2, handling="drop"
):
"""
Drops or truncates samples based on sequence length constraints.
If handling is "drop", returns a boolean or list of booleans indicating whether each sample's sequence length is within the specified range. If handling is "truncate", returns the sample with sequences longer than sequence_len truncated and sequences shorter than min_sequence_len omitted. Supports both single-example and batched inputs.
Args:
sample: A dictionary containing at least an "input_ids" field, representing either a single sequence or a batch of sequences.
sequence_len: Maximum allowed sequence length.
min_sequence_len: Minimum allowed sequence length.
handling: "drop" to filter out samples outside the range, "truncate" to truncate long sequences.
Returns:
In "drop" mode, a boolean or list of booleans indicating which samples to keep. In "truncate" mode, the modified sample with sequences truncated as needed.
"""
min_sequence_len = min_sequence_len or 2
result = None
if handling == "drop":
return drop_long_seq(sample, sequence_len, min_sequence_len)
input_ids = sample["input_ids"]
# Edge case: if input_ids is empty
if not input_ids:
result = False if handling == "drop" else sample
# Single example (input_ids is a list of int)
elif isinstance(input_ids[0], int):
length = len(input_ids)
# Samples that are too short are never truncated: drop mode filters them out, truncate mode passes them through for upstream filtering
if length < min_sequence_len:
result = False if handling == "drop" else sample
# If truncation is enabled and the sample is too long, truncate it
elif length > sequence_len and handling == "truncate":
sample["input_ids"] = input_ids[:sequence_len]
# Also truncate attention_mask if present
if "attention_mask" in sample:
sample["attention_mask"] = sample["attention_mask"][:sequence_len]
# Also truncate labels if present
if "labels" in sample:
sample["labels"] = sample["labels"][:sequence_len]
# Also truncate position_ids if present
if "position_ids" in sample:
sample["position_ids"] = sample["position_ids"][:sequence_len]
# Update length if present
if "length" in sample:
sample["length"] = sequence_len
result = sample
# For drop mode or if the sample doesn't exceed max length
else:
result = (
min_sequence_len <= length <= sequence_len
if handling == "drop"
else sample
)
# Batched (input_ids is a list of lists)
else:
if handling == "drop":
results = []
for seq in input_ids:
length = len(seq)
results.append(min_sequence_len <= length <= sequence_len)
result = results
else: # truncate
# Check each sequence in the batch
for i, seq in enumerate(input_ids):
length = len(seq)
# Skip sequences that are too short
if length < min_sequence_len:
continue
# Truncate sequences that are too long
if length > sequence_len:
input_ids[i] = seq[:sequence_len]
# Also truncate attention_mask if present
if "attention_mask" in sample:
sample["attention_mask"][i] = sample["attention_mask"][i][
:sequence_len
]
# Also truncate labels if present
if "labels" in sample:
sample["labels"][i] = sample["labels"][i][:sequence_len]
# Also truncate position_ids if present
if "position_ids" in sample:
sample["position_ids"][i] = sample["position_ids"][i][
:sequence_len
]
# Update length if present
if "length" in sample:
sample["length"][i] = sequence_len
result = sample
return result
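# Illustrative usage (mirrors the unit tests added below): with
# sequence_len=5,
#   sample = {"input_ids": [1, 2, 3, 4, 5, 6, 7]}
#   truncate_or_drop_long_seq(sample, sequence_len=5, handling="truncate")
#       -> {"input_ids": [1, 2, 3, 4, 5]}
#   truncate_or_drop_long_seq(sample, sequence_len=5, handling="drop")
#       -> False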
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
"""
Prepares training and evaluation datasets for sample packing and model-specific requirements.
Removes unnecessary columns based on model type, filters out samples with no trainable tokens, and optionally adds length or position ID columns for sample packing or PoSE techniques. Returns the processed training and evaluation datasets.
"""
drop_attn_mask = cfg.model_config_type in ["mamba", "gemma3"]
if drop_attn_mask:
LOG.info("dropping attention_mask column")
@@ -370,15 +492,48 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
def process_pretraining_datasets_for_packing(
train_dataset, sequence_len, skip_position_ids=True, drop_attention_mask=False
train_dataset,
sequence_len,
skip_position_ids=True,
drop_attention_mask=False,
handling="drop",
):
drop_long = partial(drop_long_seq, sequence_len=sequence_len)
train_dataset = train_dataset.filter(
drop_long,
desc="Dropping Long Sequences",
load_from_cache_file=False,
"""
Processes a pretraining dataset by truncating or dropping sequences based on length.
Depending on the handling mode, sequences longer than `sequence_len` are either truncated or dropped; in drop mode, sequences shorter than the minimum length are dropped as well. Optionally adds position IDs and removes the attention mask column.
Args:
train_dataset: The dataset to process.
sequence_len: Maximum allowed sequence length.
skip_position_ids: If False, adds position IDs to each sample.
drop_attention_mask: If True, removes the attention mask column.
handling: "drop" to remove long sequences, "truncate" to truncate them.
Returns:
The processed dataset with sequences handled according to the specified mode.
"""
# Define the function to use for handling sequences based on the mode
seq_handler_fn = partial(
truncate_or_drop_long_seq,
sequence_len=sequence_len,
handling=handling, # Pass handling mode
)
# Use map for truncate mode and filter for drop mode
if handling == "truncate":
train_dataset = train_dataset.map(
seq_handler_fn,
desc="Truncating Long Sequences",
load_from_cache_file=False,
)
else: # handling == "drop"
train_dataset = train_dataset.filter(
seq_handler_fn, # Use the same function, it returns boolean for drop mode
desc="Dropping Long Sequences",
load_from_cache_file=False,
)
if not skip_position_ids:
train_dataset = train_dataset.map(
add_position_ids,
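# Illustrative call site (assumed wiring; the handling value would come
# from the new config field):
#   process_pretraining_datasets_for_packing(
#       train_dataset,
#       cfg.sequence_len,
#       handling=cfg.sequence_len_overflow_handling,
#   )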

View File

@@ -25,6 +25,7 @@ class TestSequenceParallelism:
micro_batch_size=1,
pad_to_sequence_len=True,
ring_attn_func=None,
threshold=2.0,
):
"""Helper method to run sequence parallel tests with different configurations"""
cfg = DictDefault(
@@ -93,22 +94,22 @@ class TestSequenceParallelism:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.6, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", threshold, "Train Loss is too high"
)
@pytest.mark.parametrize(
"sample_packing, micro_batch_size, pad_to_sequence_len, ring_attn_func",
"sample_packing, micro_batch_size, pad_to_sequence_len, ring_attn_func, threshold",
[
(True, 1, True, None), # defaults to varlen_llama3 ring_attn_func
(False, 2, True, None), # defaults to batch_ring ring_attn_func
(False, 2, True, "batch_zigzag"),
# (False, 2, False), # not yet working
(True, 1, True, None, 2.5), # defaults to varlen_llama3 ring_attn_func
(False, 2, True, None, 2.5), # defaults to batch_ring ring_attn_func
# (False, 2, True, "batch_zigzag", 2.5),
(False, 2, False, None, 2.5), # defaults to batch_ring ring_attn_func
],
ids=[
"sample_packing, varlen_llama3 ring_attn_func",
"no sample_packing, pad_to_sequence_len, batch_ring ring_attn_func",
# "no sample_packing, no pad_to_sequence_len, batch_zigzag ring_attn_func",
"no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func",
"no sample_packing, no pad_to_sequence_len, batch_zigzag ring_attn_func",
# "no sample_packing, pad_to_sequence_len", # not yet working
],
)
def test_sequence_parallel_training(
@@ -118,6 +119,7 @@ class TestSequenceParallelism:
micro_batch_size,
pad_to_sequence_len,
ring_attn_func,
threshold,
):
"""Test sequence parallel training with different configurations"""
self._run_sequence_parallel_test(
@@ -126,4 +128,5 @@ class TestSequenceParallelism:
micro_batch_size=micro_batch_size,
pad_to_sequence_len=pad_to_sequence_len,
ring_attn_func=ring_attn_func,
threshold=threshold,
)

View File

@@ -166,6 +166,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"""
)
@pytest.mark.skip(reason="flaky test")
@pytest.mark.parametrize(
"num_gpus",
[1, 2],
@@ -227,7 +228,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
current_env = os.environ.copy()
env = {
"NCCL_P2P_LEVEL": "NVL",
"NCCL_P2P_LEVEL": "LOC",
**current_env,
"CUDA_VISIBLE_DEVICES": "1",
"VLLM_DISABLE_COMPILE_CACHE": "1",
@@ -257,7 +258,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
f"{get_torch_dist_unique_port()}",
],
env={
"NCCL_P2P_LEVEL": "NVL",
"NCCL_P2P_LEVEL": "LOC",
"NCCL_DEBUG": "INFO",
**current_env,
},
@@ -265,6 +266,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
finally:
recursive_kill(vllm_process)
@pytest.mark.skip(reason="flaky test")
@pytest.mark.parametrize(
"num_gpus",
[1, 2],
@@ -320,7 +322,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
current_env = os.environ.copy()
env = {
"NCCL_P2P_LEVEL": "NVL", # nccl can be brittle, assume P2P isn't reliable
"NCCL_P2P_LEVEL": "LOC", # nccl can be brittle, assume P2P isn't reliable
**current_env,
"CUDA_VISIBLE_DEVICES": "1",
"VLLM_DISABLE_COMPILE_CACHE": "1",
@@ -350,7 +352,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
f"{get_torch_dist_unique_port()}",
],
env={
"NCCL_P2P_LEVEL": "NVL",
"NCCL_P2P_LEVEL": "LOC",
"NCCL_DEBUG": "INFO",
**current_env,
},

View File

@@ -26,10 +26,15 @@ class TestActivationCheckpointing:
E2E tests for activation checkpointing
"""
@pytest.mark.parametrize(
"gradient_checkpointing",
["offload", "offload_disk"],
)
def test_activation_checkpointing_offload(
self,
temp_dir,
fix_checkpoint_after_test, # pylint: disable=unused-argument,redefined-outer-name
gradient_checkpointing,
):
# pylint: disable=duplicate-code
cfg = DictDefault(
@@ -64,7 +69,7 @@ class TestActivationCheckpointing:
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
"gradient_checkpointing": "offload",
"gradient_checkpointing": gradient_checkpointing,
}
)

View File

@@ -10,14 +10,15 @@ import pytest
import torch
from accelerate.state import PartialState
from axolotl.core.trainers.mixins.sequence_parallel import apply_sequence_parallelism
from axolotl.monkeypatch.attention.ring_attn import (
RingAttnFunc,
get_ring_attn_group,
register_ring_attn,
set_ring_attn_group,
)
from axolotl.utils.ctx_managers.sequence_parallel import apply_sequence_parallelism
from axolotl.utils.dict import DictDefault
from axolotl.utils.schemas.enums import RingAttnFunc
from axolotl.utils.schemas.trl import TRLConfig
@pytest.fixture
@@ -62,12 +63,14 @@ def sequence_parallel_batch():
input_ids = torch.arange(batch_size * seq_len).reshape(batch_size, seq_len)
attention_mask = torch.ones(batch_size, seq_len)
position_ids = torch.arange(seq_len).expand(batch_size, seq_len)
labels = input_ids.clone()
# Create test batch
batch = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
"labels": labels,
}
return batch
@@ -179,12 +182,44 @@ class TestConfigValidation:
False,
"micro_batch_size must be set to 1",
),
# Valid: SP + Liger without GRPO
(
{
"sequence_parallel_degree": 2,
"flash_attention": True,
"micro_batch_size": 2,
"trl": {"use_liger_loss": True},
},
{
"sequence_parallel_degree": 2,
"flash_attention": True,
"micro_batch_size": 2,
"trl": TRLConfig(use_liger_loss=True),
},
True,
"GRPO + SP + Liger not currently supported",
),
# Invalid: GRPO + SP with Liger loss
(
{
"rl": "grpo",
"sequence_parallel_degree": 2,
"flash_attention": True,
"micro_batch_size": 2,
"trl": {"use_liger_loss": True},
},
None,
False,
"GRPO + SP + Liger not currently supported",
),
],
ids=[
"valid_config",
"default_sp_degree",
"without_flash_attention",
"sample_packing_with_large_batch",
"valid_grpo",
"grpo_with_liger_loss",
],
)
def test_sequence_parallel_config_validation(
@@ -256,7 +291,7 @@ class TestConfigValidation:
AxolotlInputConfig(**cfg)
# Verify error message
assert "ring_attn_func: INVALID_FUNC must be in" in str(excinfo.value)
assert "Input should be 'varlen_llama3' or 'batch_ring'" in str(excinfo.value)
class TestApplySequenceParallelism:
@@ -290,10 +325,11 @@ class TestApplySequenceParallelism:
def test_world_size_one(self, sequence_parallel_batch):
"""Test that function returns original batch when world size is 1."""
result = apply_sequence_parallelism(
result, _, _ = apply_sequence_parallelism(
batch=sequence_parallel_batch,
local_rank=0,
local_world_size=1,
gradient_accumulation_steps=1,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
@@ -305,10 +341,11 @@ class TestApplySequenceParallelism:
batch = sequence_parallel_batch
seq_len = batch["input_ids"].size(1)
result = apply_sequence_parallelism(
result, _, _ = apply_sequence_parallelism(
batch=batch,
local_rank=0,
local_world_size=2,
gradient_accumulation_steps=1,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
@@ -328,57 +365,59 @@ class TestApplySequenceParallelism:
seq_len = batch["input_ids"].size(1)
original_input_ids = batch["input_ids"].clone()
result = apply_sequence_parallelism(
result, _, _ = apply_sequence_parallelism(
batch=batch,
local_rank=1,
local_world_size=2,
gradient_accumulation_steps=1,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
# Verify content: rank 1 should get the second half of the sequence
assert torch.equal(result["input_ids"], original_input_ids[:, seq_len // 2 :])
def test_batch_zigzag(self, sequence_parallel_batch):
"""Test BATCH_ZIGZAG sharding pattern."""
batch = sequence_parallel_batch
original_input_ids = batch["input_ids"].clone()
seq_len = batch["input_ids"].size(1)
# TODO(djsaunde): add back once implemented.
# def test_batch_zigzag(self, sequence_parallel_batch):
# """Test BATCH_ZIGZAG sharding pattern."""
# batch = sequence_parallel_batch
# original_input_ids = batch["input_ids"].clone()
# seq_len = batch["input_ids"].size(1)
# Test rank 0
result_rank0 = apply_sequence_parallelism(
batch={k: v.clone() for k, v in batch.items()},
local_rank=0,
local_world_size=2,
ring_attn_func=RingAttnFunc.BATCH_ZIGZAG,
)
# # Test rank 0
# result_rank0 = apply_sequence_parallelism(
# batch={k: v.clone() for k, v in batch.items()},
# local_rank=0,
# local_world_size=2,
# ring_attn_func=RingAttnFunc.BATCH_ZIGZAG,
# )
# Test rank 1
result_rank1 = apply_sequence_parallelism(
batch={k: v.clone() for k, v in batch.items()},
local_rank=1,
local_world_size=2,
ring_attn_func=RingAttnFunc.BATCH_ZIGZAG,
)
# # Test rank 1
# result_rank1 = apply_sequence_parallelism(
# batch={k: v.clone() for k, v in batch.items()},
# local_rank=1,
# local_world_size=2,
# ring_attn_func=RingAttnFunc.BATCH_ZIGZAG,
# )
# Checks for both ranks
assert result_rank0["input_ids"].shape[1] == seq_len // 2
assert result_rank1["input_ids"].shape[1] == seq_len // 2
# # Checks for both ranks
# assert result_rank0["input_ids"].shape[1] == seq_len // 2
# assert result_rank1["input_ids"].shape[1] == seq_len // 2
# For a 2-rank system with 8 tokens, check specific zigzag pattern
# Rank 0 should get chunks [0, 1] and [6, 7]
# Rank 1 should get chunks [2, 3] and [4, 5]
if seq_len == 8:
# Create expected tensors for comparison
rank0_expected = torch.cat(
[original_input_ids[:, :2], original_input_ids[:, 6:8]], dim=1
)
# # For a 2-rank system with 8 tokens, check specific zigzag pattern
# # Rank 0 should get chunks [0, 1] and [6, 7]
# # Rank 1 should get chunks [2, 3] and [4, 5]
# if seq_len == 8:
# # Create expected tensors for comparison
# rank0_expected = torch.cat(
# [original_input_ids[:, :2], original_input_ids[:, 6:8]], dim=1
# )
rank1_expected = torch.cat(
[original_input_ids[:, 2:4], original_input_ids[:, 4:6]], dim=1
)
# rank1_expected = torch.cat(
# [original_input_ids[:, 2:4], original_input_ids[:, 4:6]], dim=1
# )
assert torch.equal(result_rank0["input_ids"], rank0_expected)
assert torch.equal(result_rank1["input_ids"], rank1_expected)
# assert torch.equal(result_rank0["input_ids"], rank0_expected)
# assert torch.equal(result_rank1["input_ids"], rank1_expected)
def test_partial_application(self, sequence_parallel_batch):
"""Test that we can create a partially applied version of the function."""
@@ -390,11 +429,12 @@ class TestApplySequenceParallelism:
apply_sequence_parallelism,
local_rank=0,
local_world_size=2,
gradient_accumulation_steps=1,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
# Use the partially applied function
result = rank0_ring_parallel(batch=batch)
result, _, _ = rank0_ring_parallel(batch=batch)
# Verify it works as expected
assert result["input_ids"].shape[1] == original_input_ids.shape[1] // 2
@@ -412,13 +452,15 @@ class TestApplySequenceParallelism:
original_input_ids = batch["input_ids"].clone()
# This should run without error even though position_ids is missing
result = apply_sequence_parallelism(
result, _, _ = apply_sequence_parallelism(
batch=batch,
local_rank=0,
local_world_size=2,
gradient_accumulation_steps=1,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
# Verification should pass
assert "position_ids" not in result
assert "position_ids" in result
assert result["input_ids"].shape[1] == result["position_ids"].shape[1]
assert result["input_ids"].shape[1] == original_input_ids.shape[1] // 2

View File

@@ -3,10 +3,12 @@ test module for the axolotl.utils.data module
"""
import unittest
from unittest.mock import MagicMock
from transformers import LlamaTokenizer
from axolotl.utils.data import encode_pretraining, md5
from axolotl.utils.data.rl import drop_long_rl_seq
from tests.hf_offline_utils import enable_hf_offline
@@ -58,11 +60,328 @@ class TestEncodePretraining(unittest.TestCase):
self.assertEqual(result["input_ids"][0][14], self.tokenizer.pad_token_id)
def test_md5(self):
"""
Tests that the md5 function returns the correct hash for a given string and encoding.
"""
self.assertEqual(md5("hello world"), "5eb63bbbe01eeed093cb22bb8f5acdc3")
self.assertEqual(
md5("hello world", "utf-8"), "5eb63bbbe01eeed093cb22bb8f5acdc3"
)
class TestDropLongRLSeq(unittest.TestCase):
"""
Tests for the drop_long_rl_seq function.
"""
def setUp(self):
# Mock tokenizer that returns length based on input string length
"""
Sets up a mock tokenizer and sequence length for RL sequence length tests.
The mock tokenizer simulates tokenization by returning input IDs equal to the input string's length and decodes tokens as repeated "x" characters. The sequence length limit is set to 20.
"""
self.tokenizer = MagicMock()
def side_effect_func(
text, add_special_tokens=False
): # pylint: disable=unused-argument
"""
Simulates tokenization by returning input IDs as a sequence of integers equal to the input text length.
Args:
text: The input string to tokenize.
add_special_tokens: Ignored parameter included for interface compatibility.
Returns:
A dictionary with 'input_ids' as a list of integers from 0 to len(text) - 1.
"""
return {"input_ids": list(range(len(text)))}
self.tokenizer.side_effect = side_effect_func
self.tokenizer.decode = lambda tokens, skip_special_tokens: "".join(
["x"] * len(tokens)
) # pylint: disable=unused-argument
self.sequence_len = 20
def test_dpo_drop_mode_valid(self):
"""
Tests that drop_long_rl_seq returns True in drop mode for a DPO sample within the sequence length limit.
"""
sample = {
"prompt": "p" * 5,
"chosen": "c" * 7,
"rejected": "r" * 6,
} # 5+7=12 <= 20, 5+6=11 <= 20
result = drop_long_rl_seq(
sample, "dpo", self.tokenizer, self.sequence_len, handling="drop"
)
self.assertTrue(result)
def test_dpo_drop_mode_invalid_chosen(self):
"""
Tests that in DPO drop mode, a sample is rejected when the prompt and chosen lengths exceed the sequence limit.
"""
sample = {
"prompt": "p" * 5,
"chosen": "c" * 16,
"rejected": "r" * 6,
} # 5+16=21 > 20
result = drop_long_rl_seq(
sample, "dpo", self.tokenizer, self.sequence_len, handling="drop"
)
self.assertFalse(result)
def test_dpo_drop_mode_invalid_rejected(self):
"""
Tests that in DPO drop mode, a sample is rejected when the prompt plus rejected response exceeds the sequence length limit.
"""
sample = {
"prompt": "p" * 5,
"chosen": "c" * 7,
"rejected": "r" * 16,
} # 5+16=21 > 20
result = drop_long_rl_seq(
sample, "dpo", self.tokenizer, self.sequence_len, handling="drop"
)
self.assertFalse(result)
def test_dpo_truncate_mode_no_truncation_needed(self):
"""
Verifies that in DPO truncate mode, samples within the sequence length limit are returned unchanged.
"""
sample = {
"prompt": "p" * 5,
"chosen": "c" * 7,
"rejected": "r" * 6,
} # 5+7=12 <= 20, 5+6=11 <= 20
original_sample = sample.copy()
result = drop_long_rl_seq(
sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
)
self.assertEqual(
result, original_sample
) # Should return the original sample unchanged
def test_dpo_truncate_mode_prompt_too_long(self):
"""
Tests that in DPO truncate mode, if the prompt exceeds the sequence length limit,
the original sample is returned unchanged.
"""
sample = {"prompt": "p" * 25, "chosen": "c" * 7, "rejected": "r" * 6}
original_sample = sample.copy()
result = drop_long_rl_seq(
sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
)
# Even though truncation isn't possible, the function should return the original sample
# for the map operation, assuming downstream filtering will catch it.
self.assertEqual(result, original_sample)
def test_dpo_truncate_mode_chosen_truncated(self):
"""
Tests that in DPO truncate mode, only the 'chosen' field is truncated when it exceeds the allowed sequence length, while 'prompt' and 'rejected' remain unchanged.
"""
prompt_len = 5
max_resp_len = self.sequence_len - prompt_len # 20 - 5 = 15
sample = {
"prompt": "p" * prompt_len,
"chosen": "c" * 18,
"rejected": "r" * 10,
} # 5+18=23 > 20, 5+10=15 <= 20
result = drop_long_rl_seq(
sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
)
self.assertEqual(len(result["prompt"]), prompt_len)
self.assertEqual(len(result["chosen"]), max_resp_len) # Truncated to 15
self.assertEqual(
result["chosen"], "x" * max_resp_len
) # Check decoded truncated value
self.assertEqual(len(result["rejected"]), 10) # Unchanged
def test_dpo_truncate_mode_rejected_truncated(self):
"""
Tests that in DPO truncate mode, only the 'rejected' field is truncated when it exceeds the sequence length limit, while 'prompt' and 'chosen' remain unchanged.
"""
prompt_len = 5
max_resp_len = self.sequence_len - prompt_len # 15
sample = {
"prompt": "p" * prompt_len,
"chosen": "c" * 10,
"rejected": "r" * 18,
} # 5+10=15 <= 20, 5+18=23 > 20
result = drop_long_rl_seq(
sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
)
self.assertEqual(len(result["prompt"]), prompt_len)
self.assertEqual(len(result["chosen"]), 10) # Unchanged
self.assertEqual(len(result["rejected"]), max_resp_len) # Truncated to 15
self.assertEqual(
result["rejected"], "x" * max_resp_len
) # Check decoded truncated value
def test_dpo_truncate_mode_both_truncated(self):
"""
Tests that in DPO truncate mode, both 'chosen' and 'rejected' fields are truncated when their combined lengths with the prompt exceed the sequence limit.
Verifies that both fields are truncated to fit within the allowed response length and replaced with decoded placeholder content.
"""
prompt_len = 8
max_resp_len = self.sequence_len - prompt_len # 20 - 8 = 12
sample = {
"prompt": "p" * prompt_len,
"chosen": "c" * 15,
"rejected": "r" * 14,
} # 8+15=23 > 20, 8+14=22 > 20
result = drop_long_rl_seq(
sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
)
self.assertEqual(len(result["prompt"]), prompt_len)
self.assertEqual(len(result["chosen"]), max_resp_len) # Truncated to 12
self.assertEqual(result["chosen"], "x" * max_resp_len)
self.assertEqual(len(result["rejected"]), max_resp_len) # Truncated to 12
self.assertEqual(result["rejected"], "x" * max_resp_len)
def test_dpo_truncate_mode_no_truncation_needed_but_long(self):
"""
Tests DPO truncate mode where only the overlong response is truncated.
Verifies that when the prompt plus one response exceeds the sequence length, only the response exceeding the maximum allowed length is truncated, while the other remains unchanged.
"""
# Covers the case where the initial prompt+response length check fails for
# one response only: the overlong response is truncated to max_resp_len
# while the other, which already fits within it, is left unchanged.
prompt_len = 10
max_resp_len = self.sequence_len - prompt_len # 10
sample = {
"prompt": "p" * prompt_len,
"chosen": "c" * 11,
"rejected": "r" * 9,
} # 10+11=21 > 20, 10+9=19 <= 20
result = drop_long_rl_seq(
sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
)
self.assertEqual(len(result["prompt"]), prompt_len)
self.assertEqual(len(result["chosen"]), max_resp_len) # Truncated to 10
self.assertEqual(result["chosen"], "x" * max_resp_len)
self.assertEqual(len(result["rejected"]), 9) # Unchanged, as 9 <= 10
# KTO tests: prompt + completion length handling
def test_kto_drop_mode_valid(self):
"""
Tests that drop_long_rl_seq returns True for a KTO sample within the sequence length limit.
"""
sample = {"prompt": "p" * 5, "completion": "c" * 14} # 5+14=19 <= 20
result = drop_long_rl_seq(
sample, "kto", self.tokenizer, self.sequence_len, handling="drop"
)
self.assertTrue(result)
def test_kto_drop_mode_invalid(self):
"""
Tests that drop_long_rl_seq returns False when a KTO sample exceeds the sequence length limit in drop mode.
"""
sample = {"prompt": "p" * 5, "completion": "c" * 16} # 5+16=21 > 20
result = drop_long_rl_seq(
sample, "kto", self.tokenizer, self.sequence_len, handling="drop"
)
self.assertFalse(result)
def test_kto_truncate_mode_no_truncation_needed(self):
"""
Tests that KTO truncate mode returns the original sample unchanged when the combined prompt and completion length does not exceed the sequence limit.
"""
sample = {"prompt": "p" * 5, "completion": "c" * 14} # 5+14=19 <= 20
original_sample = sample.copy()
result = drop_long_rl_seq(
sample, "kto", self.tokenizer, self.sequence_len, handling="truncate"
)
self.assertEqual(result, original_sample)
def test_kto_truncate_mode_prompt_too_long(self):
"""
Tests that in KTO truncate mode, if the prompt exceeds the sequence length limit, the original sample is returned unchanged.
"""
sample = {"prompt": "p" * 25, "completion": "c" * 7}
original_sample = sample.copy()
result = drop_long_rl_seq(
sample, "kto", self.tokenizer, self.sequence_len, handling="truncate"
)
self.assertEqual(result, original_sample) # Returns original sample
def test_kto_truncate_mode_completion_truncated(self):
"""
Tests that in KTO truncate mode, the completion is truncated when the combined prompt and completion exceed the sequence length limit.
Verifies that the prompt remains unchanged and the completion is truncated to fit within the allowed length, with the truncated completion replaced by decoded "x" characters.
"""
prompt_len = 8
max_comp_len = self.sequence_len - prompt_len # 20 - 8 = 12
sample = {"prompt": "p" * prompt_len, "completion": "c" * 15} # 8+15=23 > 20
result = drop_long_rl_seq(
sample, "kto", self.tokenizer, self.sequence_len, handling="truncate"
)
self.assertEqual(len(result["prompt"]), prompt_len)
self.assertEqual(len(result["completion"]), max_comp_len) # Truncated to 12
self.assertEqual(result["completion"], "x" * max_comp_len)
def test_missing_keys_dpo(self):
"""
Tests that a ValueError is raised when required keys are missing for DPO samples.
Verifies that the function raises an error if the sample does not contain 'chosen' and 'rejected' keys.
"""
sample = {"prompt": "p"}
with self.assertRaisesRegex(
ValueError, "Prompt, chosen and rejected keys are required"
):
drop_long_rl_seq(sample, "dpo", self.tokenizer, self.sequence_len)
def test_missing_keys_kto(self):
"""
Tests that a ValueError is raised when required keys are missing for RL type "kto".
Verifies that calling drop_long_rl_seq with a sample missing the "completion" key raises
a ValueError with the expected error message.
"""
sample = {"prompt": "p"}
with self.assertRaisesRegex(
ValueError, "Prompt and completion keys are required"
):
drop_long_rl_seq(sample, "kto", self.tokenizer, self.sequence_len)
def test_unknown_rl_type(self):
"""
Tests that a ValueError is raised when an unknown RL type is provided to drop_long_rl_seq.
"""
sample = {}
with self.assertRaisesRegex(ValueError, "Unknown RL type"):
drop_long_rl_seq(sample, "xyz", self.tokenizer, self.sequence_len)
# GRPO test - current implementation always passes
def test_grpo_drop(self):
"""
Tests that drop_long_rl_seq in GRPO drop mode always returns True, regardless of input.
"""
sample = {}
result = drop_long_rl_seq(
sample, "grpo", self.tokenizer, self.sequence_len, handling="drop"
)
self.assertTrue(result)
def test_grpo_truncate(self):
"""
Tests that in truncate mode for RL type "grpo", the original sample is returned unchanged.
"""
sample = {"a": 1}
result = drop_long_rl_seq(
sample, "grpo", self.tokenizer, self.sequence_len, handling="truncate"
)
self.assertEqual(result, sample)
if __name__ == "__main__":
unittest.main()

175
tests/test_trainer_utils.py Normal file
View File

@@ -0,0 +1,175 @@
"""Module containing tests for trainer utility functions."""
import unittest
from functools import partial
from axolotl.utils.trainer import truncate_or_drop_long_seq
# Test cases for truncate_or_drop_long_seq
class TestTruncateOrDropLongSeq(unittest.TestCase):
"""
Test suite for truncate_or_drop_long_seq function.
"""
def setUp(self):
# Example sequence length settings
"""
Sets up default sequence length parameters for the test cases.
"""
self.sequence_len = 10
self.min_sequence_len = 3
def test_drop_mode_single(self):
"""
Verifies that 'drop' mode correctly filters single sequence examples based on length.
Tests that sequences shorter than the minimum, longer than the maximum, or empty are dropped,
while sequences within the valid length range are kept.
"""
handler = partial(
truncate_or_drop_long_seq,
sequence_len=self.sequence_len,
min_sequence_len=self.min_sequence_len,
handling="drop",
)
# Too short
sample_short = {"input_ids": [1, 2]}
self.assertFalse(handler(sample_short))
# Too long
sample_long = {"input_ids": list(range(self.sequence_len + 1))}
self.assertFalse(handler(sample_long))
# Just right
sample_ok = {"input_ids": list(range(self.min_sequence_len))}
self.assertTrue(handler(sample_ok))
# Empty
sample_empty = {"input_ids": []}
self.assertFalse(handler(sample_empty))
def test_truncate_mode_single(self):
"""
Tests that 'truncate_or_drop_long_seq' correctly truncates or preserves single examples in "truncate" mode.
Verifies that sequences longer than the maximum length are truncated, while sequences that are too short, empty, or within the valid range remain unchanged.
"""
handler = partial(
truncate_or_drop_long_seq,
sequence_len=self.sequence_len,
min_sequence_len=self.min_sequence_len,
handling="truncate",
)
# Too short: the function returns the sample unchanged in truncate mode;
# dropping too-short samples is left to upstream filter/map logic.
sample_short = {"input_ids": [1, 2], "labels": [1, 2]}
result_short = handler(sample_short)
self.assertEqual(result_short["input_ids"], [1, 2]) # Unchanged
# Too long
original_long = list(range(self.sequence_len + 5))
sample_long = {"input_ids": list(original_long), "labels": list(original_long)}
result_long = handler(sample_long)
self.assertEqual(len(result_long["input_ids"]), self.sequence_len)
self.assertEqual(result_long["input_ids"], list(range(self.sequence_len)))
self.assertEqual(len(result_long["labels"]), self.sequence_len)
self.assertEqual(result_long["labels"], list(range(self.sequence_len)))
# Just right
sample_ok = {
"input_ids": list(range(self.min_sequence_len)),
"labels": list(range(self.min_sequence_len)),
}
result_ok = handler(sample_ok)
self.assertEqual(len(result_ok["input_ids"]), self.min_sequence_len)
self.assertEqual(result_ok, sample_ok) # Should be unchanged
# Empty
sample_empty = {"input_ids": [], "labels": []}
result_empty = handler(sample_empty)
self.assertEqual(result_empty, sample_empty) # Unchanged
def test_drop_mode_batched(self):
"""
Tests that the "drop" handling mode correctly filters batched input sequences based on length constraints.
Verifies that sequences shorter than the minimum length, longer than the maximum length, or empty are dropped (returns False), while sequences within the valid range are kept (returns True).
"""
handler = partial(
truncate_or_drop_long_seq,
sequence_len=self.sequence_len,
min_sequence_len=self.min_sequence_len,
handling="drop",
)
sample = {
"input_ids": [
[1, 2], # Too short
list(range(self.sequence_len + 1)), # Too long
list(range(self.sequence_len)), # OK (len = 10)
list(range(self.min_sequence_len)), # OK (len = 3)
[], # Empty
]
}
expected = [False, False, True, True, False]
self.assertEqual(handler(sample), expected)
def test_truncate_mode_batched(self):
"""
Tests that batched examples are correctly truncated in "truncate" mode.
Verifies that sequences in both "input_ids" and "labels" longer than the maximum
allowed length are truncated, while sequences that are too short or empty remain
unchanged.
"""
handler = partial(
truncate_or_drop_long_seq,
sequence_len=self.sequence_len,
min_sequence_len=self.min_sequence_len,
handling="truncate",
)
sample = {
"input_ids": [
[1, 2], # Too short
list(range(self.sequence_len + 5)), # Too long
list(range(self.sequence_len)), # OK
list(range(self.min_sequence_len)), # OK
[], # Empty
],
"labels": [ # Add labels to test truncation
[1, 2],
list(range(self.sequence_len + 5)),
list(range(self.sequence_len)),
list(range(self.min_sequence_len)),
[],
],
}
result = handler(sample)
# Expected results after truncation (too short and empty remain unchanged by this function)
expected_input_ids = [
[1, 2], # Unchanged (too short)
list(range(self.sequence_len)), # Truncated
list(range(self.sequence_len)), # Unchanged (OK)
list(range(self.min_sequence_len)), # Unchanged (OK)
[], # Unchanged (Empty)
]
expected_labels = [
[1, 2], # Unchanged (too short)
list(range(self.sequence_len)), # Truncated
list(range(self.sequence_len)), # Unchanged (OK)
list(range(self.min_sequence_len)), # Unchanged (OK)
[], # Unchanged (Empty)
]
self.assertEqual(result["input_ids"], expected_input_ids)
self.assertEqual(result["labels"], expected_labels)
if __name__ == "__main__":
unittest.main()