Async GRPO support

This commit is contained in:
Wing Lian
2026-03-09 22:59:16 +00:00
parent cf4d550c88
commit f0c9e98699
6 changed files with 1389 additions and 6 deletions

View File

@@ -54,8 +54,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
from axolotl.core.trainers.grpo import GRPOStrategy
async_grpo = bool(self.cfg.trl and getattr(self.cfg.trl, "async_prefetch", False))
trainer_cls = GRPOStrategy.get_trainer_class(
sequence_parallel=self.cfg.context_parallel_size > 1
sequence_parallel=self.cfg.context_parallel_size > 1,
async_grpo=async_grpo,
)
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
@@ -151,7 +153,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
elif self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
from axolotl.core.trainers.grpo import GRPOStrategy
training_args_cls = GRPOStrategy.get_training_args_class()
async_grpo = bool(self.cfg.trl and getattr(self.cfg.trl, "async_prefetch", False))
training_args_cls = GRPOStrategy.get_training_args_class(async_grpo=async_grpo)
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
if self.cfg.rl is RLType.GDPO:

View File

@@ -9,8 +9,9 @@ from huggingface_hub import snapshot_download
from requests import HTTPError
from trl.trainer.grpo_trainer import RewardFunc
from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig
from axolotl.core.trainers.grpo.args import AxolotlAsyncGRPOConfig, AxolotlGRPOConfig
from axolotl.core.trainers.grpo.trainer import (
AxolotlAsyncGRPOTrainer,
AxolotlGRPOSequenceParallelTrainer,
AxolotlGRPOTrainer,
)
@@ -27,14 +28,18 @@ class GRPOStrategy:
@classmethod
def get_trainer_class(
cls, sequence_parallel: bool
) -> type[AxolotlGRPOTrainer] | type[AxolotlGRPOSequenceParallelTrainer]:
cls, sequence_parallel: bool, async_grpo: bool = False,
) -> type[AxolotlGRPOTrainer] | type[AxolotlGRPOSequenceParallelTrainer] | type[AxolotlAsyncGRPOTrainer]:
if sequence_parallel:
return AxolotlGRPOSequenceParallelTrainer
if async_grpo:
return AxolotlAsyncGRPOTrainer
return AxolotlGRPOTrainer
@classmethod
def get_training_args_class(cls) -> type[AxolotlGRPOConfig]:
def get_training_args_class(cls, async_grpo: bool = False) -> type[AxolotlGRPOConfig] | type[AxolotlAsyncGRPOConfig]:
if async_grpo:
return AxolotlAsyncGRPOConfig
return AxolotlGRPOConfig
@classmethod
@@ -131,6 +136,28 @@ class GRPOStrategy:
trl.multi_objective_aggregation
)
# Async GRPO fields
if getattr(trl, "async_prefetch", None) is not None:
grpo_args_kwargs["async_prefetch"] = trl.async_prefetch
if getattr(trl, "prefetch_depth", None) is not None:
grpo_args_kwargs["prefetch_depth"] = trl.prefetch_depth
if getattr(trl, "vllm_sync_interval", None) is not None:
grpo_args_kwargs["vllm_sync_interval"] = trl.vllm_sync_interval
if getattr(trl, "streaming_partial_batch", None) is not None:
grpo_args_kwargs["streaming_partial_batch"] = trl.streaming_partial_batch
if getattr(trl, "streaming_min_groups", None) is not None:
grpo_args_kwargs["streaming_min_groups"] = trl.streaming_min_groups
if getattr(trl, "vllm_importance_sampling_correction", None) is not None:
grpo_args_kwargs["vllm_importance_sampling_correction"] = trl.vllm_importance_sampling_correction
if getattr(trl, "vllm_importance_sampling_mode", None) is not None:
grpo_args_kwargs["vllm_importance_sampling_mode"] = trl.vllm_importance_sampling_mode
if getattr(trl, "vllm_importance_sampling_cap", None) is not None:
grpo_args_kwargs["vllm_importance_sampling_cap"] = trl.vllm_importance_sampling_cap
if getattr(trl, "off_policy_mask_threshold", None) is not None:
grpo_args_kwargs["off_policy_mask_threshold"] = trl.off_policy_mask_threshold
if getattr(trl, "use_bias_correction_kl", None) is not None:
grpo_args_kwargs["use_bias_correction_kl"] = trl.use_bias_correction_kl
return grpo_args_kwargs
@classmethod

