From 7fa1089cea4e81d2724dee23dae5ad23dfb10399 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 13 May 2025 08:30:58 -0400 Subject: [PATCH] Atropos support (#2666) [skip ci] * allow peft+liger+grpo and custom vllm serve for atropos support * set trainer class for RL --- src/axolotl/cli/args.py | 6 ++++++ src/axolotl/cli/vllm_serve.py | 4 +++- src/axolotl/core/trainer_builder.py | 4 ++++ src/axolotl/utils/schemas/config.py | 24 ++++++++++++------------ 4 files changed, 25 insertions(+), 13 deletions(-) diff --git a/src/axolotl/cli/args.py b/src/axolotl/cli/args.py index 83febd7f4..088e337e4 100644 --- a/src/axolotl/cli/args.py +++ b/src/axolotl/cli/args.py @@ -82,6 +82,12 @@ class VllmServeCliArgs: "hardware support this feature." }, ) + serve_module: Optional[str] = field( + default=None, + metadata={ + "help": "Module to serve. If not set, the default module will be used." + }, + ) @dataclass diff --git a/src/axolotl/cli/vllm_serve.py b/src/axolotl/cli/vllm_serve.py index 552f33e9e..d3c4ad68d 100644 --- a/src/axolotl/cli/vllm_serve.py +++ b/src/axolotl/cli/vllm_serve.py @@ -6,7 +6,6 @@ from pathlib import Path from typing import Union from trl.scripts.vllm_serve import ScriptArguments -from trl.scripts.vllm_serve import main as vllm_serve_main from axolotl.cli.config import load_cfg @@ -28,6 +27,9 @@ def do_vllm_serve( cfg = load_cfg(config) model = cfg.base_model + serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve") + vllm_serve_main = getattr(__import__(serve_module, fromlist=["main"]), "main") + tensor_parallel_size = ( cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size ) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 99ab397c7..25d327dcd 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1197,6 +1197,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase): else: raise ValueError(f"Unsupported RL: {self.cfg.rl}") + if self.cfg.plugins: + plugin_manager = PluginManager.get_instance() + trainer_cls = plugin_manager.get_trainer_cls(self.cfg) + sig = inspect.signature(trainer_cls) if "tokenizer" in sig.parameters.keys(): trainer_kwargs["tokenizer"] = self.tokenizer diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 34f084f10..25c802959 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -1149,18 +1149,18 @@ class AxolotlInputConfig( return data - @model_validator(mode="before") - @classmethod - def check_grpo_peft_liger(cls, data): - if ( - data.get("rl") == "grpo" - and data.get("trl", {}) - and data.get("trl").get("use_liger_loss") - and data.get("adapter") - ): - raise ValueError("PEFT + GRPO + Liger is not yet supported") - return data - + # @model_validator(mode="before") + # @classmethod + # def check_grpo_peft_liger(cls, data): + # if ( + # data.get("rl") == "grpo" + # and data.get("trl", {}) + # and data.get("trl").get("use_liger_loss") + # and data.get("adapter") + # ): + # raise ValueError("PEFT + GRPO + Liger is not yet supported") + # return data + # @model_validator(mode="before") @classmethod def check_grpo_liger_sequence_parallel(cls, data):