stronger subclassing of TRL GRPO trainer; custom distributed sampler

This commit is contained in:
Dan Saunders
2025-04-17 04:06:18 +00:00
parent 76e2d2e60b
commit b13b6e185f
15 changed files with 778 additions and 120 deletions

View File

@@ -14,6 +14,7 @@ from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_processor, load_tokenizer
from axolotl.utils.schemas.enums import RLType
from axolotl.utils.tokenization import check_dataset_labels
LOG = logging.getLogger(__name__)
@@ -125,7 +126,7 @@ def load_preference_datasets(
total_num_steps: Optional[int] = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)
if cfg.rl == "grpo":
if cfg.rl is RLType.GRPO:
total_num_steps = None
if cli_args.debug or cfg.debug:

View File

@@ -84,7 +84,7 @@ from axolotl.utils.collators import (
)
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.models import ensure_dtype
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
from axolotl.utils.schemas.enums import CustomSupportedOptimizers, RLType
try:
import torch._dynamo # pylint: disable=ungrouped-imports
@@ -1054,7 +1054,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
training_args_cls = None
blocklist_args_kwargs = []
if self.cfg.rl == "simpo":
if self.cfg.rl is RLType.SIMPO:
training_args_cls = AxolotlCPOConfig
training_args_kwargs["loss_type"] = "simpo"
training_args_kwargs["max_length"] = self.cfg.sequence_len
@@ -1062,13 +1062,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.cpo_alpha is not None:
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
elif self.cfg.rl == "orpo":
elif self.cfg.rl is RLType.ORPO:
training_args_cls = AxolotlORPOConfig
training_args_kwargs["max_length"] = self.cfg.sequence_len
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
elif self.cfg.rl == "kto":
elif self.cfg.rl is RLType.KTO:
training_args_cls = AxolotlKTOConfig
training_args_kwargs["desirable_weight"] = (
@@ -1082,14 +1082,14 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
elif self.cfg.rl == "grpo":
elif self.cfg.rl is RLType.GRPO:
training_args_cls = GRPOStrategy.get_training_args_class()
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
else:
training_args_cls = AxolotlDPOConfig
if self.cfg.rl == "ipo":
if self.cfg.rl is RLType.IPO:
training_args_kwargs["loss_type"] = "ipo"
training_args_kwargs["max_length"] = self.cfg.sequence_len
training_args_kwargs["max_completion_length"] = None
@@ -1127,7 +1127,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
def build(self, total_num_steps):
training_args = self.build_training_arguments(total_num_steps)
dpo_trainer_kwargs = {}
if self.cfg.rl == "ipo":
if self.cfg.rl is RLType.IPO:
if self.cfg.dpo_label_smoothing:
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
if self.eval_dataset:
@@ -1138,21 +1138,21 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer_kwargs["precompute_ref_log_probs"] = (
self.cfg.precompute_ref_log_probs
)
if self.cfg.rl == "grpo":
if self.cfg.rl is RLType.GRPO:
trainer_cls = GRPOStrategy.get_trainer_class()
trainer_cls_args = [self.model]
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
dpo_trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
elif self.cfg.rl in ["dpo", "ipo"]:
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
trainer_cls = DPOStrategy.get_trainer_class()
trainer_cls_args = [self.model, self.model_ref]
elif self.cfg.rl == "orpo":
elif self.cfg.rl is RLType.ORPO:
trainer_cls = AxolotlORPOTrainer
trainer_cls_args = [self.model]
elif self.cfg.rl in ["kto"]:
elif self.cfg.rl is RLType.KTO:
trainer_cls = AxolotlKTOTrainer
trainer_cls_args = [self.model]
elif self.cfg.rl in ["simpo"]:
elif self.cfg.rl is RLType.SIMPO:
trainer_cls = AxolotlCPOTrainer
trainer_cls_args = [self.model]
else:
@@ -1179,7 +1179,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
)
if self.cfg.fsdp:
ensure_dtype(dpo_trainer.model, dtype=self.cfg.torch_dtype)
if self.cfg.rl in ["dpo", "ipo"] and dpo_trainer.ref_model:
if self.cfg.rl in [RLType.DPO, RLType.IPO] and dpo_trainer.ref_model:
ensure_dtype(dpo_trainer.ref_model, dtype=self.cfg.torch_dtype)
dpo_trainer = self.hook_post_create_trainer(dpo_trainer)

View File

