Compare commits

...

3 Commits

Author SHA1 Message Date
Dan Saunders
9f30d3d33a reworking SP logic into composed handler 2025-04-04 02:25:00 +00:00
Dan Saunders
ce07081d6c doc updates; config fix 2025-04-01 20:35:10 +00:00
Dan Saunders
3ce43b6db9 simplifying trainer mixins and adding to rl trainers 2025-04-01 17:53:12 +00:00
19 changed files with 402 additions and 187 deletions

View File

@@ -686,9 +686,10 @@ ddp_broadcast_buffers:
# E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized # E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized
# subsequences, or set to 4 to split into four equal-sized subsequences. # subsequences, or set to 4 to split into four equal-sized subsequences.
# See https://axolotl-ai-cloud.github.io/axolotl/docs/sequence_parallelism.html for more details. # See https://axolotl-ai-cloud.github.io/axolotl/docs/sequence_parallelism.html for more details.
sequence_parallel_degree: sequence_parallel_degree: 4 # Set to the number of GPUs to split sequences across
# Optional; strides across the key dimension. Larger values use more memory but should make training faster. flash_attention: true # SP requires flash attention
# Must evenly divide the number of KV heads in your model. micro_batch_size: 1 # SP requires this is set to 1
# (optional) strides across the key dimension; larger values use more memory but should make training a bit faster
heads_k_stride: 1 heads_k_stride: 1
# Path to torch distx for optim 'adamw_anyprecision' # Path to torch distx for optim 'adamw_anyprecision'

View File

@@ -23,9 +23,10 @@ Use sequence parallelism when:
To enable sequence parallelism, add the following to your configuration file: To enable sequence parallelism, add the following to your configuration file:
```yaml ```yaml
# Set to a divisor (> 1) of the number of GPUs available sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
sequence_parallel_degree: 4 # Split sequences across 4 GPUs flash_attention: true # SP requires flash attention
# Optional; strides across the key dimension. Larger values use more memory but should make training faster. micro_batch_size: 1 # SP requires this is set to 1
# (optional) strides across the key dimension; larger values use more memory but should make training a bit faster
heads_k_stride: 1 heads_k_stride: 1
``` ```
@@ -66,15 +67,16 @@ sequence_len: 8192
... ...
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
flash_attention: true # Required with sequence parallelism flash_attention: true # SP requires flash attention
# Optional; strides across the key dimension. Larger values use more memory but should make training faster. micro_batch_size: 1 # SP requires this is set to 1
# (optional) strides across the key dimension; larger values use more memory but should make training a bit faster
heads_k_stride: 1 heads_k_stride: 1
... ...
``` ```
This will train the Llama 3 8B model with 8K context length, with each sequence split This will train the Llama 3 8B model with 8192 context length, with each sequence split
into 2 subsequences of length 4096 across 2 GPUs. into 4 subsequences of length 2048 across 4 GPUs.
## Sample Packing with Sequence Parallelism ## Sample Packing with Sequence Parallelism
@@ -86,12 +88,14 @@ Sequence parallelism is compatible with Axolotl's sample packing functionality.
## Effect on Batch Size ## Effect on Batch Size
First, note that sequence parallelism supports only the case where `micro_batch_size: 1`.
When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because: When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:
- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence) - Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
- The number of batches processed per step decreases - The number of batches processed per step decreases
For example: For example:
- With 8 GPUs and no sequence parallelism: 8 different batches processed per step - With 8 GPUs and no sequence parallelism: 8 different batches are processed per step
- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs) - With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4 - If your per-GPU `micro_batch_size` is 1, the global batch size decreases from 8 to 2

View File

@@ -82,3 +82,6 @@ deepspeed:
weight_decay: 0.0 weight_decay: 0.0
fsdp: fsdp:
fsdp_config: fsdp_config:
special_tokens:
pad_token: "<|end_of_text|>"

View File

