diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 9409b1ef1..f91b9554d 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -19,7 +19,7 @@ repos:
     hooks:
       - id: isort
   - repo: https://github.com/PyCQA/flake8
-    rev: 6.0.0
+    rev: 6.1.0
     hooks:
       - id: flake8
   - repo: https://github.com/PyCQA/pylint
diff --git a/docs/config.qmd b/docs/config.qmd
index f253decbe..ecb571040 100644
--- a/docs/config.qmd
+++ b/docs/config.qmd
@@ -187,6 +187,12 @@ rl:
 # whether to perform weighting if doing DPO training. Boolean.
 dpo_use_weighting:
 
+# whether to train a reward model: `True` or `False`
+reward_model:
+
+# whether to train a process reward model: `True` or `False`
+process_reward_model:
+
 # The name of the chat template to use for training, following values are supported:
 # - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value.
 # - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py
diff --git a/docs/dataset-formats/stepwise_supervised.qmd b/docs/dataset-formats/stepwise_supervised.qmd
new file mode 100644
index 000000000..17f0c9141
--- /dev/null
+++ b/docs/dataset-formats/stepwise_supervised.qmd
@@ -0,0 +1,24 @@
+---
+title: Stepwise Supervised Format
+description: Format for datasets with stepwise completions and labels
+order: 3
+---
+
+## Stepwise Supervised
+
+The stepwise supervised format is designed for chain-of-thought (CoT) reasoning datasets where each example contains multiple completion steps and a preference label for each step.
+
+### Example
+
+Here's a simple example of a stepwise supervised dataset entry:
+
+```json
+{
+  "prompt": "Which number is larger, 9.8 or 9.11?",
+  "completions": [
+    "The fractional part of 9.8 is 0.8, while the fractional part of 9.11 is 0.11.",
+    "Since 0.11 is greater than 0.8, the number 9.11 is larger than 9.8."
+  ],
+  "labels": [true, false]
+}
+```
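To make the format concrete: during tokenization, an entry like the one above is flattened into a single token sequence in which every position is masked out with `-100` except the final token of each step, which carries that step's 0/1 label. Here is a minimal illustrative sketch of that mapping (not part of this patch; the Qwen tokenizer and the `"\n"` separator simply mirror defaults used elsewhere in this PR):

```python
from transformers import AutoTokenizer

IGNORE_INDEX = -100  # positions that should not contribute to the loss

# any HF tokenizer works; Qwen2.5 is just the checkpoint used in this PR's examples
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")

example = {
    "prompt": "Which number is larger, 9.8 or 9.11?",
    "completions": [
        "The fractional part of 9.8 is 0.8, while the fractional part of 9.11 is 0.11.",
        "Since 0.11 is greater than 0.8, the number 9.11 is larger than 9.8.",
    ],
    "labels": [True, False],
}

input_ids, labels = [], []
for step, is_correct in zip(example["completions"], example["labels"]):
    # append the "\n" step separator; its token slot carries the step's label
    step_ids = tokenizer(step + "\n", add_special_tokens=False)["input_ids"]
    input_ids += step_ids
    labels += [IGNORE_INDEX] * (len(step_ids) - 1) + [int(is_correct)]

assert len(input_ids) == len(labels)  # exactly one label slot per token
print([lab for lab in labels if lab != IGNORE_INDEX])  # -> [1, 0]
```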
diff --git a/docs/reward_modelling.qmd b/docs/reward_modelling.qmd
new file mode 100644
index 000000000..8baa93424
--- /dev/null
+++ b/docs/reward_modelling.qmd
@@ -0,0 +1,49 @@
+---
+title: "Reward Modelling"
+description: "Reward models are used to guide models towards behaviors preferred by humans, by training over large datasets annotated with human preferences."
+---
+
+### Overview
+
+Reward modelling is a technique used to train models to predict the reward or value of a given input. This is particularly useful in reinforcement learning scenarios where the model needs to evaluate the quality of its actions or predictions.
+We support the reward modelling techniques implemented in `trl`.
+
+### (Outcome) Reward Models
+
+Outcome reward models are trained using data which contains preference annotations for an entire interaction between the user and model (rather than per-turn or per-step).
+
+```yaml
+base_model: google/gemma-2-2b
+model_type: AutoModelForSequenceClassification
+num_labels: 1
+tokenizer_type: AutoTokenizer
+
+reward_model: true
+chat_template: gemma
+datasets:
+  - path: argilla/distilabel-intel-orca-dpo-pairs
+    type: bradley_terry.chat_template
+
+val_set_size: 0.1
+eval_steps: 100
+```
+
+### Process Reward Models (PRM)
+
+Process reward models are trained using data which contains preference annotations for each step in a series of interactions.
+Typically, PRMs are trained to provide reward signals over each step of a reasoning trace and are used for downstream reinforcement learning.
+
+```yaml
+base_model: Qwen/Qwen2.5-3B
+model_type: AutoModelForTokenClassification
+num_labels: 2
+
+process_reward_model: true
+datasets:
+  - path: trl-lib/math_shepherd
+    type: stepwise_supervised
+    split: train
+
+val_set_size: 0.1
+eval_steps: 100
+```
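Once trained, an outcome reward model is an ordinary sequence-classification checkpoint whose single logit acts as a scalar reward. A rough inference sketch (illustrative only; `./outputs/out` is a placeholder for your trained checkpoint, and a real setup would apply the same chat template used during training):

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# placeholder path: wherever the trained reward model was saved
model = AutoModelForSequenceClassification.from_pretrained("./outputs/out")
tokenizer = AutoTokenizer.from_pretrained("./outputs/out")

def score(prompt: str, response: str) -> float:
    """Return the scalar reward the model assigns to a prompt/response pair."""
    inputs = tokenizer(prompt + response, return_tensors="pt")
    with torch.no_grad():
        return model(**inputs).logits[0, 0].item()

# a Bradley-Terry-trained RM should rank a good answer above a bad one
print(score("2+2=", "4") > score("2+2=", "5"))
```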
diff --git a/examples/gemma2/reward-model.yaml b/examples/gemma2/reward-model.yaml
index b492c6f93..ada42ec28 100644
--- a/examples/gemma2/reward-model.yaml
+++ b/examples/gemma2/reward-model.yaml
@@ -1,6 +1,7 @@
 base_model: google/gemma-2-2b
 # optionally might have model_type or tokenizer_type
 model_type: AutoModelForSequenceClassification
+num_labels: 1
 tokenizer_type: AutoTokenizer
 # Automatically upload checkpoint and final model to HF
 # hub_model_id: username/custom_model_name
diff --git a/examples/qwen2/prm.yaml b/examples/qwen2/prm.yaml
new file mode 100644
index 000000000..071e2d0f3
--- /dev/null
+++ b/examples/qwen2/prm.yaml
@@ -0,0 +1,72 @@
+base_model: Qwen/Qwen2.5-3B
+# optionally might have model_type or tokenizer_type
+model_type: AutoModelForTokenClassification
+num_labels: 2
+tokenizer_type: AutoTokenizer
+# Automatically upload checkpoint and final model to HF
+# hub_model_id: username/custom_model_name
+
+load_in_8bit: false
+load_in_4bit: false
+strict: false
+
+process_reward_model: true
+chat_template:
+datasets:
+  - path: trl-lib/math_shepherd
+    type: stepwise_supervised
+    step_separator: "\n"
+    max_completion_length:
+    train_on_last_step_only: false
+
+val_set_size: 0.2
+output_dir: ./outputs/out
+remove_unused_columns: false
+
+sequence_len: 2048
+sample_packing: false
+eval_sample_packing: false
+pad_to_sequence_len: true
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+
+gradient_accumulation_steps: 1
+micro_batch_size: 8
+eval_batch_size: 8
+num_epochs: 1
+optimizer: adamw_torch
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: true
+fp16:
+tf32:
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+  use_reentrant: false
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_ratio: 0.1
+evals_per_epoch:
+eval_table_size:
+eval_max_new_tokens: 128
+eval_steps: 100
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
diff --git a/examples/qwen2/reward-model.yaml b/examples/qwen2/reward-model.yaml
new file mode 100644
index 000000000..bbd6e66ce
--- /dev/null
+++ b/examples/qwen2/reward-model.yaml
@@ -0,0 +1,67 @@
+base_model: Qwen/Qwen2.5-0.5B
+# optionally might have model_type or tokenizer_type
+model_type: AutoModelForSequenceClassification
+num_labels: 1
+tokenizer_type: AutoTokenizer
+# Automatically upload checkpoint and final model to HF
+# hub_model_id: username/custom_model_name
+
+load_in_8bit: false
+load_in_4bit: false
+strict: false
+
+reward_model: true
+chat_template: qwen_25
+datasets:
+  - path: argilla/distilabel-intel-orca-dpo-pairs
+    type: bradley_terry.chat_template
+val_set_size: 0.0
+output_dir: ./outputs/out
+remove_unused_columns: false
+
+sequence_len: 2048
+sample_packing: false
+eval_sample_packing: false
+pad_to_sequence_len: true
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 4
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: true
+fp16:
+tf32: true
+
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+  use_reentrant: false
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_ratio: 0.1
+evals_per_epoch:
+eval_table_size:
+eval_max_new_tokens: 128
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
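For the PRM config above, inference is token classification: the model emits per-token logits, and the logits at each step-separator position score that step. A hedged sketch of that pattern (the checkpoint path is a placeholder; this assumes the separator tokenizes to a single standalone token, as `"\n"` does for Qwen's tokenizer):

```python
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer

# placeholder path: a PRM checkpoint produced by the config above
model = AutoModelForTokenClassification.from_pretrained("./outputs/out")
tokenizer = AutoTokenizer.from_pretrained("./outputs/out")

prompt = "Which number is larger, 9.8 or 9.11?"
steps = ["The fractional part of 9.8 is 0.8.", "So 9.11 is larger."]

# each step ends with the separator; the model was trained to emit the
# step's correctness logits at that final separator token
text = prompt + "".join(step + "\n" for step in steps)
inputs = tokenizer(text, return_tensors="pt")

sep_id = tokenizer.encode("\n", add_special_tokens=False)[-1]
step_positions = (inputs["input_ids"][0] == sep_id).nonzero(as_tuple=True)[0]

with torch.no_grad():
    logits = model(**inputs).logits[0]  # (seq_len, num_labels)

# probability that each step is labelled "correct" (class 1)
print(logits[step_positions].softmax(-1)[:, 1].tolist())
```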
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index d63a10e74..6bf03d78c 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -44,6 +44,8 @@ from trl import (
     KTOTrainer,
     ORPOConfig,
     ORPOTrainer,
+    PRMConfig,
+    PRMTrainer,
     RewardConfig,
     RewardTrainer,
 )
@@ -342,6 +344,13 @@ class AxolotlRewardConfig(AxolotlTrainingMixins, RewardConfig):
     """
 
 
+@dataclass
+class AxolotlPRMConfig(AxolotlTrainingMixins, PRMConfig):
+    """
+    Config for PRM (process reward model) training
+    """
+
+
 class SchedulerMixin(Trainer):
     """
     Mixin class for scheduler setup in CausalTrainer.
@@ -1244,6 +1253,14 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
     tag_names = ["axolotl", "reward"]
 
 
+class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer):
+    """
+    Extend the base trl.PRMTrainer for axolotl helpers
+    """
+
+    tag_names = ["axolotl", "prm"]
+
+
 class TrainerBuilderBase(abc.ABC):
     """
     Base class for trainer builder
@@ -1377,7 +1394,8 @@ class TrainerBuilderBase(abc.ABC):
 
 class HFCausalTrainerBuilder(TrainerBuilderBase):
     """
-    Build the HuggingFace training args/trainer for Causal models
+    Build the HuggingFace training args/trainer for causal models
+    and reward modelling using TRL.
     """
 
     def get_callbacks(self):
@@ -1452,6 +1470,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             return AxolotlMambaTrainer
         if self.cfg.reward_model:
             return AxolotlRewardTrainer
+        if self.cfg.process_reward_model:
+            return AxolotlPRMTrainer
         return AxolotlTrainer
 
     def build(self, total_num_steps):
@@ -1842,11 +1862,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 "accelerator_config"
             ] = self.cfg.accelerator_config
 
-        training_args_cls = (
-            AxolotlTrainingArguments
-            if not self.cfg.reward_model
-            else AxolotlRewardConfig
-        )
+        if self.cfg.reward_model:
+            training_args_cls = AxolotlRewardConfig
+        elif self.cfg.process_reward_model:
+            training_args_cls = AxolotlPRMConfig
+        else:
+            training_args_cls = AxolotlTrainingArguments
+
         training_args = training_args_cls(  # pylint: disable=unexpected-keyword-arg
             **training_arguments_kwargs,
         )
@@ -1880,9 +1902,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if eval_data_collator := self.build_collator(
             training_args, is_eval=True, **data_collator_kwargs
         ):
-            if not self.cfg.reward_model:
+            if not (self.cfg.reward_model or self.cfg.process_reward_model):
                 trainer_kwargs["eval_data_collator"] = eval_data_collator
-        if not self.cfg.reward_model:
+        if not (self.cfg.reward_model or self.cfg.process_reward_model):
             trainer_kwargs["bench_data_collator"] = transformers.DataCollatorForSeq2Seq(
                 self.tokenizer,
                 return_tensors="pt",
@@ -1893,8 +1915,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             trainer_kwargs["processing_class"] = self.tokenizer
         else:
             trainer_kwargs["tokenizer"] = self.tokenizer
-
-        if (trainer_cls is not AxolotlRewardTrainer) and self.cfg.datasets is not None:
+        if (
+            trainer_cls not in (AxolotlRewardTrainer, AxolotlPRMTrainer)
+            and self.cfg.datasets is not None
+        ):
             trainer_kwargs["dataset_tags"] = [
                 d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
             ]
@@ -1984,7 +2008,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
 
 class HFRLTrainerBuilder(TrainerBuilderBase):
     """
-    Trainer factory class for DPO Trainer
+    Trainer factory class for TRL-based RLHF trainers (e.g. DPO)
     """
 
     def get_callbacks(self):
diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py
index b5638a614..e4531930f 100644
--- a/src/axolotl/datasets.py
+++ b/src/axolotl/datasets.py
@@ -52,6 +52,7 @@ class TokenizedPromptDataset(Dataset):
         if self.prompt_tokenizer.supports_batched:
             map_kwargs["batched"] = True
             map_kwargs["batch_size"] = 100
+
         return dataset.map(
             self.prompt_tokenizer.tokenize_prompt,
             num_proc=num_proc,
diff --git a/src/axolotl/prompt_strategies/stepwise_supervised.py b/src/axolotl/prompt_strategies/stepwise_supervised.py
new file mode 100644
index 000000000..8be7c35e3
--- /dev/null
+++ b/src/axolotl/prompt_strategies/stepwise_supervised.py
@@ -0,0 +1,116 @@
+"""
+Module for stepwise datasets, typically including a prompt and reasoning traces,
+and (optionally) per-step or per-prompt-trace labels for reward modelling.
+"""
+
+from itertools import chain
+from typing import Dict, List, Optional, Union
+
+from transformers import BatchEncoding, PreTrainedTokenizer
+
+from axolotl.prompt_tokenizers import IGNORE_INDEX
+from axolotl.utils.dict import DictDefault
+
+
+class StepwiseSupervisedPromptTokenizingStrategy:
+    """
+    Tokenizing strategy for supervised stepwise datasets, typically used for CoT reasoning.
+    These datasets should include the following columns:
+    - prompt: the prompt text
+    - completions: a list of `n` completion steps
+    - labels: a list of `n` labels indicating the "correctness" of each step
+    """
+
+    def __init__(
+        self,
+        tokenizer,
+        sequence_len: int = 2048,
+        step_separator: str = "\n",
+        max_completion_length: Optional[int] = None,
+        train_on_last_step_only: bool = False,
+    ):
+        self.tokenizer = tokenizer
+        self.sequence_len = sequence_len
+        self.step_separator = step_separator
+        self.max_completion_length = max_completion_length
+        self.train_on_last_step_only = train_on_last_step_only
+
+    def tokenize_prompt(
+        self, prompt: Dict[str, Union[str, List[str]]]
+    ) -> BatchEncoding:
+        # Inspired by TRL's PRMTrainer
+        # https://github.com/huggingface/trl/blob/ed7de87dc766478c024b68f12530d1b0e7c3ff23/trl/trainer/prm_trainer.py#L206
+        prompt_ids = self.tokenizer(prompt["prompt"], add_special_tokens=False)[
+            "input_ids"
+        ]
+
+        completions_ids = [
+            self.tokenizer(completion, add_special_tokens=False)["input_ids"]
+            for completion in prompt["completions"]
+        ]
+
+        # Handle labels
+        if self.train_on_last_step_only:
+            labels = [IGNORE_INDEX] * (len(prompt["labels"]) - 1) + [
+                int(prompt["labels"][-1])
+            ]
+        else:
+            labels = [int(label) for label in prompt["labels"]]
+
+        # Add step separators
+        separator_ids = self.tokenizer.encode(
+            self.step_separator, add_special_tokens=False
+        )
+        completions_ids = [completion + separator_ids for completion in completions_ids]
+
+        # Create step-wise labels
+        labels = [
+            [IGNORE_INDEX] * (len(completion) - 1) + [label]  # type: ignore
+            for completion, label in zip(completions_ids, labels)
+        ]
+
+        # Join all steps
+        completion_ids = list(chain(*completions_ids))
+        labels = list(chain(*labels))  # type: ignore
+
+        # Handle max lengths
+        if self.max_completion_length:
+            completion_ids = completion_ids[: self.max_completion_length]
+            labels = labels[: self.max_completion_length]
+
+        # Add BOS token if model has one
+        if self.tokenizer.bos_token_id is not None:
+            prompt_ids = [self.tokenizer.bos_token_id] + prompt_ids
+
+        # Combine prompt and completion
+        input_ids = prompt_ids + completion_ids
+
+        full_labels = [IGNORE_INDEX] * len(prompt_ids) + labels
+        # Apply max sequence length
+        if self.sequence_len:
+            input_ids = input_ids[: self.sequence_len]
+            full_labels = full_labels[: self.sequence_len]
+
+        return {
+            "input_ids": input_ids,
+            "labels": full_labels,
+            "attention_mask": [1] * len(input_ids),
+        }
+
+    @property
+    def supports_batched(self):
+        return False
+
+
+def load(
+    tokenizer: PreTrainedTokenizer,
+    cfg: DictDefault,
+    ds_cfg: DictDefault,
+) -> StepwiseSupervisedPromptTokenizingStrategy:
+    return StepwiseSupervisedPromptTokenizingStrategy(
+        tokenizer,
+        cfg.sequence_len,
+        step_separator=ds_cfg.get("step_separator", "\n"),
+        max_completion_length=ds_cfg.max_completion_length,
+        train_on_last_step_only=ds_cfg.get("train_on_last_step_only", False),
+    )
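A small usage sketch of the strategy above, showing the label alignment it produces (assumes axolotl with this patch installed; the model choice is arbitrary):

```python
from transformers import AutoTokenizer

from axolotl.prompt_strategies.stepwise_supervised import (
    StepwiseSupervisedPromptTokenizingStrategy,
)

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
strategy = StepwiseSupervisedPromptTokenizingStrategy(
    tokenizer, sequence_len=2048, step_separator="\n"
)

sample = {
    "prompt": "What is 2 + 3 * 4?",
    "completions": ["3 * 4 = 12.", "2 + 12 = 14."],
    "labels": [True, True],
}
out = strategy.tokenize_prompt(sample)

# every position is ignored (-100) except the last token of each step,
# which carries that step's 0/1 correctness label
print([lab for lab in out["labels"] if lab != -100])  # -> [1, 1]
```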
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index b901c2a97..0bd400f6b 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -259,7 +259,7 @@ def train(
             .decode("utf-8")
         }
         if cfg.datasets is not None:
-            if cfg.rl is not None or cfg.reward_model:
+            if cfg.rl is not None or cfg.reward_model or cfg.process_reward_model:
                 dataset_tags = [
                     d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
                 ]
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index e5edf8e7b..4f0fa4c29 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -236,6 +236,18 @@ class DPODataset(BaseModel):
     revision: Optional[str] = None
 
 
+class StepwiseSupervisedDataset(BaseModel):
+    """Stepwise supervised dataset configuration subset"""
+
+    path: Optional[str] = None
+    split: Optional[str] = None
+    data_files: Optional[List[str]] = None
+    revision: Optional[str] = None
+    step_separator: Optional[str] = None
+    max_completion_length: Optional[int] = None
+    train_on_last_step_only: Optional[bool] = None
+
+
 class UserDefinedKTOType(BaseModel):
     """User defined typing for KTO"""
 
@@ -626,12 +638,14 @@ class AxolotlInputConfig(
     rl: Optional[RLType] = None
 
     reward_model: Optional[bool] = None
+    process_reward_model: Optional[bool] = None
+    num_labels: Optional[int] = None
     dpo_use_weighting: Optional[
         bool
     ] = None  # whether to use weighting in DPO trainer. If none, default is false in the trainer.
 
-    datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None  # type: ignore
-    test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None  # type: ignore
+    datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None  # type: ignore
+    test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None  # type: ignore
     shuffle_merged_datasets: Optional[bool] = True
     dataset_prepared_path: Optional[str] = None
     dataset_shard_num: Optional[int] = None
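For reference, a `datasets:` entry for this format validates against the new `StepwiseSupervisedDataset` model above; a quick illustrative check (assumes pydantic v2, which axolotl pins, and field values taken from the example configs in this PR):

```python
from axolotl.utils.config.models.input.v0_4_1 import StepwiseSupervisedDataset

# mirrors the YAML `datasets:` entry in examples/qwen2/prm.yaml;
# every field is optional, so partial configs validate cleanly
ds_cfg = StepwiseSupervisedDataset(
    path="trl-lib/math_shepherd",
    split="train",
    step_separator="\n",
    train_on_last_step_only=False,
)
print(ds_cfg.model_dump(exclude_none=True))
```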
diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py
index de373c06e..ba5d0c54d 100644
--- a/src/axolotl/utils/data/sft.py
+++ b/src/axolotl/utils/data/sft.py
@@ -8,6 +8,8 @@ from typing import List, Tuple, Union
 from datasets import (
     Dataset,
     DatasetDict,
+    Sequence,
+    Value,
     concatenate_datasets,
     load_dataset,
     load_from_disk,
@@ -467,6 +469,17 @@ def get_dataset_wrapper(
             dataset,
             **ds_kwargs,
         )
+    elif config_dataset.type.startswith("stepwise_supervised"):
+        dataset_prompter = UnsupportedPrompter()
+        ds_strategy = load(config_dataset.type, tokenizer, cfg, config_dataset)
+        # we need to explicitly cast boolean labels to int
+        # for compatibility with how trl's PRMTrainer works
+        dataset = dataset.cast_column("labels", Sequence(Value("int64")))
+        dataset_wrapper = TokenizedPromptDataset(
+            ds_strategy,
+            dataset,
+            **ds_kwargs,
+        )
     elif ds_strategy := load(
         config_dataset.type, tokenizer, cfg, config_dataset, processor=processor
     ):
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index c4b8f05b9..d46564f42 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -138,7 +138,9 @@ def load_model_config(cfg):
     config_kwargs = {}
     if cfg.revision_of_model:
         config_kwargs["revision"] = cfg.revision_of_model
-
+    if cfg.num_labels:
+        # num_labels is used to initialize classifier models
+        config_kwargs["num_labels"] = cfg.num_labels
     try:
         model_config = AutoConfig.from_pretrained(
             model_config_name,
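The explicit cast in `sft.py` above is easy to reproduce standalone: Arrow stores the `labels` column as booleans, while trl's `PRMTrainer` expects integer class labels. A minimal demonstration with toy data (assumes only the `datasets` library):

```python
from datasets import Dataset, Sequence, Value

ds = Dataset.from_list(
    [{"labels": [True, False]}, {"labels": [True, True, False]}]
)
# bool -> int64, matching what PRMTrainer expects for token-classification labels
ds = ds.cast_column("labels", Sequence(Value("int64")))
print(ds[0]["labels"])  # [1, 0]
```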
"learning_rate": 0.0005, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + "gradient_checkpointing": True, + "warmup_ratio": 0.1, + "use_tensorboard": True, + "special_tokens": {"pad_token": "<|endoftext|>"}, + "seed": 42, + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, dataset_meta=dataset_meta) + check_tensorboard( + temp_dir + "/runs", "train/train_loss", 2.5, "Train Loss is too high" + ) + + check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_reward_model_llama.py b/tests/e2e/test_reward_model_smollm2.py similarity index 75% rename from tests/e2e/test_reward_model_llama.py rename to tests/e2e/test_reward_model_smollm2.py index 4cd8602f3..7360a99dc 100644 --- a/tests/e2e/test_reward_model_llama.py +++ b/tests/e2e/test_reward_model_smollm2.py @@ -12,25 +12,25 @@ from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault -from .utils import check_model_output_exists, with_temp_dir +from .utils import check_model_output_exists, check_tensorboard, with_temp_dir LOG = logging.getLogger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" -class TestRewardModelLoraLlama(unittest.TestCase): +class TestRewardModelLoraSmolLM2(unittest.TestCase): """ Test case for Llama reward models using LoRA """ @with_temp_dir - def test_rm_fft(self, temp_dir): + def test_rm_lora(self, temp_dir): # pylint: disable=duplicate-code cfg = DictDefault( { - "base_model": "JackFram/llama-68m", + "base_model": "HuggingFaceTB/SmolLM2-135M", "model_type": "AutoModelForSequenceClassification", - "tokenizer_type": "LlamaTokenizer", + "num_labels": 1, "chat_template": "alpaca", "reward_model": True, "sequence_len": 1024, @@ -42,16 +42,16 @@ class TestRewardModelLoraLlama(unittest.TestCase): "lora_target_linear": True, "val_set_size": 0.0, "special_tokens": { - "unk_token": "", - "bos_token": "", - "eos_token": "", + "pad_token": "<|endoftext|>", }, "datasets": [ { "path": "argilla/distilabel-intel-orca-dpo-pairs", "type": "bradley_terry.chat_template", + "split": "train[:10%]", }, ], + "lora_modules_to_save": ["embed_tokens", "lm_head"], "remove_unused_columns": False, "max_steps": 10, "num_epochs": 1, @@ -59,10 +59,11 @@ class TestRewardModelLoraLlama(unittest.TestCase): "gradient_accumulation_steps": 1, "output_dir": temp_dir, "learning_rate": 0.00001, - "optimizer": "adamw_bnb_8bit", + "optimizer": "adamw_torch", "lr_scheduler": "cosine", "gradient_checkpointing": True, "warmup_ratio": 0.1, + "use_tensorboard": True, } ) normalize_config(cfg) @@ -70,4 +71,7 @@ class TestRewardModelLoraLlama(unittest.TestCase): dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, dataset_meta=dataset_meta) + check_tensorboard( + temp_dir + "/runs", "train/train_loss", 2.5, "Train Loss is too high" + ) check_model_output_exists(temp_dir, cfg) diff --git a/tests/prompt_strategies/test_stepwise.py b/tests/prompt_strategies/test_stepwise.py new file mode 100644 index 000000000..2abe4ae18 --- /dev/null +++ b/tests/prompt_strategies/test_stepwise.py @@ -0,0 +1,63 @@ +""" +tests for chat_template prompt strategy +""" + +import datasets +import pytest +from datasets import Dataset +from transformers import AutoTokenizer + +from axolotl.datasets import TokenizedPromptDataset +from axolotl.prompt_strategies.stepwise_supervised import ( + StepwiseSupervisedPromptTokenizingStrategy, +) + + +class TestStepWiseSupervisedPromptTokenizingStrategy: 
+ """ + Test class for stepwise supervised prompt strategy + """ + + @pytest.fixture() + def stepwise_supervised_dataset(self): + # pylint: disable=duplicate-code + return Dataset.from_list( + [ + { + "prompt": "Which number is larger, 9.8 or 9.11?", + "completions": [ + "The fractional part of 9.8 is 0.8, while the fractional part of 9.11 is 0.11.", + "Since 0.11 is greater than 0.8, the number 9.11 is larger than 9.8.", + "Actually, this is incorrect. In decimal numbers, 0.8 is equal to 0.80, which is larger than 0.11. Therefore, 9.8 is larger than 9.11.", + ], + "labels": [True, False, False], + } + ] + ) + + @pytest.fixture() + def tokenizer(self): + return AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B") + + def test_stepwise_supervised_dataset(self, tokenizer, stepwise_supervised_dataset): + strategy = StepwiseSupervisedPromptTokenizingStrategy( + tokenizer, + sequence_len=2048, + step_separator="\n", + ) + stepwise_supervised_dataset = stepwise_supervised_dataset.cast_column( + "labels", datasets.Sequence(datasets.Value("int64")) + ) + dataset_wrapper = TokenizedPromptDataset( + strategy, + stepwise_supervised_dataset, + process_count=1, + ) + labels = dataset_wrapper[0]["labels"] + # expected labels is: + # the prompt + first step are ignored, followed by the label for step 1 (True) + # the second step, and its label (False) + # the third step, and its label (False) + expected = [-100] * 47 + [1] + [-100] * 29 + [0] + [-100] * 48 + [0] + + assert labels == expected