Compare commits

4 Commits

| Author | SHA1 | Message | Date |
| --- | --- | --- | --- |
| Wing Lian | 96af760e08 | add option for liger_pref_rl | 2024-12-16 18:31:16 -05:00 |
| Wing Lian | cfa80dace0 | import typo | 2024-12-16 14:27:26 -05:00 |
| Wing Lian | 0a661980ca | wip for liger dpo integration | 2024-12-16 14:16:36 -05:00 |
| Wing Lian | effc4dc409 | pin to 4.47.0 (#2180) | 2024-12-12 20:17:12 -05:00 |
10 changed files with 1584 additions and 1259 deletions

View File

@@ -478,7 +478,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- model
```yaml
-base_model: ./llama-7b-hf/ # local or huggingface repo
+base_model: ./llama-7b-hf # local or huggingface repo
```
Note: The code will load the right architecture.

View File

@@ -12,7 +12,7 @@ liger-kernel==0.4.2
packaging==23.2
peft==0.14.0
-transformers>=4.46.3
+transformers==4.47.0
tokenizers>=0.20.1
accelerate==1.2.0
datasets==3.1.0
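
The diff tightens the floor constraint `transformers>=4.46.3` to an exact pin, matching the commit "pin to 4.47.0 (#2180)". A minimal runtime check of the pin, as a sketch (the check itself is not part of the diff):

```python
# Verify the installed transformers version matches the pin from this diff.
from importlib.metadata import version

installed = version("transformers")
assert installed == "4.47.0", f"expected transformers==4.47.0, got {installed}"
```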

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,220 @@
"""
extra axolotl specific training args
"""
from dataclasses import dataclass, field
from typing import Optional
from transformers import TrainingArguments
from trl import CPOConfig, DPOConfig, KTOConfig, ORPOConfig, RewardConfig


@dataclass
class AxolotlTrainingMixins:
"""
Mixin class for the Axolotl training args.
"""
# pylint: disable=duplicate-code
model_type: Optional[str] = field(
default=None, metadata={"help": "HF model configuration model_type."}
)
lr_quadratic_warmup: bool = field(
default=False,
metadata={"help": "Use quadratic warmup for cosine scheduling."},
)
pretraining: bool = field(
default=False,
metadata={
"help": "Indicates to trainer whether we are doing continued pretraining."
},
)
sample_packing: bool = field(
default=False,
metadata={"help": "Use sample packing for efficient training."},
)
multipack_real_batches: bool = field(
default=False,
metadata={"help": "Use real batches for efficient training."},
)
eval_sample_packing: Optional[bool] = field(
default=None,
metadata={"help": "Use sample packing for efficient evals."},
)
sample_packing_efficiency: float = field(
default=1.0,
metadata={"help": "Sample packing efficiency for calculating batch length."},
)
sample_packing_bin_size: int = field(
default=200,
metadata={
"help": "The max number of samples that packed sample can contain after packing. Increase for better packing."
},
)
sample_packing_group_size: int = field(
default=100000,
metadata={
"help": "The number of samples to group together for packing. Increase for better packing."
},
)
max_seq_length: int = field(
default=2048,
metadata={"help": "The maximum sequence length the model can handle"},
)
relora_steps: Optional[int] = field(
default=None,
metadata={"help": "how often to reset for ReLoRA"},
)
relora_warmup_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
relora_anneal_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
relora_prune_ratio: Optional[float] = field(
default=0.9,
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
)
bench_split: Optional[str] = field(
default="eval", metadata={"help": "The benchmark split to run on"}
)
bench_dataset: Optional[str] = field(
default="pharaouk/dharma-1/dharma_1_mini.json",
metadata={
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
},
)
do_bench_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
)
do_causal_lm_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
)
max_bench_samples: Optional[int] = field(
default=None,
metadata={
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
},
)
bench_source_max_len: int = field(
default=2048, metadata={"help": "Maximum source sequence length for bench."}
)
dataloader_prefetch_factor: Optional[int] = field(
default=None,
metadata={"help": "prefetch_factor argument to the dataloader"},
)
cosine_min_lr_ratio: Optional[float] = field(
default=None,
metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
)
cosine_constant_lr_ratio: Optional[float] = field(
default=None,
metadata={
"help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
},
)
loraplus_lr_ratio: Optional[float] = field(
default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
)
loraplus_lr_embedding: Optional[float] = field(
default=1e-6,
metadata={"help": "loraplus learning rate for lora embedding layers."},
)
embedding_lr_scale: Optional[float] = field(
default=None,
metadata={"help": "Scale the learning rate for the embedding layers."},
)
embedding_lr: Optional[float] = field(
default=None,
metadata={"help": "absolute learning rate for the embedding layers."},
)
qlora: bool = field(
default=False,
metadata={"help": "whether this is a qlora training"},
)
orpo_alpha: Optional[float] = field(
default=None,
)
lisa_n_layers: Optional[int] = field(
default=None,
metadata={"help": "the number of activate layers in LISA"},
)
lisa_step_interval: Optional[int] = field(
default=None,
metadata={"help": "how often to switch layers in LISA"},
)
lisa_layers_attribute: Optional[str] = field(
default=None,
metadata={"help": "path under the model to access the layers"},
)
curriculum_sampling: Optional[bool] = field(
default=None,
metadata={"help": "whether to use sequential sampling for curriculum learning"},
)
alternate_optimizer: Optional[str] = field(
default=None,
metadata={
"help": "workaround to pass an alternate optimizer to the HF trainer"
},
)
alternate_lr_scheduler_type: Optional[str] = field(
default=None,
metadata={
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
},
)
chat_template: Optional[str] = field(
default=None,
metadata={"help": "Chat template converting chat messages to text"},
)


@dataclass
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
    """
    Training arguments for the causal trainer.

    This code is duplicated because HF TrainingArguments does not give
    output_dir a default value, so it can't be used as a mixin.
    """


@dataclass
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
    """
    DPO config for DPO training
    """


@dataclass
class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
    """
    ORPO config for ORPO training
    """


@dataclass
class AxolotlKTOConfig(AxolotlTrainingMixins, KTOConfig):
    """
    KTO config for KTO training
    """


@dataclass
class AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig):
    """
    CPO config for CPO training
    """

    simpo_gamma: Optional[float] = field(
        default=None,
        metadata={"help": "SimPO gamma parameter"},
    )


@dataclass
class AxolotlRewardConfig(AxolotlTrainingMixins, RewardConfig):
    """
    Reward config for Reward training
    """

View File

@@ -36,6 +36,8 @@ class LigerArgs(BaseModel):
liger_cross_entropy: Optional[bool] = None
liger_fused_linear_cross_entropy: Optional[bool] = None
+    liger_pref_rl: Optional[bool] = None

@model_validator(mode="before")
@classmethod
def check_deprecated_swiglu(cls, data):
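
The hunk cuts off just as the validator body begins. For context, a `mode="before"` validator of this shape sees the raw input dict and can remap deprecated keys before field validation; a sketch under that assumption (only the three liger_* fields shown above come from the diff, and the deprecated/replacement field names are assumptions):

```python
from typing import Optional

from pydantic import BaseModel, model_validator


class LigerArgs(BaseModel):
    liger_cross_entropy: Optional[bool] = None
    liger_fused_linear_cross_entropy: Optional[bool] = None
    liger_pref_rl: Optional[bool] = None
    liger_swiglu: Optional[bool] = None  # assumed deprecated alias
    liger_glu_activation: Optional[bool] = None  # assumed replacement

    @model_validator(mode="before")
    @classmethod
    def check_deprecated_swiglu(cls, data):
        # A "before" validator receives the raw dict, so deprecated keys can
        # be remapped before pydantic validates the fields.
        if isinstance(data, dict) and data.get("liger_swiglu") is not None:
            data.setdefault("liger_glu_activation", data.pop("liger_swiglu"))
        return data
```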

View File

@@ -0,0 +1,253 @@
"""
integration of liger dpo kernels with dpotrainer
"""
from typing import Dict, List, Literal, Union
import torch
from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss
from liger_kernel.transformers.trainer.orpo_trainer import _FSDPForwardRedirection
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel
from axolotl.core.trainers.base import AxolotlDPOTrainer


class AxolotlLigerDPOTrainer(AxolotlDPOTrainer):
"""
Extend the DPO Trainer to use LIGER kernels for DPO
"""

    def concatenated_forward(
self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
):
"""
        Run the given model on the given batch of inputs, concatenating the chosen and rejected
        inputs together, and computing the DPO loss using Liger's fused kernel.
This method replaces the original `concatenated_forward` implementation to use Liger.
"""
# Prepare concatenated inputs
concatenated_batch = self.concatenated_inputs(batch, self.padding_value)
# Extract concatenated inputs
prompt_input_ids = concatenated_batch["prompt_input_ids"]
prompt_attention_mask = concatenated_batch["prompt_attention_mask"]
completion_input_ids = concatenated_batch["completion_input_ids"]
completion_attention_mask = concatenated_batch["completion_attention_mask"]
# For encoder-decoder models, you'd need to construct decoder_input_ids, etc.
# This example assumes a causal decoder-only model.
input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1)
attention_mask = torch.cat(
(prompt_attention_mask, completion_attention_mask), dim=1
)
# Align inputs by removing leading padding
for i in range(attention_mask.size(0)):
first_one_idx = torch.nonzero(attention_mask[i])[0].item()
input_ids[i] = torch.roll(input_ids[i], shifts=-first_one_idx)
attention_mask[i] = torch.roll(attention_mask[i], shifts=-first_one_idx)
# Remove trailing empty columns
empty_cols = torch.sum(attention_mask, dim=0) == 0
if empty_cols.any():
first_empty_col = torch.nonzero(empty_cols)[0].item()
input_ids = input_ids[:, :first_empty_col]
attention_mask = attention_mask[:, :first_empty_col]
if self.args.max_length is not None:
input_ids = input_ids[:, : self.args.max_length]
attention_mask = attention_mask[:, : self.args.max_length]
        # For a causal LM, labels cover the completion part only: the prompt
        # positions are zero-filled here, and the whole sequence is shifted
        # left by one token below so that position t predicts token t + 1
labels = torch.cat(
(torch.zeros_like(prompt_input_ids), completion_input_ids), dim=1
)
labels = labels[:, 1:] # shift left by one
attention_mask = attention_mask[:, 1:]
labels = labels[:, : attention_mask.size(1)]
        # Mask out padded positions from the loss
labels[~attention_mask.bool()] = self.label_pad_token_id
# Prepare reference model hidden states if ref_model exists
use_ref_model = self.ref_model is not None and not self.reference_free
# Run main model forward to get hidden states
# If using FSDP, redirect forward calls
if isinstance(model, FullyShardedDataParallel):
outputs = _FSDPForwardRedirection()(
model,
model._fsdp_wrapped_module.model, # pylint: disable=protected-access
input_ids,
attention_mask=attention_mask,
use_cache=False,
)
else:
# If model is a DataParallel, unwrap
if isinstance(model, torch.nn.DataParallel):
model = model.module
outputs = model.model(
input_ids, attention_mask=attention_mask, use_cache=False
)
last_hidden_state = outputs.last_hidden_state
ref_last_hidden_state = None
if use_ref_model:
ref_model = self.ref_model
if isinstance(ref_model, FullyShardedDataParallel):
with torch.no_grad():
ref_outputs = _FSDPForwardRedirection()(
ref_model,
                    ref_model._fsdp_wrapped_module.model,  # pylint: disable=protected-access
input_ids,
attention_mask=attention_mask,
use_cache=False,
)
else:
if isinstance(ref_model, torch.nn.DataParallel):
ref_model = ref_model.module
with torch.no_grad():
ref_outputs = ref_model.model(
input_ids, attention_mask=attention_mask, use_cache=False
)
ref_last_hidden_state = ref_outputs.last_hidden_state
# Retrieve lm_head parameters
lm_head = model.lm_head
        ref_lm_head = self.ref_model.lm_head if use_ref_model else None
# Use Liger fused DPO loss
dpo_loss_fn = LigerFusedLinearDPOLoss(
ignore_index=self.label_pad_token_id,
beta=self.beta,
compute_nll_loss=False,
compiled=True,
use_ref_model=use_ref_model,
)
# call fused Liger DPO
        if use_ref_model:
            loss_acc, aux_outputs = dpo_loss_fn(
                lm_head.weight,  # lin_weight
                last_hidden_state,  # _input
                labels,  # target
                bias=lm_head.bias,
                ref_input=ref_last_hidden_state,
                ref_weight=ref_lm_head.weight,
                ref_bias=ref_lm_head.bias,
            )
        else:
            # No reference model: the Liger kernel treats the reference logps as 0
            loss_acc, aux_outputs = dpo_loss_fn(
                lm_head.weight,
                last_hidden_state,
                labels,
                bias=lm_head.bias,
            )
        (
            policy_chosen_logps,
            policy_rejected_logps,
            policy_chosen_logits_mean,
            policy_rejected_logits_mean,
            policy_nll_loss,
        ) = aux_outputs[:5]
# Add aux loss if enabled
if self.aux_loss_enabled and hasattr(outputs, "aux_loss"):
loss_acc = loss_acc + self.aux_loss_coef * outputs.aux_loss
# Add RPO loss if requested (RPO is a variant that adds NLL loss)
if self.args.rpo_alpha is not None:
# policy_nll_loss: average negative log-likelihood of chosen completions
loss_acc = loss_acc + self.args.rpo_alpha * policy_nll_loss.mean()
return (
loss_acc,
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits_mean,
policy_rejected_logits_mean,
policy_nll_loss,
)

    def get_batch_loss_metrics(
self,
model,
batch: Dict[str, Union[List, torch.LongTensor]],
train_eval: Literal["train", "eval"] = "train",
):
"""
Compute the DPO loss and other metrics for a given batch using the Liger fused kernel.
"""
metrics = {}
(
loss,
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits_mean,
policy_rejected_logits_mean,
policy_nll_loss,
) = self.concatenated_forward(model, batch)
        # For metrics, we approximate chosen/rejected rewards as beta * policy logps.
        # Exact parity with the original DPOTrainer metrics would require the
        # reference logps, which the Liger kernel does not currently return; the
        # kernel would need to also expose ref_chosen_logps/ref_rejected_logps.
        # The approximation below drops the reference term but preserves the
        # chosen-vs-rejected ordering. Without a reference model, the reward
        # metrics are skipped entirely.
        if self.ref_model is not None and not self.reference_free:
chosen_rewards = policy_chosen_logps * self.beta # approximation
rejected_rewards = policy_rejected_logps * self.beta # approximation
reward_accuracies = (chosen_rewards > rejected_rewards).float()
metrics[f"{train_eval}_rewards/chosen"] = chosen_rewards.mean().cpu().item()
metrics[f"{train_eval}_rewards/rejected"] = (
rejected_rewards.mean().cpu().item()
)
metrics[f"{train_eval}_rewards/accuracies"] = (
reward_accuracies.mean().cpu().item()
)
metrics[f"{train_eval}_rewards/margins"] = (
(chosen_rewards - rejected_rewards).mean().cpu().item()
)
metrics[f"{train_eval}_logps/chosen"] = policy_chosen_logps.mean().cpu().item()
metrics[f"{train_eval}_logps/rejected"] = (
policy_rejected_logps.mean().cpu().item()
)
metrics[f"{train_eval}_logits/chosen"] = (
policy_chosen_logits_mean.detach().cpu().item()
)
metrics[f"{train_eval}_logits/rejected"] = (
policy_rejected_logits_mean.detach().cpu().item()
)
if self.args.rpo_alpha is not None:
metrics[f"{train_eval}_nll_loss"] = (
policy_nll_loss.mean().detach().cpu().item()
)
return loss.mean(), metrics
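
The subtlest part of `concatenated_forward` is the left-alignment step: each row is rotated so its first non-pad token lands in column 0, then columns that became all-padding are trimmed. A standalone demo of that trick (tensors invented for illustration):

```python
import torch

# Two left-padded rows: row 0 has two pad columns, row 1 has one.
input_ids = torch.tensor([[0, 0, 5, 6], [0, 7, 8, 9]])
attention_mask = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 1]])

# Roll each row left by the index of its first non-pad token.
for i in range(attention_mask.size(0)):
    first_one_idx = torch.nonzero(attention_mask[i])[0].item()
    input_ids[i] = torch.roll(input_ids[i], shifts=-first_one_idx)
    attention_mask[i] = torch.roll(attention_mask[i], shifts=-first_one_idx)

# Trim columns that are now padding in every row.
empty_cols = torch.sum(attention_mask, dim=0) == 0
if empty_cols.any():
    first_empty_col = torch.nonzero(empty_cols)[0].item()
    input_ids = input_ids[:, :first_empty_col]
    attention_mask = attention_mask[:, :first_empty_col]

print(input_ids)       # tensor([[5, 6, 0], [7, 8, 9]])
print(attention_mask)  # tensor([[1, 1, 0], [1, 1, 1]])
```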

View File

@@ -8,17 +8,36 @@ def argilla(
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
def transform_fn(sample):
if "prompt" in sample.keys():
prompt_key = "prompt"
elif "input" in sample.keys():
prompt_key = "input"
elif "question" in sample.keys():
prompt_key = "question"
else:
prompt_key = "instruction"
if "chosen" in sample.keys():
chosen_key = "chosen"
else:
chosen_key = "chosen_response"
if "rejected" in sample.keys():
rejected_key = "rejected"
else:
rejected_key = "rejected_response"
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
)
else:
sample[
"prompt"
] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
sample["chosen"] = f"{sample['chosen_response']}<|im_end|>"
sample["rejected"] = f"{sample['rejected_response']}<|im_end|>"
] = f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
sample["chosen"] = f"{sample[chosen_key]}<|im_end|>"
sample["rejected"] = f"{sample[rejected_key]}<|im_end|>"
return sample
return transform_fn
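
A quick usage sketch of the key-fallback logic above (the function's leading parameters are elided in the hunk, so a single cfg argument is assumed):

```python
# Hypothetical row exercising the fallbacks: "question" is picked as the
# prompt key, and chosen/rejected fall back to the *_response variants.
transform_fn = argilla(None)  # assumed: first positional arg is the cfg
row = transform_fn(
    {"question": "2+2?", "chosen_response": "4", "rejected_response": "5"}
)
assert row["prompt"] == "<|im_start|>user\n2+2?<|im_end|>\n<|im_start|>assistant\n"
assert row["chosen"] == "4<|im_end|>"
assert row["rejected"] == "5<|im_end|>"
```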

View File

@@ -8,17 +8,37 @@ def argilla(
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
def transform_fn(sample):
+        # pylint: disable=duplicate-code
+        if "prompt" in sample.keys():
+            prompt_key = "prompt"
+        elif "input" in sample.keys():
+            prompt_key = "input"
+        elif "question" in sample.keys():
+            prompt_key = "question"
+        else:
+            prompt_key = "instruction"
+        if "chosen" in sample.keys():
+            chosen_key = "chosen"
+        else:
+            chosen_key = "chosen_response"
+        if "rejected" in sample.keys():
+            rejected_key = "rejected"
+        else:
+            rejected_key = "rejected_response"
        if "system" in sample and sample["system"]:
            sample["prompt"] = (
                f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
-                f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+                f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
            )
        else:
            sample[
                "prompt"
-            ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
-        sample["chosen"] = f"{sample['chosen_response']}<|eot_id|>"
-        sample["rejected"] = f"{sample['rejected_response']}<|eot_id|>"
+            ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+        sample["chosen"] = f"{sample[chosen_key]}<|eot_id|>"
+        sample["rejected"] = f"{sample[rejected_key]}<|eot_id|>"
return sample
return transform_fn