Compare commits
3 Commits
coderabbit
...
sp-rl
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9f30d3d33a | ||
|
|
ce07081d6c | ||
|
|
3ce43b6db9 |
@@ -686,9 +686,10 @@ ddp_broadcast_buffers:
|
||||
# E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized
|
||||
# subsequences, or set to 4 to split into four equal-sized subsequences.
|
||||
# See https://axolotl-ai-cloud.github.io/axolotl/docs/sequence_parallelism.html for more details.
|
||||
sequence_parallel_degree:
|
||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||
# Must evenly divide the number of KV heads in your model.
|
||||
sequence_parallel_degree: 4 # Set to the number of GPUs to split sequences across
|
||||
flash_attention: true # SP requires flash attention
|
||||
micro_batch_size: 1 # SP requires this is set to 1
|
||||
# (optional) strides across the key dimension; larger values use more memory but should make training a bit faster
|
||||
heads_k_stride: 1
|
||||
|
||||
# Path to torch distx for optim 'adamw_anyprecision'
|
||||
|
||||
@@ -23,9 +23,10 @@ Use sequence parallelism when:
|
||||
To enable sequence parallelism, add the following to your configuration file:
|
||||
|
||||
```yaml
|
||||
# Set to a divisor (> 1) of the number of GPUs available
|
||||
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
|
||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
|
||||
flash_attention: true # SP requires flash attention
|
||||
micro_batch_size: 1 # SP requires this is set to 1
|
||||
# (optional) strides across the key dimension; larger values use more memory but should make training a bit faster
|
||||
heads_k_stride: 1
|
||||
```
|
||||
|
||||
@@ -66,15 +67,16 @@ sequence_len: 8192
|
||||
...
|
||||
|
||||
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
|
||||
flash_attention: true # Required with sequence parallelism
|
||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||
flash_attention: true # SP requires flash attention
|
||||
micro_batch_size: 1 # SP requires this is set to 1
|
||||
# (optional) strides across the key dimension; larger values use more memory but should make training a bit faster
|
||||
heads_k_stride: 1
|
||||
|
||||
...
|
||||
```
|
||||
|
||||
This will train the Llama 3 8B model with 8K context length, with each sequence split
|
||||
into 2 subsequences of length 4096 across 2 GPUs.
|
||||
This will train the Llama 3 8B model with 8192 context length, with each sequence split
|
||||
into 4 subsequences of length 2048 across 4 GPUs.
|
||||
|
||||
## Sample Packing with Sequence Parallelism
|
||||
|
||||
@@ -86,12 +88,14 @@ Sequence parallelism is compatible with Axolotl's sample packing functionality.
|
||||
|
||||
## Effect on Batch Size
|
||||
|
||||
First, note that sequence parallelism supports only the case where `micro_batch_size: 1`.
|
||||
|
||||
When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:
|
||||
|
||||
- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
|
||||
- The number of batches processed per step decreases
|
||||
|
||||
For example:
|
||||
- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
|
||||
- With 8 GPUs and no sequence parallelism: 8 different batches are processed per step
|
||||
- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
|
||||
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4
|
||||
- If your per-GPU `micro_batch_size` is 1, the global batch size decreases from 8 to 2
|
||||
|
||||
@@ -82,3 +82,6 @@ deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
|
||||
special_tokens:
|
||||
pad_token: "<|end_of_text|>"
|
||||
|
||||
@@ -1043,6 +1043,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.rpo_alpha is not None:
|
||||
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
|
||||
|
||||
training_args_kwargs["sequence_parallel_degree"] = (
|
||||
self.cfg.sequence_parallel_degree
|
||||
)
|
||||
|
||||
training_args_cls = None
|
||||
blocklist_args_kwargs = []
|
||||
if self.cfg.rl == "simpo":
|
||||
@@ -1161,6 +1165,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
dpo_trainer_kwargs["dataset_tags"] = [
|
||||
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
||||
]
|
||||
|
||||
dpo_trainer = trainer_cls(
|
||||
*trainer_cls_args,
|
||||
args=training_args,
|
||||
@@ -1178,21 +1183,3 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
dpo_trainer.add_callback(callback)
|
||||
|
||||
return dpo_trainer
|
||||
|
||||
|
||||
class HFPPOTrainerBuilder(TrainerBuilderBase):
|
||||
"""
|
||||
HF Factory class for PPO Trainer
|
||||
"""
|
||||
|
||||
def get_callbacks(self):
|
||||
callbacks = super().get_callbacks()
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
|
||||
return callbacks
|
||||
|
||||
def build(self, total_num_steps):
|
||||
# build PPOConfig
|
||||
pass
|
||||
|
||||
@@ -3,16 +3,16 @@
|
||||
# pylint: disable=unused-import
|
||||
# flake8: noqa
|
||||
|
||||
from .base import AxolotlTrainer
|
||||
from .dpo.trainer import AxolotlDPOTrainer
|
||||
from .grpo.trainer import AxolotlGRPOTrainer
|
||||
from .mamba import AxolotlMambaTrainer
|
||||
from .relora import ReLoRATrainer
|
||||
from .trl import (
|
||||
from axolotl.core.trainers.base import AxolotlTrainer
|
||||
from axolotl.core.trainers.dpo import AxolotlDPOTrainer
|
||||
from axolotl.core.trainers.grpo import AxolotlGRPOTrainer
|
||||
from axolotl.core.trainers.mamba import AxolotlMambaTrainer
|
||||
from axolotl.core.trainers.relora import ReLoRATrainer
|
||||
from axolotl.core.trainers.trl import (
|
||||
AxolotlCPOTrainer,
|
||||
AxolotlKTOTrainer,
|
||||
AxolotlORPOTrainer,
|
||||
AxolotlPPOTrainer,
|
||||
AxolotlPRMTrainer,
|
||||
AxolotlRewardTrainer,
|
||||
TRLPPOTrainer,
|
||||
)
|
||||
|
||||
@@ -8,10 +8,11 @@ import logging
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from functools import wraps
|
||||
from typing import Literal
|
||||
from typing import Any, Literal
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from datasets import Dataset
|
||||
from torch.utils.data import (
|
||||
BatchSampler,
|
||||
@@ -25,12 +26,8 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
||||
from trl.trainer.utils import pad_to_length
|
||||
from typing_extensions import override
|
||||
|
||||
from axolotl.core.trainers.mixins import (
|
||||
OptimizerMixin,
|
||||
RngLoaderMixin,
|
||||
SchedulerMixin,
|
||||
SequenceParallelMixin,
|
||||
)
|
||||
from axolotl.core.trainers.handlers import SequenceParallelHandler
|
||||
from axolotl.core.trainers.mixins import TrainerMixins
|
||||
from axolotl.core.trainers.utils import (
|
||||
sanitize_kwargs_for_ds_tagging,
|
||||
sanitize_kwargs_for_tagging,
|
||||
@@ -40,9 +37,7 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AxolotlTrainer(
|
||||
SchedulerMixin, OptimizerMixin, RngLoaderMixin, SequenceParallelMixin, Trainer
|
||||
):
|
||||
class AxolotlTrainer(TrainerMixins, Trainer):
|
||||
"""Extend the base Trainer for axolotl helpers"""
|
||||
|
||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||
@@ -68,9 +63,7 @@ class AxolotlTrainer(
|
||||
if self.args.orpo_alpha:
|
||||
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
||||
|
||||
# Initialize sequence parallelism if enabled
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
self._setup_sequence_parallel()
|
||||
self.sequence_parallel_handler = SequenceParallelHandler(self.args)
|
||||
|
||||
def _wrap_model(self, model, training=True, dataloader=None):
|
||||
if self.args.torch_compile:
|
||||
@@ -131,7 +124,7 @@ class AxolotlTrainer(
|
||||
|
||||
# Determine the base sampler first
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
base_sampler = self._sp_get_train_sampler(self.train_dataset)
|
||||
base_sampler = self.sequence_parallel_handler._get_train_sampler(self.train_dataset)
|
||||
elif self.args.curriculum_sampling:
|
||||
base_sampler = SequentialSampler(self.train_dataset)
|
||||
elif use_sample_packing:
|
||||
@@ -167,7 +160,7 @@ class AxolotlTrainer(
|
||||
|
||||
# Determine the base sampler
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
base_sampler = self._sp_get_eval_sampler(eval_dataset)
|
||||
base_sampler = self.sequence_parallel_handler._get_eval_sampler(eval_dataset)
|
||||
elif use_multipack:
|
||||
base_sampler = SequentialSampler(eval_dataset)
|
||||
else:
|
||||
@@ -239,7 +232,10 @@ class AxolotlTrainer(
|
||||
return dataloader
|
||||
|
||||
# Otherwise prepare with accelerator
|
||||
return self.accelerator.prepare_data_loader(dataloader)
|
||||
dataloader = self.accelerator.prepare_data_loader(dataloader)
|
||||
|
||||
return dataloader
|
||||
|
||||
|
||||
def get_train_dataloader(self) -> DataLoader:
|
||||
"""Get dataloader for training"""
|
||||
@@ -348,7 +344,57 @@ class AxolotlTrainer(
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
|
||||
return DataLoader(bench_dataset, **dataloader_params)
|
||||
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
|
||||
|
||||
def training_step(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: dict[str, torch.Tensor | Any],
|
||||
num_items_in_batch: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Perform a training step on a batch of inputs. Overrides the
|
||||
`transformers.trainer.Trainer` method to handle sequence parallelism if
|
||||
enabled.
|
||||
|
||||
Args:
|
||||
model: Model to perform training step for.
|
||||
inputs: Dictionary mapping of inputs.
|
||||
num_items_in_batch: The number of items in the batch.
|
||||
"""
|
||||
# Set up sequence parallelism for this step if enabled
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
self.sequence_parallel_handler._update_ring_flash_attn_params(inputs)
|
||||
|
||||
# Proceed with normal training step
|
||||
return super().training_step(model, inputs, num_items_in_batch) # type: ignore
|
||||
|
||||
def prediction_step(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: dict[str, torch.Tensor | Any],
|
||||
prediction_loss_only: bool,
|
||||
ignore_keys: list[str] | None = None,
|
||||
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
|
||||
"""
|
||||
Perform a prediction step on a batch of inputs. Overrides the
|
||||
`transformers.trainer.Trainer` method to handle sequence parallelism if
|
||||
enabled.
|
||||
|
||||
Args:
|
||||
model: Model to perform prediction step for.
|
||||
inputs: Dictionary mapping of inputs.
|
||||
prediction_loss_only: Whether to return only the loss.
|
||||
ignore_keys: Keys to ignore in the inputs.
|
||||
|
||||
Returns:
|
||||
Tuple of (loss, logits, labels).
|
||||
"""
|
||||
# Set up sequence parallelism for this prediction step if enabled
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
self.sequence_parallel_handler._update_ring_flash_attn_params(inputs)
|
||||
|
||||
# Proceed with normal prediction step
|
||||
return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys) # type: ignore
|
||||
|
||||
@override
|
||||
def compute_loss(
|
||||
|
||||
@@ -1,14 +1,10 @@
|
||||
"""
|
||||
DPO Specific Strategy for training
|
||||
"""
|
||||
"""DPO Specific Strategy for training"""
|
||||
|
||||
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer
|
||||
|
||||
|
||||
class DPOStrategy:
|
||||
"""
|
||||
Strategy for DPO training
|
||||
"""
|
||||
"""Strategy for DPO training"""
|
||||
|
||||
@classmethod
|
||||
def get_trainer_class(cls):
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
"""
|
||||
Axolotl specific DPO args
|
||||
"""
|
||||
"""Axolotl specific DPO args"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
@@ -11,6 +9,4 @@ from axolotl.core.training_args import AxolotlTrainingMixins
|
||||
|
||||
@dataclass
|
||||
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
|
||||
"""
|
||||
DPO config for DPO training
|
||||
"""
|
||||
"""DPO config for DPO training"""
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
"""
|
||||
DPO trainer for axolotl
|
||||
"""
|
||||
"""DPO trainer for axolotl"""
|
||||
|
||||
import gc
|
||||
from functools import wraps
|
||||
from typing import Any, Dict, Union
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from peft.optimizers import create_loraplus_optimizer
|
||||
@@ -13,7 +10,8 @@ from transformers import Trainer
|
||||
from transformers.utils import is_sagemaker_mp_enabled
|
||||
from trl import DPOTrainer
|
||||
|
||||
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
|
||||
from axolotl.core.trainers.handlers import SequenceParallelHandler
|
||||
from axolotl.core.trainers.mixins import TrainerMixins
|
||||
from axolotl.core.trainers.utils import (
|
||||
sanitize_kwargs_for_ds_tagging,
|
||||
sanitize_kwargs_for_tagging,
|
||||
@@ -23,18 +21,18 @@ if is_sagemaker_mp_enabled():
|
||||
import smdistributed.modelparallel.torch as smp
|
||||
|
||||
|
||||
class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
|
||||
"""
|
||||
Extend the base DPOTrainer for axolotl helpers
|
||||
"""
|
||||
class AxolotlDPOTrainer(TrainerMixins, DPOTrainer):
|
||||
"""Extend the base DPOTrainer for axolotl helpers"""
|
||||
|
||||
tag_names = ["axolotl", "dpo"]
|
||||
|
||||
def __init__(self, *args, dataset_tags=None, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.dataset_tags = dataset_tags
|
||||
self.optimizer = None
|
||||
self.model_accepts_loss_kwargs = False
|
||||
self.sequence_parallel_handler = SequenceParallelHandler(args=self.args)
|
||||
|
||||
def create_optimizer(self):
|
||||
# pylint: disable=duplicate-code
|
||||
@@ -88,7 +86,7 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
|
||||
max_prompt_length,
|
||||
max_completion_length,
|
||||
add_special_tokens,
|
||||
) -> Dict:
|
||||
) -> dict:
|
||||
res = DPOTrainer.tokenize_row(
|
||||
features,
|
||||
processing_class,
|
||||
@@ -117,10 +115,9 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
|
||||
def training_step(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: Dict[str, Union[torch.Tensor, Any]],
|
||||
inputs: dict[str, torch.Tensor | Any | None],
|
||||
num_items_in_batch=None,
|
||||
) -> torch.Tensor:
|
||||
loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
return loss
|
||||
self.sequence_parallel_handler.prepare_for_training_step(self, inputs)
|
||||
|
||||
return super().training_step(model, inputs, num_items_in_batch)
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
"""
|
||||
Axolotl GRPO trainer
|
||||
"""
|
||||
"""Axolotl GRPO trainer"""
|
||||
|
||||
from contextlib import nullcontext
|
||||
|
||||
@@ -8,16 +6,14 @@ from accelerate.utils import is_deepspeed_available, is_peft_model
|
||||
from trl import GRPOTrainer
|
||||
from trl.extras.profiling import profiling_decorator
|
||||
|
||||
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
|
||||
from axolotl.core.trainers.mixins import TrainerMixins
|
||||
|
||||
if is_deepspeed_available():
|
||||
import deepspeed
|
||||
|
||||
|
||||
class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
|
||||
"""
|
||||
Extend the base GRPOTrainer for axolotl helpers
|
||||
"""
|
||||
class AxolotlGRPOTrainer(TrainerMixins, GRPOTrainer):
|
||||
"""Extend the base GRPOTrainer for axolotl helpers"""
|
||||
|
||||
_tag_names = ["trl", "grpo", "axolotl"]
|
||||
|
||||
|
||||
3
src/axolotl/core/trainers/handlers/__init__.py
Normal file
3
src/axolotl/core/trainers/handlers/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""Init for trainer handlers"""
|
||||
|
||||
from axolotl.core.trainers.handlers.sequence_parallel import SequenceParallelHandler
|
||||
123
src/axolotl/core/trainers/handlers/sequence_parallel.py
Normal file
123
src/axolotl/core/trainers/handlers/sequence_parallel.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""Handler class for sequence parallel trainer logic"""
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import DistributedSampler
|
||||
|
||||
|
||||
class SequenceParallelHandler:
|
||||
"""
|
||||
Handler class that encapsulates sequence parallelism functionality.
|
||||
This replaces the SequenceParallelMixin with a composition-based approach.
|
||||
"""
|
||||
|
||||
def __init__(self, args=None):
|
||||
"""
|
||||
Initialize the sequence parallel handler.
|
||||
|
||||
Args:
|
||||
args: The arguments object containing sequence parallelism settings.
|
||||
"""
|
||||
self.args = args
|
||||
self.ring_attn_group = None
|
||||
|
||||
# Set up sequence parallelism if enabled
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
self._setup_sequence_parallel()
|
||||
|
||||
def _setup_sequence_parallel(self):
|
||||
"""Set up sequence parallelism environment."""
|
||||
from ring_flash_attn import update_ring_flash_attn_params
|
||||
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
||||
|
||||
self.update_ring_flash_attn_params = update_ring_flash_attn_params
|
||||
self.ring_attn_group = get_ring_attn_group()
|
||||
|
||||
def create_sequence_parallel_sampler(
|
||||
self,
|
||||
dataset,
|
||||
shuffle=True,
|
||||
is_eval=False,
|
||||
):
|
||||
"""
|
||||
Helper method to create sampler for sequence parallelism (SP).
|
||||
|
||||
Args:
|
||||
dataset: Dataset to sample from.
|
||||
shuffle: Whether to shuffle the dataset.
|
||||
is_eval: Whether we are creating a sampler for evaluation or training.
|
||||
|
||||
Returns:
|
||||
Distributed sampler.
|
||||
"""
|
||||
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
|
||||
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
|
||||
|
||||
return DistributedSampler(
|
||||
dataset,
|
||||
num_replicas=num_sp_groups,
|
||||
rank=sp_group_id,
|
||||
seed=self.args.seed if shuffle else None,
|
||||
shuffle=shuffle,
|
||||
drop_last=not is_eval,
|
||||
)
|
||||
|
||||
def _get_train_sampler(self, dataset):
|
||||
"""
|
||||
Get a training sampler configured for sequence parallelism.
|
||||
|
||||
Args:
|
||||
dataset: The training dataset.
|
||||
|
||||
Returns:
|
||||
Configured sequence parallel sampler.
|
||||
"""
|
||||
return self.create_sequence_parallel_sampler(
|
||||
dataset,
|
||||
shuffle=not self.args.curriculum_sampling,
|
||||
)
|
||||
|
||||
def _get_eval_sampler(self, eval_dataset):
|
||||
"""
|
||||
Get an evaluation sampler configured for sequence parallelism.
|
||||
|
||||
Args:
|
||||
eval_dataset: The evaluation dataset.
|
||||
|
||||
Returns:
|
||||
Configured sequence parallel sampler.
|
||||
"""
|
||||
return self.create_sequence_parallel_sampler(
|
||||
eval_dataset, shuffle=False, is_eval=True
|
||||
)
|
||||
|
||||
def _update_ring_flash_attn_params(self, inputs):
|
||||
"""
|
||||
Calculate the cu_seqlens for the current forward pass and pass the value to
|
||||
the substituted ring_flash_attn.
|
||||
|
||||
Args:
|
||||
inputs: Current batch of inputs.
|
||||
"""
|
||||
# At this point, inputs should already be partitioned by the sequence
|
||||
# parallel data collator
|
||||
batch_size = inputs["input_ids"].shape[0]
|
||||
seq_len = inputs["input_ids"].shape[1]
|
||||
packed_seq_lens = [seq_len] * batch_size
|
||||
|
||||
# Calculate the full sequence length across all GPUs in this SP group
|
||||
total_seq_len = seq_len * self.args.sequence_parallel_degree
|
||||
|
||||
cu_seqlens = torch.cumsum(
|
||||
torch.tensor(
|
||||
packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
|
||||
),
|
||||
dim=-1,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
cu_seqlens = F.pad(
|
||||
F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
|
||||
)
|
||||
|
||||
self.update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)
|
||||
@@ -3,7 +3,12 @@
|
||||
# pylint: disable=unused-import
|
||||
# flake8: noqa
|
||||
|
||||
from .optimizer import OptimizerMixin
|
||||
from .rng_state_loader import RngLoaderMixin
|
||||
from .scheduler import SchedulerMixin
|
||||
from .sequence_parallel import SequenceParallelMixin
|
||||
from axolotl.core.trainers.mixins.optimizer import OptimizerMixin
|
||||
from axolotl.core.trainers.mixins.rng_state_loader import RngLoaderMixin
|
||||
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
|
||||
|
||||
|
||||
class TrainerMixins(
|
||||
OptimizerMixin, RngLoaderMixin, SchedulerMixin
|
||||
):
|
||||
"""Stub class combining all mixins for Axolotl trainers."""
|
||||
|
||||
@@ -21,9 +21,7 @@ LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RngLoaderMixin(Trainer):
|
||||
"""
|
||||
mixin for method override to load RNG states from a checkpoint
|
||||
"""
|
||||
"""Mixin for method override to load RNG states from a checkpoint"""
|
||||
|
||||
def _load_rng_state(self, checkpoint):
|
||||
# Load RNG states from `checkpoint`
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Module for Axolotl trainer sequence parallelism mixin"""
|
||||
# TODO(Dan): remove
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
@@ -7,7 +8,6 @@ import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from datasets import Dataset
|
||||
from torch import nn
|
||||
from torch.utils.data import DistributedSampler, Sampler
|
||||
|
||||
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
||||
@@ -71,12 +71,12 @@ class SequenceParallelMixin:
|
||||
drop_last=not is_eval,
|
||||
)
|
||||
|
||||
def _sp_get_train_sampler(self, dataset) -> Sampler | None:
|
||||
def _get_train_sampler(self, dataset) -> Sampler | None:
|
||||
"""
|
||||
Get a training sampler configured for sequence parallelism.
|
||||
|
||||
Args:
|
||||
dataset: The training dataset
|
||||
dataset: The training dataset.
|
||||
|
||||
Returns:
|
||||
Configured sequence parallel sampler.
|
||||
@@ -86,7 +86,7 @@ class SequenceParallelMixin:
|
||||
shuffle=not self.args.curriculum_sampling,
|
||||
)
|
||||
|
||||
def _sp_get_eval_sampler(self, eval_dataset) -> Sampler | None:
|
||||
def _get_eval_sampler(self, eval_dataset) -> Sampler | None:
|
||||
"""
|
||||
Get an evaluation sampler configured for sequence parallelism.
|
||||
|
||||
@@ -130,53 +130,3 @@ class SequenceParallelMixin:
|
||||
)
|
||||
|
||||
update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)
|
||||
|
||||
def training_step(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: dict[str, torch.Tensor | Any],
|
||||
num_items_in_batch: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Perform a training step on a batch of inputs. Overrides the
|
||||
`transformers.trainer.Trainer` method to handle sequence parallelism if
|
||||
enabled.
|
||||
|
||||
Args:
|
||||
model: Model to perform training step for.
|
||||
inputs: Dictionary mapping.
|
||||
"""
|
||||
# Set up sequence parallelism for this step if enabled
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
self._update_ring_flash_attn_params(inputs)
|
||||
|
||||
# Proceed with normal training step
|
||||
return super().training_step(model, inputs, num_items_in_batch) # type: ignore
|
||||
|
||||
def prediction_step(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: dict[str, torch.Tensor | Any],
|
||||
prediction_loss_only: bool,
|
||||
ignore_keys: list[str] | None = None,
|
||||
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
|
||||
"""
|
||||
Perform a prediction step on a batch of inputs. Overrides the
|
||||
`transformers.trainer.Trainer` method to handle sequence parallelism if
|
||||
enabled.
|
||||
|
||||
Args:
|
||||
model: Model to perform prediction step for.
|
||||
inputs: Dictionary mapping of inputs.
|
||||
prediction_loss_only: Whether to return only the loss.
|
||||
ignore_keys: Keys to ignore in the inputs.
|
||||
|
||||
Returns:
|
||||
Tuple of (loss, logits, labels).
|
||||
"""
|
||||
# Set up sequence parallelism for this prediction step if enabled
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
self._update_ring_flash_attn_params(inputs)
|
||||
|
||||
# Proceed with normal prediction step
|
||||
return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys) # type: ignore
|
||||
|
||||
@@ -13,11 +13,10 @@ from trl import (
|
||||
RewardTrainer,
|
||||
)
|
||||
|
||||
from axolotl.core.trainers.mixins import RngLoaderMixin
|
||||
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
|
||||
from axolotl.core.trainers.mixins import TrainerMixins
|
||||
|
||||
|
||||
class TRLPPOTrainer(PPOTrainer):
|
||||
class AxolotlPPOTrainer(TrainerMixins, PPOTrainer):
|
||||
"""Wrapper for TRL PPO trainer to handle customizations"""
|
||||
|
||||
tag_names = ["axolotl", "ppo"]
|
||||
@@ -75,10 +74,8 @@ class TRLPPOTrainer(PPOTrainer):
|
||||
)
|
||||
|
||||
|
||||
class AxolotlORPOTrainer(RngLoaderMixin, SchedulerMixin, ORPOTrainer):
|
||||
"""
|
||||
Extend the base ORPOTrainer for axolotl helpers
|
||||
"""
|
||||
class AxolotlORPOTrainer(TrainerMixins, ORPOTrainer):
|
||||
"""Extend the base ORPOTrainer for axolotl helpers"""
|
||||
|
||||
tag_names = ["axolotl", "orpo"]
|
||||
|
||||
@@ -155,18 +152,14 @@ class AxolotlORPOTrainer(RngLoaderMixin, SchedulerMixin, ORPOTrainer):
|
||||
return loss, metrics
|
||||
|
||||
|
||||
class AxolotlKTOTrainer(RngLoaderMixin, SchedulerMixin, KTOTrainer):
|
||||
"""
|
||||
Extend the base KTOTrainer for axolotl helpers
|
||||
"""
|
||||
class AxolotlKTOTrainer(TrainerMixins, KTOTrainer):
|
||||
"""Extend the base KTOTrainer for axolotl helpers"""
|
||||
|
||||
tag_names = ["axolotl", "kto"]
|
||||
|
||||
|
||||
class AxolotlCPOTrainer(RngLoaderMixin, SchedulerMixin, CPOTrainer):
|
||||
"""
|
||||
Extend the base CPOTrainer for axolotl helpers
|
||||
"""
|
||||
class AxolotlCPOTrainer(TrainerMixins, CPOTrainer):
|
||||
"""Extend the base CPOTrainer for axolotl helpers"""
|
||||
|
||||
tag_names = ["axolotl", "cpo"]
|
||||
|
||||
@@ -245,17 +238,13 @@ class AxolotlCPOTrainer(RngLoaderMixin, SchedulerMixin, CPOTrainer):
|
||||
return loss, metrics
|
||||
|
||||
|
||||
class AxolotlRewardTrainer(RngLoaderMixin, SchedulerMixin, RewardTrainer):
|
||||
"""
|
||||
Extend the base RewardTrainer for axolotl helpers
|
||||
"""
|
||||
class AxolotlRewardTrainer(TrainerMixins, RewardTrainer):
|
||||
"""Extend the base RewardTrainer for axolotl helpers"""
|
||||
|
||||
tag_names = ["axolotl", "reward"]
|
||||
|
||||
|
||||
class AxolotlPRMTrainer(RngLoaderMixin, SchedulerMixin, PRMTrainer):
|
||||
"""
|
||||
Extend the base trl.PRMTrainer for axolotl helpers
|
||||
"""
|
||||
class AxolotlPRMTrainer(TrainerMixins, PRMTrainer):
|
||||
"""Extend the base trl.PRMTrainer for axolotl helpers"""
|
||||
|
||||
tag_names = ["axolotl", "prm"]
|
||||
|
||||
@@ -12,9 +12,7 @@ from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
|
||||
|
||||
@dataclass
|
||||
class AxolotlTrainingMixins:
|
||||
"""
|
||||
Mixin class for the Axolotl training args.
|
||||
"""
|
||||
"""Mixin class for the Axolotl training args."""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
model_type: Optional[str] = field(
|
||||
|
||||
@@ -6,11 +6,22 @@ package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patc
|
||||
their sequence parallel version of Flash Attention 2.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from accelerate.logging import get_logger
|
||||
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
|
||||
from axolotl.logging_config import configure_logging
|
||||
|
||||
try:
|
||||
from ring_flash_attn import update_ring_flash_attn_params
|
||||
except ImportError:
|
||||
# We pass silently here, but raise an ImportError in our Axolotl config validation
|
||||
# if cfg.sequence_parallel_degree > 1 and `ring-flash-attn` is not installed.
|
||||
pass
|
||||
|
||||
|
||||
configure_logging()
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
@@ -32,12 +43,120 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
|
||||
Setter for ring attention group on this rank.
|
||||
|
||||
Args:
|
||||
Process group for ring attention.
|
||||
ring_attn_group: Process group for ring attention.
|
||||
"""
|
||||
global RING_ATTN_GROUP # pylint: disable=global-statement
|
||||
RING_ATTN_GROUP = ring_attn_group
|
||||
|
||||
|
||||
def patch_flash_attention_for_sequential_batch(sequence_parallel_degree: int):
|
||||
"""
|
||||
Patch flash attention a second time to handle batched data. This is a hack to
|
||||
accommodate certain RL trainers which batch data even when `micro_batch_size: 1` is
|
||||
specified in the Axolotl config.
|
||||
|
||||
Args:
|
||||
sequence_parallel_degree: Sequence parallelism factor.
|
||||
"""
|
||||
# Store the original flash attention function
|
||||
original_flash_attention = ALL_ATTENTION_FUNCTIONS["flash_attention_2"]
|
||||
|
||||
def sequential_batch_flash_attention(
|
||||
module: torch.nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None,
|
||||
dropout: float = 0.0,
|
||||
scaling: float | None = None,
|
||||
sliding_window: int | None = None,
|
||||
softcap: float | None = None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, None]:
|
||||
# Check if we have a batch dimension > 1
|
||||
batch_size = query.shape[0]
|
||||
|
||||
if batch_size <= 1:
|
||||
return original_flash_attention(
|
||||
module,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attention_mask,
|
||||
dropout,
|
||||
scaling,
|
||||
sliding_window,
|
||||
softcap,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# Process each item in the batch separately
|
||||
outputs = []
|
||||
|
||||
for i in range(batch_size):
|
||||
# Extract single batch item
|
||||
q_item = query[i:i+1]
|
||||
k_item = key[i:i+1]
|
||||
v_item = value[i:i+1]
|
||||
|
||||
# Handle attention mask - it might be None or have different shapes
|
||||
mask_item = None
|
||||
if attention_mask is not None:
|
||||
# The mask could have different formats depending on implementation
|
||||
if attention_mask.dim() >= 3 and attention_mask.shape[0] == batch_size:
|
||||
mask_item = attention_mask[i:i+1]
|
||||
else:
|
||||
# For broadcast masks that don't have a batch dimension
|
||||
mask_item = attention_mask
|
||||
|
||||
# At this point, inputs should already be partitioned by the sequence
|
||||
# parallel data collator
|
||||
batch_size = q_item.shape[0]
|
||||
seq_len = q_item.shape[2]
|
||||
packed_seq_lens = [seq_len] * batch_size
|
||||
|
||||
# Calculate the full sequence length across all GPUs in this SP group
|
||||
total_seq_len = seq_len * sequence_parallel_degree
|
||||
|
||||
cu_seqlens = torch.cumsum(
|
||||
torch.tensor(
|
||||
packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
|
||||
),
|
||||
dim=-1,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
cu_seqlens = F.pad(
|
||||
F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
|
||||
)
|
||||
|
||||
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
|
||||
|
||||
# Call the original function for a single batch item
|
||||
output, _ = original_flash_attention(
|
||||
module,
|
||||
q_item,
|
||||
k_item,
|
||||
v_item,
|
||||
mask_item,
|
||||
dropout,
|
||||
scaling,
|
||||
sliding_window,
|
||||
softcap,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
outputs.append(output)
|
||||
|
||||
dist.barrier()
|
||||
|
||||
# Concatenate results along batch dimension
|
||||
concatenated_output = torch.cat(outputs, dim=0)
|
||||
return concatenated_output, None
|
||||
|
||||
# Replace the original function with our sequential version
|
||||
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = sequential_batch_flash_attention
|
||||
|
||||
|
||||
def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None):
|
||||
"""
|
||||
Create ring attention group and substitute flash attn with ring flash attn.
|
||||
@@ -98,3 +217,4 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
|
||||
substitute_hf_flash_attn(
|
||||
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride
|
||||
)
|
||||
patch_flash_attention_for_sequential_batch(sequence_parallel_degree)
|
||||
|
||||
@@ -1351,9 +1351,7 @@ def load_model(
|
||||
reference_model: bool = False,
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
|
||||
"""
|
||||
Load a model for a given configuration and tokenizer.
|
||||
"""
|
||||
"""Load a model for a given configuration and tokenizer."""
|
||||
model_loader = ModelLoader(
|
||||
cfg,
|
||||
tokenizer,
|
||||
@@ -1362,12 +1360,16 @@ def load_model(
|
||||
reference_model=reference_model,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return model_loader.load_model()
|
||||
|
||||
|
||||
def load_adapter(model, cfg, adapter, inference=False):
|
||||
# type: (PreTrainedModel, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||
|
||||
def load_adapter(
|
||||
model: PreTrainedModel,
|
||||
cfg: DictDefault,
|
||||
adapter: str | None,
|
||||
inference: bool = False,
|
||||
) -> tuple[PreTrainedModel, PeftConfig | None]:
|
||||
if adapter is None:
|
||||
return model, None
|
||||
if hasattr(model, "enable_input_require_grads"):
|
||||
@@ -1380,8 +1382,9 @@ def load_adapter(model, cfg, adapter, inference=False):
|
||||
raise NotImplementedError(f"{adapter} peft adapter not available")
|
||||
|
||||
|
||||
def load_llama_adapter(model, cfg):
|
||||
# type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||
def load_llama_adapter(
|
||||
model: PreTrainedModel, cfg: DictDefault
|
||||
) -> tuple[PreTrainedModel, PeftConfig | None]:
|
||||
from peft import AdaptionPromptConfig, get_peft_model
|
||||
|
||||
peft_config = AdaptionPromptConfig(
|
||||
@@ -1405,7 +1408,7 @@ def load_llama_adapter(model, cfg):
|
||||
return model, peft_config
|
||||
|
||||
|
||||
def find_all_linear_names(model):
|
||||
def find_all_linear_names(model: PreTrainedModel):
|
||||
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
|
||||
lora_module_names = set()
|
||||
for name, module in model.named_modules():
|
||||
|
||||
Reference in New Issue
Block a user