Atropos support (#2666) [skip ci]
* allow peft+liger+grpo and custom vllm serve for atropos support * set trainer class for RL
This commit is contained in:
@@ -82,6 +82,12 @@ class VllmServeCliArgs:
|
||||
"hardware support this feature."
|
||||
},
|
||||
)
|
||||
serve_module: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Module to serve. If not set, the default module will be used."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -6,7 +6,6 @@ from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
from trl.scripts.vllm_serve import ScriptArguments
|
||||
from trl.scripts.vllm_serve import main as vllm_serve_main
|
||||
|
||||
from axolotl.cli.config import load_cfg
|
||||
|
||||
@@ -28,6 +27,9 @@ def do_vllm_serve(
|
||||
cfg = load_cfg(config)
|
||||
model = cfg.base_model
|
||||
|
||||
serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve")
|
||||
vllm_serve_main = getattr(__import__(serve_module, fromlist=["main"]), "main")
|
||||
|
||||
tensor_parallel_size = (
|
||||
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
|
||||
)
|
||||
|
||||
@@ -1197,6 +1197,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
else:
|
||||
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
||||
|
||||
if self.cfg.plugins:
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
|
||||
|
||||
sig = inspect.signature(trainer_cls)
|
||||
if "tokenizer" in sig.parameters.keys():
|
||||
trainer_kwargs["tokenizer"] = self.tokenizer
|
||||
|
||||
@@ -1149,18 +1149,18 @@ class AxolotlInputConfig(
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_grpo_peft_liger(cls, data):
|
||||
if (
|
||||
data.get("rl") == "grpo"
|
||||
and data.get("trl", {})
|
||||
and data.get("trl").get("use_liger_loss")
|
||||
and data.get("adapter")
|
||||
):
|
||||
raise ValueError("PEFT + GRPO + Liger is not yet supported")
|
||||
return data
|
||||
|
||||
# @model_validator(mode="before")
|
||||
# @classmethod
|
||||
# def check_grpo_peft_liger(cls, data):
|
||||
# if (
|
||||
# data.get("rl") == "grpo"
|
||||
# and data.get("trl", {})
|
||||
# and data.get("trl").get("use_liger_loss")
|
||||
# and data.get("adapter")
|
||||
# ):
|
||||
# raise ValueError("PEFT + GRPO + Liger is not yet supported")
|
||||
# return data
|
||||
#
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_grpo_liger_sequence_parallel(cls, data):
|
||||
|
||||
Reference in New Issue
Block a user