Compare commits

..

9 Commits

Author SHA1 Message Date
Wing Lian
f9c7c3bb72 don't use is_main_process during config validation (#2569) 2025-04-26 14:14:52 -04:00
Wing Lian
caf5cb63ea add e2e smoke test for using activation/gradient checkpointing with offload (#2565)
* add e2e smoke test for using activation/gradient checkpointing with offload

* disable duplicate code check for the test

* fix relative import

* seq len too small to test this dataset with packing

* Fix checkpoint ptaching for tests
2025-04-25 21:11:17 -04:00
Wing Lian
5dba5c82a8 fix support for wandb run_name for rl trainers (#2566) [skip ci]
* fix support for wandb run_name for rl trainers

* prefer to use wandb random names for run_name
2025-04-25 21:10:54 -04:00
Chiwan Park
e3c9d541a7 fix: crash when pretraining_dataset with dispatch_batches is false (#2558) 2025-04-25 17:15:03 -04:00
NanoCode012
9eba0ad118 chore(doc): update docker tags on doc (#2559) [skip ci] 2025-04-25 17:14:48 -04:00
Wing Lian
53dbf97d85 make cce default to true when using the plugin (#2562) [skip ci] 2025-04-25 17:14:26 -04:00
Eko Julianto Salim
2c2563bc34 fix: gradient checkpointing functools.partial object has no attribute __self__ (#2563) [skip ci]
* fix: gradient checkpointing causing functools.partial error

* lint

* chore: lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-04-25 17:02:37 -04:00
Wing Lian
5cb3398460 don't fail on codecov upload for external contributor PRs (#2564) [skip ci] 2025-04-25 15:10:55 -04:00
Dan Saunders
ae1c7ace63 Sequence parallel training context manager (#2553)
* ctx manager for SP

* updates

* update

* further simplifying

* accommodate both training context managers

* simplifying

* simplifying

* nit

* reorg

* tweak codecov yaml

* add gather post hook, simplify, fixes

* pytest

* pytest fix
2025-04-25 10:33:54 -04:00
28 changed files with 214 additions and 1136 deletions

View File

@@ -8,6 +8,7 @@ on:
- 'setup.py'
- 'pyproject.toml'
- '.github/workflows/multi-gpu-e2e.yml'
- 'src/axolotl/core/trainers/mixins/sequence_parallel.py'
workflow_dispatch:
schedule:
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday

View File

@@ -52,4 +52,4 @@ pytest -v --durations=10 \
--cov-append \
--cov-report=xml:e2e-coverage.xml
codecov upload-process -t $CODECOV_TOKEN -f e2e-coverage.xml -F e2e,pytorch-${PYTORCH_VERSION}
codecov upload-process -t $CODECOV_TOKEN -f e2e-coverage.xml -F e2e,pytorch-${PYTORCH_VERSION} || true

View File

@@ -28,6 +28,8 @@ main-base-py{python_version}-cu{cuda_version}-{pytorch_version}
Tags examples:
- `main-base-py3.11-cu128-2.7.0`
- `main-base-py3.11-cu126-2.7.0`
- `main-base-py3.11-cu124-2.6.0`
- `main-base-py3.11-cu124-2.5.1`
- `main-base-py3.11-cu124-2.4.1`
@@ -50,7 +52,7 @@ Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl)
# on push to main
main-py{python_version}-cu{cuda_version}-{pytorch_version}
# latest main (currently torch 2.5.1, python 3.11, cuda 12.4)
# latest main (currently torch 2.6.0, python 3.11, cuda 12.4)
main-latest
# nightly build
@@ -68,6 +70,7 @@ There may be some extra tags appended to the image, like `-vllm` which installs
Tags examples:
- `main-py3.11-cu126-2.7.0`
- `main-py3.11-cu124-2.6.0`
- `main-py3.11-cu124-2.5.1`
- `main-py3.11-cu124-2.4.1`

View File

@@ -10,7 +10,6 @@ plugins:
liger_glu_activation: true
liger_rms_norm: true
liger_layer_norm: true
cut_cross_entropy: true
llama4_linearized_experts: true # needed with custom linearized experts model
load_in_4bit: true

View File

@@ -14,7 +14,6 @@ from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_processor, load_tokenizer
from axolotl.utils.schemas.enums import RLType
from axolotl.utils.tokenization import check_dataset_labels
LOG = logging.getLogger(__name__)
@@ -126,7 +125,7 @@ def load_preference_datasets(
total_num_steps: Optional[int] = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)
if cfg.rl is RLType.GRPO:
if cfg.rl == "grpo":
total_num_steps = None
if cli_args.debug or cfg.debug:

View File

@@ -84,7 +84,7 @@ from axolotl.utils.collators import (
)
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.models import ensure_dtype
from axolotl.utils.schemas.enums import CustomSupportedOptimizers, RLType
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
try:
import torch._dynamo # pylint: disable=ungrouped-imports
@@ -538,6 +538,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
report_to = []
if self.cfg.use_wandb:
report_to.append("wandb")
if self.cfg.wandb_name:
training_arguments_kwargs["run_name"] = self.cfg.wandb_name
if self.cfg.use_mlflow:
report_to.append("mlflow")
if self.cfg.use_tensorboard:
@@ -1009,8 +1011,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
training_args_kwargs["dataloader_prefetch_factor"] = (
self.cfg.dataloader_prefetch_factor
)
if self.cfg.seed:
training_args_kwargs["seed"] = self.cfg.seed
if self.cfg.gradient_checkpointing:
training_args_kwargs["gradient_checkpointing"] = (
self.cfg.gradient_checkpointing
@@ -1048,13 +1048,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.rpo_alpha is not None:
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
training_args_kwargs["sequence_parallel_degree"] = (
self.cfg.sequence_parallel_degree
)
if self.cfg.use_wandb:
training_args_kwargs["run_name"] = self.cfg.wandb_name
training_args_cls = None
blocklist_args_kwargs = []
if self.cfg.rl is RLType.SIMPO:
if self.cfg.rl == "simpo":
training_args_cls = AxolotlCPOConfig
training_args_kwargs["loss_type"] = "simpo"
training_args_kwargs["max_length"] = self.cfg.sequence_len
@@ -1062,13 +1061,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.cpo_alpha is not None:
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
elif self.cfg.rl is RLType.ORPO:
elif self.cfg.rl == "orpo":
training_args_cls = AxolotlORPOConfig
training_args_kwargs["max_length"] = self.cfg.sequence_len
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
elif self.cfg.rl is RLType.KTO:
elif self.cfg.rl == "kto":
training_args_cls = AxolotlKTOConfig
training_args_kwargs["desirable_weight"] = (
@@ -1082,14 +1081,14 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
elif self.cfg.rl is RLType.GRPO:
elif self.cfg.rl == "grpo":
training_args_cls = GRPOStrategy.get_training_args_class()
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
else:
training_args_cls = AxolotlDPOConfig
if self.cfg.rl is RLType.IPO:
if self.cfg.rl == "ipo":
training_args_kwargs["loss_type"] = "ipo"
training_args_kwargs["max_length"] = self.cfg.sequence_len
training_args_kwargs["max_completion_length"] = None
@@ -1122,37 +1121,43 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
**training_args_kwargs,
)
# unset run_name so wandb sets up experiment names
if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
training_args.run_name = ( # pylint: disable=attribute-defined-outside-init
None
)
return training_args
def build(self, total_num_steps):
training_args = self.build_training_arguments(total_num_steps)
trainer_kwargs = {}
if self.cfg.rl is RLType.IPO:
dpo_trainer_kwargs = {}
if self.cfg.rl == "ipo":
if self.cfg.dpo_label_smoothing:
trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
if self.eval_dataset:
trainer_kwargs["eval_dataset"] = self.eval_dataset
dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
if self.cfg.adapter and self.peft_config:
trainer_kwargs["peft_config"] = self.peft_config
dpo_trainer_kwargs["peft_config"] = self.peft_config
if self.cfg.precompute_ref_log_probs is not None:
trainer_kwargs["precompute_ref_log_probs"] = (
dpo_trainer_kwargs["precompute_ref_log_probs"] = (
self.cfg.precompute_ref_log_probs
)
if self.cfg.rl is RLType.GRPO:
if self.cfg.rl == "grpo":
trainer_cls = GRPOStrategy.get_trainer_class()
trainer_cls_args = [self.model]
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
dpo_trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
elif self.cfg.rl in ["dpo", "ipo"]:
trainer_cls = DPOStrategy.get_trainer_class()
trainer_cls_args = [self.model, self.model_ref]
elif self.cfg.rl is RLType.ORPO:
elif self.cfg.rl == "orpo":
trainer_cls = AxolotlORPOTrainer
trainer_cls_args = [self.model]
elif self.cfg.rl is RLType.KTO:
elif self.cfg.rl in ["kto"]:
trainer_cls = AxolotlKTOTrainer
trainer_cls_args = [self.model]
elif self.cfg.rl is RLType.SIMPO:
elif self.cfg.rl in ["simpo"]:
trainer_cls = AxolotlCPOTrainer
trainer_cls_args = [self.model]
else:
@@ -1160,33 +1165,33 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
sig = inspect.signature(trainer_cls)
if "tokenizer" in sig.parameters.keys():
trainer_kwargs["tokenizer"] = self.tokenizer
dpo_trainer_kwargs["tokenizer"] = self.tokenizer
else:
trainer_kwargs["processing_class"] = self.tokenizer
dpo_trainer_kwargs["processing_class"] = self.tokenizer
if self.cfg.datasets is not None and (
trainer_cls is DPOStrategy.get_trainer_class()
):
trainer_kwargs["dataset_tags"] = [
dpo_trainer_kwargs["dataset_tags"] = [
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
]
trainer = trainer_cls(
dpo_trainer = trainer_cls(
*trainer_cls_args,
args=training_args,
train_dataset=self.train_dataset,
callbacks=self.get_callbacks(),
**trainer_kwargs,
**dpo_trainer_kwargs,
)
if self.cfg.fsdp:
ensure_dtype(trainer.model, dtype=self.cfg.torch_dtype)
if self.cfg.rl in [RLType.DPO, RLType.IPO] and trainer.ref_model:
ensure_dtype(trainer.ref_model, dtype=self.cfg.torch_dtype)
ensure_dtype(dpo_trainer.model, dtype=self.cfg.torch_dtype)
if self.cfg.rl in ["dpo", "ipo"] and dpo_trainer.ref_model:
ensure_dtype(dpo_trainer.ref_model, dtype=self.cfg.torch_dtype)
trainer = self.hook_post_create_trainer(trainer)
for callback in self.get_post_trainer_create_callbacks(trainer):
trainer.add_callback(callback)
dpo_trainer = self.hook_post_create_trainer(dpo_trainer)
for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
dpo_trainer.add_callback(callback)
return trainer
return dpo_trainer
class HFPPOTrainerBuilder(TrainerBuilderBase):

View File

@@ -3,7 +3,6 @@ DPO Specific Strategy for training
"""
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer
from axolotl.utils.schemas.enums import RLType
class DPOStrategy:
@@ -24,7 +23,7 @@ class DPOStrategy:
@classmethod
def set_training_args_kwargs(cls, cfg):
training_args_kwargs = {}
if cfg.rl is RLType.IPO:
if cfg.rl == "ipo":
training_args_kwargs["loss_type"] = "ipo"
training_args_kwargs["max_length"] = cfg.sequence_len
training_args_kwargs["max_completion_length"] = None

View File

@@ -11,4 +11,6 @@ from axolotl.core.training_args import AxolotlTrainingMixins
@dataclass
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
"""Axolotl GRPO Config for GRPO training"""
"""
Axolotl GRPO Config for GRPO training
"""

View File

@@ -1,124 +0,0 @@
"""
Repeat random sampler (akin to the one implemented in
https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py) that adds
sequence parallelism functionality; i.e., duplicating data across ranks in the same
sequencee parallel group.
"""
from typing import Sized
import torch
from torch.utils.data import Sampler
class SequenceParallelRepeatRandomSampler(Sampler):
"""
Sampler for GRPO training with sequence parallelism that ensures:
1. Ranks in the same sequence parallel group receive identical data
2. Each index is repeated multiple times for sampling different completions
3. Entire batches are repeated for reuse in multiple updates
"""
def __init__(
self,
dataset: Sized,
mini_repeat_count: int,
world_size: int,
rank: int,
batch_size: int = 1,
repeat_count: int = 1,
sequence_parallel_degree: int = 1,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False,
):
self.dataset = dataset
self.mini_repeat_count = mini_repeat_count
self.batch_size = batch_size
self.repeat_count = repeat_count
self.shuffle = shuffle
self.seed = seed
self.drop_last = drop_last
self.epoch = 0
self.world_size = world_size
self.rank = rank
# Sequence parallelism parameters
self.sequence_parallel_degree = sequence_parallel_degree
self.num_sp_groups = world_size // sequence_parallel_degree
self.sp_group_id = rank // sequence_parallel_degree
# Adjust dataset size for distributed sampling
self.num_samples = len(self.dataset)
self.total_size = self.num_samples
# Calculate effective number of samples per SP group
if (
self.drop_last
and self.total_size % (self.num_sp_groups * self.batch_size) != 0
):
# Drop last incomplete batch if drop_last is True
self.num_samples_per_sp_group = (
self.total_size // self.batch_size // self.num_sp_groups
) * self.batch_size
else:
# Round up to include last batch if drop_last is False
self.num_samples_per_sp_group = (
(self.total_size + self.batch_size * self.num_sp_groups - 1)
// (self.batch_size * self.num_sp_groups)
* self.batch_size
)
def __iter__(self):
# Deterministically shuffle based on epoch and seed
if self.shuffle:
# Use same seed for all ranks in the same SP group
g = torch.Generator()
seed_value = self.seed + self.epoch + self.sp_group_id * 10000
g.manual_seed(seed_value)
indices = torch.randperm(len(self.dataset), generator=g).tolist()
else:
indices = list(range(len(self.dataset)))
# Add extra samples to make it evenly divisible by batch_size
if len(indices) % self.batch_size != 0:
padding = indices[: self.batch_size - len(indices) % self.batch_size]
indices += padding
# Subsample based on SP group ID
# Each SP group gets distinct batches of data
batch_indices = []
for i in range(0, len(indices), self.batch_size * self.num_sp_groups):
start_idx = i + self.sp_group_id * self.batch_size
end_idx = min(start_idx + self.batch_size, len(indices))
if start_idx < len(indices):
for j in range(self.batch_size):
if start_idx + j < end_idx:
batch_indices.append(indices[start_idx + j])
# Make sure batch_indices is exactly batch_size * num_batches_per_sp_group
if self.drop_last:
num_batches_per_sp_group = self.num_samples_per_sp_group // self.batch_size
target_len = self.batch_size * num_batches_per_sp_group
if len(batch_indices) > target_len:
batch_indices = batch_indices[:target_len]
# Apply the GRPO repeat pattern
final_indices = []
for _ in range(self.repeat_count):
for idx in batch_indices:
for _ in range(self.mini_repeat_count):
final_indices.append(idx)
return iter(final_indices)
def __len__(self):
# Total length including all repetitions
return (
self.num_samples_per_sp_group * self.mini_repeat_count * self.repeat_count
)
def set_epoch(self, epoch):
"""Sets the epoch for this sampler"""
self.epoch = epoch

View File

@@ -1,279 +1,26 @@
"""Axolotl GRPO trainer"""
"""
Axolotl GRPO trainer
"""
# pylint: disable=too-many-lines,duplicate-code
import warnings
from contextlib import nullcontext
from typing import Any
import datasets
import torch
import torch.distributed as dist
from accelerate.utils import (
broadcast_object_list,
gather,
gather_object,
is_peft_model,
)
from datasets import Dataset, IterableDataset
from torch import nn
from torch.utils.data import (
BatchSampler,
DataLoader,
Sampler,
)
from transformers import (
PreTrainedModel,
PreTrainedTokenizerBase,
Trainer,
TrainerCallback,
is_wandb_available,
)
from transformers.trainer_utils import seed_worker
from transformers.utils import is_peft_available
from accelerate.utils import is_deepspeed_available, is_peft_model
from trl import GRPOTrainer
from trl.data_utils import (
apply_chat_template,
is_conversational,
maybe_apply_chat_template,
)
from trl.extras.profiling import profiling_context, profiling_decorator
from trl.import_utils import (
is_deepspeed_available,
is_rich_available,
)
from trl.models import (
unwrap_model_for_generation,
)
from trl.trainer.grpo_config import GRPOConfig
from trl.trainer.grpo_trainer import RewardFunc
from trl.trainer.utils import (
pad,
print_prompt_completions_sample,
selective_log_softmax,
)
from trl.extras.profiling import profiling_decorator
from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.monkeypatch.attention.ring_attn.patch import get_ring_attn_group
if is_peft_available():
# pylint: disable=unused-import
from peft import PeftConfig
if is_deepspeed_available():
import deepspeed
if is_wandb_available():
import wandb
class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
"""Extend the base GRPOTrainer for axolotl helpers"""
"""
Extend the base GRPOTrainer for axolotl helpers
"""
_tag_names = ["trl", "grpo", "axolotl"]
def __init__(
self,
model: str | PreTrainedModel,
reward_funcs: RewardFunc | list[RewardFunc],
args: GRPOConfig | None = None,
train_dataset: Dataset | IterableDataset | None = None,
eval_dataset: (
Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None
) = None,
processing_class: PreTrainedTokenizerBase | None = None,
reward_processing_classes: (
PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None
) = None,
callbacks: list[TrainerCallback] | None = None,
optimizers: tuple[
torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None
] = (None, None),
peft_config: "PeftConfig | None" = None,
):
# First call the superclass constructor with all arguments
super().__init__(
model=model,
reward_funcs=reward_funcs,
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
processing_class=processing_class,
reward_processing_classes=reward_processing_classes,
callbacks=callbacks,
optimizers=optimizers,
peft_config=peft_config,
)
# Now execute your custom logic
# Get number of SP groups (number of processes divided by SP degree)
num_processes = self.accelerator.num_processes
num_sp_groups = num_processes // self.args.sequence_parallel_degree
# Calculate batch size per SP group (not per process)
sp_group_batch_size = self.args.per_device_train_batch_size * num_sp_groups
possible_values = [
n_gen
for n_gen in range(2, sp_group_batch_size + 1)
if (sp_group_batch_size) % n_gen == 0
]
if self.num_generations not in possible_values:
raise ValueError(
f"The batch size per SP group ({num_sp_groups} x "
f"{self.args.per_device_train_batch_size}) must be evenly divisible by "
f"the number of generations per prompt ({self.num_generations}). Given "
"the current configuration, the valid values for the number of "
f"generations are: {possible_values}."
)
if self.args.eval_strategy != "no":
# If sequence parallelism is enabled, calculate batch size per SP group
sp_group_eval_batch_size = args.per_device_eval_batch_size * num_sp_groups # type: ignore[union-attr]
possible_values = [
n_gen
for n_gen in range(2, sp_group_eval_batch_size + 1)
if (sp_group_eval_batch_size) % n_gen == 0
]
if self.num_generations not in possible_values:
raise ValueError(
f"With sequence parallelism (degree {self.args.sequence_parallel_degree}), "
f"the eval batch size per SP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) "
f"must be evenly divisible by the number of generations per prompt "
f"({self.num_generations}). Given the current eval batch size, "
f"the valid values for the number of generations are: {possible_values}."
)
# Initialize the SP group
self.sp_group = get_ring_attn_group()
self.local_rank = dist.get_rank(group=self.sp_group)
self.local_world_size = dist.get_world_size(group=self.sp_group)
print("end of trainer init")
def _get_train_sampler(self) -> Sampler:
# Get distributed training info
world_size = dist.get_world_size()
rank = dist.get_rank()
effective_batch_size = (
self.args.per_device_train_batch_size
* world_size
* self.args.gradient_accumulation_steps
)
return SequenceParallelRepeatRandomSampler(
dataset=self.train_dataset,
mini_repeat_count=self.num_generations,
world_size=world_size,
rank=rank,
batch_size=effective_batch_size
// self.num_generations
// self.args.sequence_parallel_degree,
repeat_count=self.num_iterations,
sequence_parallel_degree=self.args.sequence_parallel_degree,
shuffle=True,
seed=self.args.seed,
drop_last=True,
)
def _create_dataloader_params(self, is_eval=False, custom_batch_size=None):
"""Create common dataloader parameters for train or eval."""
batch_size = custom_batch_size or (
self.args.eval_batch_size if is_eval else self._train_batch_size
)
params = {
"batch_size": batch_size,
"collate_fn": self.data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
# Add persistent workers only for training
if not is_eval and hasattr(self.args, "dataloader_persistent_workers"):
params["persistent_workers"] = self.args.dataloader_persistent_workers
# Add prefetch factor if specified
if self.args.dataloader_prefetch_factor:
params["prefetch_factor"] = self.args.dataloader_prefetch_factor
return params
def _prepare_dataloader(
self, dataset, sampler, is_eval=False, custom_batch_size=None
):
"""Prepare a dataloader with the given dataset and sampler."""
# Get base parameters
dataloader_params = self._create_dataloader_params(is_eval, custom_batch_size)
# Add sampler configuration
if not isinstance(dataset, torch.utils.data.IterableDataset):
if isinstance(sampler, BatchSampler):
# batch_size and batch_sampler are mutually exclusive
dataloader_params["batch_sampler"] = sampler
del dataloader_params["batch_size"]
else:
dataloader_params["sampler"] = sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last
if not is_eval:
dataloader_params["worker_init_fn"] = seed_worker
# Create the dataloader
dataloader = DataLoader(dataset, **dataloader_params)
if self.args.sample_packing and (
(not is_eval and not self.args.pretraining)
or (is_eval and self.args.eval_sample_packing is not False)
):
self.accelerator.even_batches = False
# Return unprepared dataloader if using sequence parallelism
# TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
# if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
# slice each batch along the sequence dimension).
if self.args.sequence_parallel_degree > 1:
return dataloader
# Otherwise prepare with accelerator
return self.accelerator.prepare_data_loader(dataloader)
def get_train_dataloader(self) -> DataLoader:
"""Get dataloader for training"""
train_dataset = self.train_dataset
# pylint: disable=access-member-before-definition
data_collator = self.data_collator # type: ignore
# Initialize SP group attributes if sequence parallelism is enabled
if self.args.sequence_parallel_degree > 1:
self.sp_group = get_ring_attn_group()
self.local_rank = dist.get_rank(group=self.sp_group)
self.local_world_size = dist.get_world_size(group=self.sp_group)
# Handle dataset preprocessing
if isinstance(train_dataset, datasets.Dataset):
# Add debug print before any modifications
if self.args.sample_packing and not self.args.pretraining:
train_dataset = train_dataset.remove_columns(["length"])
if not self.args.sample_packing or self.args.pretraining:
train_dataset = self._remove_unused_columns(
train_dataset, description="training"
)
else:
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
data_collator,
description="training",
)
# Get sampler and create dataloader
sampler = self._get_train_sampler()
dataloader = self._prepare_dataloader(train_dataset, sampler, is_eval=False)
return dataloader
@profiling_decorator
def _move_model_to_vllm(self):
# For DeepSpeed ZeRO-3, we need to gather all parameters before operations
@@ -320,577 +67,3 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
# Reset cache on main process
if self.accelerator.is_main_process:
self.vllm_client.reset_prefix_cache()
# def _generate_and_score_completions(
# self, inputs: list[dict[str, torch.Tensor | Any]]
# ) -> dict[str, torch.Tensor | Any]:
# device = self.accelerator.device
# prompts = [x["prompt"] for x in inputs]
# prompts_text = [
# maybe_apply_chat_template(example, self.processing_class)["prompt"]
# for example in inputs
# ]
# prompt_inputs = self.processing_class(
# text=prompts_text,
# return_tensors="pt",
# padding=True,
# padding_side="left",
# add_special_tokens=False,
# )
# # pylint: disable=protected-access
# prompt_inputs = Trainer._prepare_inputs(self, prompt_inputs)
# prompt_ids, prompt_mask = (
# prompt_inputs["input_ids"],
# prompt_inputs["attention_mask"],
# )
# if self.max_prompt_length is not None:
# prompt_ids = prompt_ids[:, -self.max_prompt_length :]
# prompt_mask = prompt_mask[:, -self.max_prompt_length :]
# # Generate completions using either vLLM or regular generation
# if self.args.use_vllm:
# # First, have main process load weights if needed
# # pylint: disable=access-member-before-definition
# if self.state.global_step != self._last_loaded_step: # type: ignore[has-type]
# self._move_model_to_vllm()
# # pylint: disable=attribute-defined-outside-init
# self._last_loaded_step = self.state.global_step
# all_prompts_text = gather_object(prompts_text)
# if self.accelerator.is_main_process:
# # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
# # num_generations outputs for each one. This is faster than generating outputs for each duplicate
# # prompt individually.
# # ordered_set_of_prompts = all_prompts_text[:: self.num_generations]
# ordered_set_of_prompts = all_prompts_text[
# :: self.num_generations * self.args.sequence_parallel_degree
# ]
# with profiling_context(self, "vLLM.generate"):
# completion_ids = self.vllm_client.generate(
# prompts=ordered_set_of_prompts,
# n=self.num_generations,
# repetition_penalty=self.repetition_penalty,
# temperature=self.temperature,
# top_p=self.top_p,
# top_k=-1 if self.top_k is None else self.top_k,
# min_p=0.0 if self.min_p is None else self.min_p,
# max_tokens=self.max_completion_length,
# guided_decoding_regex=self.guided_decoding_regex,
# )
# else:
# completion_ids = [None] * (
# len(all_prompts_text) // self.args.sequence_parallel_degree
# )
# # Broadcast the completions from the main process to all processes
# completion_ids = broadcast_object_list(completion_ids, from_process=0)
# # Determine the appropriate slice based on sequence parallelism
# if self.args.sequence_parallel_degree > 1:
# # Calculate SP group ID (which group of ranks this rank belongs to)
# sp_group_id = self.accelerator.process_index // self.local_world_size
# # Calculate the start index for this SP group
# sp_group_start = sp_group_id * len(prompts) * self.local_world_size
# # All ranks in the same SP group get the same data slice
# process_slice = slice(
# sp_group_start,
# sp_group_start + len(prompts),
# )
# completion_ids = completion_ids[process_slice]
# else:
# # Original behavior for non-sequence parallel case
# process_slice = slice(
# self.accelerator.process_index * len(prompts),
# (self.accelerator.process_index + 1) * len(prompts),
# )
# completion_ids = completion_ids[process_slice]
# # Pad the completions, and concatenate them with the prompts
# completion_ids = [
# torch.tensor(ids, device=device) for ids in completion_ids
# ]
# completion_ids = pad(
# completion_ids, padding_value=self.processing_class.pad_token_id
# )
# else:
# # Regular generation path
# with unwrap_model_for_generation(
# self.model_wrapped,
# self.accelerator,
# gather_deepspeed3_params=self.args.ds3_gather_for_generation,
# ) as unwrapped_model:
# prompt_completion_ids = unwrapped_model.generate(
# prompt_ids,
# attention_mask=prompt_mask,
# generation_config=self.generation_config,
# )
# # Compute prompt length and extract completion ids
# prompt_length = prompt_ids.size(1)
# prompt_ids = prompt_completion_ids[:, :prompt_length]
# completion_ids = prompt_completion_ids[:, prompt_length:]
# prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
# # Mask everything after the first EOS token
# is_eos = completion_ids == self.processing_class.eos_token_id
# eos_idx = torch.full(
# (is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device
# )
# eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
# sequence_indices = torch.arange(is_eos.size(1), device=device).expand(
# is_eos.size(0), -1
# )
# completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
# # Concatenate prompt_mask with completion_mask for logit computation
# attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C)
# logits_to_keep = completion_ids.size(
# 1
# ) # we only need to compute the logits for the completion tokens
# with torch.no_grad():
# # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip it's
# # computation here, and use per_token_logps.detach() instead.
# if self.num_iterations > 1:
# if self.args.sequence_parallel_degree > 1:
# old_per_token_logps, _ = self._get_per_token_logps_v2(
# self.model,
# prompt_completion_ids,
# attention_mask,
# logits_to_keep,
# )
# else:
# old_per_token_logps = super()._get_per_token_logps(
# self.model,
# prompt_completion_ids,
# attention_mask,
# logits_to_keep,
# )
# else:
# old_per_token_logps = None
# if self.beta == 0.0:
# ref_per_token_logps = None
# elif self.ref_model is not None:
# if self.args.sequence_parallel_degree > 1:
# ref_per_token_logps, _ = self._get_per_token_logps_v2(
# self.ref_model,
# prompt_completion_ids,
# attention_mask,
# logits_to_keep,
# )
# else:
# ref_per_token_logps = super()._get_per_token_logps(
# self.ref_model,
# prompt_completion_ids,
# attention_mask,
# logits_to_keep,
# )
# else:
# with self.accelerator.unwrap_model(self.model).disable_adapter():
# if self.args.sequence_parallel_degree > 1:
# ref_per_token_logps, _ = self._get_per_token_logps_v2(
# self.model,
# prompt_completion_ids,
# attention_mask,
# logits_to_keep,
# )
# else:
# ref_per_token_logps = super()._get_per_token_logps(
# self.model,
# prompt_completion_ids,
# attention_mask,
# logits_to_keep,
# )
# # Decode the generated completions
# completions_text = self.processing_class.batch_decode(
# completion_ids, skip_special_tokens=True
# )
# if is_conversational(inputs[0]):
# completions = []
# for prompt, completion in zip(prompts, completions_text):
# bootstrap = (
# prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
# )
# completions.append(
# [{"role": "assistant", "content": bootstrap + completion}]
# )
# else:
# completions = completions_text
# rewards_per_func = torch.zeros(
# len(prompts), len(self.reward_funcs), device=device
# )
# for i, (reward_func, reward_processing_class) in enumerate(
# zip(self.reward_funcs, self.reward_processing_classes)
# ):
# if isinstance(
# reward_func, nn.Module
# ): # Module instead of PretrainedModel for compat with compiled models
# reward_func_name = (
# f"reward {reward_func.config._name_or_path.split('/')[-1]}"
# )
# else:
# # pylint: disable=protected-access
# reward_func_name = reward_func.__name__
# with profiling_context(self, reward_func_name):
# if isinstance(
# reward_func, nn.Module
# ): # Module instead of PretrainedModel for compat with compiled models
# if is_conversational(inputs[0]):
# messages = [
# {"messages": p + c} for p, c in zip(prompts, completions)
# ]
# texts = [
# apply_chat_template(x, reward_processing_class)["text"]
# for x in messages
# ]
# else:
# texts = [p + c for p, c in zip(prompts, completions)]
# reward_inputs = reward_processing_class(
# text=texts,
# return_tensors="pt",
# padding=True,
# padding_side="right",
# add_special_tokens=False,
# )
# # pylint: disable=protected-access
# reward_inputs = Trainer._prepare_inputs(self, reward_inputs)
# with torch.inference_mode():
# rewards_per_func[:, i] = reward_func(**reward_inputs).logits[
# :, 0
# ] # Shape (B*G,)
# else:
# # Repeat all input columns (but "prompt" and "completion") to match the number of generations
# keys = [
# key for key in inputs[0] if key not in ["prompt", "completion"]
# ]
# reward_kwargs = {
# key: [example[key] for example in inputs] for key in keys
# }
# output_reward_func = reward_func(
# prompts=prompts, completions=completions, **reward_kwargs
# )
# # Convert None values to NaN
# output_reward_func = [
# reward if reward is not None else torch.nan
# for reward in output_reward_func
# ]
# rewards_per_func[:, i] = torch.tensor(
# output_reward_func, dtype=torch.float32, device=device
# )
# # If all reward functions return None for a given row, issue a detailed warning
# if torch.isnan(rewards_per_func).all(dim=1).any():
# nan_row_idx = (
# torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0]
# )
# row_reward_kwargs = {
# key: value[nan_row_idx] for key, value in reward_kwargs.items()
# }
# row_reward_kwargs["prompt"] = prompts[nan_row_idx]
# row_reward_kwargs["completion"] = completions[nan_row_idx]
# warnings.warn(
# f"All reward functions returned None for the following kwargs: {row_reward_kwargs}. "
# "Please ensure that at least one reward function returns a valid reward."
# )
# # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
# # completions may be distributed across processes
# rewards_per_func = gather(rewards_per_func)
# # Apply weights to each reward function's output and sum
# rewards = (
# rewards_per_func * self.reward_weights.to(device).unsqueeze(0)
# ).nansum(dim=1)
# # Compute grouped-wise rewards
# mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
# std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
# # Normalize the rewards to compute the advantages
# mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(
# self.num_generations, dim=0
# )
# std_grouped_rewards = std_grouped_rewards.repeat_interleave(
# self.num_generations, dim=0
# )
# advantages = rewards - mean_grouped_rewards
# if self.args.scale_rewards:
# advantages = advantages / (std_grouped_rewards + 1e-4)
# # Slice to keep only the local part of the data
# process_slice = slice(
# self.accelerator.process_index * len(prompts),
# (self.accelerator.process_index + 1) * len(prompts),
# )
# advantages = advantages[process_slice]
# # Log the metrics
# mode = "eval" if self.control.should_evaluate else "train"
# if mode == "train":
# # pylint: disable=no-member
# self._total_train_tokens += (
# self.accelerator.gather_for_metrics(attention_mask.sum()).sum().item()
# )
# # pylint: disable=no-member
# self._metrics[mode]["num_tokens"] = [self._total_train_tokens]
# completion_length = (
# self.accelerator.gather_for_metrics(completion_mask.sum(1))
# .float()
# .mean()
# .item()
# )
# self._metrics[mode]["completion_length"].append(completion_length)
# # Calculate mean reward per function, but only for samples where the function was applied
# for i, reward_func in enumerate(self.reward_funcs):
# if isinstance(
# reward_func, nn.Module
# ): # Module instead of PretrainedModel for compat with compiled models
# reward_func_name = reward_func.config._name_or_path.split("/")[-1]
# else:
# # pylint: disable=protected-access
# reward_func_name = reward_func.__name__
# # Only calculate mean for samples where this reward function was applied (non-NaN values)
# mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
# self._metrics[mode][f"rewards/{reward_func_name}"].append(mean_rewards)
# self._metrics[mode]["reward"].append(rewards.mean().item())
# self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item())
# if (
# self.log_completions
# and self.state.global_step % self.args.logging_steps == 0
# ):
# prompts_to_log = gather_object(prompts_text)
# completions_to_log = gather_object(completions_text)
# rewards_to_log = rewards.tolist()
# if self.accelerator.is_main_process:
# if is_rich_available():
# print_prompt_completions_sample(
# prompts_to_log,
# completions_to_log,
# rewards_to_log,
# self.state.global_step,
# )
# if (
# self.args.report_to
# and "wandb" in self.args.report_to
# and wandb.run is not None
# ):
# import pandas as pd
# # For logging
# table = {
# "step": [str(self.state.global_step)] * len(rewards),
# "prompt": prompts_to_log,
# "completion": completions_to_log,
# "reward": rewards.tolist(),
# }
# df = pd.DataFrame(table)
# wandb.log({"completions": wandb.Table(dataframe=df)})
# return {
# "prompt_ids": prompt_ids,
# "prompt_mask": prompt_mask,
# "completion_ids": completion_ids,
# "completion_mask": completion_mask,
# "old_per_token_logps": old_per_token_logps,
# "ref_per_token_logps": ref_per_token_logps,
# "advantages": advantages,
# }
# def _get_per_token_logps_v2(
# self, model, input_ids, attention_mask, logits_to_keep, completion_mask=None
# ):
# # Pad sequence to be divisible by SP degree if needed
# total_seq_len = input_ids.shape[1]
# if total_seq_len % self.local_world_size != 0:
# pad_len = self.local_world_size - (total_seq_len % self.local_world_size)
# pad_token_id = self.processing_class.pad_token_id or 0
# # Pad input_ids and attention_mask
# padding = torch.full(
# (input_ids.shape[0], pad_len),
# pad_token_id,
# dtype=input_ids.dtype,
# device=input_ids.device,
# )
# input_ids = torch.cat([input_ids, padding], dim=1)
# attn_padding = torch.zeros(
# (attention_mask.shape[0], pad_len),
# dtype=attention_mask.dtype,
# device=attention_mask.device,
# )
# attention_mask = torch.cat([attention_mask, attn_padding], dim=1)
# if completion_mask is not None:
# completion_mask = torch.cat([completion_mask, attn_padding], dim=1)
# total_seq_len += pad_len
# logits_to_keep += pad_len
# # Split the sequence
# slice_size = total_seq_len // self.local_world_size
# start = self.local_rank * slice_size
# end = start + slice_size
# # Get our slice
# input_ids_slice = input_ids[:, start:end]
# attention_mask_slice = attention_mask[:, start:end]
# # Calculate where our slice starts and ends relative to the completion tokens
# local_completion_mask = None
# prompt_len = input_ids.size(1) - logits_to_keep
# if start >= prompt_len:
# # Slice starts within the completion section
# start_in_completion = start - prompt_len
# end_in_completion = min(end - prompt_len, logits_to_keep)
# local_logits_to_keep = end_in_completion - start_in_completion
# if completion_mask is not None:
# local_completion_mask = completion_mask[
# :, start_in_completion:end_in_completion
# ]
# elif end <= prompt_len:
# # Slice is entirely within the prompt section (no completion tokens)
# local_logits_to_keep = 0
# if completion_mask is not None:
# local_completion_mask = torch.zeros(
# (completion_mask.size(0), 0), device=completion_mask.device
# )
# else:
# # Slice contains the boundary between prompt and completion
# start_in_completion = 0
# end_in_completion = min(end - prompt_len, logits_to_keep)
# local_logits_to_keep = end_in_completion - start_in_completion
# if completion_mask is not None:
# local_completion_mask = completion_mask[
# :, start_in_completion:end_in_completion
# ]
# # Get logits with enough context to compute log probs
# logits = model(
# input_ids=input_ids_slice,
# attention_mask=attention_mask_slice,
# logits_to_keep=local_logits_to_keep + 1,
# ).logits
# # Only the last rank that contains completion tokens needs to remove the last logit
# is_last_rank_with_completions = (
# self.local_rank == self.local_world_size - 1 # Last rank overall
# or end
# >= prompt_len
# + logits_to_keep # Our slice includes the last completion token
# )
# if is_last_rank_with_completions:
# logits = logits[:, :-1]
# if local_completion_mask is not None:
# local_completion_mask = local_completion_mask[:, :-1]
# local_logits_to_keep -= 1
# if start >= prompt_len:
# # For ranks where slice is all completion tokens,
# # we need to offset to match the logits (which predict the next token)
# offset = 1 # Skip the first token as it's predicted by the last token of the previous rank
# local_input_ids = input_ids_slice[:, offset : offset + local_logits_to_keep]
# else:
# # For the rank that contains the prompt-completion boundary,
# # we need to take completion tokens only
# offset = prompt_len - start # Where completions start in our slice
# local_input_ids = input_ids_slice[:, offset : offset + local_logits_to_keep]
# logits = logits[
# :, -local_logits_to_keep:
# ] # Take only logits for completion tokens
# logits = logits / self.temperature
# per_token_logps = selective_log_softmax(logits, local_input_ids)
# return per_token_logps, local_completion_mask
# # pylint: disable=unused-argument
# @profiling_decorator
# def compute_loss(
# self, model, inputs, return_outputs=False, num_items_in_batch=None
# ):
# if return_outputs:
# raise ValueError("The GRPOTrainer does not support returning outputs")
# # Unpack inputs
# prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
# completion_ids, completion_mask = (
# inputs["completion_ids"],
# inputs["completion_mask"],
# )
# prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
# attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
# logits_to_keep = completion_ids.size(1)
# if self.args.sequence_parallel_degree > 1:
# per_token_logps, completion_mask = self._get_per_token_logps_v2(
# model,
# prompt_completion_ids,
# attention_mask,
# logits_to_keep,
# completion_mask,
# )
# else:
# per_token_logps = super()._get_per_token_logps(
# model, prompt_completion_ids, attention_mask, logits_to_keep
# )
# # Compute the KL divergence between the model and the reference model
# if self.beta != 0.0:
# ref_per_token_logps = inputs["ref_per_token_logps"]
# per_token_kl = (
# torch.exp(ref_per_token_logps - per_token_logps)
# - (ref_per_token_logps - per_token_logps)
# - 1
# )
# # Compute the loss
# advantages = inputs["advantages"]
# # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its computation
# # and use per_token_logps.detach() instead.
# old_per_token_logps = (
# inputs["old_per_token_logps"]
# if self.num_iterations > 1
# else per_token_logps.detach()
# )
# coef_1 = torch.exp(per_token_logps - old_per_token_logps)
# coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
# per_token_loss1 = coef_1 * advantages.unsqueeze(1)
# per_token_loss2 = coef_2 * advantages.unsqueeze(1)
# per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
# if self.beta != 0.0:
# per_token_loss = per_token_loss + self.beta * per_token_kl
# loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
# # Log metrics
# mode = "eval" if self.control.should_evaluate else "train"
# if self.beta != 0.0:
# mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum()
# self._metrics[mode]["kl"].append(
# self.accelerator.gather_for_metrics(mean_kl).mean().item()
# )
# is_clipped = (per_token_loss1 < per_token_loss2).float()
# clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
# self._metrics[mode]["clip_ratio"].append(
# self.accelerator.gather_for_metrics(clip_ratio).mean().item()
# )
# return loss

View File

@@ -13,66 +13,14 @@ from torch.utils.data import DistributedSampler, Sampler
from torch.utils.hooks import RemovableHandle
from axolotl.monkeypatch.attention.ring_attn import (
RingAttnFunc,
get_ring_attn_group,
update_ring_attn_params,
)
from axolotl.utils.schemas.enums import RingAttnFunc
LOG = logging.getLogger(__name__)
def _handle_logits_to_keep(
logits_to_keep,
local_rank: int,
local_world_size: int,
ring_attn_func: RingAttnFunc,
total_seq_len: int,
):
"""
Handle logits_to_keep parameter for sequence parallelism.
Args:
logits_to_keep: Integer or tensor indicating which positions to compute logits
for.
local_rank: Rank in the sequence parallel group.
local_world_size: World size of the sequence parallel group.
ring_attn_func: Ring attention function being used.
total_seq_len: Full sequence length.
Returns:
Adjusted logits_to_keep appropriate for this rank's sharded sequence
"""
print("start of _handle_logits_to_keep")
print(dist.get_rank(), logits_to_keep)
# No transformation needed if logits_to_keep is None
if logits_to_keep is None:
return None
assert isinstance(
logits_to_keep, int
), "sequence parallelism currently only supports integer logits_to_keep"
assert ring_attn_func in [
RingAttnFunc.VARLEN_LLAMA3,
RingAttnFunc.BATCH_RING,
], "if specifying logits_to_keep, sequence parallelism currently only supports 'batch_ring' and 'varlen_llama3' `ring_attn_func`s"
# For standard sharding, each rank gets a contiguous chunk
chunk_size = total_seq_len // local_world_size
start_idx = local_rank * chunk_size
end_idx = start_idx + chunk_size
# Check if logits_to_keep is in this rank's range
if start_idx <= logits_to_keep < end_idx:
print("end of _handle_logits_to_keep")
print(dist.get_rank(), logits_to_keep - start_idx)
return logits_to_keep - start_idx
else:
print("end of _handle_logits_to_keep")
print(dist.get_rank(), -1)
return -1
def apply_sequence_parallelism(
batch: dict[str, torch.Tensor],
local_rank: int,
@@ -83,10 +31,10 @@ def apply_sequence_parallelism(
Apply sequence parallelism slicing to a batch.
Args:
batch: Batch dictionary (e.g., input_ids, attention_mask, etc.).
local_rank: Local rank in the sequence parallel group.
local_world_size: World size of the sequence parallel group.
ring_attn_func: The ring attention function to use.
batch: Batch dictionary (e.g., input_ids, attention_mask, etc.)
local_rank: Local rank in the sequence parallel group
local_world_size: World size of the sequence parallel group
ring_attn_func: The ring attention function to use
Returns:
Sliced batch dictionary.
@@ -99,10 +47,12 @@ def apply_sequence_parallelism(
total_seq_len = batch["input_ids"].size(1)
for key in batch:
if (
isinstance(batch[key], torch.Tensor)
key in batch
and isinstance(batch[key], torch.Tensor)
and batch[key].dim() > 1
and batch[key].size(1) == total_seq_len
):
if ring_attn_func in [
RingAttnFunc.VARLEN_LLAMA3,
RingAttnFunc.BATCH_RING,
@@ -127,14 +77,6 @@ def apply_sequence_parallelism(
dim=1,
).transpose(1, 2)
batch[key] = tensor[:, local_rank].contiguous()
if key == "logits_to_keep":
batch[key] = _handle_logits_to_keep(
logits_to_keep=batch[key],
local_rank=local_rank,
local_world_size=local_world_size,
ring_attn_func=ring_attn_func,
total_seq_len=total_seq_len,
)
return batch
@@ -262,11 +204,8 @@ class SequenceParallelContextManager:
# Forward post-hook to gather outputs
def sequence_parallel_post_hook(_, __, output):
print("start of sequence_parallel_post_hook")
# Gather the sharded outputs
output = self.gather_outputs(output)
print("end of sequence_parallel_post_hook")
return output
return self.gather_outputs(output)
# Register both hooks
self.hook_handles.append(

View File

@@ -9,7 +9,7 @@ from PIL.Image import Resampling
from transformers import TrainingArguments
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
from axolotl.utils.schemas.enums import RingAttnFunc
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
@dataclass

View File

@@ -27,8 +27,6 @@ pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transform
```yaml
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
cut_cross_entropy: true
```
## Supported Models

View File

@@ -28,7 +28,7 @@ class CutCrossEntropyArgs(BaseModel):
Input args for Cut Cross Entropy.
"""
cut_cross_entropy: Optional[bool] = None
cut_cross_entropy: Optional[bool] = True
@model_validator(mode="before")
@classmethod

View File

@@ -4,6 +4,7 @@
# flake8: noqa
from .patch import (
RingAttnFunc,
get_ring_attn_group,
register_ring_attn,
set_ring_attn_group,

View File

@@ -28,7 +28,7 @@ from transformers.modeling_flash_attention_utils import (
)
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from axolotl.utils.schemas.enums import RingAttnFunc
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
RING_ATTN_FUNC_MAPPING = {
RingAttnFunc.BATCH_RING: ring_flash_attn_func,

View File

@@ -6,13 +6,14 @@ package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patc
their sequence parallel version of Flash Attention 2.
"""
from enum import Enum
import torch
import torch.distributed as dist
from accelerate.logging import get_logger
from axolotl.logging_config import configure_logging
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.utils.schemas.enums import RingAttnFunc
configure_logging()
LOG = get_logger(__name__)
@@ -42,6 +43,17 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
RING_ATTN_GROUP = ring_attn_group
class RingAttnFunc(str, Enum):
"""Enum class for supported `ring-flash-attn` implementations"""
# VARLEN_RING = "varlen_ring"
# VARLEN_ZIGZAG = "varlen_zigzag"
VARLEN_LLAMA3 = "varlen_llama3"
BATCH_RING = "batch_ring"
BATCH_ZIGZAG = "batch_zigzag"
BATCH_STRIPE = "batch_stripe"
def register_ring_attn(
sequence_parallel_degree: int,
heads_k_stride: int | None,

View File

@@ -34,7 +34,6 @@ from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.freeze import freeze_layers_except
from axolotl.utils.models import load_model, load_processor, load_tokenizer
from axolotl.utils.schemas.enums import RLType
from axolotl.utils.trainer import setup_trainer
try:
@@ -109,7 +108,7 @@ def setup_reference_model(
Reference model if needed for RL training, `None` otherwise.
"""
model_ref = None
if cfg.rl and cfg.rl != RLType.ORPO:
if cfg.rl and cfg.rl != "orpo":
if cfg.adapter and not cfg.rl_adapter_ref_model:
# use built-in trl autounwrap
LOG.debug("Passing model_ref: None to RL trainer")

View File

@@ -18,9 +18,8 @@ from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.models import load_tokenizer
from axolotl.utils.schemas.enums import RLType
LOG = logging.getLogger(__name__)
LOG = logging.getLogger("axolotl")
def _get_path(ds_hash, cfg):
@@ -81,7 +80,7 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs):
def drop_long_rl_seq(
sample, rl, tokenizer, sequence_len # pylint: disable=invalid-name
):
if rl in (RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO):
if rl in ("dpo", "ipo", "orpo", "simpo"):
if not (
sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
):
@@ -101,7 +100,7 @@ def drop_long_rl_seq(
len_prompt + len_rejected
) <= sequence_len
if rl is RLType.KTO:
if rl == "kto":
if not (sample.get("prompt") and sample.get("completion")):
raise ValueError("Prompt and completion keys are required for KTO datasets")
@@ -115,7 +114,7 @@ def drop_long_rl_seq(
return (len_prompt + len_completion) <= sequence_len
if rl is RLType.GRPO:
if rl == "grpo":
return True
raise ValueError("Unknown RL type")
@@ -138,9 +137,9 @@ def load_prepare_preference_datasets(cfg):
if _type:
if isinstance(_type, DictDefault):
_type = "user_defined.default"
if _cfg.rl is RLType.ORPO:
if _cfg.rl == "orpo":
ds_transform_fn = load_orpo(_type, _cfg, dataset_idx=i)
elif _cfg.rl is RLType.KTO:
elif _cfg.rl == "kto":
ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
else:
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
@@ -151,7 +150,7 @@ def load_prepare_preference_datasets(cfg):
split_datasets[i] = map_dataset(
cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
)
elif _cfg.rl is RLType.KTO:
elif _cfg.rl == "kto":
ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
map_kwargs = {}
if isinstance(ds_transform_fn, tuple):

View File

@@ -134,10 +134,9 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
"csv", data_files=f.name, split="train", streaming=True
)
else:
if is_local_main_process():
iter_ds = load_dataset(
path, streaming=True, split=split, name=name, data_files=data_files
)
iter_ds = load_dataset(
path, streaming=True, split=split, name=name, data_files=data_files
)
if skip:
LOG.info(f"Skipping {skip} samples from the dataset")

View File

@@ -1,5 +1,7 @@
"""custom checkpointing utils"""
from functools import partial
from axolotl.utils.gradient_checkpointing.unsloth import (
Unsloth_Offloaded_Gradient_Checkpointer,
)
@@ -9,6 +11,10 @@ def hf_grad_checkpoint_offload_wrapper(
decoder_layer, *args, use_reentrant=None
): # pylint: disable=unused-argument
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
decoder_layer.__self__,
(
decoder_layer.func.__self__
if isinstance(decoder_layer, partial)
else decoder_layer.__self__
),
*args,
)

View File

@@ -72,7 +72,6 @@ from axolotl.utils.distributed import (
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
from axolotl.utils.schemas.enums import RLType
LOG = logging.getLogger(__name__)
@@ -1341,7 +1340,7 @@ class ModelLoader:
# then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
if (
self.cfg.adapter
and self.cfg.rl in [RLType.DPO, RLType.IPO, RLType.KTO]
and self.cfg.rl in ["dpo", "ipo", "kto"]
and not self.cfg.merge_lora
):
_, lora_config = load_lora(

View File

@@ -18,7 +18,6 @@ from pydantic import (
)
from transformers.utils.import_utils import is_torch_npu_available
from axolotl.utils.distributed import is_main_process
from axolotl.utils.schemas.datasets import (
DatasetConfig,
DPODataset,
@@ -28,7 +27,7 @@ from axolotl.utils.schemas.datasets import (
StepwiseSupervisedDataset,
)
from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType
from axolotl.utils.schemas.enums import ChatTemplate, RLType
from axolotl.utils.schemas.integrations import (
CometConfig,
GradioConfig,
@@ -260,7 +259,7 @@ class AxolotlInputConfig(
sequence_parallel_degree: int | None = None
heads_k_stride: int | None = None
ring_attn_func: RingAttnFunc | None = None
ring_attn_func: str | None = None
special_tokens: SpecialTokensConfig | None = None
tokens: list[str] | None = None
@@ -719,10 +718,9 @@ class AxolotlInputConfig(
and data.get("eval_sample_packing") is None
and not data.get("eval_table_size")
):
if is_main_process():
LOG.info(
"explicitly setting `eval_sample_packing` to match `sample_packing`"
)
LOG.info(
"explicitly setting `eval_sample_packing` to match `sample_packing`"
)
data["eval_sample_packing"] = True
if (
@@ -784,7 +782,7 @@ class AxolotlInputConfig(
@model_validator(mode="after")
def check_simpo_warmup(self):
if self.rl is RLType.SIMPO and self.warmup_ratio:
if self.rl == "simpo" and self.warmup_ratio:
raise ValueError(
"warmup_ratio is not supported with the simpo trainer. Please use `warmup_steps` instead"
)
@@ -1179,15 +1177,14 @@ class AxolotlInputConfig(
# TODO: monkeypatch / callback to average losses correctly across SP ranks
# / fix gradient scaling across SP ranks. Losses, grads should be scaled
# according to the proportion of non-padding tokens per rank.
if is_main_process():
LOG.warning(
"Sequence parallelism (SP) is enabled with "
f"sequence_parallel_degree={self.sequence_parallel_degree}. "
"Please note that logged losses may differ slightly to the non-SP "
"losses due to transformers Trainer implementation details. "
"Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
"for more details."
)
LOG.warning(
"Sequence parallelism (SP) is enabled with "
f"sequence_parallel_degree={self.sequence_parallel_degree}. "
"Please note that logged losses may differ slightly to the non-SP "
"losses due to transformers Trainer implementation details. "
"Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
"for more details."
)
return self
@@ -1196,6 +1193,8 @@ class AxolotlInputConfig(
if getattr(self, "sequence_parallel_degree", 1) == 1:
return self
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
if self.ring_attn_func is not None:
valid_funcs = list(RingAttnFunc)
if self.ring_attn_func in valid_funcs:

View File

@@ -6,12 +6,12 @@ from enum import Enum
class RLType(str, Enum):
"""RL trainer type configuration subset"""
DPO = "dpo" # pylint: disable=invalid-name
GRPO = "grpo" # pylint: disable=invalid-name
IPO = "ipo" # pylint: disable=invalid-name
ORPO = "orpo" # pylint: disable=invalid-name
KTO = "kto" # pylint: disable=invalid-name
SIMPO = "simpo" # pylint: disable=invalid-name
dpo = "dpo" # pylint: disable=invalid-name
grpo = "grpo" # pylint: disable=invalid-name
ipo = "ipo" # pylint: disable=invalid-name
orpo = "orpo" # pylint: disable=invalid-name
kto = "kto" # pylint: disable=invalid-name
simpo = "simpo" # pylint: disable=invalid-name
class ChatTemplate(str, Enum):
@@ -53,14 +53,3 @@ class CustomSupportedOptimizers(str, Enum):
ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name
adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name
muon = "muon" # pylint: disable=invalid-name
class RingAttnFunc(str, Enum):
"""Enum class for supported `ring-flash-attn` implementations"""
# VARLEN_RING = "varlen_ring"
# VARLEN_ZIGZAG = "varlen_zigzag"
VARLEN_LLAMA3 = "varlen_llama3"
BATCH_RING = "batch_ring"
BATCH_ZIGZAG = "batch_zigzag"
BATCH_STRIPE = "batch_stripe"

View File

@@ -528,6 +528,13 @@ def setup_torch_compile_env(cfg):
def setup_deepspeed_env(cfg, stage=None):
from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
from axolotl.utils.distributed import distributed_state
if distributed_state and distributed_state.initialized:
raise RuntimeError(
"Distributed State already initialized before Deepspeed setup"
)
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
if stage:

View File

@@ -0,0 +1,77 @@
"""
E2E tests for activation checkpointing
"""
import pytest
import transformers
from torch.utils.checkpoint import checkpoint
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists
@pytest.fixture()
def fix_checkpoint_after_test():
yield
transformers.modeling_utils.checkpoint = checkpoint
class TestActivationCheckpointing:
"""
E2E tests for activation checkpointing
"""
def test_activation_checkpointing_offload(
self,
temp_dir,
fix_checkpoint_after_test, # pylint: disable=unused-argument,redefined-outer-name
):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024,
"val_set_size": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>",
"eos_token": "<|im_end|>",
},
"datasets": [
{
"chat_template": "chatml",
"path": "mlabonne/FineTome-100k",
"type": "chat_template",
"split": "train[:10%]",
"field_messages": "conversations",
"message_field_role": "from",
"message_field_content": "value",
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
"gradient_checkpointing": "offload",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -1,4 +1,6 @@
"""E2E tests for mixtral"""
"""
E2E tests for mixtral
"""
import logging
import os

View File

@@ -12,12 +12,12 @@ from accelerate.state import PartialState
from axolotl.core.trainers.mixins.sequence_parallel import apply_sequence_parallelism
from axolotl.monkeypatch.attention.ring_attn import (
RingAttnFunc,
get_ring_attn_group,
register_ring_attn,
set_ring_attn_group,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.schemas.enums import RingAttnFunc
@pytest.fixture
@@ -131,11 +131,6 @@ class TestConfigValidation:
# Mock the ring_flash_attn module
monkeypatch.setitem(sys.modules, "ring_flash_attn", MagicMock())
# Mock the is_main_process function to return True
monkeypatch.setattr(
"axolotl.utils.schemas.config.is_main_process", lambda: True
)
@pytest.fixture
def base_cfg(self):
"""Create a base configuration for testing."""