stronger subclassing of TRL GRPO trainer; custom distributed sampler

This commit is contained in:
Dan Saunders
2025-04-17 04:06:18 +00:00
parent 76e2d2e60b
commit b13b6e185f
15 changed files with 778 additions and 120 deletions

View File

@@ -14,6 +14,7 @@ from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_preference_datasets from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_processor, load_tokenizer from axolotl.utils.models import load_processor, load_tokenizer
from axolotl.utils.schemas.enums import RLType
from axolotl.utils.tokenization import check_dataset_labels from axolotl.utils.tokenization import check_dataset_labels
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@@ -125,7 +126,7 @@ def load_preference_datasets(
total_num_steps: Optional[int] = int( total_num_steps: Optional[int] = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
) )
if cfg.rl == "grpo": if cfg.rl is RLType.GRPO:
total_num_steps = None total_num_steps = None
if cli_args.debug or cfg.debug: if cli_args.debug or cfg.debug:

View File

@@ -84,7 +84,7 @@ from axolotl.utils.collators import (
) )
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.models import ensure_dtype from axolotl.utils.models import ensure_dtype
from axolotl.utils.schemas.enums import CustomSupportedOptimizers from axolotl.utils.schemas.enums import CustomSupportedOptimizers, RLType
try: try:
import torch._dynamo # pylint: disable=ungrouped-imports import torch._dynamo # pylint: disable=ungrouped-imports
@@ -1054,7 +1054,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
training_args_cls = None training_args_cls = None
blocklist_args_kwargs = [] blocklist_args_kwargs = []
if self.cfg.rl == "simpo": if self.cfg.rl is RLType.SIMPO:
training_args_cls = AxolotlCPOConfig training_args_cls = AxolotlCPOConfig
training_args_kwargs["loss_type"] = "simpo" training_args_kwargs["loss_type"] = "simpo"
training_args_kwargs["max_length"] = self.cfg.sequence_len training_args_kwargs["max_length"] = self.cfg.sequence_len
@@ -1062,13 +1062,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.cpo_alpha is not None: if self.cfg.cpo_alpha is not None:
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
elif self.cfg.rl == "orpo": elif self.cfg.rl is RLType.ORPO:
training_args_cls = AxolotlORPOConfig training_args_cls = AxolotlORPOConfig
training_args_kwargs["max_length"] = self.cfg.sequence_len training_args_kwargs["max_length"] = self.cfg.sequence_len
if self.cfg.max_prompt_len: if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
elif self.cfg.rl == "kto": elif self.cfg.rl is RLType.KTO:
training_args_cls = AxolotlKTOConfig training_args_cls = AxolotlKTOConfig
training_args_kwargs["desirable_weight"] = ( training_args_kwargs["desirable_weight"] = (
@@ -1082,14 +1082,14 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.max_prompt_len: if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
elif self.cfg.rl == "grpo": elif self.cfg.rl is RLType.GRPO:
training_args_cls = GRPOStrategy.get_training_args_class() training_args_cls = GRPOStrategy.get_training_args_class()
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg)) training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs() blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
else: else:
training_args_cls = AxolotlDPOConfig training_args_cls = AxolotlDPOConfig
if self.cfg.rl == "ipo": if self.cfg.rl is RLType.IPO:
training_args_kwargs["loss_type"] = "ipo" training_args_kwargs["loss_type"] = "ipo"
training_args_kwargs["max_length"] = self.cfg.sequence_len training_args_kwargs["max_length"] = self.cfg.sequence_len
training_args_kwargs["max_completion_length"] = None training_args_kwargs["max_completion_length"] = None
@@ -1127,7 +1127,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
def build(self, total_num_steps): def build(self, total_num_steps):
training_args = self.build_training_arguments(total_num_steps) training_args = self.build_training_arguments(total_num_steps)
dpo_trainer_kwargs = {} dpo_trainer_kwargs = {}
if self.cfg.rl == "ipo": if self.cfg.rl is RLType.IPO:
if self.cfg.dpo_label_smoothing: if self.cfg.dpo_label_smoothing:
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
if self.eval_dataset: if self.eval_dataset:
@@ -1138,21 +1138,21 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer_kwargs["precompute_ref_log_probs"] = ( dpo_trainer_kwargs["precompute_ref_log_probs"] = (
self.cfg.precompute_ref_log_probs self.cfg.precompute_ref_log_probs
) )
if self.cfg.rl == "grpo": if self.cfg.rl is RLType.GRPO:
trainer_cls = GRPOStrategy.get_trainer_class() trainer_cls = GRPOStrategy.get_trainer_class()
trainer_cls_args = [self.model] trainer_cls_args = [self.model]
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg)) trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
dpo_trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg)) dpo_trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
elif self.cfg.rl in ["dpo", "ipo"]: elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
trainer_cls = DPOStrategy.get_trainer_class() trainer_cls = DPOStrategy.get_trainer_class()
trainer_cls_args = [self.model, self.model_ref] trainer_cls_args = [self.model, self.model_ref]
elif self.cfg.rl == "orpo": elif self.cfg.rl is RLType.ORPO:
trainer_cls = AxolotlORPOTrainer trainer_cls = AxolotlORPOTrainer
trainer_cls_args = [self.model] trainer_cls_args = [self.model]
elif self.cfg.rl in ["kto"]: elif self.cfg.rl is RLType.KTO:
trainer_cls = AxolotlKTOTrainer trainer_cls = AxolotlKTOTrainer
trainer_cls_args = [self.model] trainer_cls_args = [self.model]
elif self.cfg.rl in ["simpo"]: elif self.cfg.rl is RLType.SIMPO:
trainer_cls = AxolotlCPOTrainer trainer_cls = AxolotlCPOTrainer
trainer_cls_args = [self.model] trainer_cls_args = [self.model]
else: else:
@@ -1179,7 +1179,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
) )
if self.cfg.fsdp: if self.cfg.fsdp:
ensure_dtype(dpo_trainer.model, dtype=self.cfg.torch_dtype) ensure_dtype(dpo_trainer.model, dtype=self.cfg.torch_dtype)
if self.cfg.rl in ["dpo", "ipo"] and dpo_trainer.ref_model: if self.cfg.rl in [RLType.DPO, RLType.IPO] and dpo_trainer.ref_model:
ensure_dtype(dpo_trainer.ref_model, dtype=self.cfg.torch_dtype) ensure_dtype(dpo_trainer.ref_model, dtype=self.cfg.torch_dtype)
dpo_trainer = self.hook_post_create_trainer(dpo_trainer) dpo_trainer = self.hook_post_create_trainer(dpo_trainer)

