custom reward function loading, proeprly done

This commit is contained in:
Sunny Liu
2025-04-08 14:40:20 -04:00
committed by Sung Ching Liu
parent ce0259db13
commit 7d479348ee

View File

@@ -1,5 +1,5 @@
""" """
E2E tests for GRPO E2E tests for preprocessing
""" """
import logging import logging
@@ -9,7 +9,7 @@ import unittest
import transformers import transformers
from axolotl.cli.args import PreprocessCliArgs from axolotl.cli.args import PreprocessCliArgs
from axolotl.common.datasets import load_preference_datasets from axolotl.cli.preprocess import do_preprocess
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -19,7 +19,7 @@ LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"
class TestGRPO(unittest.TestCase): class TestCustomRewardFunctionLoading(unittest.TestCase):
""" """
Test case for GRPO training using single GPU Test case for GRPO training using single GPU
""" """
@@ -48,19 +48,22 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "HuggingFaceTB/SmolLM2-135M",
"rl": "grpo", "rl": "grpo",
"trl": [ "trl": {
{ "beta": 0.001,
"beta": 0.001, "max_completion_length": 256,
"max_completion_length": 256, "use_vllm": True,
"use_vllm": True, "num_generations": 4,
"num_generations": 4, "reward_funcs": [
"reward_funcs": [ "rewards.rand_reward_func"
"rewards.rand_reward_func" ], # format: '{file_name}.{fn_name}'
], # format: '{file_name}.{fn_name}' "reward_weights": [1.0],
"reward_weights": [1.0], },
}, "vllm": {
], "max_model_len": 800,
"enable_prefix_caching": True,
},
"datasets": [ "datasets": [
{ {
"path": "openai/gsm8k", "path": "openai/gsm8k",
@@ -68,7 +71,10 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"type": "rewards.oai_gsm8k_transform", "type": "rewards.oai_gsm8k_transform",
}, },
], ],
"output_dir": temp_dir, "dataset_prepared_path": temp_dir,
"gradient_accumulation_steps": 1,
"micro_batch_size": 1,
"learning_rate": 0.000005,
} }
) )
@@ -79,4 +85,4 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
parser = transformers.HfArgumentParser(PreprocessCliArgs) parser = transformers.HfArgumentParser(PreprocessCliArgs)
cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True) cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
load_preference_datasets(cfg=cfg, cli_args=cli_args) do_preprocess(cfg, cli_args)