Files
axolotl/tests/prompt_strategies/test_stepwise.py
salman 54dd7abfc1 Process reward models (#2241)
* adding model_cfg to set num_labels

* using a num_labels field instead

* linting

* WIP stepwise prompt tokenizer

* this should work?

* trainer working?

* pushing to runpod

* fixing saving

* updating conf

* updating config, adding docs

* adding stepwise supervision docpage

* updating tests

* adding test for dataset

* fixing tests

* linting

* addressing some comments

* adding additional cfg fields support

* updating tests, fixing cfg

* fixing tests

* updating loss

* Update test_process_reward_model_smollm2.py

* updating loss values and seed

* dumb pre-commit
2025-01-29 00:08:33 -05:00

64 lines
2.2 KiB
Python

"""
tests for stepwise supervised prompt strategy
"""
import datasets
import pytest
from datasets import Dataset
from transformers import AutoTokenizer
from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_strategies.stepwise_supervised import (
StepwiseSupervisedPromptTokenizingStrategy,
)
class TestStepWiseSupervisedPromptTokenizingStrategy:
    """
    Tests covering tokenization with the stepwise supervised prompt strategy
    """

    @pytest.fixture()
    def stepwise_supervised_dataset(self):
        """
        Build a one-row dataset holding a prompt, three completion steps and
        one boolean label per step.
        """
        # pylint: disable=duplicate-code
        reasoning_steps = [
            "The fractional part of 9.8 is 0.8, while the fractional part of 9.11 is 0.11.",
            "Since 0.11 is greater than 0.8, the number 9.11 is larger than 9.8.",
            "Actually, this is incorrect. In decimal numbers, 0.8 is equal to 0.80, which is larger than 0.11. Therefore, 9.8 is larger than 9.11.",
        ]
        record = {
            "prompt": "Which number is larger, 9.8 or 9.11?",
            "completions": reasoning_steps,
            "labels": [True, False, False],
        }
        return Dataset.from_list([record])

    @pytest.fixture()
    def tokenizer(self):
        """
        Load the pretrained tokenizer used by the test.
        """
        model_id = "Qwen/Qwen2.5-0.5B"
        return AutoTokenizer.from_pretrained(model_id)

    def test_stepwise_supervised_dataset(self, tokenizer, stepwise_supervised_dataset):
        """
        Tokenize the single-row dataset and compare the produced ``labels``
        sequence against the expected per-token values.
        """
        prompt_strategy = StepwiseSupervisedPromptTokenizingStrategy(
            tokenizer,
            sequence_len=2048,
            step_separator="\n",
        )

        # Re-type the boolean step labels as a sequence of int64 values
        # before handing the dataset to the tokenizing wrapper.
        int_label_feature = datasets.Sequence(datasets.Value("int64"))
        int_labeled = stepwise_supervised_dataset.cast_column(
            "labels", int_label_feature
        )

        tokenized = TokenizedPromptDataset(
            prompt_strategy,
            int_labeled,
            process_count=1,
        )
        first_row = tokenized[0]
        labels = first_row["labels"]

        # Each segment is (number of ignored positions, label that follows):
        #   - prompt + first step are ignored, then the label for step 1 (True)
        #   - second step is ignored, then its label (False)
        #   - third step is ignored, then its label (False)
        segments = ((47, 1), (29, 0), (48, 0))
        expected = []
        for ignored_count, step_label in segments:
            expected.extend([-100] * ignored_count)
            expected.append(step_label)

        assert labels == expected