View File

@@ -3,6 +3,7 @@ DPO Specific Strategy for training
""" """
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer
from axolotl.utils.schemas.enums import RLType
class DPOStrategy: class DPOStrategy:
@@ -23,7 +24,7 @@ class DPOStrategy:
@classmethod @classmethod
def set_training_args_kwargs(cls, cfg): def set_training_args_kwargs(cls, cfg):
training_args_kwargs = {} training_args_kwargs = {}
if cfg.rl == "ipo": if cfg.rl is RLType.IPO:
training_args_kwargs["loss_type"] = "ipo" training_args_kwargs["loss_type"] = "ipo"
training_args_kwargs["max_length"] = cfg.sequence_len training_args_kwargs["max_length"] = cfg.sequence_len
training_args_kwargs["max_completion_length"] = None training_args_kwargs["max_completion_length"] = None

View File

@@ -11,6 +11,4 @@ from axolotl.core.training_args import AxolotlTrainingMixins
@dataclass @dataclass
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig): class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
""" """Axolotl GRPO Config for GRPO training"""
Axolotl GRPO Config for GRPO training
"""

View File

@@ -0,0 +1,124 @@
"""
Repeat random sampler (akin to the one implemented in
https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py) that adds
sequence parallelism functionality; i.e., duplicating data across ranks in the same
sequence parallel group.
"""
from typing import Optional, Sized
import torch
from torch.utils.data import Sampler
class SequenceParallelRepeatRandomSampler(Sampler):
    """
    Sampler for GRPO training with sequence parallelism that ensures:
    1. Ranks in the same sequence parallel group receive identical data
    2. Each index is repeated multiple times for sampling different completions
    3. Entire batches are repeated for reuse in multiple updates

    Args:
        dataset: Sized dataset to draw indices from.
        mini_repeat_count: Number of consecutive copies of every index.
        batch_size: Number of distinct indices per batch handed to one SP group.
        repeat_count: Number of times the whole per-group index sequence is
            emitted.
        sequence_parallel_degree: Number of ranks per sequence parallel group.
        world_size: Total number of ranks. When `None`, taken from
            `torch.distributed` if a process group is initialized, else 1.
        rank: Global rank of this process. When `None`, taken from
            `torch.distributed` if a process group is initialized, else 0.
        shuffle: Whether to shuffle the indices every epoch.
        seed: Base seed for shuffling.
        drop_last: Drop the trailing indices that do not fill one batch for
            every SP group (`True`) or pad by wrapping around (`False`).

    Raises:
        ValueError: If `batch_size` or `sequence_parallel_degree` is not
            positive, or `world_size` is not a positive multiple of
            `sequence_parallel_degree`.
    """

    def __init__(
        self,
        dataset: Sized,
        mini_repeat_count: int,
        batch_size: int = 1,
        repeat_count: int = 1,
        sequence_parallel_degree: int = 1,
        world_size: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
    ):
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if sequence_parallel_degree < 1:
            raise ValueError(
                "sequence_parallel_degree must be >= 1, "
                f"got {sequence_parallel_degree}"
            )

        # The signature advertises `None` defaults, so resolve them instead of
        # failing on the integer division below.
        if world_size is None or rank is None:
            default_world_size, default_rank = self._distributed_defaults()
            world_size = default_world_size if world_size is None else world_size
            rank = default_rank if rank is None else rank

        if world_size < 1 or world_size % sequence_parallel_degree != 0:
            raise ValueError(
                f"world_size ({world_size}) must be a positive multiple of "
                f"sequence_parallel_degree ({sequence_parallel_degree})"
            )

        self.dataset = dataset
        self.mini_repeat_count = mini_repeat_count
        self.batch_size = batch_size
        self.repeat_count = repeat_count
        self.shuffle = shuffle
        self.seed = seed
        self.drop_last = drop_last
        self.epoch = 0

        self.world_size = world_size
        self.rank = rank

        # Sequence parallelism parameters
        self.sequence_parallel_degree = sequence_parallel_degree
        self.num_sp_groups = world_size // sequence_parallel_degree
        self.sp_group_id = rank // sequence_parallel_degree

        self.num_samples = len(self.dataset)
        self.total_size = self.num_samples

        # One "round" hands exactly one batch to every SP group.
        round_size = self.batch_size * self.num_sp_groups
        if self.drop_last:
            num_rounds = self.total_size // round_size
        else:
            num_rounds = -(-self.total_size // round_size)  # ceil division
        self.num_samples_per_sp_group = num_rounds * self.batch_size

    @staticmethod
    def _distributed_defaults():
        """Return `(world_size, rank)` from torch.distributed, or `(1, 0)`."""
        import torch.distributed as dist  # local: only needed for the fallback

        if dist.is_available() and dist.is_initialized():
            return dist.get_world_size(), dist.get_rank()
        return 1, 0

    def __iter__(self):
        # Deterministically shuffle based on epoch and seed
        if self.shuffle:
            # Use same seed for all ranks in the same SP group
            # NOTE(review): the seed also differs *between* SP groups, so the
            # strided split below does not partition one shared permutation;
            # groups may overlap in the samples they see. Confirm intended.
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch + self.sp_group_id * 10000)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        # Resize so that every SP group receives exactly
        # `num_samples_per_sp_group` indices; this keeps `__len__` truthful and
        # all ranks in lock-step.
        round_size = self.batch_size * self.num_sp_groups
        target_total = (self.num_samples_per_sp_group // self.batch_size) * round_size
        if len(indices) > target_total:
            # drop_last: discard the incomplete trailing round
            indices = indices[:target_total]
        elif indices and len(indices) < target_total:
            # Pad by wrapping around (repeatedly, if the dataset is tiny)
            wraps = -(-target_total // len(indices))
            indices = (indices * wraps)[:target_total]

        # Each SP group takes its own batch out of every round
        group_indices = []
        first = self.sp_group_id * self.batch_size
        for start in range(first, len(indices), round_size):
            group_indices.extend(indices[start : start + self.batch_size])

        # Apply the GRPO repeat pattern
        # NOTE(review): the whole per-group sequence is repeated `repeat_count`
        # times, rather than each batch being repeated back-to-back as TRL's
        # RepeatRandomSampler does. Confirm this ordering is intended.
        one_pass = [
            idx for idx in group_indices for _ in range(self.mini_repeat_count)
        ]
        return iter(one_pass * self.repeat_count)

    def __len__(self):
        # Total length including all repetitions
        return (
            self.num_samples_per_sp_group * self.mini_repeat_count * self.repeat_count
        )

    def set_epoch(self, epoch):
        """Sets the epoch for this sampler"""
        self.epoch = epoch

View File

@@ -1,28 +1,186 @@
"""Axolotl GRPO trainer""" """Axolotl GRPO trainer"""
import warnings
from contextlib import nullcontext from contextlib import nullcontext
from typing import Any
import datasets
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from accelerate.utils import is_deepspeed_available, is_peft_model from accelerate.utils import (
from trl import GRPOTrainer broadcast_object_list,
from trl.extras.profiling import profiling_decorator gather,
from trl.trainer.utils import selective_log_softmax gather_object,
is_peft_model,
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.monkeypatch.attention.ring_attn import (
get_ring_attn_group,
) )
from torch import nn
from torch.utils.data import (
BatchSampler,
DataLoader,
Sampler,
)
from transformers import Trainer, is_wandb_available
from transformers.trainer_utils import seed_worker
from trl import GRPOTrainer
from trl.data_utils import (
apply_chat_template,
is_conversational,
maybe_apply_chat_template,
)
from trl.extras.profiling import profiling_context, profiling_decorator
from trl.import_utils import is_deepspeed_available, is_rich_available
from trl.models import unwrap_model_for_generation
from trl.trainer.utils import (
pad,
print_prompt_completions_sample,
selective_log_softmax,
)
from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.monkeypatch.attention.ring_attn.patch import get_ring_attn_group
if is_deepspeed_available(): if is_deepspeed_available():
import deepspeed import deepspeed
if is_wandb_available():
import wandb
class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer): class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
"""Extend the base GRPOTrainer for axolotl helpers""" """Extend the base GRPOTrainer for axolotl helpers"""
_tag_names = ["trl", "grpo", "axolotl"] _tag_names = ["trl", "grpo", "axolotl"]
def __init__(self, *args, **kwargs):
    """Initialize the base GRPO trainer, then the sequence parallel state.

    All positional and keyword arguments are forwarded unchanged to the parent
    constructor.
    """
    super().__init__(*args, **kwargs)

    # Single-rank defaults, so the attributes always exist and no
    # torch.distributed call is made when sequence parallelism is disabled
    # (the same guard `get_train_dataloader` uses).
    self.sp_group = None
    self.local_rank = 0
    self.local_world_size = 1

    if self.args.sequence_parallel_degree > 1:
        # Rank / size of this process *within* its sequence parallel group
        self.sp_group = get_ring_attn_group()
        self.local_rank = dist.get_rank(group=self.sp_group)
        self.local_world_size = dist.get_world_size(group=self.sp_group)
def _get_train_sampler(self) -> Sampler:
    """Build the sequence-parallel-aware repeat sampler for training.

    Returns:
        A `SequenceParallelRepeatRandomSampler` over `self.train_dataset` that
        repeats every prompt `self.num_generations` times and repeats the
        index sequence `self.num_iterations` times.
    """
    # Fall back to a single-process layout when no process group exists
    # instead of letting torch.distributed raise.
    if dist.is_available() and dist.is_initialized():
        world_size = dist.get_world_size()
        rank = dist.get_rank()
    else:
        world_size, rank = 1, 0

    effective_batch_size = (
        self.args.per_device_train_batch_size
        * world_size
        * self.args.gradient_accumulation_steps
    )
    sp_degree = self.args.sequence_parallel_degree

    # Number of distinct prompts per sampler batch: the global batch divided
    # by the generations per prompt and by the sequence parallel degree.
    prompts_per_batch = effective_batch_size // self.num_generations // sp_degree

    return SequenceParallelRepeatRandomSampler(
        dataset=self.train_dataset,
        mini_repeat_count=self.num_generations,
        batch_size=prompts_per_batch,
        repeat_count=self.num_iterations,
        sequence_parallel_degree=sp_degree,
        world_size=world_size,
        rank=rank,
        shuffle=True,
        seed=self.args.seed,
        drop_last=True,
    )
def _create_dataloader_params(self, is_eval=False, custom_batch_size=None):
    """Assemble the DataLoader keyword arguments shared by train and eval.

    A truthy `custom_batch_size` wins; otherwise the eval or train batch size
    is used depending on `is_eval`.
    """
    args = self.args

    if custom_batch_size:
        batch_size = custom_batch_size
    elif is_eval:
        batch_size = args.eval_batch_size
    else:
        batch_size = self._train_batch_size

    loader_kwargs = {
        "batch_size": batch_size,
        "collate_fn": self.data_collator,
        "num_workers": args.dataloader_num_workers,
        "pin_memory": args.dataloader_pin_memory,
    }

    # Persistent workers are a training-only option
    training = not is_eval
    if training and hasattr(args, "dataloader_persistent_workers"):
        loader_kwargs["persistent_workers"] = args.dataloader_persistent_workers

    # Forward the prefetch factor only when one is configured
    if args.dataloader_prefetch_factor:
        loader_kwargs["prefetch_factor"] = args.dataloader_prefetch_factor

    return loader_kwargs
def _prepare_dataloader(
self, dataset, sampler, is_eval=False, custom_batch_size=None
):
"""Prepare a dataloader with the given dataset and sampler.

Builds the DataLoader keyword arguments via `_create_dataloader_params`,
attaches the sampler, constructs the `DataLoader`, and returns it either
as-is (sequence parallelism enabled) or wrapped by
`self.accelerator.prepare_data_loader`.

NOTE(review): this view of the file has lost its indentation, so the exact
nesting of the conditionals below cannot be confirmed from here.
"""
# Get base parameters
dataloader_params = self._create_dataloader_params(is_eval, custom_batch_size)
# Add sampler configuration
# Only done when the dataset is not an IterableDataset.
if not isinstance(dataset, torch.utils.data.IterableDataset):
if isinstance(sampler, BatchSampler):
# batch_size and batch_sampler are mutually exclusive
dataloader_params["batch_sampler"] = sampler
del dataloader_params["batch_size"]
else:
dataloader_params["sampler"] = sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last
# Training dataloaders get transformers' `seed_worker` as worker init function
if not is_eval:
dataloader_params["worker_init_fn"] = seed_worker
# Create the dataloader
dataloader = DataLoader(dataset, **dataloader_params)
# With sample packing active for this split, turn off accelerate's
# `even_batches` behavior.
if self.args.sample_packing and (
(not is_eval and not self.args.pretraining)
or (is_eval and self.args.eval_sample_packing is not False)
):
self.accelerator.even_batches = False
# Return unprepared dataloader if using sequence parallelism
# TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
# if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
# slice each batch along the sequence dimension).
if self.args.sequence_parallel_degree > 1:
return dataloader
# Otherwise prepare with accelerator
return self.accelerator.prepare_data_loader(dataloader)
def get_train_dataloader(self) -> DataLoader:
"""Get dataloader for training

Refreshes the sequence parallel group attributes when sequence parallelism
is enabled, preprocesses `self.train_dataset`, builds the sampler with
`_get_train_sampler`, and returns the result of `_prepare_dataloader`.

NOTE(review): this view of the file has lost its indentation, so which `if`
the `else:` below belongs to cannot be confirmed from here.
"""
train_dataset = self.train_dataset
# Captured before the preprocessing below, which may reassign self.data_collator
data_collator = self.data_collator # type: ignore
# Initialize SP group attributes if sequence parallelism is enabled
if self.args.sequence_parallel_degree > 1:
self.sp_group = get_ring_attn_group()
self.local_rank = dist.get_rank(group=self.sp_group)
self.local_world_size = dist.get_world_size(group=self.sp_group)
# Handle dataset preprocessing
if isinstance(train_dataset, datasets.Dataset):
# With sample packing (and not pretraining) the "length" column is removed
if self.args.sample_packing and not self.args.pretraining:
train_dataset = train_dataset.remove_columns(["length"])
# Without sample packing, or when pretraining, unused columns are removed
if not self.args.sample_packing or self.args.pretraining:
train_dataset = self._remove_unused_columns(
train_dataset, description="training"
)
else:
# Otherwise column removal is delegated to a wrapped collator
self.data_collator = self._get_collator_with_removed_columns(
data_collator,
description="training",
)
# Get sampler and create dataloader
sampler = self._get_train_sampler()
dataloader = self._prepare_dataloader(train_dataset, sampler, is_eval=False)
return dataloader
@profiling_decorator @profiling_decorator
def _move_model_to_vllm(self): def _move_model_to_vllm(self):
# For DeepSpeed ZeRO-3, we need to gather all parameters before operations # For DeepSpeed ZeRO-3, we need to gather all parameters before operations
@@ -70,20 +228,376 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
if self.accelerator.is_main_process: if self.accelerator.is_main_process:
self.vllm_client.reset_prefix_cache() self.vllm_client.reset_prefix_cache()
def _generate_and_score_completions(
    self, inputs: list[dict[str, torch.Tensor | Any]]
) -> dict[str, torch.Tensor | Any]:
    """Generate completions for a batch of prompts and score them.

    Args:
        inputs: Batch of examples; every example is a mapping with at least a
            `"prompt"` entry. Any other entries (except `"completion"`) are
            forwarded as keyword arguments to callable reward functions.

    Returns:
        Dict with `prompt_ids`, `prompt_mask`, `completion_ids`,
        `completion_mask`, `old_per_token_logps`, `ref_per_token_logps`
        (the latter two may be `None`) and the local slice of `advantages`.
    """
    device = self.accelerator.device
    prompts = [x["prompt"] for x in inputs]
    prompts_text = [
        maybe_apply_chat_template(example, self.processing_class)["prompt"]
        for example in inputs
    ]
    prompt_inputs = self.processing_class(
        text=prompts_text,
        return_tensors="pt",
        padding=True,
        padding_side="left",
        add_special_tokens=False,
    )
    prompt_inputs = Trainer._prepare_inputs(self, prompt_inputs)
    prompt_ids, prompt_mask = (
        prompt_inputs["input_ids"],
        prompt_inputs["attention_mask"],
    )

    if self.max_prompt_length is not None:
        # Prompts are left-padded, so keep the right-most tokens
        prompt_ids = prompt_ids[:, -self.max_prompt_length :]
        prompt_mask = prompt_mask[:, -self.max_prompt_length :]

    # Generate completions using either vLLM or regular generation
    if self.args.use_vllm:
        # First, have main process load weights if needed
        if self.state.global_step != self._last_loaded_step:
            self._move_model_to_vllm()
            self._last_loaded_step = self.state.global_step

        # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
        all_prompts_text = gather_object(prompts_text)
        if self.accelerator.is_main_process:
            # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
            # num_generations outputs for each one. This is faster than generating outputs for each duplicate
            # prompt individually.
            ordered_set_of_prompts = all_prompts_text[:: self.num_generations]
            with profiling_context(self, "vLLM.generate"):
                completion_ids = self.vllm_client.generate(
                    prompts=ordered_set_of_prompts,
                    n=self.num_generations,
                    repetition_penalty=self.repetition_penalty,
                    temperature=self.temperature,
                    top_p=self.top_p,
                    top_k=-1 if self.top_k is None else self.top_k,
                    min_p=0.0 if self.min_p is None else self.min_p,
                    max_tokens=self.max_completion_length,
                    guided_decoding_regex=self.guided_decoding_regex,
                )
        else:
            completion_ids = [None] * len(all_prompts_text)

        # Broadcast the completions from the main process to all processes
        completion_ids = broadcast_object_list(completion_ids, from_process=0)

        # Determine the appropriate slice based on sequence parallelism
        if self.args.sequence_parallel_degree > 1:
            # Calculate SP group ID (which group of ranks this rank belongs to)
            sp_group_id = self.accelerator.process_index // self.local_world_size

            # Calculate the start index for this SP group
            sp_group_start = sp_group_id * len(prompts) * self.local_world_size

            # All ranks in the same SP group get the same data slice
            # This ensures identical inputs for sequence-parallel processing
            # NOTE(review): this slice holds len(prompts) * local_world_size
            # completions while `prompt_ids` has len(prompts) rows; confirm the
            # concatenation below is valid in the sequence parallel case.
            process_slice = slice(
                sp_group_start,
                sp_group_start + len(prompts) * self.local_world_size,
            )

            # Take the full SP group's worth of completions
            completion_ids = completion_ids[process_slice]
        else:
            # Original behavior for non-sequence-parallel case
            process_slice = slice(
                self.accelerator.process_index * len(prompts),
                (self.accelerator.process_index + 1) * len(prompts),
            )
            completion_ids = completion_ids[process_slice]

        # Pad the completions, and concatenate them with the prompts
        completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
        completion_ids = pad(
            completion_ids, padding_value=self.processing_class.pad_token_id
        )
        prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
    else:
        # Regular generation path
        with unwrap_model_for_generation(
            self.model_wrapped,
            self.accelerator,
            gather_deepspeed3_params=self.args.ds3_gather_for_generation,
        ) as unwrapped_model:
            prompt_completion_ids = unwrapped_model.generate(
                prompt_ids,
                attention_mask=prompt_mask,
                generation_config=self.generation_config,
            )

        # Compute prompt length and extract completion ids
        prompt_length = prompt_ids.size(1)
        prompt_ids = prompt_completion_ids[:, :prompt_length]
        completion_ids = prompt_completion_ids[:, prompt_length:]

    # Mask everything after the first EOS token
    is_eos = completion_ids == self.processing_class.eos_token_id
    eos_idx = torch.full(
        (is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device
    )
    eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
    sequence_indices = torch.arange(is_eos.size(1), device=device).expand(
        is_eos.size(0), -1
    )
    completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()

    # Concatenate prompt_mask with completion_mask for logit computation
    attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B, P+C)

    # we only need to compute the logits for the completion tokens
    logits_to_keep = completion_ids.size(1)

    with torch.no_grad():
        # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its
        # computation here, and use per_token_logps.detach() instead.
        if self.num_iterations > 1:
            old_per_token_logps = self._get_per_token_logps(
                self.model, prompt_completion_ids, attention_mask, logits_to_keep
            )
        else:
            old_per_token_logps = None

        if self.beta == 0.0:
            ref_per_token_logps = None
        elif self.ref_model is not None:
            ref_per_token_logps = self._get_per_token_logps(
                self.ref_model,
                prompt_completion_ids,
                attention_mask,
                logits_to_keep,
            )
        else:
            # No separate reference model: use the policy with adapters disabled
            with self.accelerator.unwrap_model(self.model).disable_adapter():
                ref_per_token_logps = self._get_per_token_logps(
                    self.model,
                    prompt_completion_ids,
                    attention_mask,
                    logits_to_keep,
                )

    # Decode the generated completions
    completions_text = self.processing_class.batch_decode(
        completion_ids, skip_special_tokens=True
    )
    if is_conversational(inputs[0]):
        completions = []
        for prompt, completion in zip(prompts, completions_text):
            # A trailing assistant turn in the prompt is a prefix of the completion
            bootstrap = (
                prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
            )
            completions.append(
                [{"role": "assistant", "content": bootstrap + completion}]
            )
    else:
        completions = completions_text

    rewards_per_func = torch.zeros(
        len(prompts), len(self.reward_funcs), device=device
    )
    # Defined up-front so the all-NaN warning below cannot hit an unbound name
    # when every reward function is a model.
    reward_kwargs: dict[str, list] = {}
    for i, (reward_func, reward_processing_class) in enumerate(
        zip(self.reward_funcs, self.reward_processing_classes)
    ):
        # Module instead of PretrainedModel for compat with compiled models
        is_reward_model = isinstance(reward_func, nn.Module)
        if is_reward_model:
            reward_func_name = (
                f"reward {reward_func.config._name_or_path.split('/')[-1]}"
            )
        else:
            reward_func_name = reward_func.__name__

        with profiling_context(self, reward_func_name):
            if is_reward_model:
                if is_conversational(inputs[0]):
                    messages = [
                        {"messages": p + c} for p, c in zip(prompts, completions)
                    ]
                    texts = [
                        apply_chat_template(x, reward_processing_class)["text"]
                        for x in messages
                    ]
                else:
                    texts = [p + c for p, c in zip(prompts, completions)]
                reward_inputs = reward_processing_class(
                    text=texts,
                    return_tensors="pt",
                    padding=True,
                    padding_side="right",
                    add_special_tokens=False,
                )
                reward_inputs = Trainer._prepare_inputs(self, reward_inputs)
                with torch.inference_mode():
                    rewards_per_func[:, i] = reward_func(**reward_inputs).logits[
                        :, 0
                    ]  # Shape (B*G,)
            else:
                # Repeat all input columns (but "prompt" and "completion") to match the number of generations
                keys = [
                    key for key in inputs[0] if key not in ["prompt", "completion"]
                ]
                reward_kwargs = {
                    key: [example[key] for example in inputs] for key in keys
                }
                output_reward_func = reward_func(
                    prompts=prompts, completions=completions, **reward_kwargs
                )
                # Convert None values to NaN
                output_reward_func = [
                    reward if reward is not None else torch.nan
                    for reward in output_reward_func
                ]
                rewards_per_func[:, i] = torch.tensor(
                    output_reward_func, dtype=torch.float32, device=device
                )

    # If all reward functions return None for a given row, issue a detailed warning
    if torch.isnan(rewards_per_func).all(dim=1).any():
        nan_row_idx = (
            torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0]
        )
        row_reward_kwargs = {
            key: value[nan_row_idx] for key, value in reward_kwargs.items()
        }
        row_reward_kwargs["prompt"] = prompts[nan_row_idx]
        row_reward_kwargs["completion"] = completions[nan_row_idx]
        warnings.warn(
            f"All reward functions returned None for the following kwargs: {row_reward_kwargs}. "
            "Please ensure that at least one reward function returns a valid reward."
        )

    # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
    # completions may be distributed across processes
    rewards_per_func = gather(rewards_per_func)

    # Apply weights to each reward function's output and sum
    rewards = (
        rewards_per_func * self.reward_weights.to(device).unsqueeze(0)
    ).nansum(dim=1)

    # Compute grouped-wise rewards
    mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
    std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)

    # Normalize the rewards to compute the advantages
    mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(
        self.num_generations, dim=0
    )
    std_grouped_rewards = std_grouped_rewards.repeat_interleave(
        self.num_generations, dim=0
    )
    advantages = rewards - mean_grouped_rewards
    if self.args.scale_rewards:
        advantages = advantages / (std_grouped_rewards + 1e-4)

    # Slice to keep only the local part of the data
    # NOTE(review): unlike the completion slice in the vLLM path, this slice is
    # not sequence-parallel aware; confirm it is correct when
    # sequence_parallel_degree > 1.
    process_slice = slice(
        self.accelerator.process_index * len(prompts),
        (self.accelerator.process_index + 1) * len(prompts),
    )
    advantages = advantages[process_slice]

    # Log the metrics
    mode = "eval" if self.control.should_evaluate else "train"

    if mode == "train":
        self._total_train_tokens += (
            self.accelerator.gather_for_metrics(attention_mask.sum()).sum().item()
        )
    self._metrics[mode]["num_tokens"] = [self._total_train_tokens]

    completion_length = (
        self.accelerator.gather_for_metrics(completion_mask.sum(1))
        .float()
        .mean()
        .item()
    )
    self._metrics[mode]["completion_length"].append(completion_length)

    # Calculate mean reward per function, but only for samples where the function was applied
    for i, reward_func in enumerate(self.reward_funcs):
        # Module instead of PretrainedModel for compat with compiled models
        if isinstance(reward_func, nn.Module):
            reward_func_name = reward_func.config._name_or_path.split("/")[-1]
        else:
            reward_func_name = reward_func.__name__
        # Only calculate mean for samples where this reward function was applied (non-NaN values)
        mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
        self._metrics[mode][f"rewards/{reward_func_name}"].append(mean_rewards)
    self._metrics[mode]["reward"].append(rewards.mean().item())
    self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item())

    if (
        self.log_completions
        and self.state.global_step % self.args.logging_steps == 0
    ):
        prompts_to_log = gather_object(prompts_text)
        completions_to_log = gather_object(completions_text)
        rewards_to_log = rewards.tolist()

        if self.accelerator.is_main_process:
            if is_rich_available():
                print_prompt_completions_sample(
                    prompts_to_log,
                    completions_to_log,
                    rewards_to_log,
                    self.state.global_step,
                )
            if (
                self.args.report_to
                and "wandb" in self.args.report_to
                and wandb.run is not None
            ):
                import pandas as pd

                # For logging
                table = {
                    "step": [str(self.state.global_step)] * len(rewards),
                    "prompt": prompts_to_log,
                    "completion": completions_to_log,
                    "reward": rewards.tolist(),
                }
                df = pd.DataFrame(table)
                wandb.log({"completions": wandb.Table(dataframe=df)})

    return {
        "prompt_ids": prompt_ids,
        "prompt_mask": prompt_mask,
        "completion_ids": completion_ids,
        "completion_mask": completion_mask,
        "old_per_token_logps": old_per_token_logps,
        "ref_per_token_logps": ref_per_token_logps,
        "advantages": advantages,
    }
