support for QAT with RL (DPO) (#2776)
@@ -5,6 +5,10 @@ tokenizer_type: AutoTokenizer
 # Automatically upload checkpoint and final model to HF
 # hub_model_id: username/custom_model_name

+special_tokens:
+  pad_token: <|finetune_right_pad_id|>
+  eos_token: <|eot_id|>
+
 load_in_8bit: true
 load_in_4bit: false
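The example config gains Llama-3's right-padding and end-of-turn tokens. For reference, the same settings can be expressed programmatically via DictDefault, as the e2e tests at the bottom of this diff do — a sketch only, with field names taken from this diff and values from the YAML above:

from axolotl.utils.dict import DictDefault

# Sketch: programmatic equivalent of the YAML fragment above.
cfg = DictDefault(
    {
        "tokenizer_type": "AutoTokenizer",
        "special_tokens": {
            "pad_token": "<|finetune_right_pad_id|>",
            "eos_token": "<|eot_id|>",
        },
        "load_in_8bit": True,
        "load_in_4bit": False,
    }
)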
@@ -95,7 +95,7 @@ def load_datasets(


 def load_preference_datasets(
-    *, cfg: DictDefault, cli_args: PreprocessCliArgs | TrainerCliArgs
+    *, cfg: DictDefault, cli_args: PreprocessCliArgs | TrainerCliArgs | None = None
 ) -> TrainDatasetMeta:
     """Loads one or more training or evaluation datasets for RL training using paired
     preference data, calling `axolotl.utils.data.rl.prepare_preference_datasets`.
@@ -118,16 +118,19 @@ def load_preference_datasets(
             math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
         )

-    if cli_args.debug or cfg.debug:
+    if (cli_args and cli_args.debug) or cfg.debug:
         LOG.info("check_dataset_labels...")

+        num_examples = cli_args.debug_num_examples if cli_args else 1
+        text_only = cli_args.debug_text_only if cli_args else False
+
         tokenizer = load_tokenizer(cfg)
-        train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples)
+        train_samples = sample_dataset(train_dataset, num_examples)
         check_dataset_labels(
             dataset=train_samples,
             tokenizer=tokenizer,
-            num_examples=cli_args.debug_num_examples,
-            text_only=cli_args.debug_text_only,
+            num_examples=num_examples,
+            text_only=text_only,
             rl_mode=True,
         )
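With `cli_args` now optional, preference datasets can be loaded straight from a validated config — exactly what the new e2e test below does. A minimal sketch, assuming `cfg` is a DictDefault with `rl: dpo` and a preference `datasets` entry (see the test for a complete config):

from axolotl.common.datasets import load_preference_datasets
from axolotl.utils.config import normalize_config, validate_config

cfg = validate_config(cfg)  # cfg defined elsewhere; see the e2e test below
normalize_config(cfg)
dataset_meta = load_preference_datasets(cfg=cfg)  # no CLI args required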
@@ -19,6 +19,7 @@ from axolotl.core.training_args import (
 )
 from axolotl.integrations.base import PluginManager
 from axolotl.loaders.utils import ensure_dtype
+from axolotl.utils.callbacks.qat import QATCallback
 from axolotl.utils.logging import get_logger
 from axolotl.utils.schemas.enums import RLType

@@ -31,6 +32,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
     def get_callbacks(self):
         callbacks = super().get_callbacks()

+        if self.cfg.qat:
+            callbacks.append(QATCallback(self.cfg.qat))
+
        return callbacks

     def get_post_trainer_create_callbacks(self, trainer):
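This is the heart of the change: the RL trainer builder now registers the QAT callback whenever a `qat:` block is present, so quantization-aware training applies to DPO runs just as it does to SFT. For intuition only, here is the general shape of such a callback using the standard `transformers` hooks — a hypothetical sketch, not axolotl's actual `QATCallback` implementation:

from transformers import TrainerCallback

class FakeQuantizeCallback(TrainerCallback):
    """Hypothetical illustration of the QAT callback pattern."""

    def __init__(self, qat_cfg):
        self.qat_cfg = qat_cfg  # e.g. weight/activation dtypes, group size

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        # A real QAT callback would attach fake-quantization to `model` here,
        # so training observes quantization error end to end.
        ...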
@@ -138,22 +142,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):

         elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
             training_args_cls = AxolotlDPOConfig
-            if self.cfg.rl is RLType.IPO:
-                training_args_kwargs["loss_type"] = "ipo"
-
-            # Not compatible with IPO
-            if self.cfg.rl is RLType.DPO and self.cfg.dpo_label_smoothing:
-                training_args_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
-
-            training_args_kwargs["max_completion_length"] = None
-            training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
-            training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
-            if self.cfg.dpo_use_weighting is not None:
-                training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
-            if self.cfg.dpo_use_logits_to_keep is not None:
-                training_args_kwargs["use_logits_to_keep"] = (
-                    self.cfg.dpo_use_logits_to_keep
-                )
+            training_args_kwargs.update(DPOStrategy.set_training_args_kwargs(self.cfg))
         else:
             raise ValueError(f"Unsupported RL: {self.cfg.rl}")
@@ -22,10 +22,19 @@ class DPOStrategy:
         training_args_kwargs = {}
         if cfg.rl is RLType.IPO:
             training_args_kwargs["loss_type"] = "ipo"
-            training_args_kwargs["max_length"] = cfg.sequence_len

         # Label smoothing is not compatible with IPO
         if cfg.rl is RLType.DPO and cfg.dpo_label_smoothing:
             training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing
+
+        training_args_kwargs["max_completion_length"] = None
+        training_args_kwargs["max_length"] = cfg.sequence_len
+        training_args_kwargs["max_prompt_length"] = cfg.sequence_len
+        training_args_kwargs["generate_during_eval"] = cfg.use_wandb
         if cfg.dpo_use_weighting is not None:
             training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
+        if cfg.dpo_padding_free is not None:
+            training_args_kwargs["padding_free"] = cfg.dpo_padding_free
+        if cfg.dpo_norm_loss is not None:
+            training_args_kwargs["dpo_norm_loss"] = cfg.dpo_norm_loss
         if cfg.dpo_use_logits_to_keep is not None:
             training_args_kwargs["use_logits_to_keep"] = cfg.dpo_use_logits_to_keep
         return training_args_kwargs
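Centralizing the kwargs in `DPOStrategy.set_training_args_kwargs` keeps the builder branch above to a single line. Roughly what it returns for a plain DPO config — a sketch assuming `cfg.rl` has already been validated into an `RLType` and that `DPOStrategy` from the hunk above is importable:

from axolotl.utils.dict import DictDefault
from axolotl.utils.schemas.enums import RLType

cfg = DictDefault(
    {
        "rl": RLType.DPO,
        "sequence_len": 2048,
        "use_wandb": False,
        "dpo_norm_loss": True,  # other dpo_* fields left as None -> omitted
    }
)
kwargs = DPOStrategy.set_training_args_kwargs(cfg)
# kwargs == {
#     "max_completion_length": None,
#     "max_length": 2048,
#     "max_prompt_length": 2048,
#     "generate_during_eval": False,
#     "dpo_norm_loss": True,
# }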
@@ -14,3 +14,5 @@ class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
     """
     DPO config for DPO training
     """
+
+    dpo_norm_loss: bool | None = False
@@ -83,3 +83,20 @@ class AxolotlDPOTrainer(
         gc.collect()
         torch.cuda.empty_cache()
         return loss
+
+    def concatenated_forward(
+        self,
+        model: nn.Module,
+        batch: dict[str, Union[list, torch.LongTensor]],
+        is_ref_model: bool = False,
+    ) -> dict[str, torch.Tensor]:
+        if self.args.dpo_norm_loss:
+            # fmt: off
+            loss_type: str = self.loss_type  # type: ignore[has-type] # pylint: disable=access-member-before-definition
+            # fmt: on
+            # concatenated_forward handles avg token logprob for ipo case already
+            self.loss_type = "ipo"  # pylint: disable=attribute-defined-outside-init
+            res = super().concatenated_forward(model, batch, is_ref_model=is_ref_model)
+            self.loss_type = loss_type  # pylint: disable=attribute-defined-outside-init
+            return res
+        return super().concatenated_forward(model, batch, is_ref_model=is_ref_model)
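The temporary `loss_type = "ipo"` swap works because TRL's `concatenated_forward` averages per-token log-probabilities for the IPO loss instead of summing them (as the inline comment notes), which removes length bias from the preference margin. A toy illustration of the difference:

import torch

# Per-token log-probs for a 4-token completion
logps = torch.tensor([-1.2, -0.8, -2.0, -0.5])

summed = logps.sum()     # DPO default: longer completions accumulate more mass
averaged = logps.mean()  # IPO-style normalization used when dpo_norm_loss is set

print(f"sum={summed:.2f}, mean={averaged:.2f}")  # sum=-4.50, mean=-1.12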
@@ -46,6 +46,14 @@ def default(
         )

     messages = sample[field_messages]
+    if isinstance(messages, str):
+        messages = [
+            {
+                message_property_mappings["role"]: "user",
+                message_property_mappings["content"]: messages,
+            }
+        ]
+
     messages = [
         {
             "role": role_map[m[message_property_mappings["role"]]],
@@ -53,13 +61,35 @@ def default(
         }
         for m in messages
     ]

+    chosen_raw = sample[field_chosen]
+    if isinstance(chosen_raw, str):
+        chosen_msg = {
+            message_property_mappings["role"]: "assistant",
+            message_property_mappings["content"]: chosen_raw,
+        }
+    elif isinstance(chosen_raw, dict):
+        chosen_msg = chosen_raw
+    else:
+        chosen_msg = chosen_raw[-1]
     chosen = {
-        "role": role_map[sample[field_chosen][message_property_mappings["role"]]],
-        "content": sample[field_chosen][message_property_mappings["content"]],
+        "role": role_map[chosen_msg[message_property_mappings["role"]]],
+        "content": chosen_msg[message_property_mappings["content"]],
     }

+    rejected_raw = sample[field_rejected]
+    if isinstance(rejected_raw, str):
+        rejected_msg = {
+            message_property_mappings["role"]: "assistant",
+            message_property_mappings["content"]: rejected_raw,
+        }
+    elif isinstance(rejected_raw, dict):
+        rejected_msg = rejected_raw
+    else:
+        rejected_msg = rejected_raw[-1]
     rejected = {
-        "role": role_map[sample[field_rejected][message_property_mappings["role"]]],
-        "content": sample[field_rejected][message_property_mappings["content"]],
+        "role": role_map[rejected_msg[message_property_mappings["role"]]],
+        "content": rejected_msg[message_property_mappings["content"]],
     }
     dummy_user_message = {"role": "user", "content": "[[dummy_message]]"}
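With this normalization, `chosen`/`rejected` (and `messages`) may arrive as a bare string, a single message dict, or a list of messages, in which case the final turn is used. Hypothetical rows, assuming the default role/content property mappings — all three shapes normalize to the same chosen/rejected pair:

# All three yield chosen == {"role": "assistant", "content": "4"}:
row_str = {"conversation": "What is 2+2?", "chosen": "4", "rejected": "5"}
row_dict = {
    "conversation": [{"role": "user", "content": "What is 2+2?"}],
    "chosen": {"role": "assistant", "content": "4"},
    "rejected": {"role": "assistant", "content": "5"},
}
row_list = {
    "conversation": [{"role": "user", "content": "What is 2+2?"}],
    "chosen": [{"role": "assistant", "content": "4"}],  # last turn is taken
    "rejected": [{"role": "assistant", "content": "5"}],
}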
@@ -102,6 +102,8 @@ class AxolotlInputConfig(
     dpo_use_weighting: bool | None = None
     dpo_use_logits_to_keep: bool | None = None
     dpo_label_smoothing: float | None = None
+    dpo_norm_loss: bool | None = None
+    dpo_padding_free: bool | None = None

     datasets: (
         Annotated[
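Both new options default to `None`, which `DPOStrategy` treats as "don't pass to TRL", so existing configs are unaffected. Enabled, they would look like this — a sketch, shown as the Python dict the YAML config parses into:

cfg_fragment = {
    "rl": "dpo",
    "dpo_norm_loss": True,     # length-normalize completion log-probs
    "dpo_padding_free": True,  # forwarded to TRL's DPOConfig as `padding_free`
}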
@@ -2,24 +2,22 @@
 E2E tests for QAT
 """

-import unittest
 from pathlib import Path

-from axolotl.common.datasets import load_datasets
+from axolotl.common.datasets import load_datasets, load_preference_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config
 from axolotl.utils.dict import DictDefault

-from .utils import check_model_output_exists, with_temp_dir
+from .utils import check_model_output_exists, check_tensorboard, with_temp_dir


-class TestQATLlama(unittest.TestCase):
+class TestQATLlama:
     """
     Test case for QAT Llama models
     """

     @with_temp_dir
-    def test_qat_lora(self, temp_dir):
+    def test_qat(self, temp_dir):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
@@ -67,3 +65,69 @@ class TestQATLlama:

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(Path(temp_dir) / "checkpoint-5", cfg)
+
+    @with_temp_dir
+    def test_qat_dpo(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "HuggingFaceTB/SmolLM2-135M",
+                "sequence_len": 2048,
+                "sample_packing": False,
+                "eval_sample_packing": False,
+                "pad_to_sequence_len": True,
+                "val_set_size": 0.01,
+                "special_tokens": {
+                    "pad_token": "<|endoftext|>",
+                },
+                "rl": "dpo",
+                "chat_template": "chatml",
+                "datasets": [
+                    {
+                        "path": "fozziethebeat/alpaca_messages_2k_dpo_test",
+                        "type": "chat_template.default",
+                        "field_messages": "conversation",
+                        "field_chosen": "chosen",
+                        "field_rejected": "rejected",
+                        "message_field_role": "role",
+                        "message_field_content": "content",
+                        "roles": {
+                            "system": ["system"],
+                            "user": ["user"],
+                            "assistant": ["assistant"],
+                        },
+                    },
+                ],
+                "num_epochs": 1,
+                "max_steps": 5,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 2,
+                "output_dir": temp_dir,
+                "warmup_steps": 0,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch_fused",
+                "lr_scheduler": "cosine",
+                "flash_attention": True,
+                "use_tensorboard": True,
+                "bf16": True,
+                "qat": {
+                    "quantize_embedding": True,
+                    "activation_dtype": "int8",
+                    "weight_dtype": "int8",
+                    "group_size": 8,
+                },
+            }
+        )
+        cfg = validate_config(cfg)
+        normalize_config(cfg)
+        dataset_meta = load_preference_datasets(cfg=cfg)
+
+        train(cfg=cfg, dataset_meta=dataset_meta)
+        check_model_output_exists(Path(temp_dir) / "checkpoint-5", cfg)
+
+        loss_threshold = 2.3
+        check_tensorboard(
+            temp_dir + "/runs",
+            "train/train_loss",
+            loss_threshold,
+            "Train Loss is too high",
+        )