support for QAT w RL (DPO) (#2776)
This commit is contained in:
@@ -5,6 +5,10 @@ tokenizer_type: AutoTokenizer
|
|||||||
# Automatically upload checkpoint and final model to HF
|
# Automatically upload checkpoint and final model to HF
|
||||||
# hub_model_id: username/custom_model_name
|
# hub_model_id: username/custom_model_name
|
||||||
|
|
||||||
|
special_tokens:
|
||||||
|
pad_token: <|finetune_right_pad_id|>
|
||||||
|
eos_token: <|eot_id|>
|
||||||
|
|
||||||
load_in_8bit: true
|
load_in_8bit: true
|
||||||
load_in_4bit: false
|
load_in_4bit: false
|
||||||
|
|
||||||
|
|||||||
@@ -95,7 +95,7 @@ def load_datasets(
|
|||||||
|
|
||||||
|
|
||||||
def load_preference_datasets(
|
def load_preference_datasets(
|
||||||
*, cfg: DictDefault, cli_args: PreprocessCliArgs | TrainerCliArgs
|
*, cfg: DictDefault, cli_args: PreprocessCliArgs | TrainerCliArgs | None = None
|
||||||
) -> TrainDatasetMeta:
|
) -> TrainDatasetMeta:
|
||||||
"""Loads one or more training or evaluation datasets for RL training using paired
|
"""Loads one or more training or evaluation datasets for RL training using paired
|
||||||
preference data, calling `axolotl.utils.data.rl.prepare_preference_datasets`.
|
preference data, calling `axolotl.utils.data.rl.prepare_preference_datasets`.
|
||||||
@@ -118,16 +118,19 @@ def load_preference_datasets(
|
|||||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||||
)
|
)
|
||||||
|
|
||||||
if cli_args.debug or cfg.debug:
|
if (cli_args and cli_args.debug) or cfg.debug:
|
||||||
LOG.info("check_dataset_labels...")
|
LOG.info("check_dataset_labels...")
|
||||||
|
|
||||||
|
num_examples = cli_args.debug_num_examples if cli_args else 1
|
||||||
|
text_only = cli_args.debug_text_only if cli_args else False
|
||||||
|
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples)
|
train_samples = sample_dataset(train_dataset, num_examples)
|
||||||
check_dataset_labels(
|
check_dataset_labels(
|
||||||
dataset=train_samples,
|
dataset=train_samples,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
num_examples=cli_args.debug_num_examples,
|
num_examples=num_examples,
|
||||||
text_only=cli_args.debug_text_only,
|
text_only=text_only,
|
||||||
rl_mode=True,
|
rl_mode=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from axolotl.core.training_args import (
|
|||||||
)
|
)
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.loaders.utils import ensure_dtype
|
from axolotl.loaders.utils import ensure_dtype
|
||||||
|
from axolotl.utils.callbacks.qat import QATCallback
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.schemas.enums import RLType
|
from axolotl.utils.schemas.enums import RLType
|
||||||
|
|
||||||
@@ -31,6 +32,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
def get_callbacks(self):
|
def get_callbacks(self):
|
||||||
callbacks = super().get_callbacks()
|
callbacks = super().get_callbacks()
|
||||||
|
|
||||||
|
if self.cfg.qat:
|
||||||
|
callbacks.append(QATCallback(self.cfg.qat))
|
||||||
|
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
def get_post_trainer_create_callbacks(self, trainer):
|
def get_post_trainer_create_callbacks(self, trainer):
|
||||||
@@ -138,22 +142,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
|
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
|
||||||
training_args_cls = AxolotlDPOConfig
|
training_args_cls = AxolotlDPOConfig
|
||||||
if self.cfg.rl is RLType.IPO:
|
training_args_kwargs.update(DPOStrategy.set_training_args_kwargs(self.cfg))
|
||||||
training_args_kwargs["loss_type"] = "ipo"
|
|
||||||
|
|
||||||
# Not compatible with IPO
|
|
||||||
if self.cfg.rl is RLType.DPO and self.cfg.dpo_label_smoothing:
|
|
||||||
training_args_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
|
||||||
|
|
||||||
training_args_kwargs["max_completion_length"] = None
|
|
||||||
training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
|
||||||
training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
|
|
||||||
if self.cfg.dpo_use_weighting is not None:
|
|
||||||
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
|
|
||||||
if self.cfg.dpo_use_logits_to_keep is not None:
|
|
||||||
training_args_kwargs["use_logits_to_keep"] = (
|
|
||||||
self.cfg.dpo_use_logits_to_keep
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
||||||
|
|
||||||
|
|||||||
@@ -22,10 +22,19 @@ class DPOStrategy:
|
|||||||
training_args_kwargs = {}
|
training_args_kwargs = {}
|
||||||
if cfg.rl is RLType.IPO:
|
if cfg.rl is RLType.IPO:
|
||||||
training_args_kwargs["loss_type"] = "ipo"
|
training_args_kwargs["loss_type"] = "ipo"
|
||||||
training_args_kwargs["max_length"] = cfg.sequence_len
|
# Label smoothing is not compatible with IPO
|
||||||
|
if cfg.rl is RLType.DPO and cfg.dpo_label_smoothing:
|
||||||
|
training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing
|
||||||
training_args_kwargs["max_completion_length"] = None
|
training_args_kwargs["max_completion_length"] = None
|
||||||
|
training_args_kwargs["max_length"] = cfg.sequence_len
|
||||||
training_args_kwargs["max_prompt_length"] = cfg.sequence_len
|
training_args_kwargs["max_prompt_length"] = cfg.sequence_len
|
||||||
training_args_kwargs["generate_during_eval"] = cfg.use_wandb
|
training_args_kwargs["generate_during_eval"] = cfg.use_wandb
|
||||||
if cfg.dpo_use_weighting is not None:
|
if cfg.dpo_use_weighting is not None:
|
||||||
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
|
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
|
||||||
|
if cfg.dpo_padding_free is not None:
|
||||||
|
training_args_kwargs["padding_free"] = cfg.dpo_padding_free
|
||||||
|
if cfg.dpo_norm_loss is not None:
|
||||||
|
training_args_kwargs["dpo_norm_loss"] = cfg.dpo_norm_loss
|
||||||
|
if cfg.dpo_use_logits_to_keep is not None:
|
||||||
|
training_args_kwargs["use_logits_to_keep"] = cfg.dpo_use_logits_to_keep
|
||||||
return training_args_kwargs
|
return training_args_kwargs
|
||||||
|
|||||||
@@ -14,3 +14,5 @@ class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
|
|||||||
"""
|
"""
|
||||||
DPO config for DPO training
|
DPO config for DPO training
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
dpo_norm_loss: bool | None = False
|
||||||
|
|||||||
@@ -83,3 +83,20 @@ class AxolotlDPOTrainer(
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
def concatenated_forward(
|
||||||
|
self,
|
||||||
|
model: nn.Module,
|
||||||
|
batch: dict[str, Union[list, torch.LongTensor]],
|
||||||
|
is_ref_model: bool = False,
|
||||||
|
) -> dict[str, torch.Tensor]:
|
||||||
|
if self.args.dpo_norm_loss:
|
||||||
|
# fmt: off
|
||||||
|
loss_type: str = self.loss_type # type: ignore[has-type] # pylint: disable=access-member-before-definition
|
||||||
|
# fmt: on
|
||||||
|
# concatenated_forward handles avg token logprob for ipo case already
|
||||||
|
self.loss_type = "ipo" # pylint: disable=attribute-defined-outside-init
|
||||||
|
res = super().concatenated_forward(model, batch, is_ref_model=is_ref_model)
|
||||||
|
self.loss_type = loss_type # pylint: disable=attribute-defined-outside-init
|
||||||
|
return res
|
||||||
|
return super().concatenated_forward(model, batch, is_ref_model=is_ref_model)
|
||||||
|
|||||||
@@ -46,6 +46,14 @@ def default(
|
|||||||
)
|
)
|
||||||
|
|
||||||
messages = sample[field_messages]
|
messages = sample[field_messages]
|
||||||
|
if isinstance(messages, str):
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
message_property_mappings["role"]: "user",
|
||||||
|
message_property_mappings["content"]: messages,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{
|
||||||
"role": role_map[m[message_property_mappings["role"]]],
|
"role": role_map[m[message_property_mappings["role"]]],
|
||||||
@@ -53,13 +61,35 @@ def default(
|
|||||||
}
|
}
|
||||||
for m in messages
|
for m in messages
|
||||||
]
|
]
|
||||||
|
|
||||||
|
chosen_raw = sample[field_chosen]
|
||||||
|
if isinstance(chosen_raw, str):
|
||||||
|
chosen_msg = {
|
||||||
|
message_property_mappings["role"]: "assistant",
|
||||||
|
message_property_mappings["content"]: chosen_raw,
|
||||||
|
}
|
||||||
|
elif isinstance(chosen_raw, dict):
|
||||||
|
chosen_msg = chosen_raw
|
||||||
|
else:
|
||||||
|
chosen_msg = chosen_raw[-1]
|
||||||
chosen = {
|
chosen = {
|
||||||
"role": role_map[sample[field_chosen][message_property_mappings["role"]]],
|
"role": role_map[chosen_msg[message_property_mappings["role"]]],
|
||||||
"content": sample[field_chosen][message_property_mappings["content"]],
|
"content": chosen_msg[message_property_mappings["content"]],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
rejected_raw = sample[field_rejected]
|
||||||
|
if isinstance(rejected_raw, str):
|
||||||
|
rejected_msg = {
|
||||||
|
message_property_mappings["role"]: "assistant",
|
||||||
|
message_property_mappings["content"]: rejected_raw,
|
||||||
|
}
|
||||||
|
elif isinstance(rejected_raw, dict):
|
||||||
|
rejected_msg = rejected_raw
|
||||||
|
else:
|
||||||
|
rejected_msg = rejected_raw[-1]
|
||||||
rejected = {
|
rejected = {
|
||||||
"role": role_map[sample[field_rejected][message_property_mappings["role"]]],
|
"role": role_map[rejected_msg[message_property_mappings["role"]]],
|
||||||
"content": sample[field_rejected][message_property_mappings["content"]],
|
"content": rejected_msg[message_property_mappings["content"]],
|
||||||
}
|
}
|
||||||
dummy_user_message = {"role": "user", "content": "[[dummy_message]]"}
|
dummy_user_message = {"role": "user", "content": "[[dummy_message]]"}
|
||||||
|
|
||||||
|
|||||||
@@ -102,6 +102,8 @@ class AxolotlInputConfig(
|
|||||||
dpo_use_weighting: bool | None = None
|
dpo_use_weighting: bool | None = None
|
||||||
dpo_use_logits_to_keep: bool | None = None
|
dpo_use_logits_to_keep: bool | None = None
|
||||||
dpo_label_smoothing: float | None = None
|
dpo_label_smoothing: float | None = None
|
||||||
|
dpo_norm_loss: bool | None = None
|
||||||
|
dpo_padding_free: bool | None = None
|
||||||
|
|
||||||
datasets: (
|
datasets: (
|
||||||
Annotated[
|
Annotated[
|
||||||
|
|||||||
@@ -2,24 +2,22 @@
|
|||||||
E2E tests for QAT
|
E2E tests for QAT
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import unittest
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from .utils import check_model_output_exists, with_temp_dir
|
from .utils import check_model_output_exists, check_tensorboard
|
||||||
|
|
||||||
|
|
||||||
class TestQATLlama(unittest.TestCase):
|
class TestQATLlama:
|
||||||
"""
|
"""
|
||||||
Test case for QAT Llama models
|
Test case for QAT Llama models
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@with_temp_dir
|
def test_qat(self, temp_dir):
|
||||||
def test_qat_lora(self, temp_dir):
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
@@ -67,3 +65,69 @@ class TestQATLlama(unittest.TestCase):
|
|||||||
|
|
||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(Path(temp_dir) / "checkpoint-5", cfg)
|
check_model_output_exists(Path(temp_dir) / "checkpoint-5", cfg)
|
||||||
|
|
||||||
|
def test_qat_dpo(self, temp_dir):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"sequence_len": 2048,
|
||||||
|
"sample_packing": False,
|
||||||
|
"eval_sample_packing": False,
|
||||||
|
"pad_to_sequence_len": True,
|
||||||
|
"val_set_size": 0.01,
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
},
|
||||||
|
"rl": "dpo",
|
||||||
|
"chat_template": "chatml",
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
|
||||||
|
"type": "chat_template.default",
|
||||||
|
"field_messages": "conversation",
|
||||||
|
"field_chosen": "chosen",
|
||||||
|
"field_rejected": "rejected",
|
||||||
|
"message_field_role": "role",
|
||||||
|
"message_field_content": "content",
|
||||||
|
"roles": {
|
||||||
|
"system": ["system"],
|
||||||
|
"user": ["user"],
|
||||||
|
"assistant": ["assistant"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"max_steps": 5,
|
||||||
|
"micro_batch_size": 2,
|
||||||
|
"gradient_accumulation_steps": 2,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"warmup_steps": 0,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_torch_fused",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"flash_attention": True,
|
||||||
|
"use_tensorboard": True,
|
||||||
|
"bf16": True,
|
||||||
|
"qat": {
|
||||||
|
"quantize_embedding": True,
|
||||||
|
"activation_dtype": "int8",
|
||||||
|
"weight_dtype": "int8",
|
||||||
|
"group_size": 8,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
normalize_config(cfg)
|
||||||
|
dataset_meta = load_preference_datasets(cfg=cfg)
|
||||||
|
|
||||||
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
check_model_output_exists(Path(temp_dir) / "checkpoint-5", cfg)
|
||||||
|
|
||||||
|
loss_threshold = 2.3
|
||||||
|
check_tensorboard(
|
||||||
|
temp_dir + "/runs",
|
||||||
|
"train/train_loss",
|
||||||
|
loss_threshold,
|
||||||
|
"Train Loss is too high",
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user