Compare commits

...

13 Commits

Author SHA1 Message Date
Sung Ching Liu
f8e92407ff Update src/axolotl/common/datasets.py
Co-authored-by: Dan Saunders <danjsaund@gmail.com>
2025-04-17 09:47:14 -04:00
Sung Ching Liu
c12906134d Update src/axolotl/prompt_strategies/base.py
Co-authored-by: Dan Saunders <danjsaund@gmail.com>
2025-04-17 09:47:14 -04:00
Sunny Liu
8154d26614 nit 2025-04-17 09:47:14 -04:00
Sunny Liu
fefcbc300d barebone-ify the test so we get rid of unneeded processes 2025-04-17 09:47:14 -04:00
Sunny Liu
7d479348ee custom reward function loading, properly done 2025-04-17 09:47:14 -04:00
bursteratom
ce0259db13 add outputdir 2025-04-17 09:47:14 -04:00
Sung Ching Liu
2798817cf9 Update tests/e2e/solo/test_grpo.py
Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-04-17 09:47:14 -04:00
Sunny Liu
0e1b081e49 add unit test 2025-04-17 09:47:14 -04:00
Sunny Liu
8df37ad91f proper import from file_path after all else fails 2025-04-17 09:47:14 -04:00
Sung Ching Liu
9b74298328 Update src/axolotl/prompt_strategies/base.py
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2025-04-17 09:47:14 -04:00
Sunny Liu
ae8738aa87 skip check_datasets_label during debug for grpo 2025-04-17 09:47:14 -04:00
Sunny Liu
ec52561a0c import from filepath if can't import_module 2025-04-17 09:47:14 -04:00
Sunny Liu
eadb16c709 test import-within-import relative path 2025-04-17 09:47:14 -04:00
3 changed files with 159 additions and 29 deletions

View File

