From 5a36b6ff2dfa2cfd46fbe04dee4086c1ff344be8 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 13 May 2025 08:30:58 -0400 Subject: [PATCH] Atropos support (#2666) [skip ci] * allow peft+liger+grpo and custom vllm serve for atropos support * set trainer class for RL --- src/axolotl/cli/args.py | 6 ++++++ src/axolotl/cli/vllm_serve.py | 4 +++- src/axolotl/core/trainer_builder.py | 4 ++++ src/axolotl/utils/schemas/config.py | 18 +++++++++++++++--- 4 files changed, 28 insertions(+), 4 deletions(-) diff --git a/src/axolotl/cli/args.py b/src/axolotl/cli/args.py index 83febd7f4..088e337e4 100644 --- a/src/axolotl/cli/args.py +++ b/src/axolotl/cli/args.py @@ -82,6 +82,12 @@ class VllmServeCliArgs: "hardware support this feature." }, ) + serve_module: Optional[str] = field( + default=None, + metadata={ + "help": "Module to serve. If not set, the default module will be used." + }, + ) @dataclass diff --git a/src/axolotl/cli/vllm_serve.py b/src/axolotl/cli/vllm_serve.py index 552f33e9e..d3c4ad68d 100644 --- a/src/axolotl/cli/vllm_serve.py +++ b/src/axolotl/cli/vllm_serve.py @@ -6,7 +6,6 @@ from pathlib import Path from typing import Union from trl.scripts.vllm_serve import ScriptArguments -from trl.scripts.vllm_serve import main as vllm_serve_main from axolotl.cli.config import load_cfg @@ -28,6 +27,9 @@ def do_vllm_serve( cfg = load_cfg(config) model = cfg.base_model + serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve") + vllm_serve_main = getattr(__import__(serve_module, fromlist=["main"]), "main") + tensor_parallel_size = ( cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size ) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 670561ede..2e8911a31 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1188,6 +1188,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase): else: raise ValueError(f"Unsupported RL: {self.cfg.rl}") + if self.cfg.plugins: + plugin_manager = PluginManager.get_instance() + trainer_cls = plugin_manager.get_trainer_cls(self.cfg) + sig = inspect.signature(trainer_cls) if "tokenizer" in sig.parameters.keys(): dpo_trainer_kwargs["tokenizer"] = self.tokenizer diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index cd9891e04..a618e1ae6 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -1149,16 +1149,28 @@ class AxolotlInputConfig( return data + # @model_validator(mode="before") + # @classmethod + # def check_grpo_peft_liger(cls, data): + # if ( + # data.get("rl") == "grpo" + # and data.get("trl", {}) + # and data.get("trl").get("use_liger_loss") + # and data.get("adapter") + # ): + # raise ValueError("PEFT + GRPO + Liger is not yet supported") + # return data + # @model_validator(mode="before") @classmethod - def check_grpo_peft_liger(cls, data): + def check_grpo_liger_sequence_parallel(cls, data): if ( data.get("rl") == "grpo" and data.get("trl", {}) and data.get("trl").get("use_liger_loss") - and data.get("adapter") + and data.get("sequence_parallel_degree", 1) > 1 ): - raise ValueError("PEFT + GRPO + Liger is not yet supported") + raise ValueError("GRPO + SP + Liger not currently supported") return data @model_validator(mode="after")