async grpo support
This commit is contained in:
@@ -54,8 +54,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
|
||||
from axolotl.core.trainers.grpo import GRPOStrategy
|
||||
|
||||
async_grpo = bool(self.cfg.trl and getattr(self.cfg.trl, "async_prefetch", False))
|
||||
trainer_cls = GRPOStrategy.get_trainer_class(
|
||||
sequence_parallel=self.cfg.context_parallel_size > 1
|
||||
sequence_parallel=self.cfg.context_parallel_size > 1,
|
||||
async_grpo=async_grpo,
|
||||
)
|
||||
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
|
||||
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
|
||||
@@ -151,7 +153,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
elif self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
|
||||
from axolotl.core.trainers.grpo import GRPOStrategy
|
||||
|
||||
training_args_cls = GRPOStrategy.get_training_args_class()
|
||||
async_grpo = bool(self.cfg.trl and getattr(self.cfg.trl, "async_prefetch", False))
|
||||
training_args_cls = GRPOStrategy.get_training_args_class(async_grpo=async_grpo)
|
||||
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
|
||||
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
|
||||
if self.cfg.rl is RLType.GDPO:
|
||||
|
||||
@@ -9,8 +9,9 @@ from huggingface_hub import snapshot_download
|
||||
from requests import HTTPError
|
||||
from trl.trainer.grpo_trainer import RewardFunc
|
||||
|
||||
from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig
|
||||
from axolotl.core.trainers.grpo.args import AxolotlAsyncGRPOConfig, AxolotlGRPOConfig
|
||||
from axolotl.core.trainers.grpo.trainer import (
|
||||
AxolotlAsyncGRPOTrainer,
|
||||
AxolotlGRPOSequenceParallelTrainer,
|
||||
AxolotlGRPOTrainer,
|
||||
)
|
||||
@@ -27,14 +28,18 @@ class GRPOStrategy:
|
||||
|
||||
@classmethod
def get_trainer_class(
    cls,
    sequence_parallel: bool,
    async_grpo: bool = False,
) -> (
    type[AxolotlGRPOTrainer]
    | type[AxolotlGRPOSequenceParallelTrainer]
    | type[AxolotlAsyncGRPOTrainer]
):
    """Select the GRPO trainer class for the requested execution mode.

    Args:
        sequence_parallel: When true, the sequence-parallel trainer is
            returned. This flag takes precedence over ``async_grpo``.
        async_grpo: When true (and ``sequence_parallel`` is false), the
            async GRPO trainer is returned.

    Returns:
        ``AxolotlGRPOSequenceParallelTrainer``, ``AxolotlAsyncGRPOTrainer``
        or ``AxolotlGRPOTrainer``, checked in that order.
    """
    if sequence_parallel:
        if async_grpo:
            # Sequence parallelism wins; surface the dropped request instead
            # of silently ignoring the user's async setting.
            import logging

            logging.getLogger(__name__).warning(
                "async_grpo was requested together with sequence parallelism; "
                "the sequence-parallel GRPO trainer is used and the async "
                "trainer is ignored."
            )
        return AxolotlGRPOSequenceParallelTrainer
    if async_grpo:
        return AxolotlAsyncGRPOTrainer
    return AxolotlGRPOTrainer
|
||||
|
||||
@classmethod
def get_training_args_class(
    cls, async_grpo: bool = False
) -> type[AxolotlGRPOConfig] | type[AxolotlAsyncGRPOConfig]:
    """Select the training-arguments class matching the GRPO mode.

    Args:
        async_grpo: When true, the async GRPO config class is returned.
            Defaults to false, so existing zero-argument callers keep
            receiving ``AxolotlGRPOConfig``.

    Returns:
        ``AxolotlAsyncGRPOConfig`` if ``async_grpo`` is true, otherwise
        ``AxolotlGRPOConfig``.
    """
    return AxolotlAsyncGRPOConfig if async_grpo else AxolotlGRPOConfig
