Atropos support (#2666) [skip ci]

* allow peft+liger+grpo and custom vllm serve for atropos support

* set trainer class for RL
This commit is contained in:
Wing Lian
2025-05-13 08:30:58 -04:00
committed by GitHub
parent 80304c26a7
commit 7fa1089cea
4 changed files with 25 additions and 13 deletions

View File

@@ -82,6 +82,12 @@ class VllmServeCliArgs:
"hardware support this feature." "hardware support this feature."
}, },
) )
serve_module: Optional[str] = field(
default=None,
metadata={
"help": "Module to serve. If not set, the default module will be used."
},
)
@dataclass @dataclass

View File

@@ -6,7 +6,6 @@ from pathlib import Path
from typing import Union from typing import Union
from trl.scripts.vllm_serve import ScriptArguments from trl.scripts.vllm_serve import ScriptArguments
from trl.scripts.vllm_serve import main as vllm_serve_main
from axolotl.cli.config import load_cfg from axolotl.cli.config import load_cfg
@@ -28,6 +27,9 @@ def do_vllm_serve(
cfg = load_cfg(config) cfg = load_cfg(config)
model = cfg.base_model model = cfg.base_model
serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve")
vllm_serve_main = getattr(__import__(serve_module, fromlist=["main"]), "main")
tensor_parallel_size = ( tensor_parallel_size = (
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
) )

View File

@@ -1197,6 +1197,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
else: else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}") raise ValueError(f"Unsupported RL: {self.cfg.rl}")
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
sig = inspect.signature(trainer_cls) sig = inspect.signature(trainer_cls)
if "tokenizer" in sig.parameters.keys(): if "tokenizer" in sig.parameters.keys():
trainer_kwargs["tokenizer"] = self.tokenizer trainer_kwargs["tokenizer"] = self.tokenizer

View File

@@ -1149,18 +1149,18 @@ class AxolotlInputConfig(
return data return data
@model_validator(mode="before") # @model_validator(mode="before")
@classmethod # @classmethod
def check_grpo_peft_liger(cls, data): # def check_grpo_peft_liger(cls, data):
if ( # if (
data.get("rl") == "grpo" # data.get("rl") == "grpo"
and data.get("trl", {}) # and data.get("trl", {})
and data.get("trl").get("use_liger_loss") # and data.get("trl").get("use_liger_loss")
and data.get("adapter") # and data.get("adapter")
): # ):
raise ValueError("PEFT + GRPO + Liger is not yet supported") # raise ValueError("PEFT + GRPO + Liger is not yet supported")
return data # return data
#
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_grpo_liger_sequence_parallel(cls, data): def check_grpo_liger_sequence_parallel(cls, data):