custom reward function loading, properly done

This commit is contained in:
Sunny Liu
2025-04-08 14:40:20 -04:00
committed by Sung Ching Liu
parent ce0259db13
commit 7d479348ee

View File

@@ -1,5 +1,5 @@
"""
E2E tests for GRPO
E2E tests for preprocessing
"""
import logging
@@ -9,7 +9,7 @@ import unittest
import transformers
from axolotl.cli.args import PreprocessCliArgs
from axolotl.common.datasets import load_preference_datasets
from axolotl.cli.preprocess import do_preprocess
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
@@ -19,7 +19,7 @@ LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestGRPO(unittest.TestCase):
class TestCustomRewardFunctionLoading(unittest.TestCase):
"""
Test case for GRPO training using single GPU
"""
@@ -48,19 +48,22 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"rl": "grpo",
"trl": [
{
"beta": 0.001,
"max_completion_length": 256,
"use_vllm": True,
"num_generations": 4,
"reward_funcs": [
"rewards.rand_reward_func"
], # format: '{file_name}.{fn_name}'
"reward_weights": [1.0],
},
],
"trl": {
"beta": 0.001,
"max_completion_length": 256,
"use_vllm": True,
"num_generations": 4,
"reward_funcs": [
"rewards.rand_reward_func"
], # format: '{file_name}.{fn_name}'
"reward_weights": [1.0],
},
"vllm": {
"max_model_len": 800,
"enable_prefix_caching": True,
},
"datasets": [
{
"path": "openai/gsm8k",
@@ -68,7 +71,10 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"type": "rewards.oai_gsm8k_transform",
},
],
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir,
"gradient_accumulation_steps": 1,
"micro_batch_size": 1,
"learning_rate": 0.000005,
}
)
@@ -79,4 +85,4 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
parser = transformers.HfArgumentParser(PreprocessCliArgs)
cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
load_preference_datasets(cfg=cfg, cli_args=cli_args)
do_preprocess(cfg, cli_args)