# Get the per-token log probabilities for the completions for the model and the reference model # Get the per-token log probabilities for the completions for the model and the reference model
def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
if dist.get_rank() == 0:
import ipdb; ipdb.set_trace()
dist.barrier()
if dist.get_rank() == 1:
import ipdb; ipdb.set_trace()
dist.barrier()
if self.args.sequence_parallel_degree > 1: if self.args.sequence_parallel_degree > 1:
sp_group = get_ring_attn_group() print(f"{self.local_rank}: input_ids.shape: {input_ids.shape}")
self.local_rank = dist.get_rank(group=sp_group) print(f"{self.local_rank}: input_ids[0, :20]: {input_ids[0, :20]}")
self.local_world_size = dist.get_world_size(group=sp_group) print(f"{self.local_rank}: input_ids[0, -20:]: {input_ids[0, -20:]}")
# Pad sequence if needed # Pad sequence if needed
total_seq_len = input_ids.shape[1] total_seq_len = input_ids.shape[1]
@@ -123,7 +637,9 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
# Calculate if this rank contains any tokens we need to keep # Calculate if this rank contains any tokens we need to keep
tokens_before_our_slice = self.local_rank * slice_size tokens_before_our_slice = self.local_rank * slice_size
print(f"{self.local_rank}: slice_size: {slice_size}") print(f"{self.local_rank}: slice_size: {slice_size}")
print(f"{self.local_rank}: tokens_before_our_slice: {tokens_before_our_slice}") print(
f"{self.local_rank}: tokens_before_our_slice: {tokens_before_our_slice}"
)
if tokens_before_our_slice < logits_to_keep: if tokens_before_our_slice < logits_to_keep:
# How many tokens from our slice are needed # How many tokens from our slice are needed
tokens_needed_from_slice = logits_to_keep - tokens_before_our_slice tokens_needed_from_slice = logits_to_keep - tokens_before_our_slice
@@ -132,59 +648,76 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
# This rank doesn't contain any tokens we need to keep # This rank doesn't contain any tokens we need to keep
logits_to_keep = 0 logits_to_keep = 0
print(f"{self.local_rank}: logits_to_keep: {logits_to_keep}") print(f"{self.local_rank}: logits_to_keep: {logits_to_keep}")
# We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits logits = model(
logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred input_ids=input_ids,
attention_mask=attention_mask,
logits_to_keep=logits_to_keep + 1,
).logits
logits = logits[
:, :-1, :
] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
print(f"{self.local_rank}: logits.shape: {logits.shape}") print(f"{self.local_rank}: logits.shape: {logits.shape}")
# First, let all ranks know the shape of each rank's tensor # First, let all ranks know the shape of each rank's tensor
local_shape = torch.tensor([logits.shape[0], logits.shape[1], logits.shape[2]], device=logits.device) local_shape = torch.tensor(
all_shapes = [torch.zeros_like(local_shape) for _ in range(self.local_world_size)] [logits.shape[0], logits.shape[1], logits.shape[2]],
dist.all_gather(all_shapes, local_shape, group=sp_group) device=logits.device,
# Use a list-based approach to collect logits of different sizes
if self.local_rank == 0:
# Root process allocates space for receiving
gathered_logits = []
for shape in all_shapes:
b, s, v = shape.tolist()
gathered_logits.append(torch.zeros((b, s, v), dtype=logits.dtype, device=logits.device))
else:
gathered_logits = None
# Gather to rank 0
dist.gather(logits, gathered_logits, dst=0, group=sp_group)
# On rank 0, concatenate and distribute the result
if self.local_rank == 0:
concatenated_logits = torch.cat(gathered_logits, dim=1)
# Trim to keep only what we need
if concatenated_logits.shape[1] > logits_to_keep:
concatenated_logits = concatenated_logits[:, -logits_to_keep:, :]
else:
concatenated_logits = torch.zeros(
(logits.shape[0], logits_to_keep, logits.shape[2]),
dtype=logits.dtype,
device=logits.device
) )
all_shapes = [
torch.zeros_like(local_shape) for _ in range(self.local_world_size)
]
dist.all_gather(all_shapes, local_shape, group=self.sp_group)
# Broadcast the result back to all ranks # Use a list-based approach to collect logits of different sizes
dist.broadcast(concatenated_logits, src=0, group=sp_group) if self.local_rank == 0:
logits = concatenated_logits # Root process allocates space for receiving
gathered_logits = []
for shape in all_shapes:
b, s, v = shape.tolist()
gathered_logits.append(
torch.zeros((b, s, v), dtype=logits.dtype, device=logits.device)
)
else:
gathered_logits = None
input_ids = input_ids[:, -logits_to_keep:] # Gather to rank 0
# For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves. dist.gather(logits, gathered_logits, dst=0, group=self.sp_group)
# See https://github.com/huggingface/trl/issues/2770
logits = logits[:, -logits_to_keep:]
# Divide logits by sampling temperature.
# See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
logits = logits / self.temperature
dist.barrier() # On rank 0, concatenate and distribute the result
if self.local_rank == 0:
concatenated_logits = torch.cat(gathered_logits, dim=1)
# Trim to keep only what we need
if concatenated_logits.shape[1] > logits_to_keep:
concatenated_logits = concatenated_logits[:, -logits_to_keep:, :]
else:
concatenated_logits = torch.zeros(
(logits.shape[0], logits_to_keep, logits.shape[2]),
dtype=logits.dtype,
device=logits.device,
)
return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens # Broadcast the result back to all ranks
dist.broadcast(concatenated_logits, src=0, group=self.sp_group)
logits = concatenated_logits
# super()._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) input_ids = input_ids[:, -logits_to_keep:]
# For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
# See https://github.com/huggingface/trl/issues/2770
logits = logits[:, -logits_to_keep:]
# Divide logits by sampling temperature.
# See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
logits = logits / self.temperature
dist.barrier()
return selective_log_softmax(
logits, input_ids
) # compute logprobs for the input tokens
else:
super()._get_per_token_logps(
model, input_ids, attention_mask, logits_to_keep
)

