Files
axolotl/src/axolotl/prompt_strategies/stepwise_supervised.py
salman 54dd7abfc1 Process reward models (#2241)
* adding model_cfg to set num_labels

* using a num_labels field instead

* linting

* WIP stepwise prompt tokenizer

* this should work?

* trainer working?

* pushing to runpod

* fixing saving

* updating conf

* updating config, adding docs

* adding stepwise supervision docpage

* updating tests

* adding test for dataset

* fixing tests

* linting

* addressing some comments

* adding additional cfg fields support

* updating tests, fixing cfg

* fixing tests

* updating loss

* Update test_process_reward_model_smollm2.py

* updating loss values and seed

* dumb pre-commit
2025-01-29 00:08:33 -05:00

117 lines
3.9 KiB
Python

"""
Module for stepwise datasets, typically including a prompt and reasoning traces,
and (optionally) per-step, or per-prompt-trace labels for reward modelling.
"""
from itertools import chain
from typing import Dict, List, Optional, Union
from transformers import BatchEncoding, PreTrainedTokenizer
from axolotl.prompt_tokenizers import IGNORE_INDEX
from axolotl.utils.dict import DictDefault
class StepwiseSupervisedPromptTokenizingStrategy:
    """
    Tokenizing strategy for supervised stepwise datasets, typically used for COT-reasoning.
    These datasets should include the following columns:
    - prompt: the prompt text
    - completions: a list of `n` completion steps
    - labels: a list of `n` labels indicating the "correctness" of each step

    Every step is tokenized on its own and has the tokenized step separator
    appended. The label of a step is placed on the last token of that
    (step + separator) span; every other position, including the whole prompt,
    is set to ``IGNORE_INDEX``.
    """

    def __init__(
        self,
        tokenizer,
        sequence_len: int = 2048,
        step_separator: str = "\n",
        max_completion_length: Optional[int] = None,
        train_on_last_step_only: bool = False,
    ):
        """
        Store the tokenizer and the truncation / labelling options.

        Args:
            tokenizer: object that is called as
                ``tokenizer(text, add_special_tokens=False)["input_ids"]``,
                provides ``encode(text, add_special_tokens=False)`` and exposes
                a ``bos_token_id`` attribute (may be ``None``).
            sequence_len: maximum length of the final ``input_ids``; a falsy
                value disables this truncation.
            step_separator: text appended (after tokenization) to every step.
            max_completion_length: maximum number of completion tokens kept
                (prompt excluded); a falsy value disables this truncation.
            train_on_last_step_only: when true, only the label of the last
                entry in ``labels`` is kept, all earlier steps are masked.
        """
        self.tokenizer = tokenizer
        self.sequence_len = sequence_len
        self.step_separator = step_separator
        self.max_completion_length = max_completion_length
        self.train_on_last_step_only = train_on_last_step_only

    def tokenize_prompt(
        self, prompt: Dict[str, Union[str, List[str]]]
    ) -> BatchEncoding:
        """
        Tokenize one dataset row into token ids with per-token step labels.

        Args:
            prompt: row with the keys ``"prompt"``, ``"completions"`` and
                ``"labels"`` as described in the class docstring.

        Returns:
            A plain dict with ``"input_ids"``, ``"labels"`` and
            ``"attention_mask"``. All three lists have the same length.
        """
        # Inspired by TRL's PRMTrainer
        # https://github.com/huggingface/trl/blob/ed7de87dc766478c024b68f12530d1b0e7c3ff23/trl/trainer/prm_trainer.py#L206
        prompt_ids = self.tokenizer(prompt["prompt"], add_special_tokens=False)[
            "input_ids"
        ]
        completions_ids = [
            self.tokenizer(completion, add_special_tokens=False)["input_ids"]
            for completion in prompt["completions"]
        ]

        # Handle labels
        raw_labels = prompt["labels"]
        if self.train_on_last_step_only:
            # Slicing (instead of indexing with -1) keeps an empty label list
            # from raising IndexError; it simply yields no supervised step.
            step_labels = [IGNORE_INDEX] * (len(raw_labels) - 1) + [
                int(label) for label in raw_labels[-1:]
            ]
        else:
            step_labels = [int(label) for label in raw_labels]

        # Add step separators
        separator_ids = self.tokenizer.encode(
            self.step_separator, add_special_tokens=False
        )
        completions_ids = [completion + separator_ids for completion in completions_ids]

        # Create step-wise labels and join all steps. Tokens and labels are
        # built in the same loop so that both lists always stay aligned.
        completion_ids: List[int] = []
        labels: List[int] = []
        for index, step_ids in enumerate(completions_ids):
            if not step_ids:
                # A step without any token has no position that could carry
                # its label; emitting one would shift all following labels.
                continue
            # Steps that have no matching label are masked rather than
            # dropped, otherwise the labels would be shorter than the tokens.
            label = step_labels[index] if index < len(step_labels) else IGNORE_INDEX
            completion_ids.extend(step_ids)
            labels.extend([IGNORE_INDEX] * (len(step_ids) - 1) + [label])

        # Handle max lengths
        if self.max_completion_length:
            completion_ids = completion_ids[: self.max_completion_length]
            labels = labels[: self.max_completion_length]

        # Add BOS token if model has one
        if self.tokenizer.bos_token_id is not None:
            prompt_ids = [self.tokenizer.bos_token_id] + prompt_ids

        # Combine prompt and completion; the prompt is never supervised
        input_ids = prompt_ids + completion_ids
        full_labels = [IGNORE_INDEX] * len(prompt_ids) + labels

        # Apply max sequence length
        if self.sequence_len:
            input_ids = input_ids[: self.sequence_len]
            full_labels = full_labels[: self.sequence_len]

        return {
            "input_ids": input_ids,
            "labels": full_labels,
            "attention_mask": [1] * len(input_ids),
        }

    @property
    def supports_batched(self):
        """Rows are tokenized one at a time; batched input is not supported."""
        return False
def load(
    tokenizer: PreTrainedTokenizer,
    cfg: DictDefault,
    ds_cfg: DictDefault,
) -> StepwiseSupervisedPromptTokenizingStrategy:
    """
    Build the stepwise supervised tokenizing strategy from the configuration.

    Args:
        tokenizer: tokenizer handed to the strategy unchanged.
        cfg: global config; only ``sequence_len`` is read here.
        ds_cfg: dataset config; ``step_separator`` (default ``"\\n"``),
            ``max_completion_length`` and ``train_on_last_step_only``
            (default ``False``) are read here.

    Returns:
        A configured ``StepwiseSupervisedPromptTokenizingStrategy``.
    """
    max_sequence_length = cfg.sequence_len
    strategy_options = {
        "step_separator": ds_cfg.get("step_separator", "\n"),
        # Read via attribute access, so no explicit default is supplied here;
        # presumably DictDefault yields None for a missing key - TODO confirm.
        "max_completion_length": ds_cfg.max_completion_length,
        "train_on_last_step_only": ds_cfg.get("train_on_last_step_only", False),
    }
    return StepwiseSupervisedPromptTokenizingStrategy(
        tokenizer,
        max_sequence_length,
        **strategy_options,
    )