custom reward function loading, properly done
This commit is contained in:
committed by
Sung Ching Liu
parent
ce0259db13
commit
7d479348ee
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
E2E tests for GRPO
|
||||
E2E tests for preprocessing
|
||||
"""
|
||||
|
||||
import logging
|
||||
@@ -9,7 +9,7 @@ import unittest
|
||||
import transformers
|
||||
|
||||
from axolotl.cli.args import PreprocessCliArgs
|
||||
from axolotl.common.datasets import load_preference_datasets
|
||||
from axolotl.cli.preprocess import do_preprocess
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
@@ -19,7 +19,7 @@ LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestGRPO(unittest.TestCase):
|
||||
class TestCustomRewardFunctionLoading(unittest.TestCase):
|
||||
"""
|
||||
Test case for GRPO training using single GPU
|
||||
"""
|
||||
@@ -48,19 +48,22 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"rl": "grpo",
|
||||
"trl": [
|
||||
{
|
||||
"beta": 0.001,
|
||||
"max_completion_length": 256,
|
||||
"use_vllm": True,
|
||||
"num_generations": 4,
|
||||
"reward_funcs": [
|
||||
"rewards.rand_reward_func"
|
||||
], # format: '{file_name}.{fn_name}'
|
||||
"reward_weights": [1.0],
|
||||
},
|
||||
],
|
||||
"trl": {
|
||||
"beta": 0.001,
|
||||
"max_completion_length": 256,
|
||||
"use_vllm": True,
|
||||
"num_generations": 4,
|
||||
"reward_funcs": [
|
||||
"rewards.rand_reward_func"
|
||||
], # format: '{file_name}.{fn_name}'
|
||||
"reward_weights": [1.0],
|
||||
},
|
||||
"vllm": {
|
||||
"max_model_len": 800,
|
||||
"enable_prefix_caching": True,
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "openai/gsm8k",
|
||||
@@ -68,7 +71,10 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||
"type": "rewards.oai_gsm8k_transform",
|
||||
},
|
||||
],
|
||||
"output_dir": temp_dir,
|
||||
"dataset_prepared_path": temp_dir,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"micro_batch_size": 1,
|
||||
"learning_rate": 0.000005,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -79,4 +85,4 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||
parser = transformers.HfArgumentParser(PreprocessCliArgs)
|
||||
cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
|
||||
|
||||
load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||
do_preprocess(cfg, cli_args)
|
||||
Reference in New Issue
Block a user