Compare commits
5 Commits
liger-dpo
...
activation
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7ac9cbebb9 | ||
|
|
15f2fa4c8e | ||
|
|
43a2f9a155 | ||
|
|
8b79f1cbf6 | ||
|
|
3872d5eaed |
@@ -12,7 +12,7 @@ liger-kernel==0.4.2
|
|||||||
|
|
||||||
packaging==23.2
|
packaging==23.2
|
||||||
peft==0.14.0
|
peft==0.14.0
|
||||||
transformers==4.47.0
|
transformers>=4.46.3
|
||||||
tokenizers>=0.20.1
|
tokenizers>=0.20.1
|
||||||
accelerate==1.2.0
|
accelerate==1.2.0
|
||||||
datasets==3.1.0
|
datasets==3.1.0
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,220 +0,0 @@
|
|||||||
"""
|
|
||||||
extra axolotl specific training args
|
|
||||||
"""
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from transformers import TrainingArguments
|
|
||||||
from trl import CPOConfig, DPOConfig, KTOConfig, ORPOConfig, RewardConfig
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AxolotlTrainingMixins:
|
|
||||||
"""
|
|
||||||
Mixin class for the Axolotl training args.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
model_type: Optional[str] = field(
|
|
||||||
default=None, metadata={"help": "HF model configuration model_type."}
|
|
||||||
)
|
|
||||||
lr_quadratic_warmup: bool = field(
|
|
||||||
default=False,
|
|
||||||
metadata={"help": "Use quadratic warmup for cosine scheduling."},
|
|
||||||
)
|
|
||||||
pretraining: bool = field(
|
|
||||||
default=False,
|
|
||||||
metadata={
|
|
||||||
"help": "Indicates to trainer whether we are doing continued pretraining."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
sample_packing: bool = field(
|
|
||||||
default=False,
|
|
||||||
metadata={"help": "Use sample packing for efficient training."},
|
|
||||||
)
|
|
||||||
multipack_real_batches: bool = field(
|
|
||||||
default=False,
|
|
||||||
metadata={"help": "Use real batches for efficient training."},
|
|
||||||
)
|
|
||||||
eval_sample_packing: Optional[bool] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Use sample packing for efficient evals."},
|
|
||||||
)
|
|
||||||
sample_packing_efficiency: float = field(
|
|
||||||
default=1.0,
|
|
||||||
metadata={"help": "Sample packing efficiency for calculating batch length."},
|
|
||||||
)
|
|
||||||
sample_packing_bin_size: int = field(
|
|
||||||
default=200,
|
|
||||||
metadata={
|
|
||||||
"help": "The max number of samples that packed sample can contain after packing. Increase for better packing."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
sample_packing_group_size: int = field(
|
|
||||||
default=100000,
|
|
||||||
metadata={
|
|
||||||
"help": "The number of samples to group together for packing. Increase for better packing."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
max_seq_length: int = field(
|
|
||||||
default=2048,
|
|
||||||
metadata={"help": "The maximum sequence length the model can handle"},
|
|
||||||
)
|
|
||||||
relora_steps: Optional[int] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "how often to reset for ReLoRA"},
|
|
||||||
)
|
|
||||||
relora_warmup_steps: Optional[int] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
|
||||||
)
|
|
||||||
relora_anneal_steps: Optional[int] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
|
||||||
)
|
|
||||||
relora_prune_ratio: Optional[float] = field(
|
|
||||||
default=0.9,
|
|
||||||
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
|
|
||||||
)
|
|
||||||
bench_split: Optional[str] = field(
|
|
||||||
default="eval", metadata={"help": "The benchmark split to run on"}
|
|
||||||
)
|
|
||||||
bench_dataset: Optional[str] = field(
|
|
||||||
default="pharaouk/dharma-1/dharma_1_mini.json",
|
|
||||||
metadata={
|
|
||||||
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
do_bench_eval: Optional[bool] = field(
|
|
||||||
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
|
|
||||||
)
|
|
||||||
do_causal_lm_eval: Optional[bool] = field(
|
|
||||||
default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
|
|
||||||
)
|
|
||||||
max_bench_samples: Optional[int] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={
|
|
||||||
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
bench_source_max_len: int = field(
|
|
||||||
default=2048, metadata={"help": "Maximum source sequence length for bench."}
|
|
||||||
)
|
|
||||||
dataloader_prefetch_factor: Optional[int] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "prefetch_factor argument to the dataloader"},
|
|
||||||
)
|
|
||||||
cosine_min_lr_ratio: Optional[float] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
|
|
||||||
)
|
|
||||||
cosine_constant_lr_ratio: Optional[float] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={
|
|
||||||
"help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
loraplus_lr_ratio: Optional[float] = field(
|
|
||||||
default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
|
|
||||||
)
|
|
||||||
loraplus_lr_embedding: Optional[float] = field(
|
|
||||||
default=1e-6,
|
|
||||||
metadata={"help": "loraplus learning rate for lora embedding layers."},
|
|
||||||
)
|
|
||||||
embedding_lr_scale: Optional[float] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Scale the learning rate for the embedding layers."},
|
|
||||||
)
|
|
||||||
embedding_lr: Optional[float] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "absolute learning rate for the embedding layers."},
|
|
||||||
)
|
|
||||||
qlora: bool = field(
|
|
||||||
default=False,
|
|
||||||
metadata={"help": "whether this is a qlora training"},
|
|
||||||
)
|
|
||||||
orpo_alpha: Optional[float] = field(
|
|
||||||
default=None,
|
|
||||||
)
|
|
||||||
lisa_n_layers: Optional[int] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "the number of activate layers in LISA"},
|
|
||||||
)
|
|
||||||
lisa_step_interval: Optional[int] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "how often to switch layers in LISA"},
|
|
||||||
)
|
|
||||||
lisa_layers_attribute: Optional[str] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "path under the model to access the layers"},
|
|
||||||
)
|
|
||||||
curriculum_sampling: Optional[bool] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "whether to use sequential sampling for curriculum learning"},
|
|
||||||
)
|
|
||||||
alternate_optimizer: Optional[str] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={
|
|
||||||
"help": "workaround to pass an alternate optimizer to the HF trainer"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
alternate_lr_scheduler_type: Optional[str] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={
|
|
||||||
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
chat_template: Optional[str] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Chat template converting chat messages to text"},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
|
||||||
"""
|
|
||||||
Training arguments for Causal trainer
|
|
||||||
|
|
||||||
This code is duplicated due to HF TrainingArguments not setting output_dir with a default value
|
|
||||||
so it can't be used as a mixin.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
|
|
||||||
"""
|
|
||||||
DPO config for DPO training
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
|
|
||||||
"""
|
|
||||||
ORPO config for ORPO training
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AxolotlKTOConfig(AxolotlTrainingMixins, KTOConfig):
|
|
||||||
"""
|
|
||||||
KTO config for KTO training
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig):
|
|
||||||
"""
|
|
||||||
CPO config for CPO training
|
|
||||||
"""
|
|
||||||
|
|
||||||
simpo_gamma: Optional[float] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "simpo gamma parameter"},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AxolotlRewardConfig(AxolotlTrainingMixins, RewardConfig):
|
|
||||||
"""
|
|
||||||
Reward config for Reward training
|
|
||||||
"""
|
|
||||||
@@ -36,8 +36,6 @@ class LigerArgs(BaseModel):
|
|||||||
liger_cross_entropy: Optional[bool] = None
|
liger_cross_entropy: Optional[bool] = None
|
||||||
liger_fused_linear_cross_entropy: Optional[bool] = None
|
liger_fused_linear_cross_entropy: Optional[bool] = None
|
||||||
|
|
||||||
liger_pref_rl: Optional[bool] = None
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_deprecated_swiglu(cls, data):
|
def check_deprecated_swiglu(cls, data):
|
||||||
|
|||||||
@@ -1,253 +0,0 @@
|
|||||||
"""
|
|
||||||
integration of liger dpo kernels with dpotrainer
|
|
||||||
"""
|
|
||||||
from typing import Dict, List, Literal, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss
|
|
||||||
from liger_kernel.transformers.trainer.orpo_trainer import _FSDPForwardRedirection
|
|
||||||
from torch import nn
|
|
||||||
from torch.distributed.fsdp import FullyShardedDataParallel
|
|
||||||
|
|
||||||
from axolotl.core.trainers.base import AxolotlDPOTrainer
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlLigerDPOTrainer(AxolotlDPOTrainer):
|
|
||||||
"""
|
|
||||||
Extend the DPO Trainer to use LIGER kernels for DPO
|
|
||||||
"""
|
|
||||||
|
|
||||||
def concatenated_forward(
|
|
||||||
self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together,
|
|
||||||
and compute the DPO loss using Liger's fused kernel.
|
|
||||||
|
|
||||||
This method replaces the original `concatenated_forward` implementation to use Liger.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Prepare concatenated inputs
|
|
||||||
concatenated_batch = self.concatenated_inputs(batch, self.padding_value)
|
|
||||||
|
|
||||||
# Extract concatenated inputs
|
|
||||||
prompt_input_ids = concatenated_batch["prompt_input_ids"]
|
|
||||||
prompt_attention_mask = concatenated_batch["prompt_attention_mask"]
|
|
||||||
completion_input_ids = concatenated_batch["completion_input_ids"]
|
|
||||||
completion_attention_mask = concatenated_batch["completion_attention_mask"]
|
|
||||||
|
|
||||||
# For encoder-decoder models, you'd need to construct decoder_input_ids, etc.
|
|
||||||
# This example assumes a causal decoder-only model.
|
|
||||||
input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1)
|
|
||||||
attention_mask = torch.cat(
|
|
||||||
(prompt_attention_mask, completion_attention_mask), dim=1
|
|
||||||
)
|
|
||||||
|
|
||||||
# Align inputs by removing leading padding
|
|
||||||
for i in range(attention_mask.size(0)):
|
|
||||||
first_one_idx = torch.nonzero(attention_mask[i])[0].item()
|
|
||||||
input_ids[i] = torch.roll(input_ids[i], shifts=-first_one_idx)
|
|
||||||
attention_mask[i] = torch.roll(attention_mask[i], shifts=-first_one_idx)
|
|
||||||
|
|
||||||
# Remove trailing empty columns
|
|
||||||
empty_cols = torch.sum(attention_mask, dim=0) == 0
|
|
||||||
if empty_cols.any():
|
|
||||||
first_empty_col = torch.nonzero(empty_cols)[0].item()
|
|
||||||
input_ids = input_ids[:, :first_empty_col]
|
|
||||||
attention_mask = attention_mask[:, :first_empty_col]
|
|
||||||
|
|
||||||
if self.args.max_length is not None:
|
|
||||||
input_ids = input_ids[:, : self.args.max_length]
|
|
||||||
attention_mask = attention_mask[:, : self.args.max_length]
|
|
||||||
|
|
||||||
# Labels are completion_input_ids shifted by one token right
|
|
||||||
# For causal LM, labels are the completion part only
|
|
||||||
labels = torch.cat(
|
|
||||||
(torch.zeros_like(prompt_input_ids), completion_input_ids), dim=1
|
|
||||||
)
|
|
||||||
labels = labels[:, 1:] # shift left by one
|
|
||||||
attention_mask = attention_mask[:, 1:]
|
|
||||||
labels = labels[:, : attention_mask.size(1)]
|
|
||||||
|
|
||||||
# Mask out the prompt portion from loss
|
|
||||||
labels[~attention_mask.bool()] = self.label_pad_token_id
|
|
||||||
|
|
||||||
# Prepare reference model hidden states if ref_model exists
|
|
||||||
use_ref_model = self.ref_model is not None and not self.reference_free
|
|
||||||
|
|
||||||
# Run main model forward to get hidden states
|
|
||||||
# If using FSDP, redirect forward calls
|
|
||||||
if isinstance(model, FullyShardedDataParallel):
|
|
||||||
outputs = _FSDPForwardRedirection()(
|
|
||||||
model,
|
|
||||||
model._fsdp_wrapped_module.model, # pylint: disable=protected-access
|
|
||||||
input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
use_cache=False,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# If model is a DataParallel, unwrap
|
|
||||||
if isinstance(model, torch.nn.DataParallel):
|
|
||||||
model = model.module
|
|
||||||
outputs = model.model(
|
|
||||||
input_ids, attention_mask=attention_mask, use_cache=False
|
|
||||||
)
|
|
||||||
|
|
||||||
last_hidden_state = outputs.last_hidden_state
|
|
||||||
|
|
||||||
ref_last_hidden_state = None
|
|
||||||
if use_ref_model:
|
|
||||||
ref_model = self.ref_model
|
|
||||||
if isinstance(ref_model, FullyShardedDataParallel):
|
|
||||||
with torch.no_grad():
|
|
||||||
ref_outputs = _FSDPForwardRedirection()(
|
|
||||||
ref_model,
|
|
||||||
ref_model._fsdp_wrapped_module.model, # pylint: disable=protected-access
|
|
||||||
input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
use_cache=False,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
if isinstance(ref_model, torch.nn.DataParallel):
|
|
||||||
ref_model = ref_model.module
|
|
||||||
with torch.no_grad():
|
|
||||||
ref_outputs = ref_model.model(
|
|
||||||
input_ids, attention_mask=attention_mask, use_cache=False
|
|
||||||
)
|
|
||||||
ref_last_hidden_state = ref_outputs.last_hidden_state
|
|
||||||
|
|
||||||
# Retrieve lm_head parameters
|
|
||||||
lm_head = model.lm_head
|
|
||||||
ref_lm_head = (
|
|
||||||
self.ref_model.lm_head
|
|
||||||
if (use_ref_model and self.ref_model is not None)
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use Liger fused DPO loss
|
|
||||||
dpo_loss_fn = LigerFusedLinearDPOLoss(
|
|
||||||
ignore_index=self.label_pad_token_id,
|
|
||||||
beta=self.beta,
|
|
||||||
compute_nll_loss=False,
|
|
||||||
compiled=True,
|
|
||||||
use_ref_model=use_ref_model,
|
|
||||||
)
|
|
||||||
|
|
||||||
# call fused Liger DPO
|
|
||||||
if use_ref_model:
|
|
||||||
loss_acc, aux_outputs = dpo_loss_fn(
|
|
||||||
lm_head.weight, # lin_weight
|
|
||||||
last_hidden_state, # _input
|
|
||||||
labels, # target
|
|
||||||
bias=lm_head.bias,
|
|
||||||
ref_input=ref_last_hidden_state,
|
|
||||||
ref_weight=ref_lm_head.weight,
|
|
||||||
ref_bias=ref_lm_head.bias,
|
|
||||||
)
|
|
||||||
|
|
||||||
(
|
|
||||||
policy_chosen_logps,
|
|
||||||
policy_rejected_logps,
|
|
||||||
policy_chosen_logits_mean,
|
|
||||||
policy_rejected_logits_mean,
|
|
||||||
policy_nll_loss,
|
|
||||||
) = aux_outputs[:5]
|
|
||||||
|
|
||||||
else:
|
|
||||||
# No reference model scenario: Liger kernel treats ref_logps as 0
|
|
||||||
loss_acc, aux_outputs = dpo_loss_fn(
|
|
||||||
lm_head.weight,
|
|
||||||
last_hidden_state,
|
|
||||||
labels,
|
|
||||||
bias=lm_head.bias,
|
|
||||||
)
|
|
||||||
(
|
|
||||||
policy_chosen_logps,
|
|
||||||
policy_rejected_logps,
|
|
||||||
policy_chosen_logits_mean,
|
|
||||||
policy_rejected_logits_mean,
|
|
||||||
policy_nll_loss,
|
|
||||||
) = aux_outputs[:5]
|
|
||||||
|
|
||||||
# Add aux loss if enabled
|
|
||||||
if self.aux_loss_enabled and hasattr(outputs, "aux_loss"):
|
|
||||||
loss_acc = loss_acc + self.aux_loss_coef * outputs.aux_loss
|
|
||||||
|
|
||||||
# Add RPO loss if requested (RPO is a variant that adds NLL loss)
|
|
||||||
if self.args.rpo_alpha is not None:
|
|
||||||
# policy_nll_loss: average negative log-likelihood of chosen completions
|
|
||||||
loss_acc = loss_acc + self.args.rpo_alpha * policy_nll_loss.mean()
|
|
||||||
|
|
||||||
return (
|
|
||||||
loss_acc,
|
|
||||||
policy_chosen_logps,
|
|
||||||
policy_rejected_logps,
|
|
||||||
policy_chosen_logits_mean,
|
|
||||||
policy_rejected_logits_mean,
|
|
||||||
policy_nll_loss,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_batch_loss_metrics(
|
|
||||||
self,
|
|
||||||
model,
|
|
||||||
batch: Dict[str, Union[List, torch.LongTensor]],
|
|
||||||
train_eval: Literal["train", "eval"] = "train",
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Compute the DPO loss and other metrics for a given batch using the Liger fused kernel.
|
|
||||||
"""
|
|
||||||
metrics = {}
|
|
||||||
|
|
||||||
(
|
|
||||||
loss,
|
|
||||||
policy_chosen_logps,
|
|
||||||
policy_rejected_logps,
|
|
||||||
policy_chosen_logits_mean,
|
|
||||||
policy_rejected_logits_mean,
|
|
||||||
policy_nll_loss,
|
|
||||||
) = self.concatenated_forward(model, batch)
|
|
||||||
|
|
||||||
# For metrics, we approximate chosen/rejected rewards as beta * (log π(y) - log π_ref(y)) if ref model used.
|
|
||||||
# If no ref model is used, we can't compute reward_accuracies meaningfully. For simplicity, we assume ref_model presence.
|
|
||||||
if self.ref_model is not None and not self.reference_free:
|
|
||||||
# If you want full parity with original DPOTrainer metrics (like chosen_rewards, rejected_rewards),
|
|
||||||
# you'd need to run reference forward or store reference log ps. The Liger kernel currently doesn't
|
|
||||||
# return ref_chosen_logps/ref_rejected_logps explicitly. By design, Liger directly computes DPO.
|
|
||||||
#
|
|
||||||
# Here we approximate chosen_rewards and rejected_rewards from the difference in chosen/rejected logps.
|
|
||||||
# Since Liger DPO does not output ref logps separately, you may need to modify the Liger kernel to
|
|
||||||
# also output them if you need all the metrics. For now, we'll skip them or provide a placeholder.
|
|
||||||
|
|
||||||
# Placeholder: chosen/rejected "rewards" can't be retrieved directly from Liger as-is.
|
|
||||||
# If needed, integrate ref_chosen_logps/ref_rejected_logps into Liger kernel returns.
|
|
||||||
chosen_rewards = policy_chosen_logps * self.beta # approximation
|
|
||||||
rejected_rewards = policy_rejected_logps * self.beta # approximation
|
|
||||||
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
|
||||||
metrics[f"{train_eval}_rewards/chosen"] = chosen_rewards.mean().cpu().item()
|
|
||||||
metrics[f"{train_eval}_rewards/rejected"] = (
|
|
||||||
rejected_rewards.mean().cpu().item()
|
|
||||||
)
|
|
||||||
metrics[f"{train_eval}_rewards/accuracies"] = (
|
|
||||||
reward_accuracies.mean().cpu().item()
|
|
||||||
)
|
|
||||||
metrics[f"{train_eval}_rewards/margins"] = (
|
|
||||||
(chosen_rewards - rejected_rewards).mean().cpu().item()
|
|
||||||
)
|
|
||||||
|
|
||||||
metrics[f"{train_eval}_logps/chosen"] = policy_chosen_logps.mean().cpu().item()
|
|
||||||
metrics[f"{train_eval}_logps/rejected"] = (
|
|
||||||
policy_rejected_logps.mean().cpu().item()
|
|
||||||
)
|
|
||||||
metrics[f"{train_eval}_logits/chosen"] = (
|
|
||||||
policy_chosen_logits_mean.detach().cpu().item()
|
|
||||||
)
|
|
||||||
metrics[f"{train_eval}_logits/rejected"] = (
|
|
||||||
policy_rejected_logits_mean.detach().cpu().item()
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.args.rpo_alpha is not None:
|
|
||||||
metrics[f"{train_eval}_nll_loss"] = (
|
|
||||||
policy_nll_loss.mean().detach().cpu().item()
|
|
||||||
)
|
|
||||||
|
|
||||||
return loss.mean(), metrics
|
|
||||||
170
src/axolotl/monkeypatch/models/llama/modeling_llama.py
Normal file
170
src/axolotl/monkeypatch/models/llama/modeling_llama.py
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
import contextlib
|
||||||
|
import inspect
|
||||||
|
import types
|
||||||
|
|
||||||
|
from torchtune.training import OffloadActivations
|
||||||
|
from transformers import LlamaConfig, LlamaForCausalLM
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.unsloth_ import detab_code
|
||||||
|
|
||||||
|
HF_MODEL_OUTPUTS = """
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
""".lstrip()
|
||||||
|
|
||||||
|
PATCHED_HF_MODEL_OUTPUTS = """
|
||||||
|
with self.act_offloading_ctx_manager:
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
""".lstrip()
|
||||||
|
|
||||||
|
LCE_MODEL_OUTPUTS = """
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
)
|
||||||
|
""".lstrip()
|
||||||
|
|
||||||
|
PATCHED_LCE_OUTPUTS = """
|
||||||
|
with self.act_offloading_ctx_manager:
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
)
|
||||||
|
""".lstrip()
|
||||||
|
|
||||||
|
HF_GA_FORWARD_1 = """
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
""".lstrip()
|
||||||
|
|
||||||
|
PATCHED_HF_GA_FORWARD_1 = """
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# remove num_items_in_batch otherwise self.model attempts to pass it to flash_attention
|
||||||
|
num_items_in_batch = kwargs.pop("num_items_in_batch", None)
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
""".lstrip()
|
||||||
|
|
||||||
|
HF_GA_FORWARD_2 = """
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
|
||||||
|
""".lstrip()
|
||||||
|
|
||||||
|
PATCHED_HF_GA_FORWARD_2 = """
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, num_items_in_batch=num_items_in_batch, **kwargs)
|
||||||
|
""".lstrip()
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlLlamaForCausalLM(LlamaForCausalLM):
    """
    Llama causal LM whose `forward` is rebuilt from patched source text.

    Each `enable_*` / `patch_*` classmethod takes the source of the forward
    that is currently installed, applies plain-text substitutions using the
    module-level templates, and installs the result as the new `forward`.

    The raw source of the installed forward is remembered on the class, because
    `inspect.getsource` cannot read back a function that was compiled from a
    string; without this, a second patch could not see the first one.
    """

    # Entered around the decoder call by the offloading-patched forward; a
    # no-op context until `enable_act_offloading` replaces it.
    act_offloading_ctx_manager = contextlib.nullcontext()
    # Raw (still indented) source text of the installed forward, and the
    # function it was originally read from. Both stay None until a forward has
    # been installed through `_install_forward`.
    _forward_source = None
    _forward_origin = None

    def __init__(self, config: LlamaConfig):
        super().__init__(config)

    @classmethod
    def _current_forward(cls):
        """Return `(raw_source, origin_function)` for the installed forward."""
        if cls._forward_source is None:
            # Nothing has been installed yet, so `cls.forward` is still a
            # regular function whose source file can be read.
            return inspect.getsource(cls.forward), cls.forward
        return cls._forward_source, cls._forward_origin

    @classmethod
    def _install_forward(cls, source, filename, origin):
        """
        Compile `source` and install the function it defines as `cls.forward`.

        Args:
            source: raw source text of the function, possibly still indented
                as a method body.
            filename: label passed to `compile` (shows up in tracebacks).
            origin: the function the text was originally read from; it supplies
                the function name and the globals the text relies on.
        """
        runnable = source
        if source[:1].isspace():
            # Method source is indented; it must start at column 0 to compile.
            runnable, _ = detab_code(source)
        # Execute with a copy of the defining module's globals so decorators,
        # type names and helpers referenced by the source resolve.
        namespace = dict(vars(inspect.getmodule(origin)))
        # The text comes from installed library code, not from user input.
        exec(compile(runnable, filename, "exec"), namespace)  # pylint: disable=exec-used
        # Assign the plain function: normal attribute lookup then binds it to
        # each model instance. Wrapping it in `types.MethodType(..., cls)`
        # would bind the class itself as `self`.
        cls.forward = namespace[origin.__name__]
        cls._forward_source = source
        cls._forward_origin = origin

    @classmethod
    def set_forward(cls):
        """Install a recompiled copy of the stock `LlamaForCausalLM.forward`."""
        stock_forward = LlamaForCausalLM.forward
        cls._install_forward(inspect.getsource(stock_forward), "<forward>", stock_forward)

    @classmethod
    def enable_act_offloading(cls):
        """Wrap the decoder call in the offloading context and activate it."""
        forward_source, origin = cls._current_forward()
        forward_source = forward_source.replace(
            HF_MODEL_OUTPUTS, PATCHED_HF_MODEL_OUTPUTS
        )
        # replace forward method with patched version
        cls._install_forward(
            forward_source, "<llama_forward_w_act_offloading>", origin
        )
        cls.act_offloading_ctx_manager = OffloadActivations()

    @classmethod
    def enable_liger_fce(cls, enable_act_offloading=True):
        """
        Use Liger's fused-linear-cross-entropy forward for this class.

        Args:
            enable_act_offloading: when true, the decoder call inside the Liger
                forward is wrapped in `self.act_offloading_ctx_manager`.
                NOTE(review): this method does not itself install
                `OffloadActivations`; the manager stays whatever the class
                currently holds — confirm that is intended.
        """
        from liger_kernel.transformers.model.llama import (
            lce_forward as llama_lce_forward,
        )

        lce_source = inspect.getsource(llama_lce_forward)
        if enable_act_offloading:
            lce_source = lce_source.replace(LCE_MODEL_OUTPUTS, PATCHED_LCE_OUTPUTS)
            # replace forward method with patched version
            cls._install_forward(
                lce_source, "<llama_lce_forward_w_act_offloading>", llama_lce_forward
            )
        else:
            # No text patching needed: install the Liger function as-is and
            # remember its source so later patches can still build on it.
            cls.forward = llama_lce_forward
            cls._forward_source = lce_source
            cls._forward_origin = llama_lce_forward

    @classmethod
    def patch_hf_ga(cls):
        """Apply the gradient-accumulation fix (`num_items_in_batch` handling)."""
        # bugfix patch for gradient accumulation
        forward_source, origin = cls._current_forward()
        forward_source = forward_source.replace(
            HF_GA_FORWARD_1, PATCHED_HF_GA_FORWARD_1
        )
        forward_source = forward_source.replace(
            HF_GA_FORWARD_2, PATCHED_HF_GA_FORWARD_2
        )
        # replace forward method with patched version
        cls._install_forward(forward_source, "<llama_forward_ga_fix>", origin)
|
||||||
|
|
||||||
|
|
||||||
|
def replace_auto_model():
    """
    Make the causal-LM auto mapping resolve `LlamaConfig` to the Axolotl class.

    Returns:
        The `AxolotlLlamaForCausalLM` class, with a freshly installed forward,
        so the caller can apply further patches to it.
    """
    from transformers import LlamaConfig
    from transformers.models.auto import MODEL_FOR_CAUSAL_LM_MAPPING

    # The auto mapping is a lazy mapping: overrides have to go through
    # `register` to be seen by lookups, and `exist_ok=True` is required because
    # `LlamaConfig` already has a built-in entry.
    MODEL_FOR_CAUSAL_LM_MAPPING.register(
        LlamaConfig, AxolotlLlamaForCausalLM, exist_ok=True
    )
    AxolotlLlamaForCausalLM.set_forward()

    return AxolotlLlamaForCausalLM
|
||||||
@@ -8,36 +8,17 @@ def argilla(
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||||
def transform_fn(sample):
|
def transform_fn(sample):
|
||||||
if "prompt" in sample.keys():
|
|
||||||
prompt_key = "prompt"
|
|
||||||
elif "input" in sample.keys():
|
|
||||||
prompt_key = "input"
|
|
||||||
elif "question" in sample.keys():
|
|
||||||
prompt_key = "question"
|
|
||||||
else:
|
|
||||||
prompt_key = "instruction"
|
|
||||||
|
|
||||||
if "chosen" in sample.keys():
|
|
||||||
chosen_key = "chosen"
|
|
||||||
else:
|
|
||||||
chosen_key = "chosen_response"
|
|
||||||
|
|
||||||
if "rejected" in sample.keys():
|
|
||||||
rejected_key = "rejected"
|
|
||||||
else:
|
|
||||||
rejected_key = "rejected_response"
|
|
||||||
|
|
||||||
if "system" in sample and sample["system"]:
|
if "system" in sample and sample["system"]:
|
||||||
sample["prompt"] = (
|
sample["prompt"] = (
|
||||||
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
|
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
|
||||||
f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
|
f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
sample[
|
sample[
|
||||||
"prompt"
|
"prompt"
|
||||||
] = f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
|
] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
|
||||||
sample["chosen"] = f"{sample[chosen_key]}<|im_end|>"
|
sample["chosen"] = f"{sample['chosen_response']}<|im_end|>"
|
||||||
sample["rejected"] = f"{sample[rejected_key]}<|im_end|>"
|
sample["rejected"] = f"{sample['rejected_response']}<|im_end|>"
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
return transform_fn
|
return transform_fn
|
||||||
|
|||||||
@@ -8,37 +8,17 @@ def argilla(
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||||
def transform_fn(sample):
|
def transform_fn(sample):
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
if "prompt" in sample.keys():
|
|
||||||
prompt_key = "prompt"
|
|
||||||
elif "input" in sample.keys():
|
|
||||||
prompt_key = "input"
|
|
||||||
elif "question" in sample.keys():
|
|
||||||
prompt_key = "question"
|
|
||||||
else:
|
|
||||||
prompt_key = "instruction"
|
|
||||||
|
|
||||||
if "chosen" in sample.keys():
|
|
||||||
chosen_key = "chosen"
|
|
||||||
else:
|
|
||||||
chosen_key = "chosen_response"
|
|
||||||
|
|
||||||
if "rejected" in sample.keys():
|
|
||||||
rejected_key = "rejected"
|
|
||||||
else:
|
|
||||||
rejected_key = "rejected_response"
|
|
||||||
|
|
||||||
if "system" in sample and sample["system"]:
|
if "system" in sample and sample["system"]:
|
||||||
sample["prompt"] = (
|
sample["prompt"] = (
|
||||||
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
|
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
|
||||||
f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
sample[
|
sample[
|
||||||
"prompt"
|
"prompt"
|
||||||
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
sample["chosen"] = f"{sample[chosen_key]}<|eot_id|>"
|
sample["chosen"] = f"{sample['chosen_response']}<|eot_id|>"
|
||||||
sample["rejected"] = f"{sample[rejected_key]}<|eot_id|>"
|
sample["rejected"] = f"{sample['rejected_response']}<|eot_id|>"
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
return transform_fn
|
return transform_fn
|
||||||
|
|||||||
@@ -679,6 +679,7 @@ class AxolotlInputConfig(
|
|||||||
default=False
|
default=False
|
||||||
)
|
)
|
||||||
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
|
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
|
||||||
|
activation_offloading: Optional[bool] = None
|
||||||
|
|
||||||
unfrozen_parameters: Optional[List[str]] = None
|
unfrozen_parameters: Optional[List[str]] = None
|
||||||
|
|
||||||
|
|||||||
@@ -380,6 +380,15 @@ class ModelLoader:
|
|||||||
plugin_manager = PluginManager.get_instance()
|
plugin_manager = PluginManager.get_instance()
|
||||||
plugin_manager.pre_model_load(self.cfg)
|
plugin_manager.pre_model_load(self.cfg)
|
||||||
|
|
||||||
|
if self.cfg.model_config_type == "llama":
|
||||||
|
from axolotl.monkeypatch.models.llama.modeling_llama import replace_auto_model
|
||||||
|
|
||||||
|
AxolotlLlamaForCausalLM = replace_auto_model()
|
||||||
|
|
||||||
|
AxolotlLlamaForCausalLM.patch_hf_ga()
|
||||||
|
if self.cfg.activation_offloading:
|
||||||
|
AxolotlLlamaForCausalLM.enable_act_offloading()
|
||||||
|
|
||||||
if self.cfg.fsdp:
|
if self.cfg.fsdp:
|
||||||
from axolotl.monkeypatch.trainer_fsdp_optim import (
|
from axolotl.monkeypatch.trainer_fsdp_optim import (
|
||||||
patch_training_loop_for_fsdp,
|
patch_training_loop_for_fsdp,
|
||||||
@@ -1183,6 +1192,8 @@ class ModelLoader:
|
|||||||
|
|
||||||
self.apply_lora_patch()
|
self.apply_lora_patch()
|
||||||
|
|
||||||
|
# self.apply_patches_to_model()
|
||||||
|
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|||||||
Reference in New Issue
Block a user