diff --git a/tests/e2e/solo/test_grpo.py b/tests/e2e/solo/test_preprocess.py similarity index 65% rename from tests/e2e/solo/test_grpo.py rename to tests/e2e/solo/test_preprocess.py index 7ae7094ed..1fd7353c4 100644 --- a/tests/e2e/solo/test_grpo.py +++ b/tests/e2e/solo/test_preprocess.py @@ -1,5 +1,5 @@ """ -E2E tests for GRPO +E2E tests for preprocessing """ import logging @@ -9,7 +9,7 @@ import unittest import transformers from axolotl.cli.args import PreprocessCliArgs -from axolotl.common.datasets import load_preference_datasets +from axolotl.cli.preprocess import do_preprocess from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault @@ -19,7 +19,7 @@ LOG = logging.getLogger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" -class TestGRPO(unittest.TestCase): +class TestCustomRewardFunctionLoading(unittest.TestCase): """ Test case for GRPO training using single GPU """ @@ -48,19 +48,22 @@ def oai_gsm8k_transform(cfg, *args, **kwargs): # pylint: disable=duplicate-code cfg = DictDefault( { + "base_model": "HuggingFaceTB/SmolLM2-135M", "rl": "grpo", - "trl": [ - { - "beta": 0.001, - "max_completion_length": 256, - "use_vllm": True, - "num_generations": 4, - "reward_funcs": [ - "rewards.rand_reward_func" - ], # format: '{file_name}.{fn_name}' - "reward_weights": [1.0], - }, - ], + "trl": { + "beta": 0.001, + "max_completion_length": 256, + "use_vllm": True, + "num_generations": 4, + "reward_funcs": [ + "rewards.rand_reward_func" + ], # format: '{file_name}.{fn_name}' + "reward_weights": [1.0], + }, + "vllm": { + "max_model_len": 800, + "enable_prefix_caching": True, + }, "datasets": [ { "path": "openai/gsm8k", @@ -68,7 +71,10 @@ def oai_gsm8k_transform(cfg, *args, **kwargs): "type": "rewards.oai_gsm8k_transform", }, ], - "output_dir": temp_dir, + "dataset_prepared_path": temp_dir, + "gradient_accumulation_steps": 1, + "micro_batch_size": 1, + "learning_rate": 0.000005, } ) @@ -79,4 +85,4 @@ def oai_gsm8k_transform(cfg, *args, **kwargs): parser = transformers.HfArgumentParser(PreprocessCliArgs) cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True) - load_preference_datasets(cfg=cfg, cli_args=cli_args) + do_preprocess(cfg, cli_args)