@@ -3,6 +3,7 @@ DPO Specific Strategy for training
"""
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer
from axolotl.utils.schemas.enums import RLType
class DPOStrategy:
@@ -23,7 +24,7 @@ class DPOStrategy:
@classmethod
def set_training_args_kwargs(cls, cfg):
training_args_kwargs = {}
if cfg.rl == "ipo":
if cfg.rl is RLType.IPO:
training_args_kwargs["loss_type"] = "ipo"
training_args_kwargs["max_length"] = cfg.sequence_len
training_args_kwargs["max_completion_length"] = None

View File

@@ -11,6 +11,4 @@ from axolotl.core.training_args import AxolotlTrainingMixins
@dataclass
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
    """Axolotl GRPO Config for GRPO training"""

View File

@@ -0,0 +1,124 @@
"""
Repeat random sampler (akin to the one implemented in
https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py) that adds
sequence parallelism functionality; i.e., duplicating data across ranks in the same
sequence parallel group.
"""
from typing import Optional, Sized
import torch
from torch.utils.data import Sampler
class SequenceParallelRepeatRandomSampler(Sampler):
    """
    Sampler for GRPO training with sequence parallelism that ensures:

    1. Ranks in the same sequence parallel (SP) group receive identical data.
    2. Each index is repeated ``mini_repeat_count`` times consecutively (one
       occurrence per sampled completion).
    3. The whole batch sequence is repeated ``repeat_count`` times for reuse
       across multiple optimization updates.
    """

    def __init__(
        self,
        dataset: Sized,
        mini_repeat_count: int,
        batch_size: int = 1,
        repeat_count: int = 1,
        sequence_parallel_degree: int = 1,
        world_size: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
    ):
        """
        Args:
            dataset: Sized dataset to sample indices for.
            mini_repeat_count: Number of consecutive repeats of each index.
            batch_size: Number of unique indices per batch for each SP group.
            repeat_count: Number of times the full batch sequence is repeated.
            sequence_parallel_degree: Number of ranks sharing identical data.
            world_size: Total number of distributed ranks; defaults to 1
                (single process) when not provided.
            rank: Global rank of this process; defaults to 0 when not provided.
            shuffle: Whether to shuffle indices each epoch.
            seed: Base random seed, combined with epoch and SP group id.
            drop_last: Whether to drop the trailing incomplete batch.
        """
        self.dataset = dataset
        self.mini_repeat_count = mini_repeat_count
        self.batch_size = batch_size
        self.repeat_count = repeat_count
        self.shuffle = shuffle
        self.seed = seed
        self.drop_last = drop_last
        self.epoch = 0

        # Fall back to single-process defaults. Previously, leaving these as
        # None raised a TypeError on the integer divisions below.
        self.world_size = world_size if world_size is not None else 1
        self.rank = rank if rank is not None else 0

        # Sequence parallelism parameters
        self.sequence_parallel_degree = sequence_parallel_degree
        self.num_sp_groups = self.world_size // sequence_parallel_degree
        self.sp_group_id = self.rank // sequence_parallel_degree

        # Adjust dataset size for distributed sampling
        self.num_samples = len(self.dataset)
        self.total_size = self.num_samples

        # Calculate effective number of samples per SP group
        if (
            self.drop_last
            and self.total_size % (self.num_sp_groups * self.batch_size) != 0
        ):
            # Drop last incomplete batch if drop_last is True
            self.num_samples_per_sp_group = (
                self.total_size // self.batch_size // self.num_sp_groups
            ) * self.batch_size
        else:
            # Round up to include last batch if drop_last is False
            self.num_samples_per_sp_group = (
                (self.total_size + self.batch_size * self.num_sp_groups - 1)
                // (self.batch_size * self.num_sp_groups)
                * self.batch_size
            )

    def __iter__(self):
        # Deterministically shuffle based on epoch and seed
        if self.shuffle:
            # All ranks in the same SP group derive the same seed, so they
            # produce identical permutations; distinct groups differ.
            g = torch.Generator()
            seed_value = self.seed + self.epoch + self.sp_group_id * 10000
            g.manual_seed(seed_value)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        # Add extra samples (wrap-around) to make the index list evenly
        # divisible by batch_size
        if len(indices) % self.batch_size != 0:
            padding = indices[: self.batch_size - len(indices) % self.batch_size]
            indices += padding

        # Subsample based on SP group ID: each SP group gets distinct,
        # contiguous batches of data
        batch_indices = []
        for i in range(0, len(indices), self.batch_size * self.num_sp_groups):
            start_idx = i + self.sp_group_id * self.batch_size
            end_idx = min(start_idx + self.batch_size, len(indices))
            if start_idx < len(indices):
                for j in range(self.batch_size):
                    if start_idx + j < end_idx:
                        batch_indices.append(indices[start_idx + j])

        # Make sure batch_indices is exactly batch_size * num_batches_per_sp_group
        if self.drop_last:
            num_batches_per_sp_group = self.num_samples_per_sp_group // self.batch_size
            target_len = self.batch_size * num_batches_per_sp_group
            if len(batch_indices) > target_len:
                batch_indices = batch_indices[:target_len]

        # Apply the GRPO repeat pattern: repeat each index
        # mini_repeat_count times, then repeat the whole sequence
        # repeat_count times
        final_indices = []
        for _ in range(self.repeat_count):
            for idx in batch_indices:
                for _ in range(self.mini_repeat_count):
                    final_indices.append(idx)

        return iter(final_indices)

    def __len__(self):
        # Total length including all repetitions
        return (
            self.num_samples_per_sp_group * self.mini_repeat_count * self.repeat_count
        )

    def set_epoch(self, epoch):
        """Sets the epoch for this sampler"""
        self.epoch = epoch

View File

@@ -1,28 +1,186 @@
"""Axolotl GRPO trainer"""
import warnings
from contextlib import nullcontext
from typing import Any
import datasets
import torch
import torch.distributed as dist
from accelerate.utils import is_deepspeed_available, is_peft_model
from trl import GRPOTrainer
from trl.extras.profiling import profiling_decorator
from trl.trainer.utils import selective_log_softmax
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.monkeypatch.attention.ring_attn import (
get_ring_attn_group,
from accelerate.utils import (
broadcast_object_list,
gather,
gather_object,
is_peft_model,
)
from torch import nn
from torch.utils.data import (
BatchSampler,
DataLoader,
Sampler,
)
from transformers import Trainer, is_wandb_available
from transformers.trainer_utils import seed_worker
from trl import GRPOTrainer
from trl.data_utils import (
apply_chat_template,
is_conversational,
maybe_apply_chat_template,
)
from trl.extras.profiling import profiling_context, profiling_decorator
from trl.import_utils import is_deepspeed_available, is_rich_available
from trl.models import unwrap_model_for_generation
from trl.trainer.utils import (
pad,
print_prompt_completions_sample,
selective_log_softmax,
)
from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.monkeypatch.attention.ring_attn.patch import get_ring_attn_group
if is_deepspeed_available():
import deepspeed
if is_wandb_available():
import wandb
class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
"""Extend the base GRPOTrainer for axolotl helpers"""
_tag_names = ["trl", "grpo", "axolotl"]
def __init__(self, *args, **kwargs):
# Call parent constructor with all arguments
super().__init__(*args, **kwargs)
# Initialize the SP group
self.sp_group = get_ring_attn_group()
self.local_rank = dist.get_rank(group=self.sp_group)
self.local_world_size = dist.get_world_size(group=self.sp_group)
def _get_train_sampler(self) -> Sampler:
# Get distributed training info
world_size = dist.get_world_size()
rank = dist.get_rank()
effective_batch_size = (
self.args.per_device_train_batch_size
* world_size
* self.args.gradient_accumulation_steps
)
return SequenceParallelRepeatRandomSampler(
dataset=self.train_dataset,
mini_repeat_count=self.num_generations,
batch_size=effective_batch_size
// self.num_generations
// self.args.sequence_parallel_degree,
repeat_count=self.num_iterations,
sequence_parallel_degree=self.args.sequence_parallel_degree,
world_size=world_size,
rank=rank,
shuffle=True,
seed=self.args.seed,
drop_last=True,
)
def _create_dataloader_params(self, is_eval=False, custom_batch_size=None):
"""Create common dataloader parameters for train or eval."""
batch_size = custom_batch_size or (
self.args.eval_batch_size if is_eval else self._train_batch_size
)
params = {
"batch_size": batch_size,
"collate_fn": self.data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
# Add persistent workers only for training
if not is_eval and hasattr(self.args, "dataloader_persistent_workers"):
params["persistent_workers"] = self.args.dataloader_persistent_workers
# Add prefetch factor if specified
if self.args.dataloader_prefetch_factor:
params["prefetch_factor"] = self.args.dataloader_prefetch_factor
return params
    def _prepare_dataloader(
        self, dataset, sampler, is_eval=False, custom_batch_size=None
    ):
        """Prepare a dataloader with the given dataset and sampler."""
        # Get base parameters
        dataloader_params = self._create_dataloader_params(is_eval, custom_batch_size)

        # Add sampler configuration; iterable datasets manage their own
        # ordering, so no sampler is attached for them.
        if not isinstance(dataset, torch.utils.data.IterableDataset):
            if isinstance(sampler, BatchSampler):
                # batch_size and batch_sampler are mutually exclusive
                dataloader_params["batch_sampler"] = sampler
                del dataloader_params["batch_size"]
            else:
                dataloader_params["sampler"] = sampler
                dataloader_params["drop_last"] = self.args.dataloader_drop_last

            if not is_eval:
                # Seed dataloader workers deterministically (transformers helper).
                dataloader_params["worker_init_fn"] = seed_worker

        # Create the dataloader
        dataloader = DataLoader(dataset, **dataloader_params)

        if self.args.sample_packing and (
            (not is_eval and not self.args.pretraining)
            or (is_eval and self.args.eval_sample_packing is not False)
        ):
            # Packed batches may be uneven across processes; stop accelerate
            # from padding/duplicating to force even batches.
            self.accelerator.even_batches = False

        # Return unprepared dataloader if using sequence parallelism
        # TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
        # if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
        # slice each batch along the sequence dimension).
        if self.args.sequence_parallel_degree > 1:
            return dataloader

        # Otherwise prepare with accelerator
        return self.accelerator.prepare_data_loader(dataloader)
    def get_train_dataloader(self) -> DataLoader:
        """Get dataloader for training"""
        train_dataset = self.train_dataset
        data_collator = self.data_collator  # type: ignore

        # Initialize SP group attributes if sequence parallelism is enabled
        if self.args.sequence_parallel_degree > 1:
            self.sp_group = get_ring_attn_group()
            self.local_rank = dist.get_rank(group=self.sp_group)
            self.local_world_size = dist.get_world_size(group=self.sp_group)

        # Handle dataset preprocessing
        if isinstance(train_dataset, datasets.Dataset):
            if self.args.sample_packing and not self.args.pretraining:
                # Packed datasets carry a helper "length" column that must not
                # reach the model as an input feature.
                train_dataset = train_dataset.remove_columns(["length"])
            if not self.args.sample_packing or self.args.pretraining:
                train_dataset = self._remove_unused_columns(
                    train_dataset, description="training"
                )
        else:
            # Non-`datasets.Dataset` inputs: strip unused columns inside the
            # collator instead of on the dataset itself.
            self.data_collator = self._get_collator_with_removed_columns(
                data_collator,
                description="training",
            )

        # Get sampler and create dataloader
        sampler = self._get_train_sampler()
        dataloader = self._prepare_dataloader(train_dataset, sampler, is_eval=False)

        return dataloader
@profiling_decorator
def _move_model_to_vllm(self):
# For DeepSpeed ZeRO-3, we need to gather all parameters before operations
@@ -70,20 +228,376 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
if self.accelerator.is_main_process:
self.vllm_client.reset_prefix_cache()
def _generate_and_score_completions(
self, inputs: dict[str | torch.Tensor | Any]
) -> dict[str, torch.Tensor | Any]:
device = self.accelerator.device
prompts = [x["prompt"] for x in inputs]
prompts_text = [
maybe_apply_chat_template(example, self.processing_class)["prompt"]
for example in inputs
]
prompt_inputs = self.processing_class(
text=prompts_text,
return_tensors="pt",
padding=True,
padding_side="left",
add_special_tokens=False,
)
prompt_inputs = Trainer._prepare_inputs(self, prompt_inputs)
prompt_ids, prompt_mask = (
prompt_inputs["input_ids"],
prompt_inputs["attention_mask"],
)
if self.max_prompt_length is not None:
prompt_ids = prompt_ids[:, -self.max_prompt_length :]
prompt_mask = prompt_mask[:, -self.max_prompt_length :]
# Generate completions using either vLLM or regular generation
if self.args.use_vllm:
# First, have main process load weights if needed
if self.state.global_step != self._last_loaded_step:
self._move_model_to_vllm()
self._last_loaded_step = self.state.global_step
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
all_prompts_text = gather_object(prompts_text)
if self.accelerator.is_main_process:
# Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
# num_generations outputs for each one. This is faster than generating outputs for each duplicate
# prompt individually.
ordered_set_of_prompts = all_prompts_text[:: self.num_generations]
with profiling_context(self, "vLLM.generate"):
completion_ids = self.vllm_client.generate(
prompts=ordered_set_of_prompts,
n=self.num_generations,
repetition_penalty=self.repetition_penalty,
temperature=self.temperature,
top_p=self.top_p,
top_k=-1 if self.top_k is None else self.top_k,
min_p=0.0 if self.min_p is None else self.min_p,
max_tokens=self.max_completion_length,
guided_decoding_regex=self.guided_decoding_regex,
)
else:
completion_ids = [None] * len(all_prompts_text)
# Broadcast the completions from the main process to all processes
completion_ids = broadcast_object_list(completion_ids, from_process=0)
# Determine the appropriate slice based on sequence parallelism
if self.args.sequence_parallel_degree > 1:
# Calculate SP group ID (which group of ranks this rank belongs to)
sp_group_id = self.accelerator.process_index // self.local_world_size
# Calculate the start index for this SP group
sp_group_start = sp_group_id * len(prompts) * self.local_world_size
# All ranks in the same SP group get the same data slice
# This ensures identical inputs for sequence-parallel processing
process_slice = slice(
sp_group_start,
sp_group_start + len(prompts) * self.local_world_size,
)
# Take the full SP group's worth of completions
completion_ids = completion_ids[process_slice]
else:
# Original behavior for non-sequence-parallel case
process_slice = slice(
self.accelerator.process_index * len(prompts),
(self.accelerator.process_index + 1) * len(prompts),
)
completion_ids = completion_ids[process_slice]
if dist.get_rank() == 0:
import ipdb
ipdb.set_trace()
dist.barrier()
if dist.get_rank() == 1:
import ipdb
ipdb.set_trace()
dist.barrier()
# Pad the completions, and concatenate them with the prompts
completion_ids = [
torch.tensor(ids, device=device) for ids in completion_ids
]
completion_ids = pad(
completion_ids, padding_value=self.processing_class.pad_token_id
)
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
else:
# Regular generation path
with unwrap_model_for_generation(
self.model_wrapped,
self.accelerator,
gather_deepspeed3_params=self.args.ds3_gather_for_generation,
) as unwrapped_model:
prompt_completion_ids = unwrapped_model.generate(
prompt_ids,
attention_mask=prompt_mask,
generation_config=self.generation_config,
)
# Compute prompt length and extract completion ids
prompt_length = prompt_ids.size(1)
prompt_ids = prompt_completion_ids[:, :prompt_length]
completion_ids = prompt_completion_ids[:, prompt_length:]
# Mask everything after the first EOS token
is_eos = completion_ids == self.processing_class.eos_token_id
eos_idx = torch.full(
(is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device
)
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
sequence_indices = torch.arange(is_eos.size(1), device=device).expand(
is_eos.size(0), -1
)
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
# Concatenate prompt_mask with completion_mask for logit computation
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C)
logits_to_keep = completion_ids.size(
1
) # we only need to compute the logits for the completion tokens
with torch.no_grad():
# When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip it's
# computation here, and use per_token_logps.detach() instead.
if self.num_iterations > 1:
old_per_token_logps = self._get_per_token_logps(
self.model, prompt_completion_ids, attention_mask, logits_to_keep
)
else:
old_per_token_logps = None
if self.beta == 0.0:
ref_per_token_logps = None
elif self.ref_model is not None:
ref_per_token_logps = self._get_per_token_logps(
self.ref_model,
prompt_completion_ids,
attention_mask,
logits_to_keep,
)
else:
with self.accelerator.unwrap_model(self.model).disable_adapter():
ref_per_token_logps = self._get_per_token_logps(
self.model,
prompt_completion_ids,
attention_mask,
logits_to_keep,
)
# Decode the generated completions
completions_text = self.processing_class.batch_decode(
completion_ids, skip_special_tokens=True
)
if is_conversational(inputs[0]):
completions = []
for prompt, completion in zip(prompts, completions_text):
bootstrap = (
prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
)
completions.append(
[{"role": "assistant", "content": bootstrap + completion}]
)
else:
completions = completions_text
rewards_per_func = torch.zeros(
len(prompts), len(self.reward_funcs), device=device
)
for i, (reward_func, reward_processing_class) in enumerate(
zip(self.reward_funcs, self.reward_processing_classes)
):
if isinstance(
reward_func, nn.Module
): # Module instead of PretrainedModel for compat with compiled models
reward_func_name = (
f"reward {reward_func.config._name_or_path.split('/')[-1]}"
)
else:
reward_func_name = reward_func.__name__
with profiling_context(self, reward_func_name):
if isinstance(
reward_func, nn.Module
): # Module instead of PretrainedModel for compat with compiled models
if is_conversational(inputs[0]):
messages = [
{"messages": p + c} for p, c in zip(prompts, completions)
]
texts = [
apply_chat_template(x, reward_processing_class)["text"]
for x in messages
]
else:
texts = [p + c for p, c in zip(prompts, completions)]
reward_inputs = reward_processing_class(
text=texts,
return_tensors="pt",
padding=True,
padding_side="right",
add_special_tokens=False,
)
reward_inputs = Trainer._prepare_inputs(self, reward_inputs)
with torch.inference_mode():
rewards_per_func[:, i] = reward_func(**reward_inputs).logits[
:, 0
] # Shape (B*G,)
else:
# Repeat all input columns (but "prompt" and "completion") to match the number of generations
keys = [
key for key in inputs[0] if key not in ["prompt", "completion"]
]
reward_kwargs = {
key: [example[key] for example in inputs] for key in keys
}
output_reward_func = reward_func(
prompts=prompts, completions=completions, **reward_kwargs
)
# Convert None values to NaN
output_reward_func = [
reward if reward is not None else torch.nan
for reward in output_reward_func
]
rewards_per_func[:, i] = torch.tensor(
output_reward_func, dtype=torch.float32, device=device
)
# If all reward functions return None for a given row, issue a detailed warning
if torch.isnan(rewards_per_func).all(dim=1).any():
nan_row_idx = (
torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0]
)
row_reward_kwargs = {
key: value[nan_row_idx] for key, value in reward_kwargs.items()
}
row_reward_kwargs["prompt"] = prompts[nan_row_idx]
row_reward_kwargs["completion"] = completions[nan_row_idx]
warnings.warn(
f"All reward functions returned None for the following kwargs: {row_reward_kwargs}. "
"Please ensure that at least one reward function returns a valid reward."
)
# Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
# completions may be distributed across processes
rewards_per_func = gather(rewards_per_func)
# Apply weights to each reward function's output and sum
rewards = (
rewards_per_func * self.reward_weights.to(device).unsqueeze(0)
).nansum(dim=1)
# Compute grouped-wise rewards
mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
# Normalize the rewards to compute the advantages
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(
self.num_generations, dim=0
)
std_grouped_rewards = std_grouped_rewards.repeat_interleave(
self.num_generations, dim=0
)
advantages = rewards - mean_grouped_rewards
if self.args.scale_rewards:
advantages = advantages / (std_grouped_rewards + 1e-4)
# Slice to keep only the local part of the data
process_slice = slice(
self.accelerator.process_index * len(prompts),
(self.accelerator.process_index + 1) * len(prompts),
)
advantages = advantages[process_slice]
# Log the metrics
mode = "eval" if self.control.should_evaluate else "train"
if mode == "train":
self._total_train_tokens += (
self.accelerator.gather_for_metrics(attention_mask.sum()).sum().item()
)
self._metrics[mode]["num_tokens"] = [self._total_train_tokens]
completion_length = (
self.accelerator.gather_for_metrics(completion_mask.sum(1))
.float()
.mean()
.item()
)
self._metrics[mode]["completion_length"].append(completion_length)
# Calculate mean reward per function, but only for samples where the function was applied
for i, reward_func in enumerate(self.reward_funcs):
if isinstance(
reward_func, nn.Module
): # Module instead of PretrainedModel for compat with compiled models
reward_func_name = reward_func.config._name_or_path.split("/")[-1]
else:
reward_func_name = reward_func.__name__
# Only calculate mean for samples where this reward function was applied (non-NaN values)
mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
self._metrics[mode][f"rewards/{reward_func_name}"].append(mean_rewards)
self._metrics[mode]["reward"].append(rewards.mean().item())
self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item())
if (
self.log_completions
and self.state.global_step % self.args.logging_steps == 0
):
prompts_to_log = gather_object(prompts_text)
completions_to_log = gather_object(completions_text)
rewards_to_log = rewards.tolist()
if self.accelerator.is_main_process:
if is_rich_available():
print_prompt_completions_sample(
prompts_to_log,
completions_to_log,
rewards_to_log,
self.state.global_step,
)
if (
self.args.report_to
and "wandb" in self.args.report_to
and wandb.run is not None
):
import pandas as pd
# For logging
table = {
"step": [str(self.state.global_step)] * len(rewards),
"prompt": prompts_to_log,
"completion": completions_to_log,
"reward": rewards.tolist(),
}
df = pd.DataFrame(table)
wandb.log({"completions": wandb.Table(dataframe=df)})
return {
"prompt_ids": prompt_ids,
"prompt_mask": prompt_mask,
"completion_ids": completion_ids,
"completion_mask": completion_mask,
"old_per_token_logps": old_per_token_logps,
"ref_per_token_logps": ref_per_token_logps,
"advantages": advantages,
}
# Get the per-token log probabilities for the completions for the model and the reference model
def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
if dist.get_rank() == 0:
import ipdb; ipdb.set_trace()
dist.barrier()
if dist.get_rank() == 1:
import ipdb; ipdb.set_trace()
dist.barrier()
if self.args.sequence_parallel_degree > 1:
sp_group = get_ring_attn_group()
self.local_rank = dist.get_rank(group=sp_group)
self.local_world_size = dist.get_world_size(group=sp_group)
print(f"{self.local_rank}: input_ids.shape: {input_ids.shape}")
print(f"{self.local_rank}: input_ids[0, :20]: {input_ids[0, :20]}")
print(f"{self.local_rank}: input_ids[0, -20:]: {input_ids[0, -20:]}")
# Pad sequence if needed
total_seq_len = input_ids.shape[1]
@@ -123,7 +637,9 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
# Calculate if this rank contains any tokens we need to keep
tokens_before_our_slice = self.local_rank * slice_size
print(f"{self.local_rank}: slice_size: {slice_size}")
print(f"{self.local_rank}: tokens_before_our_slice: {tokens_before_our_slice}")
print(
f"{self.local_rank}: tokens_before_our_slice: {tokens_before_our_slice}"
)
if tokens_before_our_slice < logits_to_keep:
# How many tokens from our slice are needed
tokens_needed_from_slice = logits_to_keep - tokens_before_our_slice
@@ -132,59 +648,76 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
# This rank doesn't contain any tokens we need to keep
logits_to_keep = 0
print(f"{self.local_rank}: logits_to_keep: {logits_to_keep}")
print(f"{self.local_rank}: logits_to_keep: {logits_to_keep}")
# We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
# We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
logits = model(
input_ids=input_ids,
attention_mask=attention_mask,
logits_to_keep=logits_to_keep + 1,
).logits
logits = logits[
:, :-1, :
] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
print(f"{self.local_rank}: logits.shape: {logits.shape}")
print(f"{self.local_rank}: logits.shape: {logits.shape}")
# First, let all ranks know the shape of each rank's tensor
local_shape = torch.tensor([logits.shape[0], logits.shape[1], logits.shape[2]], device=logits.device)
all_shapes = [torch.zeros_like(local_shape) for _ in range(self.local_world_size)]
dist.all_gather(all_shapes, local_shape, group=sp_group)
# Use a list-based approach to collect logits of different sizes
if self.local_rank == 0:
# Root process allocates space for receiving
gathered_logits = []
for shape in all_shapes:
b, s, v = shape.tolist()
gathered_logits.append(torch.zeros((b, s, v), dtype=logits.dtype, device=logits.device))
else:
gathered_logits = None
# Gather to rank 0
dist.gather(logits, gathered_logits, dst=0, group=sp_group)
# On rank 0, concatenate and distribute the result
if self.local_rank == 0:
concatenated_logits = torch.cat(gathered_logits, dim=1)
# Trim to keep only what we need
if concatenated_logits.shape[1] > logits_to_keep:
concatenated_logits = concatenated_logits[:, -logits_to_keep:, :]
else:
concatenated_logits = torch.zeros(
(logits.shape[0], logits_to_keep, logits.shape[2]),
dtype=logits.dtype,
device=logits.device
# First, let all ranks know the shape of each rank's tensor
local_shape = torch.tensor(
[logits.shape[0], logits.shape[1], logits.shape[2]],
device=logits.device,
)
# Broadcast the result back to all ranks
dist.broadcast(concatenated_logits, src=0, group=sp_group)
logits = concatenated_logits
all_shapes = [
torch.zeros_like(local_shape) for _ in range(self.local_world_size)
]
dist.all_gather(all_shapes, local_shape, group=self.sp_group)
input_ids = input_ids[:, -logits_to_keep:]
# For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
# See https://github.com/huggingface/trl/issues/2770
logits = logits[:, -logits_to_keep:]
# Divide logits by sampling temperature.
# See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
logits = logits / self.temperature
# Use a list-based approach to collect logits of different sizes
if self.local_rank == 0:
# Root process allocates space for receiving
gathered_logits = []
for shape in all_shapes:
b, s, v = shape.tolist()
gathered_logits.append(
torch.zeros((b, s, v), dtype=logits.dtype, device=logits.device)
)
else:
gathered_logits = None
dist.barrier()
# Gather to rank 0
dist.gather(logits, gathered_logits, dst=0, group=self.sp_group)
return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens
# On rank 0, concatenate and distribute the result
if self.local_rank == 0:
concatenated_logits = torch.cat(gathered_logits, dim=1)
# Trim to keep only what we need
if concatenated_logits.shape[1] > logits_to_keep:
concatenated_logits = concatenated_logits[:, -logits_to_keep:, :]
else:
concatenated_logits = torch.zeros(
(logits.shape[0], logits_to_keep, logits.shape[2]),
dtype=logits.dtype,
device=logits.device,
)
# super()._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
# Broadcast the result back to all ranks
dist.broadcast(concatenated_logits, src=0, group=self.sp_group)
logits = concatenated_logits
input_ids = input_ids[:, -logits_to_keep:]
# For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
# See https://github.com/huggingface/trl/issues/2770
logits = logits[:, -logits_to_keep:]
# Divide logits by sampling temperature.
# See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
logits = logits / self.temperature
dist.barrier()
return selective_log_softmax(
logits, input_ids
) # compute logprobs for the input tokens
else:
super()._get_per_token_logps(
model, input_ids, attention_mask, logits_to_keep
)

View File

@@ -9,7 +9,7 @@ from PIL.Image import Resampling
from transformers import TrainingArguments
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
from axolotl.utils.schemas.enums import RingAttnFunc
@dataclass

View File

@@ -4,7 +4,6 @@
# flake8: noqa
from .patch import (
RingAttnFunc,
get_ring_attn_group,
register_ring_attn,
set_ring_attn_group,

View File

@@ -28,7 +28,7 @@ from transformers.modeling_flash_attention_utils import (
)
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
from axolotl.utils.schemas.enums import RingAttnFunc
RING_ATTN_FUNC_MAPPING = {
RingAttnFunc.BATCH_RING: ring_flash_attn_func,

View File

@@ -6,14 +6,13 @@ package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patc
their sequence parallel version of Flash Attention 2.
"""
from enum import Enum
import torch
import torch.distributed as dist
from accelerate.logging import get_logger
from axolotl.logging_config import configure_logging
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.utils.schemas.enums import RingAttnFunc
configure_logging()
LOG = get_logger(__name__)
@@ -43,17 +42,6 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
RING_ATTN_GROUP = ring_attn_group
class RingAttnFunc(str, Enum):
"""Enum class for supported `ring-flash-attn` implementations"""
# VARLEN_RING = "varlen_ring"
# VARLEN_ZIGZAG = "varlen_zigzag"
VARLEN_LLAMA3 = "varlen_llama3"
BATCH_RING = "batch_ring"
BATCH_ZIGZAG = "batch_zigzag"
BATCH_STRIPE = "batch_stripe"
def register_ring_attn(
sequence_parallel_degree: int,
heads_k_stride: int | None,

View File

@@ -34,6 +34,7 @@ from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.freeze import freeze_layers_except
from axolotl.utils.models import load_model, load_processor, load_tokenizer
from axolotl.utils.schemas.enums import RLType
from axolotl.utils.trainer import setup_trainer
try:
@@ -108,7 +109,7 @@ def setup_reference_model(
Reference model if needed for RL training, `None` otherwise.
"""
model_ref = None
if cfg.rl and cfg.rl != "orpo":
if cfg.rl and cfg.rl != RLType.ORPO:
if cfg.adapter and not cfg.rl_adapter_ref_model:
# use built-in trl autounwrap
LOG.debug("Passing model_ref: None to RL trainer")

View File

@@ -18,8 +18,9 @@ from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.models import load_tokenizer
from axolotl.utils.schemas.enums import RLType
LOG = logging.getLogger("axolotl")
LOG = logging.getLogger(__name__)
def _get_path(ds_hash, cfg):
@@ -80,7 +81,7 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs):
def drop_long_rl_seq(
sample, rl, tokenizer, sequence_len # pylint: disable=invalid-name
):
if rl in ("dpo", "ipo", "orpo", "simpo"):
if rl in (RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO):
if not (
sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
):
@@ -100,7 +101,7 @@ def drop_long_rl_seq(
len_prompt + len_rejected
) <= sequence_len
if rl == "kto":
if rl is RLType.KTO:
if not (sample.get("prompt") and sample.get("completion")):
raise ValueError("Prompt and completion keys are required for KTO datasets")
@@ -114,7 +115,7 @@ def drop_long_rl_seq(
return (len_prompt + len_completion) <= sequence_len
if rl == "grpo":
if rl is RLType.GRPO:
return True
raise ValueError("Unknown RL type")
@@ -137,9 +138,9 @@ def load_prepare_preference_datasets(cfg):
if _type:
if isinstance(_type, DictDefault):
_type = "user_defined.default"
if _cfg.rl == "orpo":
if _cfg.rl is RLType.ORPO:
ds_transform_fn = load_orpo(_type, _cfg, dataset_idx=i)
elif _cfg.rl == "kto":
elif _cfg.rl is RLType.KTO:
ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
else:
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
@@ -150,7 +151,7 @@ def load_prepare_preference_datasets(cfg):
split_datasets[i] = map_dataset(
cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
)
elif _cfg.rl == "kto":
elif _cfg.rl is RLType.KTO:
ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
map_kwargs = {}
if isinstance(ds_transform_fn, tuple):

View File

@@ -72,6 +72,7 @@ from axolotl.utils.distributed import (
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
from axolotl.utils.schemas.enums import RLType
LOG = logging.getLogger(__name__)
@@ -1340,7 +1341,7 @@ class ModelLoader:
# then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
if (
self.cfg.adapter
and self.cfg.rl in ["dpo", "ipo", "kto"]
and self.cfg.rl in [RLType.DPO, RLType.IPO, RLType.KTO]
and not self.cfg.merge_lora
):
_, lora_config = load_lora(

View File

@@ -29,7 +29,7 @@ from axolotl.utils.schemas.datasets import (
StepwiseSupervisedDataset,
)
from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
from axolotl.utils.schemas.enums import ChatTemplate, RLType
from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType
from axolotl.utils.schemas.integrations import (
CometConfig,
GradioConfig,
@@ -261,7 +261,7 @@ class AxolotlInputConfig(
sequence_parallel_degree: int | None = None
heads_k_stride: int | None = None
ring_attn_func: str | None = None
ring_attn_func: RingAttnFunc | None = None
special_tokens: SpecialTokensConfig | None = None
tokens: list[str] | None = None
@@ -785,7 +785,7 @@ class AxolotlInputConfig(
@model_validator(mode="after")
def check_simpo_warmup(self):
if self.rl == "simpo" and self.warmup_ratio:
if self.rl is RLType.SIMPO and self.warmup_ratio:
raise ValueError(
"warmup_ratio is not supported with the simpo trainer. Please use `warmup_steps` instead"
)

View File

@@ -6,12 +6,12 @@ from enum import Enum
class RLType(str, Enum):
"""RL trainer type configuration subset"""
dpo = "dpo" # pylint: disable=invalid-name
grpo = "grpo" # pylint: disable=invalid-name
ipo = "ipo" # pylint: disable=invalid-name
orpo = "orpo" # pylint: disable=invalid-name
kto = "kto" # pylint: disable=invalid-name
simpo = "simpo" # pylint: disable=invalid-name
DPO = "dpo" # pylint: disable=invalid-name
GRPO = "grpo" # pylint: disable=invalid-name
IPO = "ipo" # pylint: disable=invalid-name
ORPO = "orpo" # pylint: disable=invalid-name
KTO = "kto" # pylint: disable=invalid-name
SIMPO = "simpo" # pylint: disable=invalid-name
class ChatTemplate(str, Enum):
@@ -53,3 +53,14 @@ class CustomSupportedOptimizers(str, Enum):
ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name
adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name
muon = "muon" # pylint: disable=invalid-name
class RingAttnFunc(str, Enum):
"""Enum class for supported `ring-flash-attn` implementations"""
# VARLEN_RING = "varlen_ring"
# VARLEN_ZIGZAG = "varlen_zigzag"
VARLEN_LLAMA3 = "varlen_llama3"
BATCH_RING = "batch_ring"
BATCH_ZIGZAG = "batch_zigzag"
BATCH_STRIPE = "batch_stripe"