@@ -1043,6 +1043,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.rpo_alpha is not None: if self.cfg.rpo_alpha is not None:
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
training_args_kwargs["sequence_parallel_degree"] = (
self.cfg.sequence_parallel_degree
)
training_args_cls = None training_args_cls = None
blocklist_args_kwargs = [] blocklist_args_kwargs = []
if self.cfg.rl == "simpo": if self.cfg.rl == "simpo":
@@ -1161,6 +1165,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer_kwargs["dataset_tags"] = [ dpo_trainer_kwargs["dataset_tags"] = [
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir() d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
] ]
dpo_trainer = trainer_cls( dpo_trainer = trainer_cls(
*trainer_cls_args, *trainer_cls_args,
args=training_args, args=training_args,
@@ -1178,21 +1183,3 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer.add_callback(callback) dpo_trainer.add_callback(callback)
return dpo_trainer return dpo_trainer
class HFPPOTrainerBuilder(TrainerBuilderBase):
"""
HF Factory class for PPO Trainer
"""
def get_callbacks(self):
callbacks = super().get_callbacks()
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
return callbacks
def build(self, total_num_steps):
# build PPOConfig
pass

View File

@@ -3,16 +3,16 @@
# pylint: disable=unused-import # pylint: disable=unused-import
# flake8: noqa # flake8: noqa
from .base import AxolotlTrainer from axolotl.core.trainers.base import AxolotlTrainer
from .dpo.trainer import AxolotlDPOTrainer from axolotl.core.trainers.dpo import AxolotlDPOTrainer
from .grpo.trainer import AxolotlGRPOTrainer from axolotl.core.trainers.grpo import AxolotlGRPOTrainer
from .mamba import AxolotlMambaTrainer from axolotl.core.trainers.mamba import AxolotlMambaTrainer
from .relora import ReLoRATrainer from axolotl.core.trainers.relora import ReLoRATrainer
from .trl import ( from axolotl.core.trainers.trl import (
AxolotlCPOTrainer, AxolotlCPOTrainer,
AxolotlKTOTrainer, AxolotlKTOTrainer,
AxolotlORPOTrainer, AxolotlORPOTrainer,
AxolotlPPOTrainer,
AxolotlPRMTrainer, AxolotlPRMTrainer,
AxolotlRewardTrainer, AxolotlRewardTrainer,
TRLPPOTrainer,
) )

View File

@@ -8,10 +8,11 @@ import logging
import os import os
from collections import defaultdict from collections import defaultdict
from functools import wraps from functools import wraps
from typing import Literal from typing import Any, Literal
import datasets import datasets
import torch import torch
import torch.nn as nn
from datasets import Dataset from datasets import Dataset
from torch.utils.data import ( from torch.utils.data import (
BatchSampler, BatchSampler,
@@ -25,12 +26,8 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
from trl.trainer.utils import pad_to_length from trl.trainer.utils import pad_to_length
from typing_extensions import override from typing_extensions import override
from axolotl.core.trainers.mixins import ( from axolotl.core.trainers.handlers import SequenceParallelHandler
OptimizerMixin, from axolotl.core.trainers.mixins import TrainerMixins
RngLoaderMixin,
SchedulerMixin,
SequenceParallelMixin,
)
from axolotl.core.trainers.utils import ( from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging, sanitize_kwargs_for_ds_tagging,
sanitize_kwargs_for_tagging, sanitize_kwargs_for_tagging,
@@ -40,9 +37,7 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
class AxolotlTrainer( class AxolotlTrainer(TrainerMixins, Trainer):
SchedulerMixin, OptimizerMixin, RngLoaderMixin, SequenceParallelMixin, Trainer
):
"""Extend the base Trainer for axolotl helpers""" """Extend the base Trainer for axolotl helpers"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined] args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
@@ -68,9 +63,7 @@ class AxolotlTrainer(
if self.args.orpo_alpha: if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none") self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
# Initialize sequence parallelism if enabled self.sequence_parallel_handler = SequenceParallelHandler(self.args)
if self.args.sequence_parallel_degree > 1:
self._setup_sequence_parallel()
def _wrap_model(self, model, training=True, dataloader=None): def _wrap_model(self, model, training=True, dataloader=None):
if self.args.torch_compile: if self.args.torch_compile:
@@ -131,7 +124,7 @@ class AxolotlTrainer(
# Determine the base sampler first # Determine the base sampler first
if self.args.sequence_parallel_degree > 1: if self.args.sequence_parallel_degree > 1:
base_sampler = self._sp_get_train_sampler(self.train_dataset) base_sampler = self.sequence_parallel_handler._get_train_sampler(self.train_dataset)
elif self.args.curriculum_sampling: elif self.args.curriculum_sampling:
base_sampler = SequentialSampler(self.train_dataset) base_sampler = SequentialSampler(self.train_dataset)
elif use_sample_packing: elif use_sample_packing:
@@ -167,7 +160,7 @@ class AxolotlTrainer(
# Determine the base sampler # Determine the base sampler
if self.args.sequence_parallel_degree > 1: if self.args.sequence_parallel_degree > 1:
base_sampler = self._sp_get_eval_sampler(eval_dataset) base_sampler = self.sequence_parallel_handler._get_eval_sampler(eval_dataset)
elif use_multipack: elif use_multipack:
base_sampler = SequentialSampler(eval_dataset) base_sampler = SequentialSampler(eval_dataset)
else: else:
@@ -239,7 +232,10 @@ class AxolotlTrainer(
return dataloader return dataloader
# Otherwise prepare with accelerator # Otherwise prepare with accelerator
return self.accelerator.prepare_data_loader(dataloader) dataloader = self.accelerator.prepare_data_loader(dataloader)
return dataloader
def get_train_dataloader(self) -> DataLoader: def get_train_dataloader(self) -> DataLoader:
"""Get dataloader for training""" """Get dataloader for training"""
@@ -348,7 +344,57 @@ class AxolotlTrainer(
dataloader_params["drop_last"] = self.args.dataloader_drop_last dataloader_params["drop_last"] = self.args.dataloader_drop_last
return DataLoader(bench_dataset, **dataloader_params) return DataLoader(bench_dataset, **dataloader_params)
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
def training_step(
self,
model: nn.Module,
inputs: dict[str, torch.Tensor | Any],
num_items_in_batch: int | None = None,
) -> torch.Tensor:
"""
Perform a training step on a batch of inputs. Overrides the
`transformers.trainer.Trainer` method to handle sequence parallelism if
enabled.
Args:
model: Model to perform training step for.
inputs: Dictionary mapping of inputs.
num_items_in_batch: The number of items in the batch.
"""
# Set up sequence parallelism for this step if enabled
if self.args.sequence_parallel_degree > 1:
self.sequence_parallel_handler._update_ring_flash_attn_params(inputs)
# Proceed with normal training step
return super().training_step(model, inputs, num_items_in_batch) # type: ignore
def prediction_step(
self,
model: nn.Module,
inputs: dict[str, torch.Tensor | Any],
prediction_loss_only: bool,
ignore_keys: list[str] | None = None,
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
"""
Perform a prediction step on a batch of inputs. Overrides the
`transformers.trainer.Trainer` method to handle sequence parallelism if
enabled.
Args:
model: Model to perform prediction step for.
inputs: Dictionary mapping of inputs.
prediction_loss_only: Whether to return only the loss.
ignore_keys: Keys to ignore in the inputs.
Returns:
Tuple of (loss, logits, labels).
"""
# Set up sequence parallelism for this prediction step if enabled
if self.args.sequence_parallel_degree > 1:
self.sequence_parallel_handler._update_ring_flash_attn_params(inputs)
# Proceed with normal prediction step
return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys) # type: ignore
@override @override
def compute_loss( def compute_loss(

View File

@@ -1,14 +1,10 @@
""" """DPO Specific Strategy for training"""
DPO Specific Strategy for training
"""
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer
class DPOStrategy: class DPOStrategy:
""" """Strategy for DPO training"""
Strategy for DPO training
"""
@classmethod @classmethod
def get_trainer_class(cls): def get_trainer_class(cls):

View File

@@ -1,6 +1,4 @@
""" """Axolotl specific DPO args"""
Axolotl specific DPO args
"""
from dataclasses import dataclass from dataclasses import dataclass
@@ -11,6 +9,4 @@ from axolotl.core.training_args import AxolotlTrainingMixins
@dataclass @dataclass
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig): class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
""" """DPO config for DPO training"""
DPO config for DPO training
"""

View File

@@ -1,10 +1,7 @@
""" """DPO trainer for axolotl"""
DPO trainer for axolotl
"""
import gc
from functools import wraps from functools import wraps
from typing import Any, Dict, Union from typing import Any
import torch import torch
from peft.optimizers import create_loraplus_optimizer from peft.optimizers import create_loraplus_optimizer
@@ -13,7 +10,8 @@ from transformers import Trainer
from transformers.utils import is_sagemaker_mp_enabled from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer from trl import DPOTrainer
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin from axolotl.core.trainers.handlers import SequenceParallelHandler
from axolotl.core.trainers.mixins import TrainerMixins
from axolotl.core.trainers.utils import ( from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging, sanitize_kwargs_for_ds_tagging,
sanitize_kwargs_for_tagging, sanitize_kwargs_for_tagging,
@@ -23,18 +21,18 @@ if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp import smdistributed.modelparallel.torch as smp
class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer): class AxolotlDPOTrainer(TrainerMixins, DPOTrainer):
""" """Extend the base DPOTrainer for axolotl helpers"""
Extend the base DPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "dpo"] tag_names = ["axolotl", "dpo"]
def __init__(self, *args, dataset_tags=None, **kwargs): def __init__(self, *args, dataset_tags=None, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.dataset_tags = dataset_tags self.dataset_tags = dataset_tags
self.optimizer = None self.optimizer = None
self.model_accepts_loss_kwargs = False self.model_accepts_loss_kwargs = False
self.sequence_parallel_handler = SequenceParallelHandler(args=self.args)
def create_optimizer(self): def create_optimizer(self):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
@@ -88,7 +86,7 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
max_prompt_length, max_prompt_length,
max_completion_length, max_completion_length,
add_special_tokens, add_special_tokens,
) -> Dict: ) -> dict:
res = DPOTrainer.tokenize_row( res = DPOTrainer.tokenize_row(
features, features,
processing_class, processing_class,
@@ -117,10 +115,9 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
def training_step( def training_step(
self, self,
model: nn.Module, model: nn.Module,
inputs: Dict[str, Union[torch.Tensor, Any]], inputs: dict[str, torch.Tensor | Any | None],
num_items_in_batch=None, num_items_in_batch=None,
) -> torch.Tensor: ) -> torch.Tensor:
loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch) self.sequence_parallel_handler.prepare_for_training_step(self, inputs)
gc.collect()
torch.cuda.empty_cache() return super().training_step(model, inputs, num_items_in_batch)
return loss

View File

@@ -1,6 +1,4 @@
""" """Axolotl GRPO trainer"""
Axolotl GRPO trainer
"""
from contextlib import nullcontext from contextlib import nullcontext
@@ -8,16 +6,14 @@ from accelerate.utils import is_deepspeed_available, is_peft_model
from trl import GRPOTrainer from trl import GRPOTrainer
from trl.extras.profiling import profiling_decorator from trl.extras.profiling import profiling_decorator
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin from axolotl.core.trainers.mixins import TrainerMixins
if is_deepspeed_available(): if is_deepspeed_available():
import deepspeed import deepspeed
class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer): class AxolotlGRPOTrainer(TrainerMixins, GRPOTrainer):
""" """Extend the base GRPOTrainer for axolotl helpers"""
Extend the base GRPOTrainer for axolotl helpers
"""
_tag_names = ["trl", "grpo", "axolotl"] _tag_names = ["trl", "grpo", "axolotl"]

View File

@@ -0,0 +1,3 @@
"""Init for trainer handlers"""
from axolotl.core.trainers.handlers.sequence_parallel import SequenceParallelHandler

View File

@@ -0,0 +1,123 @@
"""Handler class for sequence parallel trainer logic"""
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.utils.data import DistributedSampler
class SequenceParallelHandler:
"""
Handler class that encapsulates sequence parallelism functionality.
This replaces the SequenceParallelMixin with a composition-based approach.
"""
def __init__(self, args=None):
"""
Initialize the sequence parallel handler.
Args:
args: The arguments object containing sequence parallelism settings.
"""
self.args = args
self.ring_attn_group = None
# Set up sequence parallelism if enabled
if self.args.sequence_parallel_degree > 1:
self._setup_sequence_parallel()
def _setup_sequence_parallel(self):
"""Set up sequence parallelism environment."""
from ring_flash_attn import update_ring_flash_attn_params
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
self.update_ring_flash_attn_params = update_ring_flash_attn_params
self.ring_attn_group = get_ring_attn_group()
def create_sequence_parallel_sampler(
self,
dataset,
shuffle=True,
is_eval=False,
):
"""
Helper method to create sampler for sequence parallelism (SP).
Args:
dataset: Dataset to sample from.
shuffle: Whether to shuffle the dataset.
is_eval: Whether we are creating a sampler for evaluation or training.
Returns:
Distributed sampler.
"""
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
return DistributedSampler(
dataset,
num_replicas=num_sp_groups,
rank=sp_group_id,
seed=self.args.seed if shuffle else None,
shuffle=shuffle,
drop_last=not is_eval,
)
def _get_train_sampler(self, dataset):
"""
Get a training sampler configured for sequence parallelism.
Args:
dataset: The training dataset.
Returns:
Configured sequence parallel sampler.
"""
return self.create_sequence_parallel_sampler(
dataset,
shuffle=not self.args.curriculum_sampling,
)
def _get_eval_sampler(self, eval_dataset):
"""
Get an evaluation sampler configured for sequence parallelism.
Args:
eval_dataset: The evaluation dataset.
Returns:
Configured sequence parallel sampler.
"""
return self.create_sequence_parallel_sampler(
eval_dataset, shuffle=False, is_eval=True
)
def _update_ring_flash_attn_params(self, inputs):
"""
Calculate the cu_seqlens for the current forward pass and pass the value to
the substituted ring_flash_attn.
Args:
inputs: Current batch of inputs.
"""
# At this point, inputs should already be partitioned by the sequence
# parallel data collator
batch_size = inputs["input_ids"].shape[0]
seq_len = inputs["input_ids"].shape[1]
packed_seq_lens = [seq_len] * batch_size
# Calculate the full sequence length across all GPUs in this SP group
total_seq_len = seq_len * self.args.sequence_parallel_degree
cu_seqlens = torch.cumsum(
torch.tensor(
packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
),
dim=-1,
dtype=torch.int32,
)
cu_seqlens = F.pad(
F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
)
self.update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)

View File

@@ -3,7 +3,12 @@
# pylint: disable=unused-import # pylint: disable=unused-import
# flake8: noqa # flake8: noqa
from .optimizer import OptimizerMixin from axolotl.core.trainers.mixins.optimizer import OptimizerMixin
from .rng_state_loader import RngLoaderMixin from axolotl.core.trainers.mixins.rng_state_loader import RngLoaderMixin
from .scheduler import SchedulerMixin from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
from .sequence_parallel import SequenceParallelMixin
class TrainerMixins(
OptimizerMixin, RngLoaderMixin, SchedulerMixin
):
"""Stub class combining all mixins for Axolotl trainers."""

View File

@@ -21,9 +21,7 @@ LOG = logging.getLogger(__name__)
class RngLoaderMixin(Trainer): class RngLoaderMixin(Trainer):
""" """Mixin for method override to load RNG states from a checkpoint"""
mixin for method override to load RNG states from a checkpoint
"""
def _load_rng_state(self, checkpoint): def _load_rng_state(self, checkpoint):
# Load RNG states from `checkpoint` # Load RNG states from `checkpoint`

View File

@@ -1,4 +1,5 @@
"""Module for Axolotl trainer sequence parallelism mixin""" """Module for Axolotl trainer sequence parallelism mixin"""
# TODO(Dan): remove
import logging import logging
from typing import Any from typing import Any
@@ -7,7 +8,6 @@ import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn.functional as F import torch.nn.functional as F
from datasets import Dataset from datasets import Dataset
from torch import nn
from torch.utils.data import DistributedSampler, Sampler from torch.utils.data import DistributedSampler, Sampler
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
@@ -71,12 +71,12 @@ class SequenceParallelMixin:
drop_last=not is_eval, drop_last=not is_eval,
) )
def _sp_get_train_sampler(self, dataset) -> Sampler | None: def _get_train_sampler(self, dataset) -> Sampler | None:
""" """
Get a training sampler configured for sequence parallelism. Get a training sampler configured for sequence parallelism.
Args: Args:
dataset: The training dataset dataset: The training dataset.
Returns: Returns:
Configured sequence parallel sampler. Configured sequence parallel sampler.
@@ -86,7 +86,7 @@ class SequenceParallelMixin:
shuffle=not self.args.curriculum_sampling, shuffle=not self.args.curriculum_sampling,
) )
def _sp_get_eval_sampler(self, eval_dataset) -> Sampler | None: def _get_eval_sampler(self, eval_dataset) -> Sampler | None:
""" """
Get an evaluation sampler configured for sequence parallelism. Get an evaluation sampler configured for sequence parallelism.
@@ -130,53 +130,3 @@ class SequenceParallelMixin:
) )
update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group) update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)
def training_step(
self,
model: nn.Module,
inputs: dict[str, torch.Tensor | Any],
num_items_in_batch: int | None = None,
) -> torch.Tensor:
"""
Perform a training step on a batch of inputs. Overrides the
`transformers.trainer.Trainer` method to handle sequence parallelism if
enabled.
Args:
model: Model to perform training step for.
inputs: Dictionary mapping.
"""
# Set up sequence parallelism for this step if enabled
if self.args.sequence_parallel_degree > 1:
self._update_ring_flash_attn_params(inputs)
# Proceed with normal training step
return super().training_step(model, inputs, num_items_in_batch) # type: ignore
def prediction_step(
self,
model: nn.Module,
inputs: dict[str, torch.Tensor | Any],
prediction_loss_only: bool,
ignore_keys: list[str] | None = None,
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
"""
Perform a prediction step on a batch of inputs. Overrides the
`transformers.trainer.Trainer` method to handle sequence parallelism if
enabled.
Args:
model: Model to perform prediction step for.
inputs: Dictionary mapping of inputs.
prediction_loss_only: Whether to return only the loss.
ignore_keys: Keys to ignore in the inputs.
Returns:
Tuple of (loss, logits, labels).
"""
# Set up sequence parallelism for this prediction step if enabled
if self.args.sequence_parallel_degree > 1:
self._update_ring_flash_attn_params(inputs)
# Proceed with normal prediction step
return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys) # type: ignore