View File

@@ -9,7 +9,7 @@ from PIL.Image import Resampling
from transformers import TrainingArguments from transformers import TrainingArguments
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc from axolotl.utils.schemas.enums import RingAttnFunc
@dataclass @dataclass

View File

@@ -4,7 +4,6 @@
# flake8: noqa # flake8: noqa
from .patch import ( from .patch import (
RingAttnFunc,
get_ring_attn_group, get_ring_attn_group,
register_ring_attn, register_ring_attn,
set_ring_attn_group, set_ring_attn_group,

View File

@@ -28,7 +28,7 @@ from transformers.modeling_flash_attention_utils import (
) )
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc from axolotl.utils.schemas.enums import RingAttnFunc
RING_ATTN_FUNC_MAPPING = { RING_ATTN_FUNC_MAPPING = {
RingAttnFunc.BATCH_RING: ring_flash_attn_func, RingAttnFunc.BATCH_RING: ring_flash_attn_func,

View File

@@ -6,14 +6,13 @@ package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patc
their sequence parallel version of Flash Attention 2. their sequence parallel version of Flash Attention 2.
""" """
from enum import Enum
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from accelerate.logging import get_logger from accelerate.logging import get_logger
from axolotl.logging_config import configure_logging from axolotl.logging_config import configure_logging
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.utils.schemas.enums import RingAttnFunc
configure_logging() configure_logging()
LOG = get_logger(__name__) LOG = get_logger(__name__)
@@ -43,17 +42,6 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
RING_ATTN_GROUP = ring_attn_group RING_ATTN_GROUP = ring_attn_group
class RingAttnFunc(str, Enum):
"""Enum class for supported `ring-flash-attn` implementations"""
# VARLEN_RING = "varlen_ring"
# VARLEN_ZIGZAG = "varlen_zigzag"
VARLEN_LLAMA3 = "varlen_llama3"
BATCH_RING = "batch_ring"
BATCH_ZIGZAG = "batch_zigzag"
BATCH_STRIPE = "batch_stripe"
def register_ring_attn( def register_ring_attn(
sequence_parallel_degree: int, sequence_parallel_degree: int,
heads_k_stride: int | None, heads_k_stride: int | None,

View File

@@ -34,6 +34,7 @@ from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.freeze import freeze_layers_except from axolotl.utils.freeze import freeze_layers_except
from axolotl.utils.models import load_model, load_processor, load_tokenizer from axolotl.utils.models import load_model, load_processor, load_tokenizer
from axolotl.utils.schemas.enums import RLType
from axolotl.utils.trainer import setup_trainer from axolotl.utils.trainer import setup_trainer
try: try:
@@ -108,7 +109,7 @@ def setup_reference_model(
Reference model if needed for RL training, `None` otherwise. Reference model if needed for RL training, `None` otherwise.
""" """
model_ref = None model_ref = None
if cfg.rl and cfg.rl != "orpo": if cfg.rl and cfg.rl != RLType.ORPO:
if cfg.adapter and not cfg.rl_adapter_ref_model: if cfg.adapter and not cfg.rl_adapter_ref_model:
# use built-in trl autounwrap # use built-in trl autounwrap
LOG.debug("Passing model_ref: None to RL trainer") LOG.debug("Passing model_ref: None to RL trainer")

View File

@@ -18,8 +18,9 @@ from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.models import load_tokenizer from axolotl.utils.models import load_tokenizer
from axolotl.utils.schemas.enums import RLType
LOG = logging.getLogger("axolotl") LOG = logging.getLogger(__name__)
def _get_path(ds_hash, cfg): def _get_path(ds_hash, cfg):
@@ -80,7 +81,7 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs):
def drop_long_rl_seq( def drop_long_rl_seq(
sample, rl, tokenizer, sequence_len # pylint: disable=invalid-name sample, rl, tokenizer, sequence_len # pylint: disable=invalid-name
): ):
if rl in ("dpo", "ipo", "orpo", "simpo"): if rl in (RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO):
if not ( if not (
sample.get("prompt") and sample.get("chosen") and sample.get("rejected") sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
): ):
@@ -100,7 +101,7 @@ def drop_long_rl_seq(
len_prompt + len_rejected len_prompt + len_rejected
) <= sequence_len ) <= sequence_len
if rl == "kto": if rl is RLType.KTO:
if not (sample.get("prompt") and sample.get("completion")): if not (sample.get("prompt") and sample.get("completion")):
raise ValueError("Prompt and completion keys are required for KTO datasets") raise ValueError("Prompt and completion keys are required for KTO datasets")
@@ -114,7 +115,7 @@ def drop_long_rl_seq(
return (len_prompt + len_completion) <= sequence_len return (len_prompt + len_completion) <= sequence_len
if rl == "grpo": if rl is RLType.GRPO:
return True return True
raise ValueError("Unknown RL type") raise ValueError("Unknown RL type")
@@ -137,9 +138,9 @@ def load_prepare_preference_datasets(cfg):
if _type: if _type:
if isinstance(_type, DictDefault): if isinstance(_type, DictDefault):
_type = "user_defined.default" _type = "user_defined.default"
if _cfg.rl == "orpo": if _cfg.rl is RLType.ORPO:
ds_transform_fn = load_orpo(_type, _cfg, dataset_idx=i) ds_transform_fn = load_orpo(_type, _cfg, dataset_idx=i)
elif _cfg.rl == "kto": elif _cfg.rl is RLType.KTO:
ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i) ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
else: else:
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i) ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
@@ -150,7 +151,7 @@ def load_prepare_preference_datasets(cfg):
split_datasets[i] = map_dataset( split_datasets[i] = map_dataset(
cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
) )
elif _cfg.rl == "kto": elif _cfg.rl is RLType.KTO:
ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i) ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
map_kwargs = {} map_kwargs = {}
if isinstance(ds_transform_fn, tuple): if isinstance(ds_transform_fn, tuple):

View File

@@ -72,6 +72,7 @@ from axolotl.utils.distributed import (
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
from axolotl.utils.lora_embeddings import get_linear_embedding_layers from axolotl.utils.lora_embeddings import get_linear_embedding_layers
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
from axolotl.utils.schemas.enums import RLType
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@@ -1340,7 +1341,7 @@ class ModelLoader:
# then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config # then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
if ( if (
self.cfg.adapter self.cfg.adapter
and self.cfg.rl in ["dpo", "ipo", "kto"] and self.cfg.rl in [RLType.DPO, RLType.IPO, RLType.KTO]
and not self.cfg.merge_lora and not self.cfg.merge_lora
): ):
_, lora_config = load_lora( _, lora_config = load_lora(

View File

@@ -29,7 +29,7 @@ from axolotl.utils.schemas.datasets import (
StepwiseSupervisedDataset, StepwiseSupervisedDataset,
) )
from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
from axolotl.utils.schemas.enums import ChatTemplate, RLType from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType
from axolotl.utils.schemas.integrations import ( from axolotl.utils.schemas.integrations import (
CometConfig, CometConfig,
GradioConfig, GradioConfig,
@@ -261,7 +261,7 @@ class AxolotlInputConfig(
sequence_parallel_degree: int | None = None sequence_parallel_degree: int | None = None
heads_k_stride: int | None = None heads_k_stride: int | None = None
ring_attn_func: str | None = None ring_attn_func: RingAttnFunc | None = None
special_tokens: SpecialTokensConfig | None = None special_tokens: SpecialTokensConfig | None = None
tokens: list[str] | None = None tokens: list[str] | None = None
@@ -785,7 +785,7 @@ class AxolotlInputConfig(
@model_validator(mode="after") @model_validator(mode="after")
def check_simpo_warmup(self): def check_simpo_warmup(self):
if self.rl == "simpo" and self.warmup_ratio: if self.rl is RLType.SIMPO and self.warmup_ratio:
raise ValueError( raise ValueError(
"warmup_ratio is not supported with the simpo trainer. Please use `warmup_steps` instead" "warmup_ratio is not supported with the simpo trainer. Please use `warmup_steps` instead"
) )

View File

@@ -6,12 +6,12 @@ from enum import Enum
class RLType(str, Enum):
    """RL trainer type configuration subset.

    Members subclass ``str``, so each one compares equal to its lowercase
    config value (e.g. ``RLType.DPO == "dpo"``) and can be looked up by value
    with ``RLType("dpo")``.
    """

    DPO = "dpo"
    GRPO = "grpo"
    IPO = "ipo"
    ORPO = "orpo"
    KTO = "kto"
    SIMPO = "simpo"
class ChatTemplate(str, Enum): class ChatTemplate(str, Enum):
@@ -53,3 +53,14 @@ class CustomSupportedOptimizers(str, Enum):
ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name
adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name
muon = "muon" # pylint: disable=invalid-name muon = "muon" # pylint: disable=invalid-name
class RingAttnFunc(str, Enum):
    """Enum class for supported `ring-flash-attn` implementations"""

    # NOTE(review): the two varlen variants below are commented out in the
    # original; presumably they are not supported yet -- confirm before enabling.
    # VARLEN_RING = "varlen_ring"
    # VARLEN_ZIGZAG = "varlen_zigzag"
    VARLEN_LLAMA3 = "varlen_llama3"
    BATCH_RING = "batch_ring"
    BATCH_ZIGZAG = "batch_zigzag"
    BATCH_STRIPE = "batch_stripe"