support for QAT w RL (DPO) (#2776)

Wing Lian
2025-06-13 10:00:35 -04:00
committed by GitHub
parent eac4a61f55
commit b2274d430b
9 changed files with 152 additions and 32 deletions

View File

@@ -5,6 +5,10 @@ tokenizer_type: AutoTokenizer
 # Automatically upload checkpoint and final model to HF
 # hub_model_id: username/custom_model_name
+special_tokens:
+  pad_token: <|finetune_right_pad_id|>
+  eos_token: <|eot_id|>
 load_in_8bit: true
 load_in_4bit: false

View File

@@ -95,7 +95,7 @@ def load_datasets(
 def load_preference_datasets(
-    *, cfg: DictDefault, cli_args: PreprocessCliArgs | TrainerCliArgs
+    *, cfg: DictDefault, cli_args: PreprocessCliArgs | TrainerCliArgs | None = None
 ) -> TrainDatasetMeta:
     """Loads one or more training or evaluation datasets for RL training using paired
     preference data, calling `axolotl.utils.data.rl.prepare_preference_datasets`.
@@ -118,16 +118,19 @@ def load_preference_datasets(
             math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
         )

-    if cli_args.debug or cfg.debug:
+    if (cli_args and cli_args.debug) or cfg.debug:
         LOG.info("check_dataset_labels...")
+        num_examples = cli_args.debug_num_examples if cli_args else 1
+        text_only = cli_args.debug_text_only if cli_args else False

         tokenizer = load_tokenizer(cfg)
-        train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples)
+        train_samples = sample_dataset(train_dataset, num_examples)
         check_dataset_labels(
             dataset=train_samples,
             tokenizer=tokenizer,
-            num_examples=cli_args.debug_num_examples,
-            text_only=cli_args.debug_text_only,
+            num_examples=num_examples,
+            text_only=text_only,
             rl_mode=True,
         )
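
With `cli_args` now optional, the preference-data loader can be invoked directly from code, which is how the new QAT DPO e2e test calls it. A minimal usage sketch, assuming `cfg` is a `DictDefault` DPO config that has already been through `validate_config`/`normalize_config`, like the one built in `test_qat_dpo` below:

```python
from axolotl.common.datasets import load_preference_datasets

# `cfg` is assumed to already exist (built from a YAML config and passed
# through validate_config/normalize_config). With cli_args omitted, the
# debug path falls back to num_examples=1 and text_only=False.
dataset_meta = load_preference_datasets(cfg=cfg)
```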

View File

@@ -19,6 +19,7 @@ from axolotl.core.training_args import (
 )
 from axolotl.integrations.base import PluginManager
 from axolotl.loaders.utils import ensure_dtype
+from axolotl.utils.callbacks.qat import QATCallback
 from axolotl.utils.logging import get_logger
 from axolotl.utils.schemas.enums import RLType
@@ -31,6 +32,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
     def get_callbacks(self):
         callbacks = super().get_callbacks()
+        if self.cfg.qat:
+            callbacks.append(QATCallback(self.cfg.qat))
         return callbacks

     def get_post_trainer_create_callbacks(self, trainer):
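
`QATCallback` already exists in `axolotl.utils.callbacks.qat`; the change above simply registers it with the RL trainer builder whenever a `qat:` section is present in the config. As a rough sketch of the registration pattern only — the callback's internals are not shown in this diff, and `apply_fake_quantization` below is a hypothetical placeholder, not the real API:

```python
from transformers import TrainerCallback


def apply_fake_quantization(model, qat_cfg):
    """Hypothetical placeholder: a real implementation would swap in
    fake-quantized modules per the configured weight/activation dtypes."""


class SketchQATCallback(TrainerCallback):
    """Sketch following the standard transformers TrainerCallback
    interface; not the real QATCallback."""

    def __init__(self, qat_cfg):
        self.qat_cfg = qat_cfg

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        # Prepare the model for quantization-aware training once, at train start
        apply_fake_quantization(model, self.qat_cfg)
```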
@@ -138,22 +142,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
             training_args_cls = AxolotlDPOConfig
-            if self.cfg.rl is RLType.IPO:
-                training_args_kwargs["loss_type"] = "ipo"
-            # Not compatible with IPO
-            if self.cfg.rl is RLType.DPO and self.cfg.dpo_label_smoothing:
-                training_args_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
-            training_args_kwargs["max_completion_length"] = None
-            training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
-            training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
-            if self.cfg.dpo_use_weighting is not None:
-                training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
-            if self.cfg.dpo_use_logits_to_keep is not None:
-                training_args_kwargs["use_logits_to_keep"] = (
-                    self.cfg.dpo_use_logits_to_keep
-                )
+            training_args_kwargs.update(DPOStrategy.set_training_args_kwargs(self.cfg))
         else:
             raise ValueError(f"Unsupported RL: {self.cfg.rl}")

View File

@@ -22,10 +22,19 @@ class DPOStrategy:
         training_args_kwargs = {}
         if cfg.rl is RLType.IPO:
             training_args_kwargs["loss_type"] = "ipo"
-        training_args_kwargs["max_length"] = cfg.sequence_len
+        # Label smoothing is not compatible with IPO
+        if cfg.rl is RLType.DPO and cfg.dpo_label_smoothing:
+            training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing
+        training_args_kwargs["max_completion_length"] = None
+        training_args_kwargs["max_length"] = cfg.sequence_len
+        training_args_kwargs["max_prompt_length"] = cfg.sequence_len
+        training_args_kwargs["generate_during_eval"] = cfg.use_wandb
         if cfg.dpo_use_weighting is not None:
             training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
+        if cfg.dpo_padding_free is not None:
+            training_args_kwargs["padding_free"] = cfg.dpo_padding_free
+        if cfg.dpo_norm_loss is not None:
+            training_args_kwargs["dpo_norm_loss"] = cfg.dpo_norm_loss
         if cfg.dpo_use_logits_to_keep is not None:
             training_args_kwargs["use_logits_to_keep"] = cfg.dpo_use_logits_to_keep
         return training_args_kwargs
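
Consolidating these kwargs in `DPOStrategy.set_training_args_kwargs` gives the builder (see the `training_args_kwargs.update(...)` call above) a single source of truth, and is where the two new flags land: `dpo_padding_free` maps to TRL's `padding_free` argument, while `dpo_norm_loss` maps to the custom `dpo_norm_loss` field added to `AxolotlDPOConfig` below. A sketch of what the helper produces for a given config; the import path for `DPOStrategy` is not shown in this diff:

```python
from axolotl.utils.dict import DictDefault
from axolotl.utils.schemas.enums import RLType
# import DPOStrategy from its module (path not shown in this diff)

cfg = DictDefault(
    {
        "rl": RLType.DPO,
        "sequence_len": 2048,
        "use_wandb": False,
        "dpo_label_smoothing": 0.1,
        "dpo_padding_free": True,
        "dpo_norm_loss": True,
    }
)
kwargs = DPOStrategy.set_training_args_kwargs(cfg)
# Expected contents, following the logic above (unset dpo_* flags are skipped):
# {"label_smoothing": 0.1, "max_completion_length": None,
#  "max_length": 2048, "max_prompt_length": 2048,
#  "generate_during_eval": False, "padding_free": True,
#  "dpo_norm_loss": True}
```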

View File

@@ -14,3 +14,5 @@ class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
     """
     DPO config for DPO training
     """
+
+    dpo_norm_loss: bool | None = False

View File

@@ -83,3 +83,20 @@ class AxolotlDPOTrainer(
             gc.collect()
             torch.cuda.empty_cache()
         return loss
+
+    def concatenated_forward(
+        self,
+        model: nn.Module,
+        batch: dict[str, Union[list, torch.LongTensor]],
+        is_ref_model: bool = False,
+    ) -> dict[str, torch.Tensor]:
+        if self.args.dpo_norm_loss:
+            # fmt: off
+            loss_type: str = self.loss_type  # type: ignore[has-type] # pylint: disable=access-member-before-definition
+            # fmt: on
+            # concatenated_forward handles avg token logprob for ipo case already
+            self.loss_type = "ipo"  # pylint: disable=attribute-defined-outside-init
+            res = super().concatenated_forward(model, batch, is_ref_model=is_ref_model)
+            self.loss_type = loss_type  # pylint: disable=attribute-defined-outside-init
+            return res
+        return super().concatenated_forward(model, batch, is_ref_model=is_ref_model)
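
The override leans on behavior already present in TRL's `DPOTrainer.concatenated_forward`: when `loss_type` is `"ipo"`, the summed per-token log-probabilities are divided by the completion length, which is exactly what the in-code comment above points at. Temporarily switching `loss_type` to `"ipo"` therefore yields length-normalized log-probs for the `dpo_norm_loss` variant without duplicating any forward logic, and the original `loss_type` is restored afterwards. The same save/swap/restore pattern can be packaged as a context manager — a sketch, not part of this change; the `try/finally` form also restores the attribute if the forward pass raises:

```python
from contextlib import contextmanager


@contextmanager
def temporary_attr(obj, name, value):
    """Temporarily set `obj.name` to `value`, restoring the original on exit."""
    original = getattr(obj, name)
    setattr(obj, name, value)
    try:
        yield
    finally:
        setattr(obj, name, original)


# Hypothetical usage mirroring the override above:
# with temporary_attr(self, "loss_type", "ipo"):
#     res = super().concatenated_forward(model, batch, is_ref_model=is_ref_model)
```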

View File

@@ -46,6 +46,14 @@ def default(
     )

     messages = sample[field_messages]
+    if isinstance(messages, str):
+        messages = [
+            {
+                message_property_mappings["role"]: "user",
+                message_property_mappings["content"]: messages,
+            }
+        ]
     messages = [
         {
             "role": role_map[m[message_property_mappings["role"]]],
@@ -53,13 +61,35 @@
         }
         for m in messages
     ]

+    chosen_raw = sample[field_chosen]
+    if isinstance(chosen_raw, str):
+        chosen_msg = {
+            message_property_mappings["role"]: "assistant",
+            message_property_mappings["content"]: chosen_raw,
+        }
+    elif isinstance(chosen_raw, dict):
+        chosen_msg = chosen_raw
+    else:
+        chosen_msg = chosen_raw[-1]
     chosen = {
-        "role": role_map[sample[field_chosen][message_property_mappings["role"]]],
-        "content": sample[field_chosen][message_property_mappings["content"]],
+        "role": role_map[chosen_msg[message_property_mappings["role"]]],
+        "content": chosen_msg[message_property_mappings["content"]],
     }
+
+    rejected_raw = sample[field_rejected]
+    if isinstance(rejected_raw, str):
+        rejected_msg = {
+            message_property_mappings["role"]: "assistant",
+            message_property_mappings["content"]: rejected_raw,
+        }
+    elif isinstance(rejected_raw, dict):
+        rejected_msg = rejected_raw
+    else:
+        rejected_msg = rejected_raw[-1]
     rejected = {
-        "role": role_map[sample[field_rejected][message_property_mappings["role"]]],
-        "content": sample[field_rejected][message_property_mappings["content"]],
+        "role": role_map[rejected_msg[message_property_mappings["role"]]],
+        "content": rejected_msg[message_property_mappings["content"]],
     }

     dummy_user_message = {"role": "user", "content": "[[dummy_message]]"}
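
With these changes, `field_messages` may be a bare string (wrapped as a single user turn), and `field_chosen`/`field_rejected` may each be a bare string (wrapped as an assistant message), a single message dict, or a list of messages, in which case only the last message is used. Illustrative rows, assuming the default `role`/`content` property mappings — all three normalize to the same preference pair:

```python
# Bare strings everywhere
row_strings = {
    "conversation": "What is 2+2?",  # becomes a single user turn
    "chosen": "4",                   # becomes an assistant message
    "rejected": "5",
}

# Single message dicts for chosen/rejected
row_dicts = {
    "conversation": [{"role": "user", "content": "What is 2+2?"}],
    "chosen": {"role": "assistant", "content": "4"},
    "rejected": {"role": "assistant", "content": "5"},
}

# Lists of messages; only the last entry is taken as the response
row_lists = {
    "conversation": [{"role": "user", "content": "What is 2+2?"}],
    "chosen": [{"role": "assistant", "content": "4"}],
    "rejected": [{"role": "assistant", "content": "5"}],
}
```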

View File

@@ -102,6 +102,8 @@ class AxolotlInputConfig(
     dpo_use_weighting: bool | None = None
     dpo_use_logits_to_keep: bool | None = None
     dpo_label_smoothing: float | None = None
+    dpo_norm_loss: bool | None = None
+    dpo_padding_free: bool | None = None
     datasets: (
         Annotated[
View File

@@ -2,24 +2,22 @@
 E2E tests for QAT
 """

-import unittest
 from pathlib import Path

-from axolotl.common.datasets import load_datasets
+from axolotl.common.datasets import load_datasets, load_preference_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config
 from axolotl.utils.dict import DictDefault

-from .utils import check_model_output_exists, with_temp_dir
+from .utils import check_model_output_exists, check_tensorboard, with_temp_dir


-class TestQATLlama(unittest.TestCase):
+class TestQATLlama:
     """
     Test case for QAT Llama models
     """

     @with_temp_dir
-    def test_qat_lora(self, temp_dir):
+    def test_qat(self, temp_dir):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
@@ -67,3 +65,69 @@ class TestQATLlama:
         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(Path(temp_dir) / "checkpoint-5", cfg)
+
+    @with_temp_dir
+    def test_qat_dpo(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "HuggingFaceTB/SmolLM2-135M",
+                "sequence_len": 2048,
+                "sample_packing": False,
+                "eval_sample_packing": False,
+                "pad_to_sequence_len": True,
+                "val_set_size": 0.01,
+                "special_tokens": {
+                    "pad_token": "<|endoftext|>",
+                },
+                "rl": "dpo",
+                "chat_template": "chatml",
+                "datasets": [
+                    {
+                        "path": "fozziethebeat/alpaca_messages_2k_dpo_test",
+                        "type": "chat_template.default",
+                        "field_messages": "conversation",
+                        "field_chosen": "chosen",
+                        "field_rejected": "rejected",
+                        "message_field_role": "role",
+                        "message_field_content": "content",
+                        "roles": {
+                            "system": ["system"],
+                            "user": ["user"],
+                            "assistant": ["assistant"],
+                        },
+                    },
+                ],
+                "num_epochs": 1,
+                "max_steps": 5,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 2,
+                "output_dir": temp_dir,
+                "warmup_steps": 0,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch_fused",
+                "lr_scheduler": "cosine",
+                "flash_attention": True,
+                "use_tensorboard": True,
+                "bf16": True,
+                "qat": {
+                    "quantize_embedding": True,
+                    "activation_dtype": "int8",
+                    "weight_dtype": "int8",
+                    "group_size": 8,
+                },
+            }
+        )
+
+        cfg = validate_config(cfg)
+        normalize_config(cfg)
+        dataset_meta = load_preference_datasets(cfg=cfg)
+
+        train(cfg=cfg, dataset_meta=dataset_meta)
+        check_model_output_exists(Path(temp_dir) / "checkpoint-5", cfg)
+
+        loss_threshold = 2.3
+        check_tensorboard(
+            temp_dir + "/runs",
+            "train/train_loss",
+            loss_threshold,
+            "Train Loss is too high",
+        )