Compare commits
3 Commits
feat/beaut
...
liger-dpo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
96af760e08 | ||
|
|
cfa80dace0 | ||
|
|
0a661980ca |
File diff suppressed because it is too large
Load Diff
1033
src/axolotl/core/trainers/base.py
Normal file
1033
src/axolotl/core/trainers/base.py
Normal file
File diff suppressed because it is too large
Load Diff
220
src/axolotl/core/training_args.py
Normal file
220
src/axolotl/core/training_args.py
Normal file
@@ -0,0 +1,220 @@
|
||||
"""
|
||||
extra axolotl specific training args
|
||||
"""
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from transformers import TrainingArguments
|
||||
from trl import CPOConfig, DPOConfig, KTOConfig, ORPOConfig, RewardConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlTrainingMixins:
|
||||
"""
|
||||
Mixin class for the Axolotl training args.
|
||||
"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
model_type: Optional[str] = field(
|
||||
default=None, metadata={"help": "HF model configuration model_type."}
|
||||
)
|
||||
lr_quadratic_warmup: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use quadratic warmup for cosine scheduling."},
|
||||
)
|
||||
pretraining: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Indicates to trainer whether we are doing continued pretraining."
|
||||
},
|
||||
)
|
||||
sample_packing: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use sample packing for efficient training."},
|
||||
)
|
||||
multipack_real_batches: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use real batches for efficient training."},
|
||||
)
|
||||
eval_sample_packing: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={"help": "Use sample packing for efficient evals."},
|
||||
)
|
||||
sample_packing_efficiency: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "Sample packing efficiency for calculating batch length."},
|
||||
)
|
||||
sample_packing_bin_size: int = field(
|
||||
default=200,
|
||||
metadata={
|
||||
"help": "The max number of samples that packed sample can contain after packing. Increase for better packing."
|
||||
},
|
||||
)
|
||||
sample_packing_group_size: int = field(
|
||||
default=100000,
|
||||
metadata={
|
||||
"help": "The number of samples to group together for packing. Increase for better packing."
|
||||
},
|
||||
)
|
||||
max_seq_length: int = field(
|
||||
default=2048,
|
||||
metadata={"help": "The maximum sequence length the model can handle"},
|
||||
)
|
||||
relora_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how often to reset for ReLoRA"},
|
||||
)
|
||||
relora_warmup_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
||||
)
|
||||
relora_anneal_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
||||
)
|
||||
relora_prune_ratio: Optional[float] = field(
|
||||
default=0.9,
|
||||
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
|
||||
)
|
||||
bench_split: Optional[str] = field(
|
||||
default="eval", metadata={"help": "The benchmark split to run on"}
|
||||
)
|
||||
bench_dataset: Optional[str] = field(
|
||||
default="pharaouk/dharma-1/dharma_1_mini.json",
|
||||
metadata={
|
||||
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
|
||||
},
|
||||
)
|
||||
do_bench_eval: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
|
||||
)
|
||||
do_causal_lm_eval: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
|
||||
)
|
||||
max_bench_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
|
||||
},
|
||||
)
|
||||
bench_source_max_len: int = field(
|
||||
default=2048, metadata={"help": "Maximum source sequence length for bench."}
|
||||
)
|
||||
dataloader_prefetch_factor: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "prefetch_factor argument to the dataloader"},
|
||||
)
|
||||
cosine_min_lr_ratio: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
|
||||
)
|
||||
cosine_constant_lr_ratio: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
|
||||
},
|
||||
)
|
||||
loraplus_lr_ratio: Optional[float] = field(
|
||||
default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
|
||||
)
|
||||
loraplus_lr_embedding: Optional[float] = field(
|
||||
default=1e-6,
|
||||
metadata={"help": "loraplus learning rate for lora embedding layers."},
|
||||
)
|
||||
embedding_lr_scale: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={"help": "Scale the learning rate for the embedding layers."},
|
||||
)
|
||||
embedding_lr: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={"help": "absolute learning rate for the embedding layers."},
|
||||
)
|
||||
qlora: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "whether this is a qlora training"},
|
||||
)
|
||||
orpo_alpha: Optional[float] = field(
|
||||
default=None,
|
||||
)
|
||||
lisa_n_layers: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "the number of activate layers in LISA"},
|
||||
)
|
||||
lisa_step_interval: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how often to switch layers in LISA"},
|
||||
)
|
||||
lisa_layers_attribute: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "path under the model to access the layers"},
|
||||
)
|
||||
curriculum_sampling: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={"help": "whether to use sequential sampling for curriculum learning"},
|
||||
)
|
||||
alternate_optimizer: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "workaround to pass an alternate optimizer to the HF trainer"
|
||||
},
|
||||
)
|
||||
alternate_lr_scheduler_type: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
|
||||
},
|
||||
)
|
||||
chat_template: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Chat template converting chat messages to text"},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
||||
"""
|
||||
Training arguments for Causal trainer
|
||||
|
||||
This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value
|
||||
so it can't be used as a mixin.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
|
||||
"""
|
||||
DPO config for DPO training
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
|
||||
"""
|
||||
ORPO config for ORPO training
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlKTOConfig(AxolotlTrainingMixins, KTOConfig):
|
||||
"""
|
||||
KTO config for KTO training
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig):
|
||||
"""
|
||||
CPO config for CPO training
|
||||
"""
|
||||
|
||||
simpo_gamma: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={"help": "simpo gamma parameter"},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlRewardConfig(AxolotlTrainingMixins, RewardConfig):
|
||||
"""
|
||||
Reward config for Reward training
|
||||
"""
|
||||
@@ -36,6 +36,8 @@ class LigerArgs(BaseModel):
|
||||
liger_cross_entropy: Optional[bool] = None
|
||||
liger_fused_linear_cross_entropy: Optional[bool] = None
|
||||
|
||||
liger_pref_rl: Optional[bool] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_deprecated_swiglu(cls, data):
|
||||
|
||||
0
src/axolotl/integrations/liger/trainer/__init__.py
Normal file
0
src/axolotl/integrations/liger/trainer/__init__.py
Normal file
253
src/axolotl/integrations/liger/trainer/dpo_trainer.py
Normal file
253
src/axolotl/integrations/liger/trainer/dpo_trainer.py
Normal file
@@ -0,0 +1,253 @@
|
||||
"""
|
||||
integration of liger dpo kernels with dpotrainer
|
||||
"""
|
||||
from typing import Dict, List, Literal, Union
|
||||
|
||||
import torch
|
||||
from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss
|
||||
from liger_kernel.transformers.trainer.orpo_trainer import _FSDPForwardRedirection
|
||||
from torch import nn
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel
|
||||
|
||||
from axolotl.core.trainers.base import AxolotlDPOTrainer
|
||||
|
||||
|
||||
class AxolotlLigerDPOTrainer(AxolotlDPOTrainer):
|
||||
"""
|
||||
Extend the DPO Trainer to use LIGER kernels for DPO
|
||||
"""
|
||||
|
||||
def concatenated_forward(
|
||||
self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
|
||||
):
|
||||
"""
|
||||
Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together,
|
||||
and compute the DPO loss using Liger's fused kernel.
|
||||
|
||||
This method replaces the original `concatenated_forward` implementation to use Liger.
|
||||
"""
|
||||
|
||||
# Prepare concatenated inputs
|
||||
concatenated_batch = self.concatenated_inputs(batch, self.padding_value)
|
||||
|
||||
# Extract concatenated inputs
|
||||
prompt_input_ids = concatenated_batch["prompt_input_ids"]
|
||||
prompt_attention_mask = concatenated_batch["prompt_attention_mask"]
|
||||
completion_input_ids = concatenated_batch["completion_input_ids"]
|
||||
completion_attention_mask = concatenated_batch["completion_attention_mask"]
|
||||
|
||||
# For encoder-decoder models, you'd need to construct decoder_input_ids, etc.
|
||||
# This example assumes a causal decoder-only model.
|
||||
input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1)
|
||||
attention_mask = torch.cat(
|
||||
(prompt_attention_mask, completion_attention_mask), dim=1
|
||||
)
|
||||
|
||||
# Align inputs by removing leading padding
|
||||
for i in range(attention_mask.size(0)):
|
||||
first_one_idx = torch.nonzero(attention_mask[i])[0].item()
|
||||
input_ids[i] = torch.roll(input_ids[i], shifts=-first_one_idx)
|
||||
attention_mask[i] = torch.roll(attention_mask[i], shifts=-first_one_idx)
|
||||
|
||||
# Remove trailing empty columns
|
||||
empty_cols = torch.sum(attention_mask, dim=0) == 0
|
||||
if empty_cols.any():
|
||||
first_empty_col = torch.nonzero(empty_cols)[0].item()
|
||||
input_ids = input_ids[:, :first_empty_col]
|
||||
attention_mask = attention_mask[:, :first_empty_col]
|
||||
|
||||
if self.args.max_length is not None:
|
||||
input_ids = input_ids[:, : self.args.max_length]
|
||||
attention_mask = attention_mask[:, : self.args.max_length]
|
||||
|
||||
# Labels are completion_input_ids shifted by one token right
|
||||
# For causal LM, labels are the completion part only
|
||||
labels = torch.cat(
|
||||
(torch.zeros_like(prompt_input_ids), completion_input_ids), dim=1
|
||||
)
|
||||
labels = labels[:, 1:] # shift left by one
|
||||
attention_mask = attention_mask[:, 1:]
|
||||
labels = labels[:, : attention_mask.size(1)]
|
||||
|
||||
# Mask out the prompt portion from loss
|
||||
labels[~attention_mask.bool()] = self.label_pad_token_id
|
||||
|
||||
# Prepare reference model hidden states if ref_model exists
|
||||
use_ref_model = self.ref_model is not None and not self.reference_free
|
||||
|
||||
# Run main model forward to get hidden states
|
||||
# If using FSDP, redirect forward calls
|
||||
if isinstance(model, FullyShardedDataParallel):
|
||||
outputs = _FSDPForwardRedirection()(
|
||||
model,
|
||||
model._fsdp_wrapped_module.model, # pylint: disable=protected-access
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
use_cache=False,
|
||||
)
|
||||
else:
|
||||
# If model is a DataParallel, unwrap
|
||||
if isinstance(model, torch.nn.DataParallel):
|
||||
model = model.module
|
||||
outputs = model.model(
|
||||
input_ids, attention_mask=attention_mask, use_cache=False
|
||||
)
|
||||
|
||||
last_hidden_state = outputs.last_hidden_state
|
||||
|
||||
ref_last_hidden_state = None
|
||||
if use_ref_model:
|
||||
ref_model = self.ref_model
|
||||
if isinstance(ref_model, FullyShardedDataParallel):
|
||||
with torch.no_grad():
|
||||
ref_outputs = _FSDPForwardRedirection()(
|
||||
ref_model,
|
||||
ref_model._fsdp_wrapped_module.model, # pylint: disable=protected-accessåå
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
use_cache=False,
|
||||
)
|
||||
else:
|
||||
if isinstance(ref_model, torch.nn.DataParallel):
|
||||
ref_model = ref_model.module
|
||||
with torch.no_grad():
|
||||
ref_outputs = ref_model.model(
|
||||
input_ids, attention_mask=attention_mask, use_cache=False
|
||||
)
|
||||
ref_last_hidden_state = ref_outputs.last_hidden_state
|
||||
|
||||
# Retrieve lm_head parameters
|
||||
lm_head = model.lm_head
|
||||
ref_lm_head = (
|
||||
self.ref_model.lm_head
|
||||
if (use_ref_model and self.ref_model is not None)
|
||||
else None
|
||||
)
|
||||
|
||||
# Use Liger fused DPO loss
|
||||
dpo_loss_fn = LigerFusedLinearDPOLoss(
|
||||
ignore_index=self.label_pad_token_id,
|
||||
beta=self.beta,
|
||||
compute_nll_loss=False,
|
||||
compiled=True,
|
||||
use_ref_model=use_ref_model,
|
||||
)
|
||||
|
||||
# call fused Liger DPO
|
||||
if use_ref_model:
|
||||
loss_acc, aux_outputs = dpo_loss_fn(
|
||||
lm_head.weight, # lin_weight
|
||||
last_hidden_state, # _input
|
||||
labels, # target
|
||||
bias=lm_head.bias,
|
||||
ref_input=ref_last_hidden_state,
|
||||
ref_weight=ref_lm_head.weight,
|
||||
ref_bias=ref_lm_head.bias,
|
||||
)
|
||||
|
||||
(
|
||||
policy_chosen_logps,
|
||||
policy_rejected_logps,
|
||||
policy_chosen_logits_mean,
|
||||
policy_rejected_logits_mean,
|
||||
policy_nll_loss,
|
||||
) = aux_outputs[:5]
|
||||
|
||||
else:
|
||||
# No reference model scenario: Liger kernel treats ref_logps as 0
|
||||
loss_acc, aux_outputs = dpo_loss_fn(
|
||||
lm_head.weight,
|
||||
last_hidden_state,
|
||||
labels,
|
||||
bias=lm_head.bias,
|
||||
)
|
||||
(
|
||||
policy_chosen_logps,
|
||||
policy_rejected_logps,
|
||||
policy_chosen_logits_mean,
|
||||
policy_rejected_logits_mean,
|
||||
policy_nll_loss,
|
||||
) = aux_outputs[:5]
|
||||
|
||||
# Add aux loss if enabled
|
||||
if self.aux_loss_enabled and hasattr(outputs, "aux_loss"):
|
||||
loss_acc = loss_acc + self.aux_loss_coef * outputs.aux_loss
|
||||
|
||||
# Add RPO loss if requested (RPO is a variant that adds NLL loss)
|
||||
if self.args.rpo_alpha is not None:
|
||||
# policy_nll_loss: average negative log-likelihood of chosen completions
|
||||
loss_acc = loss_acc + self.args.rpo_alpha * policy_nll_loss.mean()
|
||||
|
||||
return (
|
||||
loss_acc,
|
||||
policy_chosen_logps,
|
||||
policy_rejected_logps,
|
||||
policy_chosen_logits_mean,
|
||||
policy_rejected_logits_mean,
|
||||
policy_nll_loss,
|
||||
)
|
||||
|
||||
def get_batch_loss_metrics(
|
||||
self,
|
||||
model,
|
||||
batch: Dict[str, Union[List, torch.LongTensor]],
|
||||
train_eval: Literal["train", "eval"] = "train",
|
||||
):
|
||||
"""
|
||||
Compute the DPO loss and other metrics for a given batch using the Liger fused kernel.
|
||||
"""
|
||||
metrics = {}
|
||||
|
||||
(
|
||||
loss,
|
||||
policy_chosen_logps,
|
||||
policy_rejected_logps,
|
||||
policy_chosen_logits_mean,
|
||||
policy_rejected_logits_mean,
|
||||
policy_nll_loss,
|
||||
) = self.concatenated_forward(model, batch)
|
||||
|
||||
# For metrics, we approximate chosen/rejected rewards as beta * (log π(y) - log π_ref(y)) if ref model used.
|
||||
# If no ref model is used, we can't compute reward_accuracies meaningfully. For simplicity, we assume ref_model presence.
|
||||
if self.ref_model is not None and not self.reference_free:
|
||||
# If you want full parity with original DPOTrainer metrics (like chosen_rewards, rejected_rewards),
|
||||
# you'd need to run reference forward or store reference log ps. The Liger kernel currently doesn't
|
||||
# return ref_chosen_logps/ref_rejected_logps explicitly. By design, Liger directly computes DPO.
|
||||
#
|
||||
# Here we approximate chosen_rewards and rejected_rewards from the difference in chosen/rejected logps.
|
||||
# Since Liger DPO does not output ref logps separately, you may need to modify the Liger kernel to
|
||||
# also output them if you need all the metrics. For now, we'll skip them or provide a placeholder.
|
||||
|
||||
# Placeholder: chosen/rejected "rewards" can't be retrieved directly from Liger as-is.
|
||||
# If needed, integrate ref_chosen_logps/ref_rejected_logps into Liger kernel returns.
|
||||
chosen_rewards = policy_chosen_logps * self.beta # approximation
|
||||
rejected_rewards = policy_rejected_logps * self.beta # approximation
|
||||
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
||||
metrics[f"{train_eval}_rewards/chosen"] = chosen_rewards.mean().cpu().item()
|
||||
metrics[f"{train_eval}_rewards/rejected"] = (
|
||||
rejected_rewards.mean().cpu().item()
|
||||
)
|
||||
metrics[f"{train_eval}_rewards/accuracies"] = (
|
||||
reward_accuracies.mean().cpu().item()
|
||||
)
|
||||
metrics[f"{train_eval}_rewards/margins"] = (
|
||||
(chosen_rewards - rejected_rewards).mean().cpu().item()
|
||||
)
|
||||
|
||||
metrics[f"{train_eval}_logps/chosen"] = policy_chosen_logps.mean().cpu().item()
|
||||
metrics[f"{train_eval}_logps/rejected"] = (
|
||||
policy_rejected_logps.mean().cpu().item()
|
||||
)
|
||||
metrics[f"{train_eval}_logits/chosen"] = (
|
||||
policy_chosen_logits_mean.detach().cpu().item()
|
||||
)
|
||||
metrics[f"{train_eval}_logits/rejected"] = (
|
||||
policy_rejected_logits_mean.detach().cpu().item()
|
||||
)
|
||||
|
||||
if self.args.rpo_alpha is not None:
|
||||
metrics[f"{train_eval}_nll_loss"] = (
|
||||
policy_nll_loss.mean().detach().cpu().item()
|
||||
)
|
||||
|
||||
return loss.mean(), metrics
|
||||
@@ -8,17 +8,36 @@ def argilla(
|
||||
**kwargs,
|
||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
def transform_fn(sample):
|
||||
if "prompt" in sample.keys():
|
||||
prompt_key = "prompt"
|
||||
elif "input" in sample.keys():
|
||||
prompt_key = "input"
|
||||
elif "question" in sample.keys():
|
||||
prompt_key = "question"
|
||||
else:
|
||||
prompt_key = "instruction"
|
||||
|
||||
if "chosen" in sample.keys():
|
||||
chosen_key = "chosen"
|
||||
else:
|
||||
chosen_key = "chosen_response"
|
||||
|
||||
if "rejected" in sample.keys():
|
||||
rejected_key = "rejected"
|
||||
else:
|
||||
rejected_key = "rejected_response"
|
||||
|
||||
if "system" in sample and sample["system"]:
|
||||
sample["prompt"] = (
|
||||
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
|
||||
f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
else:
|
||||
sample[
|
||||
"prompt"
|
||||
] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
sample["chosen"] = f"{sample['chosen_response']}<|im_end|>"
|
||||
sample["rejected"] = f"{sample['rejected_response']}<|im_end|>"
|
||||
] = f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
|
||||
sample["chosen"] = f"{sample[chosen_key]}<|im_end|>"
|
||||
sample["rejected"] = f"{sample[rejected_key]}<|im_end|>"
|
||||
return sample
|
||||
|
||||
return transform_fn
|
||||
|
||||
@@ -8,17 +8,37 @@ def argilla(
|
||||
**kwargs,
|
||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
def transform_fn(sample):
|
||||
# pylint: disable=duplicate-code
|
||||
if "prompt" in sample.keys():
|
||||
prompt_key = "prompt"
|
||||
elif "input" in sample.keys():
|
||||
prompt_key = "input"
|
||||
elif "question" in sample.keys():
|
||||
prompt_key = "question"
|
||||
else:
|
||||
prompt_key = "instruction"
|
||||
|
||||
if "chosen" in sample.keys():
|
||||
chosen_key = "chosen"
|
||||
else:
|
||||
chosen_key = "chosen_response"
|
||||
|
||||
if "rejected" in sample.keys():
|
||||
rejected_key = "rejected"
|
||||
else:
|
||||
rejected_key = "rejected_response"
|
||||
|
||||
if "system" in sample and sample["system"]:
|
||||
sample["prompt"] = (
|
||||
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
|
||||
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
else:
|
||||
sample[
|
||||
"prompt"
|
||||
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
sample["chosen"] = f"{sample['chosen_response']}<|eot_id|>"
|
||||
sample["rejected"] = f"{sample['rejected_response']}<|eot_id|>"
|
||||
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
sample["chosen"] = f"{sample[chosen_key]}<|eot_id|>"
|
||||
sample["rejected"] = f"{sample[rejected_key]}<|eot_id|>"
|
||||
return sample
|
||||
|
||||
return transform_fn
|
||||
|
||||
Reference in New Issue
Block a user