reworking SP logic into composed handler

This commit is contained in:
Dan Saunders
2025-04-04 02:25:00 +00:00
parent ce07081d6c
commit 9f30d3d33a
9 changed files with 341 additions and 106 deletions

View File

@@ -1043,6 +1043,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.rpo_alpha is not None:
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
training_args_kwargs["sequence_parallel_degree"] = (
self.cfg.sequence_parallel_degree
)
training_args_cls = None
blocklist_args_kwargs = []
if self.cfg.rl == "simpo":
@@ -1161,6 +1165,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer_kwargs["dataset_tags"] = [
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
]
dpo_trainer = trainer_cls(
*trainer_cls_args,
args=training_args,
@@ -1178,21 +1183,3 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer.add_callback(callback)
return dpo_trainer
class HFPPOTrainerBuilder(TrainerBuilderBase):
    """
    Builder for the HF PPO trainer.

    Placeholder: both callback hooks defer to the base builder and `build` has
    no implementation yet.
    """

    def get_callbacks(self):
        """Return the base builder's callbacks unchanged."""
        return super().get_callbacks()

    def get_post_trainer_create_callbacks(self, trainer):
        """Return the base builder's post-creation callbacks for `trainer`."""
        return super().get_post_trainer_create_callbacks(trainer=trainer)

    def build(self, total_num_steps):
        """Not implemented yet; always returns None."""
        # build PPOConfig
        return None

View File

@@ -8,10 +8,11 @@ import logging
import os
from collections import defaultdict
from functools import wraps
from typing import Literal
from typing import Any, Literal
import datasets
import torch
import torch.nn as nn
from datasets import Dataset
from torch.utils.data import (
BatchSampler,
@@ -25,6 +26,7 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
from trl.trainer.utils import pad_to_length
from typing_extensions import override
from axolotl.core.trainers.handlers import SequenceParallelHandler
from axolotl.core.trainers.mixins import TrainerMixins
from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging,
@@ -61,9 +63,7 @@ class AxolotlTrainer(TrainerMixins, Trainer):
if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
# Initialize sequence parallelism if enabled
if self.args.sequence_parallel_degree > 1:
self._setup_sequence_parallel()
self.sequence_parallel_handler = SequenceParallelHandler(self.args)
def _wrap_model(self, model, training=True, dataloader=None):
if self.args.torch_compile:
@@ -124,7 +124,7 @@ class AxolotlTrainer(TrainerMixins, Trainer):
# Determine the base sampler first
if self.args.sequence_parallel_degree > 1:
base_sampler = self._sp_get_train_sampler(self.train_dataset)
base_sampler = self.sequence_parallel_handler._get_train_sampler(self.train_dataset)
elif self.args.curriculum_sampling:
base_sampler = SequentialSampler(self.train_dataset)
elif use_sample_packing:
@@ -160,7 +160,7 @@ class AxolotlTrainer(TrainerMixins, Trainer):
# Determine the base sampler
if self.args.sequence_parallel_degree > 1:
base_sampler = self._sp_get_eval_sampler(eval_dataset)
base_sampler = self.sequence_parallel_handler._get_eval_sampler(eval_dataset)
elif use_multipack:
base_sampler = SequentialSampler(eval_dataset)
else:
@@ -232,7 +232,10 @@ class AxolotlTrainer(TrainerMixins, Trainer):
return dataloader
# Otherwise prepare with accelerator
return self.accelerator.prepare_data_loader(dataloader)
dataloader = self.accelerator.prepare_data_loader(dataloader)
return dataloader
def get_train_dataloader(self) -> DataLoader:
"""Get dataloader for training"""
@@ -341,7 +344,57 @@ class AxolotlTrainer(TrainerMixins, Trainer):
dataloader_params["drop_last"] = self.args.dataloader_drop_last
return DataLoader(bench_dataset, **dataloader_params)
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
def training_step(
    self,
    model: nn.Module,
    inputs: dict[str, torch.Tensor | Any],
    num_items_in_batch: int | None = None,
) -> torch.Tensor:
    """
    Run a single training step, refreshing the ring flash attention parameters
    first when sequence parallelism is enabled.

    Overrides the `transformers.trainer.Trainer` method.

    Args:
        model: Model to perform training step for.
        inputs: Dictionary mapping of inputs.
        num_items_in_batch: The number of items in the batch.

    Returns:
        The value returned by the parent `training_step`.
    """
    sp_enabled = self.args.sequence_parallel_degree > 1
    if sp_enabled:
        # The handler derives cu_seqlens from this batch before the forward pass
        handler = self.sequence_parallel_handler
        handler._update_ring_flash_attn_params(inputs)

    # Everything else is delegated to the stock implementation
    return super().training_step(model, inputs, num_items_in_batch)  # type: ignore
def prediction_step(
    self,
    model: nn.Module,
    inputs: dict[str, torch.Tensor | Any],
    prediction_loss_only: bool,
    ignore_keys: list[str] | None = None,
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
    """
    Run a single prediction step, refreshing the ring flash attention
    parameters first when sequence parallelism is enabled.

    Overrides the `transformers.trainer.Trainer` method.

    Args:
        model: Model to perform prediction step for.
        inputs: Dictionary mapping of inputs.
        prediction_loss_only: Whether to return only the loss.
        ignore_keys: Keys to ignore in the inputs.

    Returns:
        Tuple of (loss, logits, labels).
    """
    if self.args.sequence_parallel_degree > 1:
        # The handler derives cu_seqlens from this batch before the forward pass
        refresh_params = self.sequence_parallel_handler._update_ring_flash_attn_params
        refresh_params(inputs)

    # Everything else is delegated to the stock implementation
    outcome = super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)  # type: ignore
    return outcome
@override
def compute_loss(

View File

@@ -1,10 +1,7 @@
"""
DPO trainer for axolotl
"""
"""DPO trainer for axolotl"""
import gc
from functools import wraps
from typing import Any, Dict, Union
from typing import Any
import torch
from peft.optimizers import create_loraplus_optimizer
@@ -13,6 +10,7 @@ from transformers import Trainer
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer
from axolotl.core.trainers.handlers import SequenceParallelHandler
from axolotl.core.trainers.mixins import TrainerMixins
from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging,
@@ -24,17 +22,17 @@ if is_sagemaker_mp_enabled():
class AxolotlDPOTrainer(TrainerMixins, DPOTrainer):
"""
Extend the base DPOTrainer for axolotl helpers
"""
"""Extend the base DPOTrainer for axolotl helpers"""
tag_names = ["axolotl", "dpo"]
def __init__(self, *args, dataset_tags=None, **kwargs):
    """
    Initialize the parent trainer, then attach axolotl-specific state.

    Args:
        *args: Positional arguments forwarded to the parent initializer.
        dataset_tags: Value stored on the trainer as `dataset_tags`.
        **kwargs: Keyword arguments forwarded to the parent initializer.
    """
    super().__init__(*args, **kwargs)

    self.dataset_tags = dataset_tags
    # Both attributes are overwritten after the parent initializer has run
    self.optimizer = None
    self.model_accepts_loss_kwargs = False

    # Sequence parallel logic lives in a composed handler built from the
    # training arguments
    training_args = self.args
    self.sequence_parallel_handler = SequenceParallelHandler(args=training_args)
def create_optimizer(self):
# pylint: disable=duplicate-code
@@ -88,7 +86,7 @@ class AxolotlDPOTrainer(TrainerMixins, DPOTrainer):
max_prompt_length,
max_completion_length,
add_special_tokens,
) -> Dict:
) -> dict:
res = DPOTrainer.tokenize_row(
features,
processing_class,
@@ -117,10 +115,9 @@ class AxolotlDPOTrainer(TrainerMixins, DPOTrainer):
def training_step(
self,
model: nn.Module,
inputs: Dict[str, Union[torch.Tensor, Any]],
inputs: dict[str, torch.Tensor | Any | None],
num_items_in_batch=None,
) -> torch.Tensor:
loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
gc.collect()
torch.cuda.empty_cache()
return loss
self.sequence_parallel_handler.prepare_for_training_step(self, inputs)
return super().training_step(model, inputs, num_items_in_batch)

View File

@@ -0,0 +1,3 @@
"""Init for trainer handlers"""
from axolotl.core.trainers.handlers.sequence_parallel import SequenceParallelHandler

View File

@@ -0,0 +1,123 @@
"""Handler class for sequence parallel trainer logic"""
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.utils.data import DistributedSampler
class SequenceParallelHandler:
    """
    Handler class that encapsulates sequence parallelism (SP) functionality.

    Trainers hold an instance of this class (composition) instead of inheriting
    the behaviour from `SequenceParallelMixin`.
    """

    def __init__(self, args=None):
        """
        Initialize the sequence parallel handler.

        Args:
            args: The arguments object containing sequence parallelism settings.
                May be `None`, in which case SP is treated as disabled.
        """
        self.args = args
        self.ring_attn_group = None
        # Bound in `_setup_sequence_parallel`; defined here so the attribute
        # always exists, even when SP is disabled.
        self.update_ring_flash_attn_params = None

        # Set up sequence parallelism if enabled
        if self.is_enabled:
            self._setup_sequence_parallel()

    @property
    def is_enabled(self) -> bool:
        """Whether `args.sequence_parallel_degree` is set and greater than 1."""
        # A missing `args`, a missing attribute and an explicit `None` degree
        # all count as "disabled" instead of raising.
        degree = getattr(self.args, "sequence_parallel_degree", None) or 1
        return degree > 1

    def _setup_sequence_parallel(self):
        """Set up sequence parallelism environment."""
        # Imported lazily so `ring_flash_attn` is only required when SP is used
        from ring_flash_attn import update_ring_flash_attn_params

        from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group

        self.update_ring_flash_attn_params = update_ring_flash_attn_params
        self.ring_attn_group = get_ring_attn_group()

    def create_sequence_parallel_sampler(
        self,
        dataset,
        shuffle=True,
        is_eval=False,
    ):
        """
        Helper method to create sampler for sequence parallelism (SP).

        Ranks are grouped into blocks of `sequence_parallel_degree` consecutive
        ranks; every rank in a block gets the same `rank` in the sampler and
        therefore the same samples.

        Args:
            dataset: Dataset to sample from.
            shuffle: Whether to shuffle the dataset.
            is_eval: Whether we are creating a sampler for evaluation or training.

        Returns:
            Distributed sampler.
        """
        sp_degree = self.args.sequence_parallel_degree
        num_sp_groups = self.args.world_size // sp_degree
        sp_group_id = dist.get_rank() // sp_degree

        return DistributedSampler(
            dataset,
            num_replicas=num_sp_groups,
            rank=sp_group_id,
            # The seed is only consulted when shuffling; pass the sampler's own
            # default (0) rather than `None` to honour its `int` contract.
            seed=self.args.seed if shuffle else 0,
            shuffle=shuffle,
            drop_last=not is_eval,
        )

    def _get_train_sampler(self, dataset):
        """
        Get a training sampler configured for sequence parallelism.

        Args:
            dataset: The training dataset.

        Returns:
            Configured sequence parallel sampler.
        """
        return self.create_sequence_parallel_sampler(
            dataset,
            shuffle=not self.args.curriculum_sampling,
        )

    def _get_eval_sampler(self, eval_dataset):
        """
        Get an evaluation sampler configured for sequence parallelism.

        Args:
            eval_dataset: The evaluation dataset.

        Returns:
            Configured sequence parallel sampler.
        """
        return self.create_sequence_parallel_sampler(
            eval_dataset, shuffle=False, is_eval=True
        )

    def prepare_for_training_step(self, trainer, inputs):
        """
        Hook for trainers to call at the start of `training_step`.

        Does nothing when sequence parallelism is disabled.

        Args:
            trainer: The calling trainer. Currently unused; accepted so that
                trainers can pass themselves.
            inputs: Current batch of inputs.
        """
        if not self.is_enabled:
            return

        # NOTE(review): batches without an `input_ids` entry are skipped rather
        # than raising KeyError — confirm this is the intended handling for
        # trainers whose batches use other keys.
        if "input_ids" not in inputs:
            return

        self._update_ring_flash_attn_params(inputs)

    def _update_ring_flash_attn_params(self, inputs):
        """
        Calculate the cu_seqlens for the current forward pass and pass the value to
        the substituted ring_flash_attn.

        Args:
            inputs: Current batch of inputs.

        Raises:
            RuntimeError: If sequence parallelism was never set up on this handler.
        """
        if self.update_ring_flash_attn_params is None:
            raise RuntimeError(
                "Sequence parallelism is not set up on this handler "
                "(sequence_parallel_degree must be > 1)."
            )

        # At this point, inputs should already be partitioned by the sequence
        # parallel data collator
        batch_size, seq_len = inputs["input_ids"].shape[:2]
        packed_seq_lens = [seq_len] * batch_size

        # Calculate the full sequence length across all GPUs in this SP group
        total_seq_len = seq_len * self.args.sequence_parallel_degree

        cu_seqlens = torch.cumsum(
            torch.tensor(
                packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
            ),
            dim=-1,
            dtype=torch.int32,
        )
        # Prepend 0 and append the full (unsharded) sequence length
        cu_seqlens = F.pad(
            F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
        )

        self.update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)

View File

@@ -6,10 +6,9 @@
from axolotl.core.trainers.mixins.optimizer import OptimizerMixin
from axolotl.core.trainers.mixins.rng_state_loader import RngLoaderMixin
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
from axolotl.core.trainers.mixins.sequence_parallel import SequenceParallelMixin
class TrainerMixins(
OptimizerMixin, RngLoaderMixin, SchedulerMixin, SequenceParallelMixin
OptimizerMixin, RngLoaderMixin, SchedulerMixin
):
"""Stub class combining all mixins for Axolotl trainers."""

View File

@@ -1,4 +1,5 @@
"""Module for Axolotl trainer sequence parallelism mixin"""
# TODO(Dan): remove
import logging
from typing import Any
@@ -7,7 +8,6 @@ import torch
import torch.distributed as dist
import torch.nn.functional as F
from datasets import Dataset
from torch import nn
from torch.utils.data import DistributedSampler, Sampler
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
@@ -71,12 +71,12 @@ class SequenceParallelMixin:
drop_last=not is_eval,
)
def _sp_get_train_sampler(self, dataset) -> Sampler | None:
def _get_train_sampler(self, dataset) -> Sampler | None:
"""
Get a training sampler configured for sequence parallelism.
Args:
dataset: The training dataset
dataset: The training dataset.
Returns:
Configured sequence parallel sampler.
@@ -86,7 +86,7 @@ class SequenceParallelMixin:
shuffle=not self.args.curriculum_sampling,
)
def _sp_get_eval_sampler(self, eval_dataset) -> Sampler | None:
def _get_eval_sampler(self, eval_dataset) -> Sampler | None:
"""
Get an evaluation sampler configured for sequence parallelism.
@@ -130,53 +130,3 @@ class SequenceParallelMixin:
)
update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)
def training_step(
    self,
    model: nn.Module,
    inputs: dict[str, torch.Tensor | Any],
    num_items_in_batch: int | None = None,
) -> torch.Tensor:
    """
    Run a single training step, refreshing the ring flash attention parameters
    first when sequence parallelism is enabled.

    Overrides the `transformers.trainer.Trainer` method.

    Args:
        model: Model to perform training step for.
        inputs: Dictionary mapping of inputs.
        num_items_in_batch: Forwarded unchanged to the parent implementation.

    Returns:
        The value returned by the parent `training_step`.
    """
    sp_enabled = self.args.sequence_parallel_degree > 1
    if sp_enabled:
        # cu_seqlens must be refreshed from this batch before the forward pass
        self._update_ring_flash_attn_params(inputs)

    # Everything else is delegated to the next class in the MRO
    return super().training_step(model, inputs, num_items_in_batch)  # type: ignore
