subclassing constructor
This commit is contained in:
@@ -1126,23 +1126,23 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
def build(self, total_num_steps):
|
def build(self, total_num_steps):
|
||||||
training_args = self.build_training_arguments(total_num_steps)
|
training_args = self.build_training_arguments(total_num_steps)
|
||||||
dpo_trainer_kwargs = {}
|
trainer_kwargs = {}
|
||||||
if self.cfg.rl is RLType.IPO:
|
if self.cfg.rl is RLType.IPO:
|
||||||
if self.cfg.dpo_label_smoothing:
|
if self.cfg.dpo_label_smoothing:
|
||||||
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
||||||
if self.eval_dataset:
|
if self.eval_dataset:
|
||||||
dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
|
trainer_kwargs["eval_dataset"] = self.eval_dataset
|
||||||
if self.cfg.adapter and self.peft_config:
|
if self.cfg.adapter and self.peft_config:
|
||||||
dpo_trainer_kwargs["peft_config"] = self.peft_config
|
trainer_kwargs["peft_config"] = self.peft_config
|
||||||
if self.cfg.precompute_ref_log_probs is not None:
|
if self.cfg.precompute_ref_log_probs is not None:
|
||||||
dpo_trainer_kwargs["precompute_ref_log_probs"] = (
|
trainer_kwargs["precompute_ref_log_probs"] = (
|
||||||
self.cfg.precompute_ref_log_probs
|
self.cfg.precompute_ref_log_probs
|
||||||
)
|
)
|
||||||
if self.cfg.rl is RLType.GRPO:
|
if self.cfg.rl is RLType.GRPO:
|
||||||
trainer_cls = GRPOStrategy.get_trainer_class()
|
trainer_cls = GRPOStrategy.get_trainer_class()
|
||||||
trainer_cls_args = [self.model]
|
trainer_cls_args = [self.model]
|
||||||
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
|
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
|
||||||
dpo_trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
|
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
|
||||||
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
|
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
|
||||||
trainer_cls = DPOStrategy.get_trainer_class()
|
trainer_cls = DPOStrategy.get_trainer_class()
|
||||||
trainer_cls_args = [self.model, self.model_ref]
|
trainer_cls_args = [self.model, self.model_ref]
|
||||||
@@ -1160,33 +1160,33 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
sig = inspect.signature(trainer_cls)
|
sig = inspect.signature(trainer_cls)
|
||||||
if "tokenizer" in sig.parameters.keys():
|
if "tokenizer" in sig.parameters.keys():
|
||||||
dpo_trainer_kwargs["tokenizer"] = self.tokenizer
|
trainer_kwargs["tokenizer"] = self.tokenizer
|
||||||
else:
|
else:
|
||||||
dpo_trainer_kwargs["processing_class"] = self.tokenizer
|
trainer_kwargs["processing_class"] = self.tokenizer
|
||||||
|
|
||||||
if self.cfg.datasets is not None and (
|
if self.cfg.datasets is not None and (
|
||||||
trainer_cls is DPOStrategy.get_trainer_class()
|
trainer_cls is DPOStrategy.get_trainer_class()
|
||||||
):
|
):
|
||||||
dpo_trainer_kwargs["dataset_tags"] = [
|
trainer_kwargs["dataset_tags"] = [
|
||||||
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
||||||
]
|
]
|
||||||
dpo_trainer = trainer_cls(
|
trainer = trainer_cls(
|
||||||
*trainer_cls_args,
|
*trainer_cls_args,
|
||||||
args=training_args,
|
args=training_args,
|
||||||
train_dataset=self.train_dataset,
|
train_dataset=self.train_dataset,
|
||||||
callbacks=self.get_callbacks(),
|
callbacks=self.get_callbacks(),
|
||||||
**dpo_trainer_kwargs,
|
**trainer_kwargs,
|
||||||
)
|
)
|
||||||
if self.cfg.fsdp:
|
if self.cfg.fsdp:
|
||||||
ensure_dtype(dpo_trainer.model, dtype=self.cfg.torch_dtype)
|
ensure_dtype(trainer.model, dtype=self.cfg.torch_dtype)
|
||||||
if self.cfg.rl in [RLType.DPO, RLType.IPO] and dpo_trainer.ref_model:
|
if self.cfg.rl in [RLType.DPO, RLType.IPO] and trainer.ref_model:
|
||||||
ensure_dtype(dpo_trainer.ref_model, dtype=self.cfg.torch_dtype)
|
ensure_dtype(trainer.ref_model, dtype=self.cfg.torch_dtype)
|
||||||
|
|
||||||
dpo_trainer = self.hook_post_create_trainer(dpo_trainer)
|
trainer = self.hook_post_create_trainer(trainer)
|
||||||
for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
|
for callback in self.get_post_trainer_create_callbacks(trainer):
|
||||||
dpo_trainer.add_callback(callback)
|
trainer.add_callback(callback)
|
||||||
|
|
||||||
return dpo_trainer
|
return trainer
|
||||||
|
|
||||||
|
|
||||||
class HFPPOTrainerBuilder(TrainerBuilderBase):
|
class HFPPOTrainerBuilder(TrainerBuilderBase):
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Axolotl GRPO trainer"""
|
"""Axolotl GRPO trainer"""
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
|
from collections import defaultdict
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -12,15 +13,29 @@ from accelerate.utils import (
|
|||||||
gather,
|
gather,
|
||||||
gather_object,
|
gather_object,
|
||||||
is_peft_model,
|
is_peft_model,
|
||||||
|
set_seed,
|
||||||
)
|
)
|
||||||
|
from datasets import Dataset, IterableDataset
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils.data import (
|
from torch.utils.data import (
|
||||||
BatchSampler,
|
BatchSampler,
|
||||||
DataLoader,
|
DataLoader,
|
||||||
Sampler,
|
Sampler,
|
||||||
)
|
)
|
||||||
from transformers import Trainer, is_wandb_available
|
from transformers import (
|
||||||
|
AutoModelForCausalLM,
|
||||||
|
AutoModelForSequenceClassification,
|
||||||
|
AutoTokenizer,
|
||||||
|
GenerationConfig,
|
||||||
|
PreTrainedModel,
|
||||||
|
PreTrainedTokenizerBase,
|
||||||
|
Trainer,
|
||||||
|
TrainerCallback,
|
||||||
|
is_wandb_available,
|
||||||
|
)
|
||||||
|
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||||
from transformers.trainer_utils import seed_worker
|
from transformers.trainer_utils import seed_worker
|
||||||
|
from transformers.utils import is_peft_available
|
||||||
from trl import GRPOTrainer
|
from trl import GRPOTrainer
|
||||||
from trl.data_utils import (
|
from trl.data_utils import (
|
||||||
apply_chat_template,
|
apply_chat_template,
|
||||||
@@ -28,8 +43,20 @@ from trl.data_utils import (
|
|||||||
maybe_apply_chat_template,
|
maybe_apply_chat_template,
|
||||||
)
|
)
|
||||||
from trl.extras.profiling import profiling_context, profiling_decorator
|
from trl.extras.profiling import profiling_context, profiling_decorator
|
||||||
from trl.import_utils import is_deepspeed_available, is_rich_available
|
from trl.extras.vllm_client import VLLMClient
|
||||||
from trl.models import unwrap_model_for_generation
|
from trl.import_utils import (
|
||||||
|
is_deepspeed_available,
|
||||||
|
is_rich_available,
|
||||||
|
is_vllm_available,
|
||||||
|
)
|
||||||
|
from trl.models import (
|
||||||
|
create_reference_model,
|
||||||
|
prepare_deepspeed,
|
||||||
|
unwrap_model_for_generation,
|
||||||
|
)
|
||||||
|
from trl.trainer.callbacks import SyncRefModelCallback
|
||||||
|
from trl.trainer.grpo_config import GRPOConfig
|
||||||
|
from trl.trainer.grpo_trainer import RewardFunc
|
||||||
from trl.trainer.utils import (
|
from trl.trainer.utils import (
|
||||||
pad,
|
pad,
|
||||||
print_prompt_completions_sample,
|
print_prompt_completions_sample,
|
||||||
@@ -40,6 +67,9 @@ from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampl
|
|||||||
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
|
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
|
||||||
from axolotl.monkeypatch.attention.ring_attn.patch import get_ring_attn_group
|
from axolotl.monkeypatch.attention.ring_attn.patch import get_ring_attn_group
|
||||||
|
|
||||||
|
if is_peft_available():
|
||||||
|
from peft import PeftConfig, get_peft_model
|
||||||
|
|
||||||
if is_deepspeed_available():
|
if is_deepspeed_available():
|
||||||
import deepspeed
|
import deepspeed
|
||||||
|
|
||||||
@@ -52,9 +82,341 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
|
|||||||
|
|
||||||
_tag_names = ["trl", "grpo", "axolotl"]
|
_tag_names = ["trl", "grpo", "axolotl"]
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(
|
||||||
# Call parent constructor with all arguments
|
self,
|
||||||
super().__init__(*args, **kwargs)
|
model: str | PreTrainedModel,
|
||||||
|
reward_funcs: RewardFunc | list[RewardFunc],
|
||||||
|
args: GRPOConfig | None = None,
|
||||||
|
train_dataset: Dataset | IterableDataset | None = None,
|
||||||
|
eval_dataset: (
|
||||||
|
Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None
|
||||||
|
) = None,
|
||||||
|
processing_class: PreTrainedTokenizerBase | None = None,
|
||||||
|
reward_processing_classes: (
|
||||||
|
PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None
|
||||||
|
) = None,
|
||||||
|
callbacks: list[TrainerCallback] | None = None,
|
||||||
|
optimizers: tuple[
|
||||||
|
torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None
|
||||||
|
] = (None, None),
|
||||||
|
peft_config: "PeftConfig | None" = None,
|
||||||
|
):
|
||||||
|
# Args
|
||||||
|
if args is None:
|
||||||
|
model_name = model if isinstance(model, str) else model.config._name_or_path
|
||||||
|
model_name = model_name.split("/")[-1]
|
||||||
|
args = GRPOConfig(f"{model_name}-GRPO")
|
||||||
|
|
||||||
|
# Models
|
||||||
|
# Trained model
|
||||||
|
model_init_kwargs = args.model_init_kwargs or {}
|
||||||
|
if isinstance(model, str):
|
||||||
|
model_id = model
|
||||||
|
torch_dtype = model_init_kwargs.get("torch_dtype")
|
||||||
|
if (
|
||||||
|
isinstance(torch_dtype, torch.dtype)
|
||||||
|
or torch_dtype == "auto"
|
||||||
|
or torch_dtype is None
|
||||||
|
):
|
||||||
|
pass # torch_dtype is already a torch.dtype or "auto" or None
|
||||||
|
elif isinstance(torch_dtype, str): # it's a str, but not "auto"
|
||||||
|
torch_dtype = getattr(torch, torch_dtype)
|
||||||
|
model_init_kwargs["torch_dtype"] = torch_dtype
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
|
||||||
|
f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
|
||||||
|
)
|
||||||
|
# Disable caching if gradient checkpointing is enabled (not supported)
|
||||||
|
model_init_kwargs["use_cache"] = (
|
||||||
|
False
|
||||||
|
if args.gradient_checkpointing
|
||||||
|
else model_init_kwargs.get("use_cache")
|
||||||
|
)
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
|
||||||
|
else:
|
||||||
|
model_id = model.config._name_or_path
|
||||||
|
if args.model_init_kwargs is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
|
||||||
|
"This argument can only be used when the `model` argument is a string."
|
||||||
|
)
|
||||||
|
|
||||||
|
if peft_config is not None:
|
||||||
|
if not is_peft_available():
|
||||||
|
raise ImportError(
|
||||||
|
"PEFT is required to use `peft_config`. Run `pip install peft`."
|
||||||
|
)
|
||||||
|
model = get_peft_model(model, peft_config)
|
||||||
|
|
||||||
|
# Enable gradient checkpointing if requested
|
||||||
|
if args.gradient_checkpointing:
|
||||||
|
model = self._enable_gradient_checkpointing(model, args)
|
||||||
|
|
||||||
|
# Reference model
|
||||||
|
self.beta = args.beta
|
||||||
|
if self.beta == 0.0:
|
||||||
|
# If beta is 0.0, the reference model is not needed
|
||||||
|
self.ref_model = None
|
||||||
|
elif is_deepspeed_zero3_enabled():
|
||||||
|
self.ref_model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_id, **model_init_kwargs
|
||||||
|
)
|
||||||
|
elif is_peft_model(model):
|
||||||
|
# If PEFT is used, the reference model is not needed since the adapter can be disabled
|
||||||
|
# to revert to the initial model.
|
||||||
|
self.ref_model = None
|
||||||
|
else:
|
||||||
|
# If PEFT configuration is not provided, create a reference model based on the initial model.
|
||||||
|
self.ref_model = create_reference_model(model)
|
||||||
|
|
||||||
|
# Processing class
|
||||||
|
if processing_class is None:
|
||||||
|
processing_class = AutoTokenizer.from_pretrained(
|
||||||
|
model.config._name_or_path, padding_side="left"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reward functions
|
||||||
|
if not isinstance(reward_funcs, list):
|
||||||
|
reward_funcs = [reward_funcs]
|
||||||
|
for i, reward_func in enumerate(reward_funcs):
|
||||||
|
if isinstance(reward_func, str):
|
||||||
|
reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
|
||||||
|
reward_func, num_labels=1, **model_init_kwargs
|
||||||
|
)
|
||||||
|
self.reward_funcs = reward_funcs
|
||||||
|
|
||||||
|
# Reward weights
|
||||||
|
if args.reward_weights is not None:
|
||||||
|
if len(args.reward_weights) != len(reward_funcs):
|
||||||
|
raise ValueError(
|
||||||
|
f"Number of reward weights ({len(args.reward_weights)}) must match number of reward "
|
||||||
|
f"functions ({len(reward_funcs)})"
|
||||||
|
)
|
||||||
|
self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32)
|
||||||
|
else:
|
||||||
|
self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)
|
||||||
|
|
||||||
|
# Reward processing class
|
||||||
|
if reward_processing_classes is None:
|
||||||
|
reward_processing_classes = [None] * len(reward_funcs)
|
||||||
|
elif not isinstance(reward_processing_classes, list):
|
||||||
|
reward_processing_classes = [reward_processing_classes]
|
||||||
|
else:
|
||||||
|
if len(reward_processing_classes) != len(reward_funcs):
|
||||||
|
raise ValueError(
|
||||||
|
"The number of reward processing classes must match the number of reward functions."
|
||||||
|
)
|
||||||
|
|
||||||
|
for i, (reward_processing_class, reward_func) in enumerate(
|
||||||
|
zip(reward_processing_classes, reward_funcs)
|
||||||
|
):
|
||||||
|
if isinstance(reward_func, PreTrainedModel):
|
||||||
|
if reward_processing_class is None:
|
||||||
|
reward_processing_class = AutoTokenizer.from_pretrained(
|
||||||
|
reward_func.config._name_or_path
|
||||||
|
)
|
||||||
|
if reward_processing_class.pad_token_id is None:
|
||||||
|
reward_processing_class.pad_token = (
|
||||||
|
reward_processing_class.eos_token
|
||||||
|
)
|
||||||
|
# The reward model computes the reward for the latest non-padded token in the input sequence.
|
||||||
|
# So it's important to set the pad token ID to the padding token ID of the processing class.
|
||||||
|
reward_func.config.pad_token_id = reward_processing_class.pad_token_id
|
||||||
|
reward_processing_classes[i] = reward_processing_class
|
||||||
|
self.reward_processing_classes = reward_processing_classes
|
||||||
|
|
||||||
|
# Data collator
|
||||||
|
def data_collator(features): # No data collation is needed in GRPO
|
||||||
|
return features
|
||||||
|
|
||||||
|
# Training arguments
|
||||||
|
self.max_prompt_length = args.max_prompt_length
|
||||||
|
self.max_completion_length = (
|
||||||
|
args.max_completion_length
|
||||||
|
) # = |o_i| in the GRPO paper
|
||||||
|
self.num_generations = args.num_generations # = G in the GRPO paper
|
||||||
|
self.temperature = args.temperature
|
||||||
|
self.top_p = args.top_p
|
||||||
|
self.top_k = args.top_k
|
||||||
|
self.min_p = args.min_p
|
||||||
|
self.repetition_penalty = args.repetition_penalty
|
||||||
|
self.use_vllm = args.use_vllm
|
||||||
|
|
||||||
|
# Multi-step
|
||||||
|
self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper
|
||||||
|
self.epsilon_low = args.epsilon
|
||||||
|
self.epsilon_high = (
|
||||||
|
args.epsilon_high if args.epsilon_high is not None else args.epsilon
|
||||||
|
)
|
||||||
|
# Tracks the number of iterations (forward + backward passes), including those within a grad accum cycle
|
||||||
|
self._step = 0
|
||||||
|
# Buffer the batch to reuse generated outputs across multiple updates. For more details, see
|
||||||
|
# `_get_train_sampler` and `_prepare_inputs`.
|
||||||
|
self._buffered_inputs = [None] * args.gradient_accumulation_steps
|
||||||
|
|
||||||
|
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
|
||||||
|
# input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
|
||||||
|
# "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
|
||||||
|
# "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
|
||||||
|
# suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
|
||||||
|
# This acts as a flag to indicate that the warning has already been issued.
|
||||||
|
model.warnings_issued["estimate_tokens"] = True
|
||||||
|
|
||||||
|
# Initialize the metrics
|
||||||
|
self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
|
||||||
|
self._total_train_tokens = 0
|
||||||
|
self.log_completions = args.log_completions
|
||||||
|
|
||||||
|
Trainer.__init__(
|
||||||
|
self,
|
||||||
|
model=model,
|
||||||
|
args=args,
|
||||||
|
data_collator=data_collator,
|
||||||
|
train_dataset=train_dataset,
|
||||||
|
eval_dataset=eval_dataset,
|
||||||
|
processing_class=processing_class,
|
||||||
|
callbacks=callbacks,
|
||||||
|
optimizers=optimizers,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get number of SP groups (number of processes divided by SP degree)
|
||||||
|
num_processes = self.accelerator.num_processes
|
||||||
|
num_sp_groups = num_processes // self.args.sequence_parallel_degree
|
||||||
|
|
||||||
|
# Calculate batch size per SP group (not per process)
|
||||||
|
sp_group_batch_size = self.args.per_device_train_batch_size * num_sp_groups
|
||||||
|
possible_values = [
|
||||||
|
n_gen
|
||||||
|
for n_gen in range(2, sp_group_batch_size + 1)
|
||||||
|
if (sp_group_batch_size) % n_gen == 0
|
||||||
|
]
|
||||||
|
|
||||||
|
if self.num_generations not in possible_values:
|
||||||
|
raise ValueError(
|
||||||
|
f"The batch size per SP group ({num_sp_groups} x {self.args.per_device_train_batch_size}) must be evenly "
|
||||||
|
f"divisible by the number of generations per prompt ({self.num_generations}). Given the current "
|
||||||
|
f"configuration, the valid values for the number of generations are: {possible_values}."
|
||||||
|
)
|
||||||
|
if self.args.eval_strategy != "no":
|
||||||
|
# If sequence parallelism is enabled, calculate batch size per SP group
|
||||||
|
sp_group_eval_batch_size = args.per_device_eval_batch_size * num_sp_groups
|
||||||
|
possible_values = [
|
||||||
|
n_gen
|
||||||
|
for n_gen in range(2, sp_group_eval_batch_size + 1)
|
||||||
|
if (sp_group_eval_batch_size) % n_gen == 0
|
||||||
|
]
|
||||||
|
|
||||||
|
if self.num_generations not in possible_values:
|
||||||
|
raise ValueError(
|
||||||
|
f"With sequence parallelism (degree {self.args.sequence_parallel_degree}), "
|
||||||
|
f"the eval batch size per SP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) "
|
||||||
|
f"must be evenly divisible by the number of generations per prompt "
|
||||||
|
f"({self.num_generations}). Given the current eval batch size, "
|
||||||
|
f"the valid values for the number of generations are: {possible_values}."
|
||||||
|
)
|
||||||
|
|
||||||
|
# # Check if the per_device_train/eval_batch_size * num processes can be divided by the number of generations
|
||||||
|
# num_processes = self.accelerator.num_processes
|
||||||
|
# global_batch_size = args.per_device_train_batch_size * num_processes
|
||||||
|
# possible_values = [
|
||||||
|
# n_gen
|
||||||
|
# for n_gen in range(2, global_batch_size + 1)
|
||||||
|
# if (global_batch_size) % n_gen == 0
|
||||||
|
# ]
|
||||||
|
# if self.num_generations not in possible_values:
|
||||||
|
# raise ValueError(
|
||||||
|
# f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
|
||||||
|
# f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train "
|
||||||
|
# f"batch size, the valid values for the number of generations are: {possible_values}."
|
||||||
|
# )
|
||||||
|
# if self.args.eval_strategy != "no":
|
||||||
|
# global_batch_size = args.per_device_eval_batch_size * num_processes
|
||||||
|
# possible_values = [
|
||||||
|
# n_gen
|
||||||
|
# for n_gen in range(2, global_batch_size + 1)
|
||||||
|
# if (global_batch_size) % n_gen == 0
|
||||||
|
# ]
|
||||||
|
# if self.num_generations not in possible_values:
|
||||||
|
# raise ValueError(
|
||||||
|
# f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly "
|
||||||
|
# f"divisible by the number of generations per prompt ({self.num_generations}). Given the current "
|
||||||
|
# f"eval batch size, the valid values for the number of generations are: {possible_values}."
|
||||||
|
# )
|
||||||
|
|
||||||
|
# Ensure each process receives a unique seed to prevent duplicate completions when generating with
|
||||||
|
# transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but
|
||||||
|
# it's safer to set it in all cases.
|
||||||
|
set_seed(args.seed, device_specific=True)
|
||||||
|
|
||||||
|
if self.use_vllm:
|
||||||
|
if not is_vllm_available():
|
||||||
|
raise ImportError(
|
||||||
|
"vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
|
||||||
|
"`pip install vllm` to use it."
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.accelerator.is_main_process:
|
||||||
|
self.vllm_client = VLLMClient(
|
||||||
|
args.vllm_server_host,
|
||||||
|
args.vllm_server_port,
|
||||||
|
connection_timeout=args.vllm_server_timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
# vLLM specific sampling arguments
|
||||||
|
self.guided_decoding_regex = args.vllm_guided_decoding_regex
|
||||||
|
|
||||||
|
self._last_loaded_step = (
|
||||||
|
0 # tag to avoid useless loading during grad accumulation
|
||||||
|
)
|
||||||
|
|
||||||
|
# When using vLLM, the main process is responsible for loading the model weights. This can cause process
|
||||||
|
# desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
|
||||||
|
# synchronize all processes after vLLM has been fully initialized.
|
||||||
|
self.accelerator.wait_for_everyone()
|
||||||
|
else:
|
||||||
|
self.generation_config = GenerationConfig(
|
||||||
|
max_new_tokens=self.max_completion_length,
|
||||||
|
do_sample=True,
|
||||||
|
pad_token_id=processing_class.pad_token_id,
|
||||||
|
bos_token_id=processing_class.bos_token_id,
|
||||||
|
eos_token_id=processing_class.eos_token_id,
|
||||||
|
temperature=self.temperature,
|
||||||
|
top_p=self.top_p,
|
||||||
|
top_k=self.top_k,
|
||||||
|
min_p=self.min_p,
|
||||||
|
repetition_penalty=self.repetition_penalty,
|
||||||
|
cache_implementation=args.cache_implementation,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
|
||||||
|
# model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
|
||||||
|
# self.model_accepts_loss_kwargs to False to enable scaling.
|
||||||
|
self.model_accepts_loss_kwargs = False
|
||||||
|
|
||||||
|
# Add tags to the model
|
||||||
|
self.model.add_model_tags(self._tag_names)
|
||||||
|
|
||||||
|
if self.ref_model is not None:
|
||||||
|
if self.is_deepspeed_enabled:
|
||||||
|
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
|
||||||
|
else:
|
||||||
|
self.ref_model = self.accelerator.prepare_model(
|
||||||
|
self.ref_model, evaluation_mode=True
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.sync_ref_model:
|
||||||
|
self.add_callback(
|
||||||
|
SyncRefModelCallback(
|
||||||
|
ref_model=self.ref_model, accelerator=self.accelerator
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
for i, reward_func in enumerate(self.reward_funcs):
|
||||||
|
if isinstance(reward_func, PreTrainedModel):
|
||||||
|
self.reward_funcs[i] = self.accelerator.prepare_model(
|
||||||
|
reward_func, evaluation_mode=True
|
||||||
|
)
|
||||||
|
|
||||||
# Initialize the SP group
|
# Initialize the SP group
|
||||||
self.sp_group = get_ring_attn_group()
|
self.sp_group = get_ring_attn_group()
|
||||||
@@ -255,6 +617,9 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
|
|||||||
prompt_ids = prompt_ids[:, -self.max_prompt_length :]
|
prompt_ids = prompt_ids[:, -self.max_prompt_length :]
|
||||||
prompt_mask = prompt_mask[:, -self.max_prompt_length :]
|
prompt_mask = prompt_mask[:, -self.max_prompt_length :]
|
||||||
|
|
||||||
|
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
|
||||||
|
all_prompts_text = gather_object(prompts_text)
|
||||||
|
|
||||||
# Generate completions using either vLLM or regular generation
|
# Generate completions using either vLLM or regular generation
|
||||||
if self.args.use_vllm:
|
if self.args.use_vllm:
|
||||||
# First, have main process load weights if needed
|
# First, have main process load weights if needed
|
||||||
@@ -262,14 +627,14 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
|
|||||||
self._move_model_to_vllm()
|
self._move_model_to_vllm()
|
||||||
self._last_loaded_step = self.state.global_step
|
self._last_loaded_step = self.state.global_step
|
||||||
|
|
||||||
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
|
|
||||||
all_prompts_text = gather_object(prompts_text)
|
|
||||||
|
|
||||||
if self.accelerator.is_main_process:
|
if self.accelerator.is_main_process:
|
||||||
# Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
|
# Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
|
||||||
# num_generations outputs for each one. This is faster than generating outputs for each duplicate
|
# num_generations outputs for each one. This is faster than generating outputs for each duplicate
|
||||||
# prompt individually.
|
# prompt individually.
|
||||||
ordered_set_of_prompts = all_prompts_text[:: self.num_generations]
|
# ordered_set_of_prompts = all_prompts_text[:: self.num_generations]
|
||||||
|
ordered_set_of_prompts = all_prompts_text[
|
||||||
|
:: self.num_generations * self.args.sequence_parallel_degree
|
||||||
|
]
|
||||||
with profiling_context(self, "vLLM.generate"):
|
with profiling_context(self, "vLLM.generate"):
|
||||||
completion_ids = self.vllm_client.generate(
|
completion_ids = self.vllm_client.generate(
|
||||||
prompts=ordered_set_of_prompts,
|
prompts=ordered_set_of_prompts,
|
||||||
@@ -297,33 +662,19 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
|
|||||||
sp_group_start = sp_group_id * len(prompts) * self.local_world_size
|
sp_group_start = sp_group_id * len(prompts) * self.local_world_size
|
||||||
|
|
||||||
# All ranks in the same SP group get the same data slice
|
# All ranks in the same SP group get the same data slice
|
||||||
# This ensures identical inputs for sequence-parallel processing
|
|
||||||
process_slice = slice(
|
process_slice = slice(
|
||||||
sp_group_start,
|
sp_group_start,
|
||||||
sp_group_start + len(prompts) * self.local_world_size,
|
sp_group_start + len(prompts),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Take the full SP group's worth of completions
|
|
||||||
completion_ids = completion_ids[process_slice]
|
completion_ids = completion_ids[process_slice]
|
||||||
else:
|
else:
|
||||||
# Original behavior for non-sequence-parallel case
|
# Original behavior for non-sequence parallel case
|
||||||
process_slice = slice(
|
process_slice = slice(
|
||||||
self.accelerator.process_index * len(prompts),
|
self.accelerator.process_index * len(prompts),
|
||||||
(self.accelerator.process_index + 1) * len(prompts),
|
(self.accelerator.process_index + 1) * len(prompts),
|
||||||
)
|
)
|
||||||
completion_ids = completion_ids[process_slice]
|
completion_ids = completion_ids[process_slice]
|
||||||
|
|
||||||
if dist.get_rank() == 0:
|
|
||||||
import ipdb
|
|
||||||
|
|
||||||
ipdb.set_trace()
|
|
||||||
dist.barrier()
|
|
||||||
if dist.get_rank() == 1:
|
|
||||||
import ipdb
|
|
||||||
|
|
||||||
ipdb.set_trace()
|
|
||||||
dist.barrier()
|
|
||||||
|
|
||||||
# Pad the completions, and concatenate them with the prompts
|
# Pad the completions, and concatenate them with the prompts
|
||||||
completion_ids = [
|
completion_ids = [
|
||||||
torch.tensor(ids, device=device) for ids in completion_ids
|
torch.tensor(ids, device=device) for ids in completion_ids
|
||||||
|
|||||||
Reference in New Issue
Block a user