View File

@@ -7,6 +7,7 @@ from dataclasses import dataclass
from trl import GRPOConfig
from axolotl.core.training_args import AxolotlTrainingMixins
from axolotl.monkeypatch.trainer.async_grpo import AsyncGRPOConfig
@dataclass
@@ -14,3 +15,10 @@ class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
"""Axolotl GRPO Config for GRPO training"""
context_parallel_size: int | None = None
@dataclass
class AxolotlAsyncGRPOConfig(AxolotlTrainingMixins, AsyncGRPOConfig):
"""Axolotl Async GRPO Config — adds async prefetch, streaming scoring, and IS correction."""
context_parallel_size: int | None = None

View File

@@ -34,6 +34,8 @@ from trl.data_utils import (
is_conversational,
maybe_apply_chat_template,
)
from axolotl.monkeypatch.trainer.async_grpo import AsyncGRPOTrainer
from trl.extras.profiling import profiling_context
from trl.models import unwrap_model_for_generation
from trl.trainer.grpo_config import GRPOConfig
@@ -66,6 +68,19 @@ class AxolotlGRPOTrainer(
_tag_names = ["trl", "grpo", "axolotl"]
class AxolotlAsyncGRPOTrainer(
RngLoaderMixin,
SchedulerMixin,
OptimizerMixin,
OptimizerInitMixin,
DistributedParallelMixin,
AsyncGRPOTrainer,
):
"""Extend AsyncGRPOTrainer with axolotl helpers (async prefetch, streaming, IS correction)."""
_tag_names = ["trl", "grpo", "async-grpo", "axolotl"]
class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
"""Extend the base GRPOTrainer for sequence parallelism handling"""

File diff suppressed because it is too large Load Diff

View File

@@ -189,3 +189,67 @@ class TRLConfig(BaseModel):
"'normalize_then_sum' (GDPO): normalizes each reward independently, then sums."
},
)
# Async GRPO fields
async_prefetch: bool = Field(
default=False,
json_schema_extra={
"description": "Generate rollouts in a background thread while training on the previous rollout."
},
)
prefetch_depth: int | None = Field(
default=None,
json_schema_extra={
"description": "Number of rollouts to prefetch ahead of training."
},
)
vllm_sync_interval: int | None = Field(
default=None,
json_schema_extra={
"description": "Sync model weights to vLLM every N optimizer steps (async mode only)."
},
)
streaming_partial_batch: bool | None = Field(
default=None,
json_schema_extra={
"description": "Score prompt groups incrementally instead of the full batch at once."
},
)
streaming_min_groups: int | None = Field(
default=None,
json_schema_extra={
"description": "Minimum prompt groups to score per streaming chunk."
},
)
vllm_importance_sampling_correction: bool | None = Field(
default=None,
json_schema_extra={
"description": "Apply IS correction for distribution mismatch between vLLM and training model."
},
)
vllm_importance_sampling_mode: (
Literal["token_truncate", "token_mask", "sequence_truncate", "sequence_mask"] | None
) = Field(
default=None,
json_schema_extra={
"description": "IS mode: token_truncate, token_mask, sequence_truncate, or sequence_mask."
},
)
vllm_importance_sampling_cap: float | None = Field(
default=None,
json_schema_extra={
"description": "Cap C for IS ratio clipping/masking."
},
)
off_policy_mask_threshold: float | None = Field(
default=None,
json_schema_extra={
"description": "KL threshold for off-policy sequence masking (OPSM). None = disabled."
},
)
use_bias_correction_kl: bool | None = Field(
default=None,
json_schema_extra={
"description": "Apply IS correction to KL divergence term."
},
)