def prediction_step(
    self,
    model: nn.Module,
    inputs: dict[str, torch.Tensor | Any],
    prediction_loss_only: bool,
    ignore_keys: list[str] | None = None,
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
    """
    Run a single prediction step, refreshing the ring flash attention
    parameters first when sequence parallelism is enabled.

    Overrides the `transformers.trainer.Trainer` method.

    Args:
        model: Model to perform prediction step for.
        inputs: Dictionary mapping of inputs.
        prediction_loss_only: Whether to return only the loss.
        ignore_keys: Keys to ignore in the inputs.

    Returns:
        Tuple of (loss, logits, labels).
    """
    sp_enabled = self.args.sequence_parallel_degree > 1
    if sp_enabled:
        # cu_seqlens must be refreshed from this batch before the forward pass
        self._update_ring_flash_attn_params(inputs)

    # Everything else is delegated to the next class in the MRO
    outcome = super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)  # type: ignore
    return outcome

View File

@@ -6,11 +6,22 @@ package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patc
their sequence parallel version of Flash Attention 2.
"""
import torch
import torch.distributed as dist
import torch.nn.functional as F
from accelerate.logging import get_logger
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from axolotl.logging_config import configure_logging
try:
from ring_flash_attn import update_ring_flash_attn_params
except ImportError:
# We pass silently here, but raise an ImportError in our Axolotl config validation
# if cfg.sequence_parallel_degree > 1 and `ring-flash-attn` is not installed.
pass
configure_logging()
LOG = get_logger(__name__)
@@ -32,12 +43,120 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
Setter for ring attention group on this rank.
Args:
Process group for ring attention.
ring_attn_group: Process group for ring attention.
"""
global RING_ATTN_GROUP # pylint: disable=global-statement
RING_ATTN_GROUP = ring_attn_group
def patch_flash_attention_for_sequential_batch(sequence_parallel_degree: int):
    """
    Patch flash attention a second time to handle batched data. This is a hack to
    accommodate certain RL trainers which batch data even when `micro_batch_size: 1` is
    specified in the Axolotl config.

    The `"flash_attention_2"` entry of `ALL_ATTENTION_FUNCTIONS` is replaced by a
    wrapper that, for batches larger than one, runs the previously registered
    function once per batch item and concatenates the outputs along dim 0.

    Args:
        sequence_parallel_degree: Sequence parallelism factor.
    """
    # Store the original flash attention function
    original_flash_attention = ALL_ATTENTION_FUNCTIONS["flash_attention_2"]

    def sequential_batch_flash_attention(
        module: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: torch.Tensor | None,
        dropout: float = 0.0,
        scaling: float | None = None,
        sliding_window: int | None = None,
        softcap: float | None = None,
        **kwargs,
    ) -> tuple[torch.Tensor, None]:
        # Check if we have a batch dimension > 1
        batch_size = query.shape[0]
        if batch_size <= 1:
            return original_flash_attention(
                module,
                query,
                key,
                value,
                attention_mask,
                dropout,
                scaling,
                sliding_window,
                softcap,
                **kwargs,
            )

        # Process each item in the batch separately
        outputs = []
        for i in range(batch_size):
            # Extract single batch item (slicing keeps the batch dimension)
            q_item = query[i : i + 1]
            k_item = key[i : i + 1]
            v_item = value[i : i + 1]

            # Handle attention mask - it might be None or have different shapes
            mask_item = None
            if attention_mask is not None:
                # The mask could have different formats depending on implementation.
                # `batch_size` must stay the FULL batch size here: it is compared
                # against the mask's leading dimension on every iteration.
                if attention_mask.dim() >= 3 and attention_mask.shape[0] == batch_size:
                    mask_item = attention_mask[i : i + 1]
                else:
                    # For broadcast masks that don't have a batch dimension
                    mask_item = attention_mask

            # At this point, inputs should already be partitioned by the sequence
            # parallel data collator.
            # Distinct names on purpose: reusing `batch_size` here would clobber
            # the full batch size used by the mask check above.
            item_batch_size = q_item.shape[0]
            # assumes query is laid out (batch, heads, seq, head_dim) — TODO confirm
            seq_len = q_item.shape[2]
            packed_seq_lens = [seq_len] * item_batch_size

            # Calculate the full sequence length across all GPUs in this SP group
            total_seq_len = seq_len * sequence_parallel_degree

            cu_seqlens = torch.cumsum(
                torch.tensor(
                    packed_seq_lens,
                    device=torch.cuda.current_device(),
                    dtype=torch.int32,
                ),
                dim=-1,
                dtype=torch.int32,
            )
            # Prepend 0 and append the full (unsharded) sequence length
            cu_seqlens = F.pad(
                F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
            )

            update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())

            # Call the original function for a single batch item
            output, _ = original_flash_attention(
                module,
                q_item,
                k_item,
                v_item,
                mask_item,
                dropout,
                scaling,
                sliding_window,
                softcap,
                **kwargs,
            )
            outputs.append(output)

            # NOTE(review): the flattened source does not show whether this
            # barrier runs per item or once after the loop; per item is assumed
            # here — confirm against the original file.
            dist.barrier()

        # Concatenate results along batch dimension
        concatenated_output = torch.cat(outputs, dim=0)
        return concatenated_output, None

    # Replace the original function with our sequential version
    ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = sequential_batch_flash_attention
