From b2274d430b650cca6414e5e712d79a8df13c0ad6 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 13 Jun 2025 10:00:35 -0400
Subject: [PATCH] support for QAT w RL (DPO) (#2776)

---
 examples/llama-3/instruct-dpo-lora-8b.yml  |  4 +
 src/axolotl/common/datasets.py             | 13 ++--
 src/axolotl/core/builders/rl.py            | 21 ++---
 src/axolotl/core/trainers/dpo/__init__.py  | 11 ++-
 src/axolotl/core/trainers/dpo/args.py      |  2 +
 src/axolotl/core/trainers/dpo/trainer.py   | 17 +++++
 .../prompt_strategies/dpo/chat_template.py | 38 +++++++++-
 src/axolotl/utils/schemas/config.py        |  2 +
 tests/e2e/test_qat.py                      | 76 +++++++++++++++++--
 9 files changed, 152 insertions(+), 32 deletions(-)

diff --git a/examples/llama-3/instruct-dpo-lora-8b.yml b/examples/llama-3/instruct-dpo-lora-8b.yml
index 13082294f..51f1c768b 100644
--- a/examples/llama-3/instruct-dpo-lora-8b.yml
+++ b/examples/llama-3/instruct-dpo-lora-8b.yml
@@ -5,6 +5,10 @@ tokenizer_type: AutoTokenizer
 
 # Automatically upload checkpoint and final model to HF
 # hub_model_id: username/custom_model_name
 
+special_tokens:
+  pad_token: <|finetune_right_pad_id|>
+  eos_token: <|eot_id|>
+
 load_in_8bit: true
 load_in_4bit: false

diff --git a/src/axolotl/common/datasets.py b/src/axolotl/common/datasets.py
index 4d64958b6..96af84c1e 100644
--- a/src/axolotl/common/datasets.py
+++ b/src/axolotl/common/datasets.py
@@ -95,7 +95,7 @@ def load_datasets(
 
 
 def load_preference_datasets(
-    *, cfg: DictDefault, cli_args: PreprocessCliArgs | TrainerCliArgs
+    *, cfg: DictDefault, cli_args: PreprocessCliArgs | TrainerCliArgs | None = None
 ) -> TrainDatasetMeta:
     """Loads one or more training or evaluation datasets for RL training using paired
     preference data, calling `axolotl.utils.data.rl.prepare_preference_datasets`.
@@ -118,16 +118,19 @@
         math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
     )
 
-    if cli_args.debug or cfg.debug:
+    if (cli_args and cli_args.debug) or cfg.debug:
         LOG.info("check_dataset_labels...")
 
+        num_examples = cli_args.debug_num_examples if cli_args else 1
+        text_only = cli_args.debug_text_only if cli_args else False
+
         tokenizer = load_tokenizer(cfg)
-        train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples)
+        train_samples = sample_dataset(train_dataset, num_examples)
         check_dataset_labels(
             dataset=train_samples,
             tokenizer=tokenizer,
-            num_examples=cli_args.debug_num_examples,
-            text_only=cli_args.debug_text_only,
+            num_examples=num_examples,
+            text_only=text_only,
             rl_mode=True,
         )
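Note on the `load_preference_datasets` change above: making `cli_args` optional (with debug defaults of one example, tokenized view) lets the function be called outside the CLI path, which the new DPO e2e test below relies on. A minimal sketch of programmatic use, assuming an already-validated `cfg` (the variable names here are illustrative, not from the repo):

    from axolotl.common.datasets import load_preference_datasets

    # cfg: a DictDefault config that has already been through
    # validate_config()/normalize_config(), with rl + datasets set
    dataset_meta = load_preference_datasets(cfg=cfg)  # no CLI args required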
diff --git a/src/axolotl/core/builders/rl.py b/src/axolotl/core/builders/rl.py
index 80c5a9eef..47ace7451 100644
--- a/src/axolotl/core/builders/rl.py
+++ b/src/axolotl/core/builders/rl.py
@@ -19,6 +19,7 @@ from axolotl.core.training_args import (
 )
 from axolotl.integrations.base import PluginManager
 from axolotl.loaders.utils import ensure_dtype
+from axolotl.utils.callbacks.qat import QATCallback
 from axolotl.utils.logging import get_logger
 from axolotl.utils.schemas.enums import RLType
 
@@ -31,6 +32,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
 
     def get_callbacks(self):
         callbacks = super().get_callbacks()
+        if self.cfg.qat:
+            callbacks.append(QATCallback(self.cfg.qat))
+
         return callbacks
 
     def get_post_trainer_create_callbacks(self, trainer):
@@ -138,22 +142,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
             training_args_cls = AxolotlDPOConfig
-            if self.cfg.rl is RLType.IPO:
-                training_args_kwargs["loss_type"] = "ipo"
-
-            # Not compatible with IPO
-            if self.cfg.rl is RLType.DPO and self.cfg.dpo_label_smoothing:
-                training_args_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
-
-            training_args_kwargs["max_completion_length"] = None
-            training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
-            training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
-            if self.cfg.dpo_use_weighting is not None:
-                training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
-            if self.cfg.dpo_use_logits_to_keep is not None:
-                training_args_kwargs["use_logits_to_keep"] = (
-                    self.cfg.dpo_use_logits_to_keep
-                )
+            training_args_kwargs.update(DPOStrategy.set_training_args_kwargs(self.cfg))
         else:
             raise ValueError(f"Unsupported RL: {self.cfg.rl}")

diff --git a/src/axolotl/core/trainers/dpo/__init__.py b/src/axolotl/core/trainers/dpo/__init__.py
index 603fdf0b6..8cd9aacf5 100644
--- a/src/axolotl/core/trainers/dpo/__init__.py
+++ b/src/axolotl/core/trainers/dpo/__init__.py
@@ -22,10 +22,19 @@ class DPOStrategy:
         training_args_kwargs = {}
         if cfg.rl is RLType.IPO:
             training_args_kwargs["loss_type"] = "ipo"
-        training_args_kwargs["max_length"] = cfg.sequence_len
+        # Label smoothing is not compatible with IPO
+        if cfg.rl is RLType.DPO and cfg.dpo_label_smoothing:
+            training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing
         training_args_kwargs["max_completion_length"] = None
+        training_args_kwargs["max_length"] = cfg.sequence_len
         training_args_kwargs["max_prompt_length"] = cfg.sequence_len
         training_args_kwargs["generate_during_eval"] = cfg.use_wandb
         if cfg.dpo_use_weighting is not None:
             training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
+        if cfg.dpo_padding_free is not None:
+            training_args_kwargs["padding_free"] = cfg.dpo_padding_free
+        if cfg.dpo_norm_loss is not None:
+            training_args_kwargs["dpo_norm_loss"] = cfg.dpo_norm_loss
+        if cfg.dpo_use_logits_to_keep is not None:
+            training_args_kwargs["use_logits_to_keep"] = cfg.dpo_use_logits_to_keep
         return training_args_kwargs

diff --git a/src/axolotl/core/trainers/dpo/args.py b/src/axolotl/core/trainers/dpo/args.py
index de1758ed0..b1e53236e 100644
--- a/src/axolotl/core/trainers/dpo/args.py
+++ b/src/axolotl/core/trainers/dpo/args.py
@@ -14,3 +14,5 @@ class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
     """
     DPO config for DPO training
     """
+
+    dpo_norm_loss: bool | None = False

diff --git a/src/axolotl/core/trainers/dpo/trainer.py b/src/axolotl/core/trainers/dpo/trainer.py
index 15af80c02..762e0a331 100644
--- a/src/axolotl/core/trainers/dpo/trainer.py
+++ b/src/axolotl/core/trainers/dpo/trainer.py
@@ -83,3 +83,20 @@ class AxolotlDPOTrainer(
         gc.collect()
         torch.cuda.empty_cache()
         return loss
+
+    def concatenated_forward(
+        self,
+        model: nn.Module,
+        batch: dict[str, Union[list, torch.LongTensor]],
+        is_ref_model: bool = False,
+    ) -> dict[str, torch.Tensor]:
+        if self.args.dpo_norm_loss:
+            # fmt: off
+            loss_type: str = self.loss_type  # type: ignore[has-type] # pylint: disable=access-member-before-definition
+            # fmt: on
+            # concatenated_forward handles avg token logprob for ipo case already
+            self.loss_type = "ipo"  # pylint: disable=attribute-defined-outside-init
+            res = super().concatenated_forward(model, batch, is_ref_model=is_ref_model)
+            self.loss_type = loss_type  # pylint: disable=attribute-defined-outside-init
+            return res
+        return super().concatenated_forward(model, batch, is_ref_model=is_ref_model)
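The `dpo_norm_loss` override above temporarily flips `loss_type` to "ipo" because TRL's `concatenated_forward` already returns length-normalized (average per-token) log-probs on that branch, then restores the original loss type so the loss itself is unchanged. A toy illustration of what normalization changes, assuming a simple per-token log-prob tensor (not repo code):

    import torch

    per_token_logps = torch.tensor([[-0.5, -1.0, -1.5]])  # one sequence, 3 tokens
    mask = torch.tensor([[1.0, 1.0, 1.0]])                # all tokens attended

    sum_logps = (per_token_logps * mask).sum(-1)                 # default DPO path: -3.0
    avg_logps = (per_token_logps * mask).sum(-1) / mask.sum(-1)  # "ipo" path: -1.0

Normalizing by sequence length keeps the implicit reward from growing with completion length, which is the point of the new flag.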
diff --git a/src/axolotl/prompt_strategies/dpo/chat_template.py b/src/axolotl/prompt_strategies/dpo/chat_template.py
index f3427022f..786770885 100644
--- a/src/axolotl/prompt_strategies/dpo/chat_template.py
+++ b/src/axolotl/prompt_strategies/dpo/chat_template.py
@@ -46,6 +46,14 @@ def default(
         )
 
         messages = sample[field_messages]
+        if isinstance(messages, str):
+            messages = [
+                {
+                    message_property_mappings["role"]: "user",
+                    message_property_mappings["content"]: messages,
+                }
+            ]
+
         messages = [
             {
                 "role": role_map[m[message_property_mappings["role"]]],
@@ -53,13 +61,35 @@
             }
             for m in messages
         ]
+
+        chosen_raw = sample[field_chosen]
+        if isinstance(chosen_raw, str):
+            chosen_msg = {
+                message_property_mappings["role"]: "assistant",
+                message_property_mappings["content"]: chosen_raw,
+            }
+        elif isinstance(chosen_raw, dict):
+            chosen_msg = chosen_raw
+        else:
+            chosen_msg = chosen_raw[-1]
         chosen = {
-            "role": role_map[sample[field_chosen][message_property_mappings["role"]]],
-            "content": sample[field_chosen][message_property_mappings["content"]],
+            "role": role_map[chosen_msg[message_property_mappings["role"]]],
+            "content": chosen_msg[message_property_mappings["content"]],
         }
+
+        rejected_raw = sample[field_rejected]
+        if isinstance(rejected_raw, str):
+            rejected_msg = {
+                message_property_mappings["role"]: "assistant",
+                message_property_mappings["content"]: rejected_raw,
+            }
+        elif isinstance(rejected_raw, dict):
+            rejected_msg = rejected_raw
+        else:
+            rejected_msg = rejected_raw[-1]
         rejected = {
-            "role": role_map[sample[field_rejected][message_property_mappings["role"]]],
-            "content": sample[field_rejected][message_property_mappings["content"]],
+            "role": role_map[rejected_msg[message_property_mappings["role"]]],
+            "content": rejected_msg[message_property_mappings["content"]],
         }
 
         dummy_user_message = {"role": "user", "content": "[[dummy_message]]"}

diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 505d39858..33a8f77db 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -102,6 +102,8 @@ class AxolotlInputConfig(
     dpo_use_weighting: bool | None = None
     dpo_use_logits_to_keep: bool | None = None
     dpo_label_smoothing: float | None = None
+    dpo_norm_loss: bool | None = None
+    dpo_padding_free: bool | None = None
 
     datasets: (
         Annotated[
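With the `chat_template.default` changes above, `field_messages` may be a bare string or a list of messages, and `field_chosen`/`field_rejected` may each be a string, a single message dict, or a list of messages (only the last message is used). An illustrative sample showing the accepted shapes (made up for this note, not from the test dataset):

    sample = {
        "conversation": "What is QAT?",             # str -> wrapped as a user turn
        "chosen": "Quantization-aware training.",   # str -> treated as an assistant message
        "rejected": {"role": "assistant", "content": "No idea."},  # dict passes through
    }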
"pad_token": "<|endoftext|>", + }, + "rl": "dpo", + "chat_template": "chatml", + "datasets": [ + { + "path": "fozziethebeat/alpaca_messages_2k_dpo_test", + "type": "chat_template.default", + "field_messages": "conversation", + "field_chosen": "chosen", + "field_rejected": "rejected", + "message_field_role": "role", + "message_field_content": "content", + "roles": { + "system": ["system"], + "user": ["user"], + "assistant": ["assistant"], + }, + }, + ], + "num_epochs": 1, + "max_steps": 5, + "micro_batch_size": 2, + "gradient_accumulation_steps": 2, + "output_dir": temp_dir, + "warmup_steps": 0, + "learning_rate": 0.00001, + "optimizer": "adamw_torch_fused", + "lr_scheduler": "cosine", + "flash_attention": True, + "use_tensorboard": True, + "bf16": True, + "qat": { + "quantize_embedding": True, + "activation_dtype": "int8", + "weight_dtype": "int8", + "group_size": 8, + }, + } + ) + cfg = validate_config(cfg) + normalize_config(cfg) + dataset_meta = load_preference_datasets(cfg=cfg) + + train(cfg=cfg, dataset_meta=dataset_meta) + check_model_output_exists(Path(temp_dir) / "checkpoint-5", cfg) + + loss_threshold = 2.3 + check_tensorboard( + temp_dir + "/runs", + "train/train_loss", + loss_threshold, + "Train Loss is too high", + )