custom reward function loading, properly done
This commit is contained in:
committed by
Sung Ching Liu
parent
ce0259db13
commit
7d479348ee
@@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
E2E tests for GRPO
|
E2E tests for preprocessing
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
@@ -9,7 +9,7 @@ import unittest
|
|||||||
import transformers
|
import transformers
|
||||||
|
|
||||||
from axolotl.cli.args import PreprocessCliArgs
|
from axolotl.cli.args import PreprocessCliArgs
|
||||||
from axolotl.common.datasets import load_preference_datasets
|
from axolotl.cli.preprocess import do_preprocess
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
@@ -19,7 +19,7 @@ LOG = logging.getLogger("axolotl.tests.e2e")
|
|||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
class TestGRPO(unittest.TestCase):
|
class TestCustomRewardFunctionLoading(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
Test case for GRPO training using single GPU
|
Test case for GRPO training using single GPU
|
||||||
"""
|
"""
|
||||||
@@ -48,19 +48,22 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
|||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
"rl": "grpo",
|
"rl": "grpo",
|
||||||
"trl": [
|
"trl": {
|
||||||
{
|
"beta": 0.001,
|
||||||
"beta": 0.001,
|
"max_completion_length": 256,
|
||||||
"max_completion_length": 256,
|
"use_vllm": True,
|
||||||
"use_vllm": True,
|
"num_generations": 4,
|
||||||
"num_generations": 4,
|
"reward_funcs": [
|
||||||
"reward_funcs": [
|
"rewards.rand_reward_func"
|
||||||
"rewards.rand_reward_func"
|
], # format: '{file_name}.{fn_name}'
|
||||||
], # format: '{file_name}.{fn_name}'
|
"reward_weights": [1.0],
|
||||||
"reward_weights": [1.0],
|
},
|
||||||
},
|
"vllm": {
|
||||||
],
|
"max_model_len": 800,
|
||||||
|
"enable_prefix_caching": True,
|
||||||
|
},
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
"path": "openai/gsm8k",
|
"path": "openai/gsm8k",
|
||||||
@@ -68,7 +71,10 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
|||||||
"type": "rewards.oai_gsm8k_transform",
|
"type": "rewards.oai_gsm8k_transform",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
"output_dir": temp_dir,
|
"dataset_prepared_path": temp_dir,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"learning_rate": 0.000005,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -79,4 +85,4 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
|||||||
parser = transformers.HfArgumentParser(PreprocessCliArgs)
|
parser = transformers.HfArgumentParser(PreprocessCliArgs)
|
||||||
cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
|
cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
|
||||||
|
|
||||||
load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
do_preprocess(cfg, cli_args)
|
||||||
Reference in New Issue
Block a user