Atropos support (#2666) [skip ci]

* allow peft+liger+grpo and custom vllm serve for atropos support

* set trainer class for RL
This commit is contained in:
Wing Lian
2025-05-13 08:30:58 -04:00
committed by GitHub
parent 80304c26a7
commit 7fa1089cea
4 changed files with 25 additions and 13 deletions

View File

@@ -82,6 +82,12 @@ class VllmServeCliArgs:
"hardware support this feature."
},
)
serve_module: Optional[str] = field(
default=None,
metadata={
"help": "Module to serve. If not set, the default module will be used."
},
)
@dataclass

View File

@@ -6,7 +6,6 @@ from pathlib import Path
from typing import Union
from trl.scripts.vllm_serve import ScriptArguments
from trl.scripts.vllm_serve import main as vllm_serve_main
from axolotl.cli.config import load_cfg
@@ -28,6 +27,9 @@ def do_vllm_serve(
cfg = load_cfg(config)
model = cfg.base_model
serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve")
vllm_serve_main = getattr(__import__(serve_module, fromlist=["main"]), "main")
tensor_parallel_size = (
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
)

View File

@@ -1197,6 +1197,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
sig = inspect.signature(trainer_cls)
if "tokenizer" in sig.parameters.keys():
trainer_kwargs["tokenizer"] = self.tokenizer

View File

@@ -1149,18 +1149,18 @@ class AxolotlInputConfig(
return data
@model_validator(mode="before")
@classmethod
def check_grpo_peft_liger(cls, data):
if (
data.get("rl") == "grpo"
and data.get("trl", {})
and data.get("trl").get("use_liger_loss")
and data.get("adapter")
):
raise ValueError("PEFT + GRPO + Liger is not yet supported")
return data
# @model_validator(mode="before")
# @classmethod
# def check_grpo_peft_liger(cls, data):
# if (
# data.get("rl") == "grpo"
# and data.get("trl", {})
# and data.get("trl").get("use_liger_loss")
# and data.get("adapter")
# ):
# raise ValueError("PEFT + GRPO + Liger is not yet supported")
# return data
#
@model_validator(mode="before")
@classmethod
def check_grpo_liger_sequence_parallel(cls, data):