|
||||
|
||||
@classmethod
|
||||
@@ -131,6 +136,28 @@ class GRPOStrategy:
|
||||
trl.multi_objective_aggregation
|
||||
)
|
||||
|
||||
# Async GRPO fields
|
||||
if getattr(trl, "async_prefetch", None) is not None:
|
||||
grpo_args_kwargs["async_prefetch"] = trl.async_prefetch
|
||||
if getattr(trl, "prefetch_depth", None) is not None:
|
||||
grpo_args_kwargs["prefetch_depth"] = trl.prefetch_depth
|
||||
if getattr(trl, "vllm_sync_interval", None) is not None:
|
||||
grpo_args_kwargs["vllm_sync_interval"] = trl.vllm_sync_interval
|
||||
if getattr(trl, "streaming_partial_batch", None) is not None:
|
||||
grpo_args_kwargs["streaming_partial_batch"] = trl.streaming_partial_batch
|
||||
if getattr(trl, "streaming_min_groups", None) is not None:
|
||||
grpo_args_kwargs["streaming_min_groups"] = trl.streaming_min_groups
|
||||
if getattr(trl, "vllm_importance_sampling_correction", None) is not None:
|
||||
grpo_args_kwargs["vllm_importance_sampling_correction"] = trl.vllm_importance_sampling_correction
|
||||
if getattr(trl, "vllm_importance_sampling_mode", None) is not None:
|
||||
grpo_args_kwargs["vllm_importance_sampling_mode"] = trl.vllm_importance_sampling_mode
|
||||
if getattr(trl, "vllm_importance_sampling_cap", None) is not None:
|
||||
grpo_args_kwargs["vllm_importance_sampling_cap"] = trl.vllm_importance_sampling_cap
|
||||
if getattr(trl, "off_policy_mask_threshold", None) is not None:
|
||||
grpo_args_kwargs["off_policy_mask_threshold"] = trl.off_policy_mask_threshold
|
||||
if getattr(trl, "use_bias_correction_kl", None) is not None:
|
||||
grpo_args_kwargs["use_bias_correction_kl"] = trl.use_bias_correction_kl
|
||||
|
||||
return grpo_args_kwargs
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -7,6 +7,7 @@ from dataclasses import dataclass
|
||||
from trl import GRPOConfig
|
||||
|
||||
from axolotl.core.training_args import AxolotlTrainingMixins
|
||||
from axolotl.monkeypatch.trainer.async_grpo import AsyncGRPOConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -14,3 +15,10 @@ class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
|
||||
"""Axolotl GRPO Config for GRPO training"""
|
||||
|
||||
context_parallel_size: int | None = None
|
||||
|
||||
|
||||
@dataclass
class AxolotlAsyncGRPOConfig(AxolotlTrainingMixins, AsyncGRPOConfig):
    """Training arguments for async GRPO runs.

    Combines ``AxolotlTrainingMixins`` with ``AsyncGRPOConfig`` and declares
    the same ``context_parallel_size`` field that ``AxolotlGRPOConfig`` has.
    """

    # Optional context-parallel degree; left unset (None) by default.
    context_parallel_size: int | None = None
|
||||
|
||||
@@ -34,6 +34,8 @@ from trl.data_utils import (
|
||||
is_conversational,
|
||||
maybe_apply_chat_template,
|
||||
)
|
||||
|
||||
from axolotl.monkeypatch.trainer.async_grpo import AsyncGRPOTrainer
|
||||
from trl.extras.profiling import profiling_context
|
||||
from trl.models import unwrap_model_for_generation
|
||||
from trl.trainer.grpo_config import GRPOConfig
|
||||
@@ -66,6 +68,19 @@ class AxolotlGRPOTrainer(
|
||||
_tag_names = ["trl", "grpo", "axolotl"]
|
||||
|
||||
|
||||
class AxolotlAsyncGRPOTrainer(
    RngLoaderMixin,
    SchedulerMixin,
    OptimizerMixin,
    OptimizerInitMixin,
    DistributedParallelMixin,
    AsyncGRPOTrainer,
):
    """Async GRPO trainer with axolotl's trainer mixins layered on top.

    Uses ``AsyncGRPOTrainer`` as the concrete base and mixes in the same
    axolotl helper mixins listed in the bases above; no methods are
    overridden here.
    """

    # Tags attached to this trainer; "async-grpo" distinguishes it from the
    # synchronous AxolotlGRPOTrainer tag list.
    _tag_names = ["trl", "grpo", "async-grpo", "axolotl"]
|
||||
|
||||
|
||||
class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
"""Extend the base GRPOTrainer for sequence parallelism handling"""
|
||||
|
||||
|
||||
1266
src/axolotl/monkeypatch/trainer/async_grpo.py
Normal file
1266
src/axolotl/monkeypatch/trainer/async_grpo.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -189,3 +189,67 @@ class TRLConfig(BaseModel):
|
||||
"'normalize_then_sum' (GDPO): normalizes each reward independently, then sums."
|
||||
},
|
||||
)
|
||||
|
||||
# Async GRPO fields
|
||||
async_prefetch: bool = Field(
|
||||
default=False,
|
||||
json_schema_extra={
|
||||
"description": "Generate rollouts in a background thread while training on the previous rollout."
|
||||
},
|
||||
)
|
||||
prefetch_depth: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Number of rollouts to prefetch ahead of training."
|
||||
},
|
||||
)
|
||||
vllm_sync_interval: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Sync model weights to vLLM every N optimizer steps (async mode only)."
|
||||
},
|
||||
)
|
||||
streaming_partial_batch: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Score prompt groups incrementally instead of the full batch at once."
|
||||
},
|
||||
)
|
||||
streaming_min_groups: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Minimum prompt groups to score per streaming chunk."
|
||||
},
|
||||
)
|
||||
vllm_importance_sampling_correction: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Apply IS correction for distribution mismatch between vLLM and training model."
|
||||
},
|
||||
)
|
||||
vllm_importance_sampling_mode: (
|
||||
Literal["token_truncate", "token_mask", "sequence_truncate", "sequence_mask"] | None
|
||||
) = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "IS mode: token_truncate, token_mask, sequence_truncate, or sequence_mask."
|
||||
},
|
||||
)
|
||||
vllm_importance_sampling_cap: float | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Cap C for IS ratio clipping/masking."
|
||||
},
|
||||
)
|
||||
off_policy_mask_threshold: float | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "KL threshold for off-policy sequence masking (OPSM). None = disabled."
|
||||
},
|
||||
)
|
||||
use_bias_correction_kl: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Apply IS correction to KL divergence term."
|
||||
},
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user