Compare commits
3 Commits
quantize-p
...
liger-dpo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
96af760e08 | ||
|
|
cfa80dace0 | ||
|
|
0a661980ca |
File diff suppressed because it is too large
Load Diff
1033
src/axolotl/core/trainers/base.py
Normal file
1033
src/axolotl/core/trainers/base.py
Normal file
File diff suppressed because it is too large
Load Diff
220
src/axolotl/core/training_args.py
Normal file
220
src/axolotl/core/training_args.py
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
"""
|
||||||
|
extra axolotl specific training args
|
||||||
|
"""
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from transformers import TrainingArguments
|
||||||
|
from trl import CPOConfig, DPOConfig, KTOConfig, ORPOConfig, RewardConfig
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AxolotlTrainingMixins:
|
||||||
|
"""
|
||||||
|
Mixin class for the Axolotl training args.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
model_type: Optional[str] = field(
|
||||||
|
default=None, metadata={"help": "HF model configuration model_type."}
|
||||||
|
)
|
||||||
|
lr_quadratic_warmup: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Use quadratic warmup for cosine scheduling."},
|
||||||
|
)
|
||||||
|
pretraining: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={
|
||||||
|
"help": "Indicates to trainer whether we are doing continued pretraining."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
sample_packing: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Use sample packing for efficient training."},
|
||||||
|
)
|
||||||
|
multipack_real_batches: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Use real batches for efficient training."},
|
||||||
|
)
|
||||||
|
eval_sample_packing: Optional[bool] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Use sample packing for efficient evals."},
|
||||||
|
)
|
||||||
|
sample_packing_efficiency: float = field(
|
||||||
|
default=1.0,
|
||||||
|
metadata={"help": "Sample packing efficiency for calculating batch length."},
|
||||||
|
)
|
||||||
|
sample_packing_bin_size: int = field(
|
||||||
|
default=200,
|
||||||
|
metadata={
|
||||||
|
"help": "The max number of samples that packed sample can contain after packing. Increase for better packing."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
sample_packing_group_size: int = field(
|
||||||
|
default=100000,
|
||||||
|
metadata={
|
||||||
|
"help": "The number of samples to group together for packing. Increase for better packing."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
max_seq_length: int = field(
|
||||||
|
default=2048,
|
||||||
|
metadata={"help": "The maximum sequence length the model can handle"},
|
||||||
|
)
|
||||||
|
relora_steps: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "how often to reset for ReLoRA"},
|
||||||
|
)
|
||||||
|
relora_warmup_steps: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
||||||
|
)
|
||||||
|
relora_anneal_steps: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
||||||
|
)
|
||||||
|
relora_prune_ratio: Optional[float] = field(
|
||||||
|
default=0.9,
|
||||||
|
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
|
||||||
|
)
|
||||||
|
bench_split: Optional[str] = field(
|
||||||
|
default="eval", metadata={"help": "The benchmark split to run on"}
|
||||||
|
)
|
||||||
|
bench_dataset: Optional[str] = field(
|
||||||
|
default="pharaouk/dharma-1/dharma_1_mini.json",
|
||||||
|
metadata={
|
||||||
|
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
do_bench_eval: Optional[bool] = field(
|
||||||
|
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
|
||||||
|
)
|
||||||
|
do_causal_lm_eval: Optional[bool] = field(
|
||||||
|
default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
|
||||||
|
)
|
||||||
|
max_bench_samples: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
bench_source_max_len: int = field(
|
||||||
|
default=2048, metadata={"help": "Maximum source sequence length for bench."}
|
||||||
|
)
|
||||||
|
dataloader_prefetch_factor: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "prefetch_factor argument to the dataloader"},
|
||||||
|
)
|
||||||
|
cosine_min_lr_ratio: Optional[float] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
|
||||||
|
)
|
||||||
|
cosine_constant_lr_ratio: Optional[float] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
loraplus_lr_ratio: Optional[float] = field(
|
||||||
|
default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
|
||||||
|
)
|
||||||
|
loraplus_lr_embedding: Optional[float] = field(
|
||||||
|
default=1e-6,
|
||||||
|
metadata={"help": "loraplus learning rate for lora embedding layers."},
|
||||||
|
)
|
||||||
|
embedding_lr_scale: Optional[float] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Scale the learning rate for the embedding layers."},
|
||||||
|
)
|
||||||
|
embedding_lr: Optional[float] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "absolute learning rate for the embedding layers."},
|
||||||
|
)
|
||||||
|
qlora: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "whether this is a qlora training"},
|
||||||
|
)
|
||||||
|
orpo_alpha: Optional[float] = field(
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
lisa_n_layers: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "the number of activate layers in LISA"},
|
||||||
|
)
|
||||||
|
lisa_step_interval: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "how often to switch layers in LISA"},
|
||||||
|
)
|
||||||
|
lisa_layers_attribute: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "path under the model to access the layers"},
|
||||||
|
)
|
||||||
|
curriculum_sampling: Optional[bool] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "whether to use sequential sampling for curriculum learning"},
|
||||||
|
)
|
||||||
|
alternate_optimizer: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "workaround to pass an alternate optimizer to the HF trainer"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
alternate_lr_scheduler_type: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
chat_template: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Chat template converting chat messages to text"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
||||||
|
"""
|
||||||
|
Training arguments for Causal trainer
|
||||||
|
|
||||||
|
This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value
|
||||||
|
so it can't be used as a mixin.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
|
||||||
|
"""
|
||||||
|
DPO config for DPO training
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
|
||||||
|
"""
|
||||||
|
ORPO config for ORPO training
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AxolotlKTOConfig(AxolotlTrainingMixins, KTOConfig):
|
||||||
|
"""
|
||||||
|
KTO config for KTO training
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig):
|
||||||
|
"""
|
||||||
|
CPO config for CPO training
|
||||||
|
"""
|
||||||
|
|
||||||
|
simpo_gamma: Optional[float] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "simpo gamma parameter"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AxolotlRewardConfig(AxolotlTrainingMixins, RewardConfig):
|
||||||
|
"""
|
||||||
|
Reward config for Reward training
|
||||||
|
"""
|
||||||
@@ -36,6 +36,8 @@ class LigerArgs(BaseModel):
|
|||||||
liger_cross_entropy: Optional[bool] = None
|
liger_cross_entropy: Optional[bool] = None
|
||||||
liger_fused_linear_cross_entropy: Optional[bool] = None
|
liger_fused_linear_cross_entropy: Optional[bool] = None
|
||||||
|
|
||||||
|
liger_pref_rl: Optional[bool] = None
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_deprecated_swiglu(cls, data):
|
def check_deprecated_swiglu(cls, data):
|
||||||
|
|||||||
0
src/axolotl/integrations/liger/trainer/__init__.py
Normal file
0
src/axolotl/integrations/liger/trainer/__init__.py
Normal file
253
src/axolotl/integrations/liger/trainer/dpo_trainer.py
Normal file
253
src/axolotl/integrations/liger/trainer/dpo_trainer.py
Normal file
@@ -0,0 +1,253 @@
|
|||||||
|
"""
|
||||||
|
integration of liger dpo kernels with dpotrainer
|
||||||
|
"""
|
||||||
|
from typing import Dict, List, Literal, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss
|
||||||
|
from liger_kernel.transformers.trainer.orpo_trainer import _FSDPForwardRedirection
|
||||||
|
from torch import nn
|
||||||
|
from torch.distributed.fsdp import FullyShardedDataParallel
|
||||||
|
|
||||||
|
from axolotl.core.trainers.base import AxolotlDPOTrainer
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlLigerDPOTrainer(AxolotlDPOTrainer):
|
||||||
|
"""
|
||||||
|
Extend the DPO Trainer to use LIGER kernels for DPO
|
||||||
|
"""
|
||||||
|
|
||||||
|
def concatenated_forward(
|
||||||
|
self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together,
|
||||||
|
and compute the DPO loss using Liger's fused kernel.
|
||||||
|
|
||||||
|
This method replaces the original `concatenated_forward` implementation to use Liger.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Prepare concatenated inputs
|
||||||
|
concatenated_batch = self.concatenated_inputs(batch, self.padding_value)
|
||||||
|
|
||||||
|
# Extract concatenated inputs
|
||||||
|
prompt_input_ids = concatenated_batch["prompt_input_ids"]
|
||||||
|
prompt_attention_mask = concatenated_batch["prompt_attention_mask"]
|
||||||
|
completion_input_ids = concatenated_batch["completion_input_ids"]
|
||||||
|
completion_attention_mask = concatenated_batch["completion_attention_mask"]
|
||||||
|
|
||||||
|
# For encoder-decoder models, you'd need to construct decoder_input_ids, etc.
|
||||||
|
# This example assumes a causal decoder-only model.
|
||||||
|
input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1)
|
||||||
|
attention_mask = torch.cat(
|
||||||
|
(prompt_attention_mask, completion_attention_mask), dim=1
|
||||||
|
)
|
||||||
|
|
||||||
|
# Align inputs by removing leading padding
|
||||||
|
for i in range(attention_mask.size(0)):
|
||||||
|
first_one_idx = torch.nonzero(attention_mask[i])[0].item()
|
||||||
|
input_ids[i] = torch.roll(input_ids[i], shifts=-first_one_idx)
|
||||||
|
attention_mask[i] = torch.roll(attention_mask[i], shifts=-first_one_idx)
|
||||||
|
|
||||||
|
# Remove trailing empty columns
|
||||||
|
empty_cols = torch.sum(attention_mask, dim=0) == 0
|
||||||
|
if empty_cols.any():
|
||||||
|
first_empty_col = torch.nonzero(empty_cols)[0].item()
|
||||||
|
input_ids = input_ids[:, :first_empty_col]
|
||||||
|
attention_mask = attention_mask[:, :first_empty_col]
|
||||||
|
|
||||||
|
if self.args.max_length is not None:
|
||||||
|
input_ids = input_ids[:, : self.args.max_length]
|
||||||
|
attention_mask = attention_mask[:, : self.args.max_length]
|
||||||
|
|
||||||
|
# Labels are completion_input_ids shifted by one token right
|
||||||
|
# For causal LM, labels are the completion part only
|
||||||
|
labels = torch.cat(
|
||||||
|
(torch.zeros_like(prompt_input_ids), completion_input_ids), dim=1
|
||||||
|
)
|
||||||
|
labels = labels[:, 1:] # shift left by one
|
||||||
|
attention_mask = attention_mask[:, 1:]
|
||||||
|
labels = labels[:, : attention_mask.size(1)]
|
||||||
|
|
||||||
|
# Mask out the prompt portion from loss
|
||||||
|
labels[~attention_mask.bool()] = self.label_pad_token_id
|
||||||
|
|
||||||
|
# Prepare reference model hidden states if ref_model exists
|
||||||
|
use_ref_model = self.ref_model is not None and not self.reference_free
|
||||||
|
|
||||||
|
# Run main model forward to get hidden states
|
||||||
|
# If using FSDP, redirect forward calls
|
||||||
|
if isinstance(model, FullyShardedDataParallel):
|
||||||
|
outputs = _FSDPForwardRedirection()(
|
||||||
|
model,
|
||||||
|
model._fsdp_wrapped_module.model, # pylint: disable=protected-access
|
||||||
|
input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
use_cache=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# If model is a DataParallel, unwrap
|
||||||
|
if isinstance(model, torch.nn.DataParallel):
|
||||||
|
model = model.module
|
||||||
|
outputs = model.model(
|
||||||
|
input_ids, attention_mask=attention_mask, use_cache=False
|
||||||
|
)
|
||||||
|
|
||||||
|
last_hidden_state = outputs.last_hidden_state
|
||||||
|
|
||||||
|
ref_last_hidden_state = None
|
||||||
|
if use_ref_model:
|
||||||
|
ref_model = self.ref_model
|
||||||
|
if isinstance(ref_model, FullyShardedDataParallel):
|
||||||
|
with torch.no_grad():
|
||||||
|
ref_outputs = _FSDPForwardRedirection()(
|
||||||
|
ref_model,
|
||||||
|
ref_model._fsdp_wrapped_module.model, # pylint: disable=protected-accessåå
|
||||||
|
input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
use_cache=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if isinstance(ref_model, torch.nn.DataParallel):
|
||||||
|
ref_model = ref_model.module
|
||||||
|
with torch.no_grad():
|
||||||
|
ref_outputs = ref_model.model(
|
||||||
|
input_ids, attention_mask=attention_mask, use_cache=False
|
||||||
|
)
|
||||||
|
ref_last_hidden_state = ref_outputs.last_hidden_state
|
||||||
|
|
||||||
|
# Retrieve lm_head parameters
|
||||||
|
lm_head = model.lm_head
|
||||||
|
ref_lm_head = (
|
||||||
|
self.ref_model.lm_head
|
||||||
|
if (use_ref_model and self.ref_model is not None)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use Liger fused DPO loss
|
||||||
|
dpo_loss_fn = LigerFusedLinearDPOLoss(
|
||||||
|
ignore_index=self.label_pad_token_id,
|
||||||
|
beta=self.beta,
|
||||||
|
compute_nll_loss=False,
|
||||||
|
compiled=True,
|
||||||
|
use_ref_model=use_ref_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
# call fused Liger DPO
|
||||||
|
if use_ref_model:
|
||||||
|
loss_acc, aux_outputs = dpo_loss_fn(
|
||||||
|
lm_head.weight, # lin_weight
|
||||||
|
last_hidden_state, # _input
|
||||||
|
labels, # target
|
||||||
|
bias=lm_head.bias,
|
||||||
|
ref_input=ref_last_hidden_state,
|
||||||
|
ref_weight=ref_lm_head.weight,
|
||||||
|
ref_bias=ref_lm_head.bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
(
|
||||||
|
policy_chosen_logps,
|
||||||
|
policy_rejected_logps,
|
||||||
|
policy_chosen_logits_mean,
|
||||||
|
policy_rejected_logits_mean,
|
||||||
|
policy_nll_loss,
|
||||||
|
) = aux_outputs[:5]
|
||||||
|
|
||||||
|
else:
|
||||||
|
# No reference model scenario: Liger kernel treats ref_logps as 0
|
||||||
|
loss_acc, aux_outputs = dpo_loss_fn(
|
||||||
|
lm_head.weight,
|
||||||
|
last_hidden_state,
|
||||||
|
labels,
|
||||||
|
bias=lm_head.bias,
|
||||||
|
)
|
||||||
|
(
|
||||||
|
policy_chosen_logps,
|
||||||
|
policy_rejected_logps,
|
||||||
|
policy_chosen_logits_mean,
|
||||||
|
policy_rejected_logits_mean,
|
||||||
|
policy_nll_loss,
|
||||||
|
) = aux_outputs[:5]
|
||||||
|
|
||||||
|
# Add aux loss if enabled
|
||||||
|
if self.aux_loss_enabled and hasattr(outputs, "aux_loss"):
|
||||||
|
loss_acc = loss_acc + self.aux_loss_coef * outputs.aux_loss
|
||||||
|
|
||||||
|
# Add RPO loss if requested (RPO is a variant that adds NLL loss)
|
||||||
|
if self.args.rpo_alpha is not None:
|
||||||
|
# policy_nll_loss: average negative log-likelihood of chosen completions
|
||||||
|
loss_acc = loss_acc + self.args.rpo_alpha * policy_nll_loss.mean()
|
||||||
|
|
||||||
|
return (
|
||||||
|
loss_acc,
|
||||||
|
policy_chosen_logps,
|
||||||
|
policy_rejected_logps,
|
||||||
|
policy_chosen_logits_mean,
|
||||||
|
policy_rejected_logits_mean,
|
||||||
|
policy_nll_loss,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_batch_loss_metrics(
|
||||||
|
self,
|
||||||
|
model,
|
||||||
|
batch: Dict[str, Union[List, torch.LongTensor]],
|
||||||
|
train_eval: Literal["train", "eval"] = "train",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Compute the DPO loss and other metrics for a given batch using the Liger fused kernel.
|
||||||
|
"""
|
||||||
|
metrics = {}
|
||||||
|
|
||||||
|
(
|
||||||
|
loss,
|
||||||
|
policy_chosen_logps,
|
||||||
|
policy_rejected_logps,
|
||||||
|
policy_chosen_logits_mean,
|
||||||
|
policy_rejected_logits_mean,
|
||||||
|
policy_nll_loss,
|
||||||
|
) = self.concatenated_forward(model, batch)
|
||||||
|
|
||||||
|
# For metrics, we approximate chosen/rejected rewards as beta * (log π(y) - log π_ref(y)) if ref model used.
|
||||||
|
# If no ref model is used, we can't compute reward_accuracies meaningfully. For simplicity, we assume ref_model presence.
|
||||||
|
if self.ref_model is not None and not self.reference_free:
|
||||||
|
# If you want full parity with original DPOTrainer metrics (like chosen_rewards, rejected_rewards),
|
||||||
|
# you'd need to run reference forward or store reference log ps. The Liger kernel currently doesn't
|
||||||
|
# return ref_chosen_logps/ref_rejected_logps explicitly. By design, Liger directly computes DPO.
|
||||||
|
#
|
||||||
|
# Here we approximate chosen_rewards and rejected_rewards from the difference in chosen/rejected logps.
|
||||||
|
# Since Liger DPO does not output ref logps separately, you may need to modify the Liger kernel to
|
||||||
|
# also output them if you need all the metrics. For now, we'll skip them or provide a placeholder.
|
||||||
|
|
||||||
|
# Placeholder: chosen/rejected "rewards" can't be retrieved directly from Liger as-is.
|
||||||
|
# If needed, integrate ref_chosen_logps/ref_rejected_logps into Liger kernel returns.
|
||||||
|
chosen_rewards = policy_chosen_logps * self.beta # approximation
|
||||||
|
rejected_rewards = policy_rejected_logps * self.beta # approximation
|
||||||
|
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
||||||
|
metrics[f"{train_eval}_rewards/chosen"] = chosen_rewards.mean().cpu().item()
|
||||||
|
metrics[f"{train_eval}_rewards/rejected"] = (
|
||||||
|
rejected_rewards.mean().cpu().item()
|
||||||
|
)
|
||||||
|
metrics[f"{train_eval}_rewards/accuracies"] = (
|
||||||
|
reward_accuracies.mean().cpu().item()
|
||||||
|
)
|
||||||
|
metrics[f"{train_eval}_rewards/margins"] = (
|
||||||
|
(chosen_rewards - rejected_rewards).mean().cpu().item()
|
||||||
|
)
|
||||||
|
|
||||||
|
metrics[f"{train_eval}_logps/chosen"] = policy_chosen_logps.mean().cpu().item()
|
||||||
|
metrics[f"{train_eval}_logps/rejected"] = (
|
||||||
|
policy_rejected_logps.mean().cpu().item()
|
||||||
|
)
|
||||||
|
metrics[f"{train_eval}_logits/chosen"] = (
|
||||||
|
policy_chosen_logits_mean.detach().cpu().item()
|
||||||
|
)
|
||||||
|
metrics[f"{train_eval}_logits/rejected"] = (
|
||||||
|
policy_rejected_logits_mean.detach().cpu().item()
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.args.rpo_alpha is not None:
|
||||||
|
metrics[f"{train_eval}_nll_loss"] = (
|
||||||
|
policy_nll_loss.mean().detach().cpu().item()
|
||||||
|
)
|
||||||
|
|
||||||
|
return loss.mean(), metrics
|
||||||
@@ -8,17 +8,36 @@ def argilla(
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||||
def transform_fn(sample):
|
def transform_fn(sample):
|
||||||
|
if "prompt" in sample.keys():
|
||||||
|
prompt_key = "prompt"
|
||||||
|
elif "input" in sample.keys():
|
||||||
|
prompt_key = "input"
|
||||||
|
elif "question" in sample.keys():
|
||||||
|
prompt_key = "question"
|
||||||
|
else:
|
||||||
|
prompt_key = "instruction"
|
||||||
|
|
||||||
|
if "chosen" in sample.keys():
|
||||||
|
chosen_key = "chosen"
|
||||||
|
else:
|
||||||
|
chosen_key = "chosen_response"
|
||||||
|
|
||||||
|
if "rejected" in sample.keys():
|
||||||
|
rejected_key = "rejected"
|
||||||
|
else:
|
||||||
|
rejected_key = "rejected_response"
|
||||||
|
|
||||||
if "system" in sample and sample["system"]:
|
if "system" in sample and sample["system"]:
|
||||||
sample["prompt"] = (
|
sample["prompt"] = (
|
||||||
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
|
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
|
||||||
f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
|
f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
sample[
|
sample[
|
||||||
"prompt"
|
"prompt"
|
||||||
] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
|
] = f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
|
||||||
sample["chosen"] = f"{sample['chosen_response']}<|im_end|>"
|
sample["chosen"] = f"{sample[chosen_key]}<|im_end|>"
|
||||||
sample["rejected"] = f"{sample['rejected_response']}<|im_end|>"
|
sample["rejected"] = f"{sample[rejected_key]}<|im_end|>"
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
return transform_fn
|
return transform_fn
|
||||||
|
|||||||
@@ -8,17 +8,37 @@ def argilla(
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||||
def transform_fn(sample):
|
def transform_fn(sample):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
if "prompt" in sample.keys():
|
||||||
|
prompt_key = "prompt"
|
||||||
|
elif "input" in sample.keys():
|
||||||
|
prompt_key = "input"
|
||||||
|
elif "question" in sample.keys():
|
||||||
|
prompt_key = "question"
|
||||||
|
else:
|
||||||
|
prompt_key = "instruction"
|
||||||
|
|
||||||
|
if "chosen" in sample.keys():
|
||||||
|
chosen_key = "chosen"
|
||||||
|
else:
|
||||||
|
chosen_key = "chosen_response"
|
||||||
|
|
||||||
|
if "rejected" in sample.keys():
|
||||||
|
rejected_key = "rejected"
|
||||||
|
else:
|
||||||
|
rejected_key = "rejected_response"
|
||||||
|
|
||||||
if "system" in sample and sample["system"]:
|
if "system" in sample and sample["system"]:
|
||||||
sample["prompt"] = (
|
sample["prompt"] = (
|
||||||
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
|
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
|
||||||
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
sample[
|
sample[
|
||||||
"prompt"
|
"prompt"
|
||||||
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
sample["chosen"] = f"{sample['chosen_response']}<|eot_id|>"
|
sample["chosen"] = f"{sample[chosen_key]}<|eot_id|>"
|
||||||
sample["rejected"] = f"{sample['rejected_response']}<|eot_id|>"
|
sample["rejected"] = f"{sample[rejected_key]}<|eot_id|>"
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
return transform_fn
|
return transform_fn
|
||||||
|
|||||||
Reference in New Issue
Block a user