@@ -129,17 +129,19 @@ def load_preference_datasets(
total_num_steps = None
if cli_args.debug or cfg.debug:
LOG.info("check_dataset_labels...")
if not cfg.rl == "grpo":
LOG.info("check_dataset_labels...")
tokenizer = load_tokenizer(cfg)
train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples)
check_dataset_labels(
train_samples,
tokenizer,
num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only,
rl_mode=True,
)
tokenizer = load_tokenizer(cfg)
train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples)
check_dataset_labels(
train_samples,
tokenizer,
num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only,
rl_mode=True,
)
return TrainDatasetMeta(
train_dataset=train_dataset,

View File

@@ -4,30 +4,73 @@ module for base dataset transform strategies
import importlib
import logging
import sys
LOG = logging.getLogger("axolotl")
def import_from_path(module_name: str, file_path: str):
    """
    Import a module from a file path.

    Args:
        module_name: Name to register the module under in ``sys.modules``.
        file_path: Path to the ``.py`` file to execute.

    Returns:
        module: The imported module.

    Raises:
        ImportError: If no import spec/loader can be created for ``file_path``.
        FileNotFoundError: If ``file_path`` does not exist (raised during exec).
    """
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Could not create module spec for: {file_path}")
    module = importlib.util.module_from_spec(spec)
    # Register before executing so the module can be found by imports that run
    # inside its own top-level code (e.g. import-within-import).
    sys.modules[module_name] = module
    try:
        # spec_from_file_location already attached a suitable SourceFileLoader;
        # constructing a second loader by hand (as before) was redundant.
        spec.loader.exec_module(module)
    except BaseException:
        # Don't leave a half-initialized module visible to later imports.
        sys.modules.pop(module_name, None)
        raise
    return module
def load(strategy, cfg, module_base=None, **kwargs):
    """
    Resolve and invoke a loader function named by a dotted ``strategy`` path.

    A bare name gets ``.default`` appended. For a multi-part path
    (``"pkg.mod.fn"``), resolution is attempted in order:
      1. absolute import of the module part,
      2. relative import anchored at ``module_base``,
      3. import from a ``.py`` file path derived from the dotted prefix
         (``"a.b.fn"`` -> ``a/b.py``) via ``import_from_path``.

    Args:
        strategy: Dotted path whose last component names the loader function.
        cfg: Config object passed as the first argument to the loader.
        module_base: Anchor package for the relative-import fallback.
        **kwargs: Extra keyword arguments forwarded to the loader.

    Returns:
        Whatever the resolved loader returns, or None when resolution fails
        (a warning is logged in that case).
    """
    # NOTE: this span in the diff dump contained both the pre- and post-change
    # bodies; this is the reconstructed final version.
    if len(strategy.split(".")) == 1:
        strategy = strategy + ".default"
    load_fn = strategy.split(".")[-1]

    func = None
    if len(strategy.split(".")) > 1:
        # 1) try the module part as an absolute import
        try:
            mod = importlib.import_module(
                strategy.split(".")[-2],
                ".".join(strategy.split(".")[:-2]),
            )
            func = getattr(mod, load_fn)
            return func(cfg, **kwargs)
        except ModuleNotFoundError:
            pass

        # 2) fall back to a relative import anchored at module_base
        try:
            mod = importlib.import_module(
                "." + ".".join(strategy.split(".")[:-1]), module_base
            )
            func = getattr(mod, load_fn)
            return func(cfg, **kwargs)
        except ModuleNotFoundError:
            pass

        # 3) after all else fails, treat the dotted prefix as a file path
        try:
            file_path = "/".join(strategy.split(".")[:-1]) + ".py"
            module_name = strategy.split(".")[-2]
            mod = import_from_path(module_name, file_path)
            func = getattr(mod, load_fn)
            if func is not None:
                return func(cfg, **kwargs)
        except FileNotFoundError:
            pass
    else:
        # NOTE(review): unreachable after the ".default" normalization above
        # (the split is always >= 2 parts) — kept for fidelity; confirm intent.
        strategy = "." + ".".join(strategy.split(".")[:-1])
        mod = importlib.import_module(strategy, module_base)
        func = getattr(mod, load_fn)
        return func(cfg, **kwargs)

    LOG.warning(f"unable to load strategy {strategy}")
    return func

View File

@@ -0,0 +1,85 @@
"""
E2E tests for preprocessing
"""
import logging
import os
import unittest
import transformers
from axolotl.cli.args import PreprocessCliArgs
from axolotl.common.datasets import load_preference_datasets
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestCustomRewardFunctionLoading(unittest.TestCase):
    """
    Test case for GRPO training using single GPU
    """

    def _utils_write_rewards(self):
        # Write the custom reward / dataset-transform functions to rewards.py
        # in the current working directory, matching the "{file_name}.{fn_name}"
        # strings referenced by the config below.
        # NOTE(review): indentation inside this string was reconstructed from a
        # whitespace-stripped dump — confirm against the repo file.
        with open("rewards.py", "w", encoding="utf-8") as fout:
            fout.write(
                """import random

def rand_reward_func(completions, **kwargs) -> list[float]:
    return [random.uniform(0, 1) for _ in completions]

def oai_gsm8k_transform(cfg, *args, **kwargs):
    def transform_fn(example, tokenizer=None):
        label = example["answer"].split("####")[-1].strip().replace(",", "")
        return {
            "prompt": [{"role": "user", "content": example["question"]},],
            "answer": label,
        }

    return transform_fn, {"remove_columns": ["question"]}
"""
            )

    @with_temp_dir
    def test_custom_rewards_fn_preprocess(self, temp_dir):
        # End-to-end preprocessing check: reward funcs and dataset transform are
        # loaded from the local rewards.py file rather than an installed module.
        # pylint: disable=duplicate-code
        cfg = DictDefault(
            {
                "base_model": "HuggingFaceTB/SmolLM2-135M",
                "strict": False,
                "rl": "grpo",
                "trl": {
                    "beta": 0.001,
                    "max_completion_length": 256,
                    "use_vllm": True,
                    "num_generations": 4,
                    "reward_funcs": [
                        "rewards.rand_reward_func"
                    ],  # format: '{file_name}.{fn_name}'
                    "reward_weights": [1.0],
                },
                "datasets": [
                    {
                        "path": "openai/gsm8k",
                        "name": "main",
                        "type": "rewards.oai_gsm8k_transform",
                    },
                ],
                "dataset_prepared_path": temp_dir,
                "gradient_accumulation_steps": 1,
                "micro_batch_size": 1,
                "learning_rate": 0.000005,
            }
        )

        self._utils_write_rewards()
        cfg = validate_config(cfg)
        normalize_config(cfg)

        # Parse default CLI args for the preprocess entry point; remaining
        # strings (pytest's own argv) are discarded.
        parser = transformers.HfArgumentParser(PreprocessCliArgs)
        cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)

        # Should succeed without raising — presumably hits the network for
        # openai/gsm8k; verify this is acceptable for the e2e suite.
        load_preference_datasets(cfg=cfg, cli_args=cli_args)