From ed2e8cacd6cc7acba38582412311a2d6052bc1cf Mon Sep 17 00:00:00 2001 From: VED <146507396+ved1beta@users.noreply.github.com> Date: Fri, 7 Nov 2025 19:21:40 +0530 Subject: [PATCH] feat:openenv rollout_func (#3239) [skip ci] * feat:openenv rollout_func * chore lint * docs * add:docs processing_class * tests * lint --- docs/rlhf.qmd | 112 +++++++++++++++++++++ src/axolotl/core/trainers/grpo/__init__.py | 32 ++++++ src/axolotl/utils/schemas/trl.py | 6 ++ tests/utils/test_grpo_rw_fnc.py | 18 ++++ 4 files changed, 168 insertions(+) create mode 100644 tests/utils/test_grpo_rw_fnc.py diff --git a/docs/rlhf.qmd b/docs/rlhf.qmd index 594ebc743..2033649cc 100644 --- a/docs/rlhf.qmd +++ b/docs/rlhf.qmd @@ -597,6 +597,118 @@ To see other examples of custom reward functions, please see [TRL GRPO Docs](htt To see all configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/v0.9.2/src/axolotl/utils/schemas/trl.py). +#### OpenEnv Rollout Functions +```bash +pip insatll openenv-core``` + +GRPO supports custom rollout functions for OpenEnv-style environments, enabling interactive tasks like web browsing, code execution, or tool use. This allows you to implement custom generation logic that interacts with external environments. + +For example, to implement a simple math-solving environment with step-by-step verification: + +```python +# math_env.py +import re + +def math_solver_rollout(model, processing_class, prompts, generation_config=None): + """ + Custom rollout function that generates step-by-step math solutions. + + Args: + model: The language model + processing_class: The tokenizer/processing_class + prompts: List of prompt dicts (with 'messages' key for chat format) + generation_config: Optional generation configuration + + Returns: + List of completion strings + """ + completions = [] + + for prompt in prompts: + # Apply chat template to prompt + messages = prompt.get("messages", []) + formatted_prompt = processing_class.apply_chat_template( + messages, processing_class=False, add_generation_prompt=True + ) + + # Generate step-by-step solution + full_response = "" + for step in range(5): # Max 5 reasoning steps + current_input = formatted_prompt + full_response + "\nNext step:" + inputs = processing_class(current_input, return_tensors="pt").to(model.device) + + outputs = model.generate( + **inputs, + max_new_tokens=100, + generation_config=generation_config, + ) + step_text = processing_class.decode( + outputs[0][inputs.input_ids.shape[1]:], + skip_special_tokens=True + ) + + # Check if solution is complete + if "FINAL ANSWER:" in step_text: + full_response += step_text + break + full_response += step_text + "\n" + + completions.append(full_response) + + return completions + +def math_reward(prompts, completions, answers, **kwargs): + """Reward function that checks mathematical correctness""" + rewards = [] + for completion, correct_answer in zip(completions, answers): + # Extract predicted answer + match = re.search(r"FINAL ANSWER:\s*(.+)", completion) + predicted = match.group(1).strip() if match else "" + + # Compare with correct answer + reward = 1.0 if predicted == str(correct_answer) else 0.0 + rewards.append(reward) + + return rewards + +def math_transform(cfg, *args, **kwargs): + """Transform dataset to GRPO format with answer field""" + def transform_fn(example, processing_class=None): + return { + "prompt": [{"role": "user", "content": example["question"]}], + "answer": str(example["answer"]), + } + return transform_fn, {"remove_columns": ["question"]} +``` + +```yaml +rl: grpo + +trl: + beta: 0.001 + max_completion_length: 512 + num_generations: 4 + rollout_func: "math_env.math_solver_rollout" # Custom rollout function + reward_funcs: ["math_env.math_reward"] + reward_weights: [1.0] + +datasets: + - path: openai/gsm8k + name: main + type: math_env.math_transform +``` + +The `rollout_func` parameter accepts a fully qualified name (e.g., `module_name.function_name`) that points to a callable function in your local directory. The function receives: + +- `model`: The language model +- `processing_class`: The tokenizer/processing class +- `prompts`: List of prompt dictionaries +- `generation_config` (optional): Generation configuration + +And should return a list of completion strings. + +For more OpenEnv examples, see [TRL OpenEnv Documentation](https://huggingface.co/docs/trl/main/en/openenv). + #### GRPO with DAPO/Dr. GRPO loss The DAPO paper and subsequently Dr. GRPO paper proposed an alternative loss function for GRPO to remediate the penalty in longer responses. diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py index bd77489eb..7f28cb8d4 100644 --- a/src/axolotl/core/trainers/grpo/__init__.py +++ b/src/axolotl/core/trainers/grpo/__init__.py @@ -126,6 +126,9 @@ class GRPOStrategy: if trl.use_liger_loss is not None: grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss + if trl.rollout_func: + grpo_args_kwargs["rollout_func"] = cls.get_rollout_func(trl.rollout_func) + return grpo_args_kwargs @classmethod @@ -201,3 +204,32 @@ class GRPOStrategy: raise ValueError( f"Reward function {reward_func_fqn} not found." ) from exc + + @classmethod + def get_rollout_func(cls, rollout_func_fqn: str): + """ + Returns the rollout function from the given fully qualified name. + + Args: + rollout_func_fqn (str): Fully qualified name of the rollout function + (e.g. my_module.my_rollout_func) + + Returns: + Callable rollout function + """ + try: + rollout_func_module_name = rollout_func_fqn.split(".")[-1] + rollout_func_module = importlib.import_module( + ".".join(rollout_func_fqn.split(".")[:-1]) + ) + rollout_func = getattr(rollout_func_module, rollout_func_module_name) + + if not callable(rollout_func): + raise ValueError( + f"Rollout function {rollout_func_fqn} must be callable" + ) + + return rollout_func + + except ModuleNotFoundError as exc: + raise ValueError(f"Rollout function {rollout_func_fqn} not found.") from exc diff --git a/src/axolotl/utils/schemas/trl.py b/src/axolotl/utils/schemas/trl.py index 624f7663e..d24d6f477 100644 --- a/src/axolotl/utils/schemas/trl.py +++ b/src/axolotl/utils/schemas/trl.py @@ -173,3 +173,9 @@ class TRLConfig(BaseModel): "description": "Enable sleep mode for vLLM to offload VRAM when idle" }, ) + rollout_func: str | None = Field( + default=None, + json_schema_extra={ + "description": "Path to custom rollout function. Must be importable from current dir." + }, + ) diff --git a/tests/utils/test_grpo_rw_fnc.py b/tests/utils/test_grpo_rw_fnc.py new file mode 100644 index 000000000..507de277b --- /dev/null +++ b/tests/utils/test_grpo_rw_fnc.py @@ -0,0 +1,18 @@ +import os + +import pytest + +from axolotl.core.trainers.grpo import GRPOStrategy + + +def test_get_rollout_func_loads_successfully(): + """Test that a valid rollout function can be loaded""" + rollout_func = GRPOStrategy.get_rollout_func("os.path.join") + assert callable(rollout_func) + assert rollout_func == os.path.join + + +def test_get_rollout_func_invalid_module_raises_error(): + """Test that invalid module path raises clear ValueError""" + with pytest.raises(ValueError, match="Rollout function .* not found"): + GRPOStrategy.get_rollout_func("nonexistent_module.my_func")