async grpo support
This commit is contained in:
@@ -54,8 +54,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
|
if self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
|
||||||
from axolotl.core.trainers.grpo import GRPOStrategy
|
from axolotl.core.trainers.grpo import GRPOStrategy
|
||||||
|
|
||||||
|
async_grpo = bool(self.cfg.trl and getattr(self.cfg.trl, "async_prefetch", False))
|
||||||
trainer_cls = GRPOStrategy.get_trainer_class(
|
trainer_cls = GRPOStrategy.get_trainer_class(
|
||||||
sequence_parallel=self.cfg.context_parallel_size > 1
|
sequence_parallel=self.cfg.context_parallel_size > 1,
|
||||||
|
async_grpo=async_grpo,
|
||||||
)
|
)
|
||||||
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
|
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
|
||||||
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
|
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
|
||||||
@@ -151,7 +153,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
elif self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
|
elif self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
|
||||||
from axolotl.core.trainers.grpo import GRPOStrategy
|
from axolotl.core.trainers.grpo import GRPOStrategy
|
||||||
|
|
||||||
training_args_cls = GRPOStrategy.get_training_args_class()
|
async_grpo = bool(self.cfg.trl and getattr(self.cfg.trl, "async_prefetch", False))
|
||||||
|
training_args_cls = GRPOStrategy.get_training_args_class(async_grpo=async_grpo)
|
||||||
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
|
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
|
||||||
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
|
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
|
||||||
if self.cfg.rl is RLType.GDPO:
|
if self.cfg.rl is RLType.GDPO:
|
||||||
|
|||||||
@@ -9,8 +9,9 @@ from huggingface_hub import snapshot_download
|
|||||||
from requests import HTTPError
|
from requests import HTTPError
|
||||||
from trl.trainer.grpo_trainer import RewardFunc
|
from trl.trainer.grpo_trainer import RewardFunc
|
||||||
|
|
||||||
from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig
|
from axolotl.core.trainers.grpo.args import AxolotlAsyncGRPOConfig, AxolotlGRPOConfig
|
||||||
from axolotl.core.trainers.grpo.trainer import (
|
from axolotl.core.trainers.grpo.trainer import (
|
||||||
|
AxolotlAsyncGRPOTrainer,
|
||||||
AxolotlGRPOSequenceParallelTrainer,
|
AxolotlGRPOSequenceParallelTrainer,
|
||||||
AxolotlGRPOTrainer,
|
AxolotlGRPOTrainer,
|
||||||
)
|
)
|
||||||
@@ -27,14 +28,18 @@ class GRPOStrategy:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_trainer_class(
|
def get_trainer_class(
|
||||||
cls, sequence_parallel: bool
|
cls, sequence_parallel: bool, async_grpo: bool = False,
|
||||||
) -> type[AxolotlGRPOTrainer] | type[AxolotlGRPOSequenceParallelTrainer]:
|
) -> type[AxolotlGRPOTrainer] | type[AxolotlGRPOSequenceParallelTrainer] | type[AxolotlAsyncGRPOTrainer]:
|
||||||
if sequence_parallel:
|
if sequence_parallel:
|
||||||
return AxolotlGRPOSequenceParallelTrainer
|
return AxolotlGRPOSequenceParallelTrainer
|
||||||
|
if async_grpo:
|
||||||
|
return AxolotlAsyncGRPOTrainer
|
||||||
return AxolotlGRPOTrainer
|
return AxolotlGRPOTrainer
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_training_args_class(cls) -> type[AxolotlGRPOConfig]:
|
def get_training_args_class(cls, async_grpo: bool = False) -> type[AxolotlGRPOConfig] | type[AxolotlAsyncGRPOConfig]:
|
||||||
|
if async_grpo:
|
||||||
|
return AxolotlAsyncGRPOConfig
|
||||||
return AxolotlGRPOConfig
|
return AxolotlGRPOConfig
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -131,6 +136,28 @@ class GRPOStrategy:
|
|||||||
trl.multi_objective_aggregation
|
trl.multi_objective_aggregation
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Async GRPO fields
|
||||||
|
if getattr(trl, "async_prefetch", None) is not None:
|
||||||
|
grpo_args_kwargs["async_prefetch"] = trl.async_prefetch
|
||||||
|
if getattr(trl, "prefetch_depth", None) is not None:
|
||||||
|
grpo_args_kwargs["prefetch_depth"] = trl.prefetch_depth
|
||||||
|
if getattr(trl, "vllm_sync_interval", None) is not None:
|
||||||
|
grpo_args_kwargs["vllm_sync_interval"] = trl.vllm_sync_interval
|
||||||
|
if getattr(trl, "streaming_partial_batch", None) is not None:
|
||||||
|
grpo_args_kwargs["streaming_partial_batch"] = trl.streaming_partial_batch
|
||||||
|
if getattr(trl, "streaming_min_groups", None) is not None:
|
||||||
|
grpo_args_kwargs["streaming_min_groups"] = trl.streaming_min_groups
|
||||||
|
if getattr(trl, "vllm_importance_sampling_correction", None) is not None:
|
||||||
|
grpo_args_kwargs["vllm_importance_sampling_correction"] = trl.vllm_importance_sampling_correction
|
||||||
|
if getattr(trl, "vllm_importance_sampling_mode", None) is not None:
|
||||||
|
grpo_args_kwargs["vllm_importance_sampling_mode"] = trl.vllm_importance_sampling_mode
|
||||||
|
if getattr(trl, "vllm_importance_sampling_cap", None) is not None:
|
||||||
|
grpo_args_kwargs["vllm_importance_sampling_cap"] = trl.vllm_importance_sampling_cap
|
||||||
|
if getattr(trl, "off_policy_mask_threshold", None) is not None:
|
||||||
|
grpo_args_kwargs["off_policy_mask_threshold"] = trl.off_policy_mask_threshold
|
||||||
|
if getattr(trl, "use_bias_correction_kl", None) is not None:
|
||||||
|
grpo_args_kwargs["use_bias_correction_kl"] = trl.use_bias_correction_kl
|
||||||
|
|
||||||
return grpo_args_kwargs
|
return grpo_args_kwargs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from dataclasses import dataclass
|
|||||||
from trl import GRPOConfig
|
from trl import GRPOConfig
|
||||||
|
|
||||||
from axolotl.core.training_args import AxolotlTrainingMixins
|
from axolotl.core.training_args import AxolotlTrainingMixins
|
||||||
|
from axolotl.monkeypatch.trainer.async_grpo import AsyncGRPOConfig
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -14,3 +15,10 @@ class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
|
|||||||
"""Axolotl GRPO Config for GRPO training"""
|
"""Axolotl GRPO Config for GRPO training"""
|
||||||
|
|
||||||
context_parallel_size: int | None = None
|
context_parallel_size: int | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AxolotlAsyncGRPOConfig(AxolotlTrainingMixins, AsyncGRPOConfig):
|
||||||
|
"""Axolotl Async GRPO Config — adds async prefetch, streaming scoring, and IS correction."""
|
||||||
|
|
||||||
|
context_parallel_size: int | None = None
|
||||||
|
|||||||
@@ -34,6 +34,8 @@ from trl.data_utils import (
|
|||||||
is_conversational,
|
is_conversational,
|
||||||
maybe_apply_chat_template,
|
maybe_apply_chat_template,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.trainer.async_grpo import AsyncGRPOTrainer
|
||||||
from trl.extras.profiling import profiling_context
|
from trl.extras.profiling import profiling_context
|
||||||
from trl.models import unwrap_model_for_generation
|
from trl.models import unwrap_model_for_generation
|
||||||
from trl.trainer.grpo_config import GRPOConfig
|
from trl.trainer.grpo_config import GRPOConfig
|
||||||
@@ -66,6 +68,19 @@ class AxolotlGRPOTrainer(
|
|||||||
_tag_names = ["trl", "grpo", "axolotl"]
|
_tag_names = ["trl", "grpo", "axolotl"]
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlAsyncGRPOTrainer(
|
||||||
|
RngLoaderMixin,
|
||||||
|
SchedulerMixin,
|
||||||
|
OptimizerMixin,
|
||||||
|
OptimizerInitMixin,
|
||||||
|
DistributedParallelMixin,
|
||||||
|
AsyncGRPOTrainer,
|
||||||
|
):
|
||||||
|
"""Extend AsyncGRPOTrainer with axolotl helpers (async prefetch, streaming, IS correction)."""
|
||||||
|
|
||||||
|
_tag_names = ["trl", "grpo", "async-grpo", "axolotl"]
|
||||||
|
|
||||||
|
|
||||||
class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||||
"""Extend the base GRPOTrainer for sequence parallelism handling"""
|
"""Extend the base GRPOTrainer for sequence parallelism handling"""
|
||||||
|
|
||||||
|
|||||||
1266
src/axolotl/monkeypatch/trainer/async_grpo.py
Normal file
1266
src/axolotl/monkeypatch/trainer/async_grpo.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -189,3 +189,67 @@ class TRLConfig(BaseModel):
|
|||||||
"'normalize_then_sum' (GDPO): normalizes each reward independently, then sums."
|
"'normalize_then_sum' (GDPO): normalizes each reward independently, then sums."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Async GRPO fields
|
||||||
|
async_prefetch: bool = Field(
|
||||||
|
default=False,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Generate rollouts in a background thread while training on the previous rollout."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
prefetch_depth: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Number of rollouts to prefetch ahead of training."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
vllm_sync_interval: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Sync model weights to vLLM every N optimizer steps (async mode only)."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
streaming_partial_batch: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Score prompt groups incrementally instead of the full batch at once."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
streaming_min_groups: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Minimum prompt groups to score per streaming chunk."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
vllm_importance_sampling_correction: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Apply IS correction for distribution mismatch between vLLM and training model."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
vllm_importance_sampling_mode: (
|
||||||
|
Literal["token_truncate", "token_mask", "sequence_truncate", "sequence_mask"] | None
|
||||||
|
) = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "IS mode: token_truncate, token_mask, sequence_truncate, or sequence_mask."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
vllm_importance_sampling_cap: float | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Cap C for IS ratio clipping/masking."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
off_policy_mask_threshold: float | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "KL threshold for off-policy sequence masking (OPSM). None = disabled."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
use_bias_correction_kl: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Apply IS correction to KL divergence term."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user