Process reward models (#2241)

* adding model_cfg to set num_labels

* using a num_labels field instead

* linting

* WIP stepwise prompt tokenizer

* this should work?

* trainer working?

* pushing to runpod

* fixing saving

* updating conf

* updating config, adding docs

* adding stepwise supervision docpage

* updating tests

* adding test for dataset

* fixing tests

* linting

* addressing some comments

* adding additional cfg fields support

* updating tests, fixing cfg

* fixing tests

* updating loss

* Update test_process_reward_model_smollm2.py

* updating loss values and seed

* dumb pre-commit
This commit is contained in:
salman
2025-01-29 05:08:33 +00:00
committed by GitHub
parent c071a530f7
commit 54dd7abfc1
17 changed files with 542 additions and 25 deletions

View File

@@ -0,0 +1,69 @@
"""
E2E tests for process reward model w/ lora llama
"""
import logging
import os
import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestProcessRewardSmolLM2(unittest.TestCase):
"""
Test case for Llama process reward models using LoRA
"""
@with_temp_dir
def test_prm(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"model_type": "AutoModelForTokenClassification",
"num_labels": 2,
"process_reward_model": True,
"sequence_len": 512,
"val_set_size": 0.0,
"datasets": [
{
"path": "trl-lib/math_shepherd",
"type": "stepwise_supervised",
"step_separator": "\n",
"split": "train[:10%]",
},
],
"max_steps": 100,
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.0005,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"gradient_checkpointing": True,
"warmup_ratio": 0.1,
"use_tensorboard": True,
"special_tokens": {"pad_token": "<|endoftext|>"},
"seed": 42,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.5, "Train Loss is too high"
)
check_model_output_exists(temp_dir, cfg)

View File

@@ -12,25 +12,25 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestRewardModelLoraLlama(unittest.TestCase):
class TestRewardModelLoraSmolLM2(unittest.TestCase):
"""
Test case for Llama reward models using LoRA
"""
@with_temp_dir
def test_rm_fft(self, temp_dir):
def test_rm_lora(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"base_model": "HuggingFaceTB/SmolLM2-135M",
"model_type": "AutoModelForSequenceClassification",
"tokenizer_type": "LlamaTokenizer",
"num_labels": 1,
"chat_template": "alpaca",
"reward_model": True,
"sequence_len": 1024,
@@ -42,16 +42,16 @@ class TestRewardModelLoraLlama(unittest.TestCase):
"lora_target_linear": True,
"val_set_size": 0.0,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "argilla/distilabel-intel-orca-dpo-pairs",
"type": "bradley_terry.chat_template",
"split": "train[:10%]",
},
],
"lora_modules_to_save": ["embed_tokens", "lm_head"],
"remove_unused_columns": False,
"max_steps": 10,
"num_epochs": 1,
@@ -59,10 +59,11 @@ class TestRewardModelLoraLlama(unittest.TestCase):
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"gradient_checkpointing": True,
"warmup_ratio": 0.1,
"use_tensorboard": True,
}
)
normalize_config(cfg)
@@ -70,4 +71,7 @@ class TestRewardModelLoraLlama(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.5, "Train Loss is too high"
)
check_model_output_exists(temp_dir, cfg)

View File

@@ -0,0 +1,63 @@
"""
tests for chat_template prompt strategy
"""
import datasets
import pytest
from datasets import Dataset
from transformers import AutoTokenizer
from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_strategies.stepwise_supervised import (
StepwiseSupervisedPromptTokenizingStrategy,
)
class TestStepWiseSupervisedPromptTokenizingStrategy:
"""
Test class for stepwise supervised prompt strategy
"""
@pytest.fixture()
def stepwise_supervised_dataset(self):
# pylint: disable=duplicate-code
return Dataset.from_list(
[
{
"prompt": "Which number is larger, 9.8 or 9.11?",
"completions": [
"The fractional part of 9.8 is 0.8, while the fractional part of 9.11 is 0.11.",
"Since 0.11 is greater than 0.8, the number 9.11 is larger than 9.8.",
"Actually, this is incorrect. In decimal numbers, 0.8 is equal to 0.80, which is larger than 0.11. Therefore, 9.8 is larger than 9.11.",
],
"labels": [True, False, False],
}
]
)
@pytest.fixture()
def tokenizer(self):
return AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
def test_stepwise_supervised_dataset(self, tokenizer, stepwise_supervised_dataset):
strategy = StepwiseSupervisedPromptTokenizingStrategy(
tokenizer,
sequence_len=2048,
step_separator="\n",
)
stepwise_supervised_dataset = stepwise_supervised_dataset.cast_column(
"labels", datasets.Sequence(datasets.Value("int64"))
)
dataset_wrapper = TokenizedPromptDataset(
strategy,
stepwise_supervised_dataset,
process_count=1,
)
labels = dataset_wrapper[0]["labels"]
# expected labels is:
# the prompt + first step are ignored, followed by the label for step 1 (True)
# the second step, and its label (False)
# the third step, and its label (False)
expected = [-100] * 47 + [1] + [-100] * 29 + [0] + [-100] * 48 + [0]
assert labels == expected