def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None):
"""
Create ring attention group and substitute flash attn with ring flash attn.
@@ -98,3 +217,4 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
substitute_hf_flash_attn(
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride
)
patch_flash_attention_for_sequential_batch(sequence_parallel_degree)

View File

@@ -1351,9 +1351,7 @@ def load_model(
reference_model: bool = False,
**kwargs, # pylint: disable=unused-argument
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
"""
Load a model for a given configuration and tokenizer.
"""
"""Load a model for a given configuration and tokenizer."""
model_loader = ModelLoader(
cfg,
tokenizer,
@@ -1362,12 +1360,16 @@ def load_model(
reference_model=reference_model,
**kwargs,
)
return model_loader.load_model()
def load_adapter(model, cfg, adapter, inference=False):
# type: (PreTrainedModel, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
def load_adapter(
model: PreTrainedModel,
cfg: DictDefault,
adapter: str | None,
inference: bool = False,
) -> tuple[PreTrainedModel, PeftConfig | None]:
if adapter is None:
return model, None
if hasattr(model, "enable_input_require_grads"):
@@ -1380,8 +1382,9 @@ def load_adapter(model, cfg, adapter, inference=False):
raise NotImplementedError(f"{adapter} peft adapter not available")
def load_llama_adapter(model, cfg):
# type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
def load_llama_adapter(
model: PreTrainedModel, cfg: DictDefault
) -> tuple[PreTrainedModel, PeftConfig | None]:
from peft import AdaptionPromptConfig, get_peft_model
peft_config = AdaptionPromptConfig(
@@ -1405,7 +1408,7 @@ def load_llama_adapter(model, cfg):
return model, peft_config
def find_all_linear_names(model):
def find_all_linear_names(model: PreTrainedModel):
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
lora_module_names = set()
for name, module in model.named_modules():