View File

@@ -13,11 +13,10 @@ from trl import (
RewardTrainer, RewardTrainer,
) )
from axolotl.core.trainers.mixins import RngLoaderMixin from axolotl.core.trainers.mixins import TrainerMixins
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
class TRLPPOTrainer(PPOTrainer): class AxolotlPPOTrainer(TrainerMixins, PPOTrainer):
"""Wrapper for TRL PPO trainer to handle customizations""" """Wrapper for TRL PPO trainer to handle customizations"""
tag_names = ["axolotl", "ppo"] tag_names = ["axolotl", "ppo"]
@@ -75,10 +74,8 @@ class TRLPPOTrainer(PPOTrainer):
) )
class AxolotlORPOTrainer(RngLoaderMixin, SchedulerMixin, ORPOTrainer): class AxolotlORPOTrainer(TrainerMixins, ORPOTrainer):
""" """Extend the base ORPOTrainer for axolotl helpers"""
Extend the base ORPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "orpo"] tag_names = ["axolotl", "orpo"]
@@ -155,18 +152,14 @@ class AxolotlORPOTrainer(RngLoaderMixin, SchedulerMixin, ORPOTrainer):
return loss, metrics return loss, metrics
class AxolotlKTOTrainer(RngLoaderMixin, SchedulerMixin, KTOTrainer): class AxolotlKTOTrainer(TrainerMixins, KTOTrainer):
""" """Extend the base KTOTrainer for axolotl helpers"""
Extend the base KTOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "kto"] tag_names = ["axolotl", "kto"]
class AxolotlCPOTrainer(RngLoaderMixin, SchedulerMixin, CPOTrainer): class AxolotlCPOTrainer(TrainerMixins, CPOTrainer):
""" """Extend the base CPOTrainer for axolotl helpers"""
Extend the base CPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "cpo"] tag_names = ["axolotl", "cpo"]
@@ -245,17 +238,13 @@ class AxolotlCPOTrainer(RngLoaderMixin, SchedulerMixin, CPOTrainer):
return loss, metrics return loss, metrics
class AxolotlRewardTrainer(RngLoaderMixin, SchedulerMixin, RewardTrainer): class AxolotlRewardTrainer(TrainerMixins, RewardTrainer):
""" """Extend the base RewardTrainer for axolotl helpers"""
Extend the base RewardTrainer for axolotl helpers
"""
tag_names = ["axolotl", "reward"] tag_names = ["axolotl", "reward"]
class AxolotlPRMTrainer(RngLoaderMixin, SchedulerMixin, PRMTrainer): class AxolotlPRMTrainer(TrainerMixins, PRMTrainer):
""" """Extend the base trl.PRMTrainer for axolotl helpers"""
Extend the base trl.PRMTrainer for axolotl helpers
"""
tag_names = ["axolotl", "prm"] tag_names = ["axolotl", "prm"]

View File

@@ -12,9 +12,7 @@ from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
@dataclass @dataclass
class AxolotlTrainingMixins: class AxolotlTrainingMixins:
""" """Mixin class for the Axolotl training args."""
Mixin class for the Axolotl training args.
"""
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
model_type: Optional[str] = field( model_type: Optional[str] = field(

View File

@@ -6,11 +6,22 @@ package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patc
their sequence parallel version of Flash Attention 2. their sequence parallel version of Flash Attention 2.
""" """
import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn.functional as F
from accelerate.logging import get_logger from accelerate.logging import get_logger
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from axolotl.logging_config import configure_logging from axolotl.logging_config import configure_logging
try:
from ring_flash_attn import update_ring_flash_attn_params
except ImportError:
# We pass silently here, but raise an ImportError in our Axolotl config validation
# if cfg.sequence_parallel_degree > 1 and `ring-flash-attn` is not installed.
pass
configure_logging() configure_logging()
LOG = get_logger(__name__) LOG = get_logger(__name__)
@@ -32,12 +43,120 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
Setter for ring attention group on this rank. Setter for ring attention group on this rank.
Args: Args:
Process group for ring attention. ring_attn_group: Process group for ring attention.
""" """
global RING_ATTN_GROUP # pylint: disable=global-statement global RING_ATTN_GROUP # pylint: disable=global-statement
RING_ATTN_GROUP = ring_attn_group RING_ATTN_GROUP = ring_attn_group
def patch_flash_attention_for_sequential_batch(sequence_parallel_degree: int):
"""
Patch flash attention a second time to handle batched data. This is a hack to
accommodate certain RL trainers which batch data even when `micro_batch_size: 1` is
specified in the Axolotl config.
Args:
sequence_parallel_degree: Sequence parallelism factor.
"""
# Store the original flash attention function
original_flash_attention = ALL_ATTENTION_FUNCTIONS["flash_attention_2"]
def sequential_batch_flash_attention(
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor | None,
dropout: float = 0.0,
scaling: float | None = None,
sliding_window: int | None = None,
softcap: float | None = None,
**kwargs,
) -> tuple[torch.Tensor, None]:
# Check if we have a batch dimension > 1
batch_size = query.shape[0]
if batch_size <= 1:
return original_flash_attention(
module,
query,
key,
value,
attention_mask,
dropout,
scaling,
sliding_window,
softcap,
**kwargs
)
# Process each item in the batch separately
outputs = []
for i in range(batch_size):
# Extract single batch item
q_item = query[i:i+1]
k_item = key[i:i+1]
v_item = value[i:i+1]
# Handle attention mask - it might be None or have different shapes
mask_item = None
if attention_mask is not None:
# The mask could have different formats depending on implementation
if attention_mask.dim() >= 3 and attention_mask.shape[0] == batch_size:
mask_item = attention_mask[i:i+1]
else:
# For broadcast masks that don't have a batch dimension
mask_item = attention_mask
# At this point, inputs should already be partitioned by the sequence
# parallel data collator
batch_size = q_item.shape[0]
seq_len = q_item.shape[2]
packed_seq_lens = [seq_len] * batch_size
# Calculate the full sequence length across all GPUs in this SP group
total_seq_len = seq_len * sequence_parallel_degree
cu_seqlens = torch.cumsum(
torch.tensor(
packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
),
dim=-1,
dtype=torch.int32,
)
cu_seqlens = F.pad(
F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
)
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
# Call the original function for a single batch item
output, _ = original_flash_attention(
module,
q_item,
k_item,
v_item,
mask_item,
dropout,
scaling,
sliding_window,
softcap,
**kwargs
)
outputs.append(output)
dist.barrier()
# Concatenate results along batch dimension
concatenated_output = torch.cat(outputs, dim=0)
return concatenated_output, None
# Replace the original function with our sequential version
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = sequential_batch_flash_attention
def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None): def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None):
""" """
Create ring attention group and substitute flash attn with ring flash attn. Create ring attention group and substitute flash attn with ring flash attn.
@@ -98,3 +217,4 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
substitute_hf_flash_attn( substitute_hf_flash_attn(
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride
) )
patch_flash_attention_for_sequential_batch(sequence_parallel_degree)

View File

@@ -1351,9 +1351,7 @@ def load_model(
reference_model: bool = False, reference_model: bool = False,
**kwargs, # pylint: disable=unused-argument **kwargs, # pylint: disable=unused-argument
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: ) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
""" """Load a model for a given configuration and tokenizer."""
Load a model for a given configuration and tokenizer.
"""
model_loader = ModelLoader( model_loader = ModelLoader(
cfg, cfg,
tokenizer, tokenizer,
@@ -1362,12 +1360,16 @@ def load_model(
reference_model=reference_model, reference_model=reference_model,
**kwargs, **kwargs,
) )
return model_loader.load_model() return model_loader.load_model()
def load_adapter(model, cfg, adapter, inference=False): def load_adapter(
# type: (PreTrainedModel, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]] model: PreTrainedModel,
cfg: DictDefault,
adapter: str | None,
inference: bool = False,
) -> tuple[PreTrainedModel, PeftConfig | None]:
if adapter is None: if adapter is None:
return model, None return model, None
if hasattr(model, "enable_input_require_grads"): if hasattr(model, "enable_input_require_grads"):
@@ -1380,8 +1382,9 @@ def load_adapter(model, cfg, adapter, inference=False):
raise NotImplementedError(f"{adapter} peft adapter not available") raise NotImplementedError(f"{adapter} peft adapter not available")
def load_llama_adapter(model, cfg): def load_llama_adapter(
# type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]] model: PreTrainedModel, cfg: DictDefault
) -> tuple[PreTrainedModel, PeftConfig | None]:
from peft import AdaptionPromptConfig, get_peft_model from peft import AdaptionPromptConfig, get_peft_model
peft_config = AdaptionPromptConfig( peft_config = AdaptionPromptConfig(
@@ -1405,7 +1408,7 @@ def load_llama_adapter(model, cfg):
return model, peft_config return model, peft_config
def find_all_linear_names(model): def find_all_linear_names(model: PreTrainedModel):
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear) cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
lora_module_names = set() lora_module_names = set()
for name, module in model.named_modules(): for name, module in model.named_modules():