Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
96af760e08 | ||
|
|
cfa80dace0 | ||
|
|
0a661980ca | ||
|
|
effc4dc409 | ||
|
|
02629c7cdf | ||
|
|
78a4aa86d6 |
5
.github/workflows/tests-nightly.yml
vendored
5
.github/workflows/tests-nightly.yml
vendored
@@ -44,6 +44,11 @@ jobs:
|
||||
python-version: ${{ matrix.python_version }}
|
||||
cache: 'pip' # caching pip dependencies
|
||||
|
||||
- name: upgrade pip
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging setuptools wheel
|
||||
|
||||
- name: Install PyTorch
|
||||
run: |
|
||||
pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
base_model: NousResearch/Meta-Llama-3.1-8B
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: tatsu-lab/alpaca
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.05
|
||||
output_dir: ./outputs/out
|
||||
|
||||
sequence_len: 8192
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 8
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: paged_adamw_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 2e-5
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: false
|
||||
|
||||
tensor_parallel: 'auto'
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 100
|
||||
evals_per_epoch: 2
|
||||
eval_table_size:
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
pad_token: <|end_of_text|>
|
||||
@@ -1,73 +0,0 @@
|
||||
base_model: NousResearch/Meta-Llama-3.1-8B
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
eval_sample_packing: false
|
||||
pad_to_sequence_len: true
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
lora_modules_to_save:
|
||||
- embed_tokens
|
||||
- lm_head
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: false
|
||||
|
||||
tensor_parallel: 'auto'
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
s2_attention:
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
pad_token: <|end_of_text|>
|
||||
@@ -12,7 +12,7 @@ liger-kernel==0.4.2
|
||||
|
||||
packaging==23.2
|
||||
peft==0.14.0
|
||||
transformers>=4.46.3
|
||||
transformers==4.47.0
|
||||
tokenizers>=0.20.1
|
||||
accelerate==1.2.0
|
||||
datasets==3.1.0
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
1033
src/axolotl/core/trainers/base.py
Normal file
1033
src/axolotl/core/trainers/base.py
Normal file
File diff suppressed because it is too large
Load Diff
220
src/axolotl/core/training_args.py
Normal file
220
src/axolotl/core/training_args.py
Normal file
@@ -0,0 +1,220 @@
|
||||
"""
|
||||
extra axolotl specific training args
|
||||
"""
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from transformers import TrainingArguments
|
||||
from trl import CPOConfig, DPOConfig, KTOConfig, ORPOConfig, RewardConfig
|
||||
|
||||
|
||||
@dataclass
class AxolotlTrainingMixins:
    """
    Mixin class for the Axolotl training args.

    Collects every Axolotl-specific option layered on top of the HF/trl
    training-argument dataclasses: sample packing, ReLoRA, LISA, benchmark
    evals, LoRA+ learning rates, alternate optimizers/schedulers, etc.
    All fields have defaults so the mixin composes cleanly with any parent
    config dataclass.
    """

    # pylint: disable=duplicate-code
    model_type: Optional[str] = field(
        default=None, metadata={"help": "HF model configuration model_type."}
    )
    lr_quadratic_warmup: bool = field(
        default=False,
        metadata={"help": "Use quadratic warmup for cosine scheduling."},
    )
    pretraining: bool = field(
        default=False,
        metadata={
            "help": "Indicates to trainer whether we are doing continued pretraining."
        },
    )
    sample_packing: bool = field(
        default=False,
        metadata={"help": "Use sample packing for efficient training."},
    )
    multipack_real_batches: bool = field(
        default=False,
        metadata={"help": "Use real batches for efficient training."},
    )
    eval_sample_packing: Optional[bool] = field(
        default=None,
        metadata={"help": "Use sample packing for efficient evals."},
    )
    sample_packing_efficiency: float = field(
        default=1.0,
        metadata={"help": "Sample packing efficiency for calculating batch length."},
    )
    sample_packing_bin_size: int = field(
        default=200,
        metadata={
            "help": "The max number of samples that packed sample can contain after packing. Increase for better packing."
        },
    )
    sample_packing_group_size: int = field(
        default=100000,
        metadata={
            "help": "The number of samples to group together for packing. Increase for better packing."
        },
    )
    max_seq_length: int = field(
        default=2048,
        metadata={"help": "The maximum sequence length the model can handle"},
    )
    relora_steps: Optional[int] = field(
        default=None,
        metadata={"help": "how often to reset for ReLoRA"},
    )
    relora_warmup_steps: Optional[int] = field(
        default=None,
        metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
    )
    # fixed: help previously duplicated relora_warmup_steps' description
    relora_anneal_steps: Optional[int] = field(
        default=None,
        metadata={"help": "how many anneal steps to take before each reset for ReLoRA"},
    )
    relora_prune_ratio: Optional[float] = field(
        default=0.9,
        metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
    )
    bench_split: Optional[str] = field(
        default="eval", metadata={"help": "The benchmark split to run on"}
    )
    bench_dataset: Optional[str] = field(
        default="pharaouk/dharma-1/dharma_1_mini.json",
        metadata={
            "help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
        },
    )
    do_bench_eval: Optional[bool] = field(
        default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
    )
    do_causal_lm_eval: Optional[bool] = field(
        default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
    )
    max_bench_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
        },
    )
    bench_source_max_len: int = field(
        default=2048, metadata={"help": "Maximum source sequence length for bench."}
    )
    dataloader_prefetch_factor: Optional[int] = field(
        default=None,
        metadata={"help": "prefetch_factor argument to the dataloader"},
    )
    cosine_min_lr_ratio: Optional[float] = field(
        default=None,
        metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
    )
    cosine_constant_lr_ratio: Optional[float] = field(
        default=None,
        metadata={
            "help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
        },
    )
    loraplus_lr_ratio: Optional[float] = field(
        default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
    )
    loraplus_lr_embedding: Optional[float] = field(
        default=1e-6,
        metadata={"help": "loraplus learning rate for lora embedding layers."},
    )
    embedding_lr_scale: Optional[float] = field(
        default=None,
        metadata={"help": "Scale the learning rate for the embedding layers."},
    )
    embedding_lr: Optional[float] = field(
        default=None,
        metadata={"help": "absolute learning rate for the embedding layers."},
    )
    qlora: bool = field(
        default=False,
        metadata={"help": "whether this is a qlora training"},
    )
    # added help metadata; this field previously had none
    orpo_alpha: Optional[float] = field(
        default=None,
        metadata={"help": "alpha weighting for the ORPO loss"},
    )
    lisa_n_layers: Optional[int] = field(
        default=None,
        metadata={"help": "the number of activate layers in LISA"},
    )
    lisa_step_interval: Optional[int] = field(
        default=None,
        metadata={"help": "how often to switch layers in LISA"},
    )
    lisa_layers_attribute: Optional[str] = field(
        default=None,
        metadata={"help": "path under the model to access the layers"},
    )
    curriculum_sampling: Optional[bool] = field(
        default=None,
        metadata={"help": "whether to use sequential sampling for curriculum learning"},
    )
    alternate_optimizer: Optional[str] = field(
        default=None,
        metadata={
            "help": "workaround to pass an alternate optimizer to the HF trainer"
        },
    )
    alternate_lr_scheduler_type: Optional[str] = field(
        default=None,
        metadata={
            "help": "workaround to pass an alternate lr scheduler to the HF trainer"
        },
    )
    chat_template: Optional[str] = field(
        default=None,
        metadata={"help": "Chat template converting chat messages to text"},
    )
|
||||
|
||||
|
||||
@dataclass
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
    """
    Training arguments for Causal trainer

    This code is duplicated due to HF TrainingArguments not setting output_dir
    with a default value so it can't be used as a mixin.
    """
|
||||
|
||||
|
||||
@dataclass
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
    """
    DPO training config: trl's ``DPOConfig`` extended with the
    Axolotl-specific mixin fields.
    """
|
||||
|
||||
|
||||
@dataclass
class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
    """
    ORPO training config: trl's ``ORPOConfig`` extended with the
    Axolotl-specific mixin fields.
    """
|
||||
|
||||
|
||||
@dataclass
class AxolotlKTOConfig(AxolotlTrainingMixins, KTOConfig):
    """
    KTO training config: trl's ``KTOConfig`` extended with the
    Axolotl-specific mixin fields.
    """
|
||||
|
||||
|
||||
@dataclass
class AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig):
    """
    CPO training config: trl's ``CPOConfig`` extended with the
    Axolotl-specific mixin fields plus the SimPO gamma term.
    """

    simpo_gamma: Optional[float] = field(
        default=None, metadata={"help": "simpo gamma parameter"}
    )
|
||||
|
||||
|
||||
@dataclass
class AxolotlRewardConfig(AxolotlTrainingMixins, RewardConfig):
    """
    Reward-model training config: trl's ``RewardConfig`` extended with the
    Axolotl-specific mixin fields.
    """
|
||||
@@ -36,6 +36,8 @@ class LigerArgs(BaseModel):
|
||||
liger_cross_entropy: Optional[bool] = None
|
||||
liger_fused_linear_cross_entropy: Optional[bool] = None
|
||||
|
||||
liger_pref_rl: Optional[bool] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_deprecated_swiglu(cls, data):
|
||||
|
||||
0
src/axolotl/integrations/liger/trainer/__init__.py
Normal file
0
src/axolotl/integrations/liger/trainer/__init__.py
Normal file
253
src/axolotl/integrations/liger/trainer/dpo_trainer.py
Normal file
253
src/axolotl/integrations/liger/trainer/dpo_trainer.py
Normal file
@@ -0,0 +1,253 @@
|
||||
"""
|
||||
integration of liger dpo kernels with dpotrainer
|
||||
"""
|
||||
from typing import Dict, List, Literal, Union
|
||||
|
||||
import torch
|
||||
from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss
|
||||
from liger_kernel.transformers.trainer.orpo_trainer import _FSDPForwardRedirection
|
||||
from torch import nn
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel
|
||||
|
||||
from axolotl.core.trainers.base import AxolotlDPOTrainer
|
||||
|
||||
|
||||
class AxolotlLigerDPOTrainer(AxolotlDPOTrainer):
    """
    Extend the DPO Trainer to use LIGER kernels for DPO.

    Replaces the logits-materializing forward pass with Liger's fused
    linear + DPO-loss kernel, so the full vocab-sized logits tensor is never
    allocated.
    """

    def concatenated_forward(
        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
    ):
        """
        Run the given model on the given batch of inputs, concatenating the
        chosen and rejected inputs together, and compute the DPO loss using
        Liger's fused kernel.

        This method replaces the original `concatenated_forward`
        implementation to use Liger.

        Returns:
            Tuple of (loss, policy_chosen_logps, policy_rejected_logps,
            policy_chosen_logits_mean, policy_rejected_logits_mean,
            policy_nll_loss).
        """

        # Prepare concatenated inputs (chosen and rejected stacked along batch)
        concatenated_batch = self.concatenated_inputs(batch, self.padding_value)

        # Extract concatenated inputs
        prompt_input_ids = concatenated_batch["prompt_input_ids"]
        prompt_attention_mask = concatenated_batch["prompt_attention_mask"]
        completion_input_ids = concatenated_batch["completion_input_ids"]
        completion_attention_mask = concatenated_batch["completion_attention_mask"]

        # For encoder-decoder models, you'd need to construct decoder_input_ids, etc.
        # This implementation assumes a causal decoder-only model.
        input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1)
        attention_mask = torch.cat(
            (prompt_attention_mask, completion_attention_mask), dim=1
        )

        # Align inputs by removing leading padding: roll each row left so its
        # first attended token sits at column 0.
        for i in range(attention_mask.size(0)):
            first_one_idx = torch.nonzero(attention_mask[i])[0].item()
            input_ids[i] = torch.roll(input_ids[i], shifts=-first_one_idx)
            attention_mask[i] = torch.roll(attention_mask[i], shifts=-first_one_idx)

        # Remove trailing columns that are padding for every row
        empty_cols = torch.sum(attention_mask, dim=0) == 0
        if empty_cols.any():
            first_empty_col = torch.nonzero(empty_cols)[0].item()
            input_ids = input_ids[:, :first_empty_col]
            attention_mask = attention_mask[:, :first_empty_col]

        # Clamp to the configured maximum length, if any
        if self.args.max_length is not None:
            input_ids = input_ids[:, : self.args.max_length]
            attention_mask = attention_mask[:, : self.args.max_length]

        # Labels are completion_input_ids shifted by one token right.
        # For causal LM, labels are the completion part only (prompt positions
        # are filled with a zero placeholder).
        labels = torch.cat(
            (torch.zeros_like(prompt_input_ids), completion_input_ids), dim=1
        )
        labels = labels[:, 1:]  # shift left by one
        attention_mask = attention_mask[:, 1:]
        labels = labels[:, : attention_mask.size(1)]

        # Mask padding positions out of the loss.
        # NOTE(review): prompt positions carry attention_mask == 1, so they are
        # NOT masked here — they keep the zero placeholder token id. Confirm
        # the Liger kernel treats the prompt region as intended.
        labels[~attention_mask.bool()] = self.label_pad_token_id

        # Only consult the reference model when one exists and we are not
        # running reference-free DPO.
        use_ref_model = self.ref_model is not None and not self.reference_free

        # Run main model forward to get hidden states.
        # If using FSDP, redirect forward calls through the wrapper so
        # parameter gathering happens correctly.
        if isinstance(model, FullyShardedDataParallel):
            outputs = _FSDPForwardRedirection()(
                model,
                model._fsdp_wrapped_module.model,  # pylint: disable=protected-access
                input_ids,
                attention_mask=attention_mask,
                use_cache=False,
            )
        else:
            # If model is a DataParallel, unwrap
            if isinstance(model, torch.nn.DataParallel):
                model = model.module
            outputs = model.model(
                input_ids, attention_mask=attention_mask, use_cache=False
            )

        last_hidden_state = outputs.last_hidden_state

        # Reference-model hidden states (no grad), mirroring the policy path
        ref_last_hidden_state = None
        if use_ref_model:
            ref_model = self.ref_model
            if isinstance(ref_model, FullyShardedDataParallel):
                with torch.no_grad():
                    ref_outputs = _FSDPForwardRedirection()(
                        ref_model,
                        ref_model._fsdp_wrapped_module.model,  # pylint: disable=protected-access
                        input_ids,
                        attention_mask=attention_mask,
                        use_cache=False,
                    )
            else:
                if isinstance(ref_model, torch.nn.DataParallel):
                    ref_model = ref_model.module
                with torch.no_grad():
                    ref_outputs = ref_model.model(
                        input_ids, attention_mask=attention_mask, use_cache=False
                    )
            ref_last_hidden_state = ref_outputs.last_hidden_state

        # Retrieve lm_head parameters; use_ref_model already guarantees
        # self.ref_model is not None, so no extra None-check is needed.
        lm_head = model.lm_head
        ref_lm_head = self.ref_model.lm_head if use_ref_model else None

        # Use Liger fused DPO loss (fuses the lm_head projection with the
        # DPO loss computation)
        dpo_loss_fn = LigerFusedLinearDPOLoss(
            ignore_index=self.label_pad_token_id,
            beta=self.beta,
            compute_nll_loss=False,
            compiled=True,
            use_ref_model=use_ref_model,
        )

        # call fused Liger DPO
        if use_ref_model:
            loss_acc, aux_outputs = dpo_loss_fn(
                lm_head.weight,  # lin_weight
                last_hidden_state,  # _input
                labels,  # target
                bias=lm_head.bias,
                ref_input=ref_last_hidden_state,
                ref_weight=ref_lm_head.weight,
                ref_bias=ref_lm_head.bias,
            )

            (
                policy_chosen_logps,
                policy_rejected_logps,
                policy_chosen_logits_mean,
                policy_rejected_logits_mean,
                policy_nll_loss,
            ) = aux_outputs[:5]

        else:
            # No reference model scenario: Liger kernel treats ref_logps as 0
            loss_acc, aux_outputs = dpo_loss_fn(
                lm_head.weight,
                last_hidden_state,
                labels,
                bias=lm_head.bias,
            )
            (
                policy_chosen_logps,
                policy_rejected_logps,
                policy_chosen_logits_mean,
                policy_rejected_logits_mean,
                policy_nll_loss,
            ) = aux_outputs[:5]

        # Add aux loss if enabled (e.g. MoE router loss)
        if self.aux_loss_enabled and hasattr(outputs, "aux_loss"):
            loss_acc = loss_acc + self.aux_loss_coef * outputs.aux_loss

        # Add RPO loss if requested (RPO is a variant that adds NLL loss)
        if self.args.rpo_alpha is not None:
            # policy_nll_loss: average negative log-likelihood of chosen completions
            loss_acc = loss_acc + self.args.rpo_alpha * policy_nll_loss.mean()

        return (
            loss_acc,
            policy_chosen_logps,
            policy_rejected_logps,
            policy_chosen_logits_mean,
            policy_rejected_logits_mean,
            policy_nll_loss,
        )

    def get_batch_loss_metrics(
        self,
        model,
        batch: Dict[str, Union[List, torch.LongTensor]],
        train_eval: Literal["train", "eval"] = "train",
    ):
        """
        Compute the DPO loss and other metrics for a given batch using the Liger fused kernel.

        Returns:
            Tuple of (mean loss, dict of ``{train_eval}_*`` metric scalars).
        """
        metrics = {}

        (
            loss,
            policy_chosen_logps,
            policy_rejected_logps,
            policy_chosen_logits_mean,
            policy_rejected_logits_mean,
            policy_nll_loss,
        ) = self.concatenated_forward(model, batch)

        # For metrics, we approximate chosen/rejected rewards as beta * (log π(y) - log π_ref(y)) if ref model used.
        # If no ref model is used, we can't compute reward_accuracies meaningfully. For simplicity, we assume ref_model presence.
        if self.ref_model is not None and not self.reference_free:
            # If you want full parity with original DPOTrainer metrics (like chosen_rewards, rejected_rewards),
            # you'd need to run reference forward or store reference log ps. The Liger kernel currently doesn't
            # return ref_chosen_logps/ref_rejected_logps explicitly. By design, Liger directly computes DPO.
            #
            # Here we approximate chosen_rewards and rejected_rewards from the difference in chosen/rejected logps.
            # Since Liger DPO does not output ref logps separately, you may need to modify the Liger kernel to
            # also output them if you need all the metrics. For now, we'll skip them or provide a placeholder.

            # Placeholder: chosen/rejected "rewards" can't be retrieved directly from Liger as-is.
            # If needed, integrate ref_chosen_logps/ref_rejected_logps into Liger kernel returns.
            chosen_rewards = policy_chosen_logps * self.beta  # approximation
            rejected_rewards = policy_rejected_logps * self.beta  # approximation
            reward_accuracies = (chosen_rewards > rejected_rewards).float()
            metrics[f"{train_eval}_rewards/chosen"] = chosen_rewards.mean().cpu().item()
            metrics[f"{train_eval}_rewards/rejected"] = (
                rejected_rewards.mean().cpu().item()
            )
            metrics[f"{train_eval}_rewards/accuracies"] = (
                reward_accuracies.mean().cpu().item()
            )
            metrics[f"{train_eval}_rewards/margins"] = (
                (chosen_rewards - rejected_rewards).mean().cpu().item()
            )

        metrics[f"{train_eval}_logps/chosen"] = policy_chosen_logps.mean().cpu().item()
        metrics[f"{train_eval}_logps/rejected"] = (
            policy_rejected_logps.mean().cpu().item()
        )
        metrics[f"{train_eval}_logits/chosen"] = (
            policy_chosen_logits_mean.detach().cpu().item()
        )
        metrics[f"{train_eval}_logits/rejected"] = (
            policy_rejected_logits_mean.detach().cpu().item()
        )

        if self.args.rpo_alpha is not None:
            metrics[f"{train_eval}_nll_loss"] = (
                policy_nll_loss.mean().detach().cpu().item()
            )

        return loss.mean(), metrics
|
||||
@@ -8,17 +8,36 @@ def argilla(
|
||||
**kwargs,
|
||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
def transform_fn(sample):
|
||||
if "prompt" in sample.keys():
|
||||
prompt_key = "prompt"
|
||||
elif "input" in sample.keys():
|
||||
prompt_key = "input"
|
||||
elif "question" in sample.keys():
|
||||
prompt_key = "question"
|
||||
else:
|
||||
prompt_key = "instruction"
|
||||
|
||||
if "chosen" in sample.keys():
|
||||
chosen_key = "chosen"
|
||||
else:
|
||||
chosen_key = "chosen_response"
|
||||
|
||||
if "rejected" in sample.keys():
|
||||
rejected_key = "rejected"
|
||||
else:
|
||||
rejected_key = "rejected_response"
|
||||
|
||||
if "system" in sample and sample["system"]:
|
||||
sample["prompt"] = (
|
||||
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
|
||||
f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
else:
|
||||
sample[
|
||||
"prompt"
|
||||
] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
sample["chosen"] = f"{sample['chosen_response']}<|im_end|>"
|
||||
sample["rejected"] = f"{sample['rejected_response']}<|im_end|>"
|
||||
] = f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
|
||||
sample["chosen"] = f"{sample[chosen_key]}<|im_end|>"
|
||||
sample["rejected"] = f"{sample[rejected_key]}<|im_end|>"
|
||||
return sample
|
||||
|
||||
return transform_fn
|
||||
|
||||
@@ -8,17 +8,37 @@ def argilla(
|
||||
**kwargs,
|
||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
def transform_fn(sample):
|
||||
# pylint: disable=duplicate-code
|
||||
if "prompt" in sample.keys():
|
||||
prompt_key = "prompt"
|
||||
elif "input" in sample.keys():
|
||||
prompt_key = "input"
|
||||
elif "question" in sample.keys():
|
||||
prompt_key = "question"
|
||||
else:
|
||||
prompt_key = "instruction"
|
||||
|
||||
if "chosen" in sample.keys():
|
||||
chosen_key = "chosen"
|
||||
else:
|
||||
chosen_key = "chosen_response"
|
||||
|
||||
if "rejected" in sample.keys():
|
||||
rejected_key = "rejected"
|
||||
else:
|
||||
rejected_key = "rejected_response"
|
||||
|
||||
if "system" in sample and sample["system"]:
|
||||
sample["prompt"] = (
|
||||
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
|
||||
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
else:
|
||||
sample[
|
||||
"prompt"
|
||||
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
sample["chosen"] = f"{sample['chosen_response']}<|eot_id|>"
|
||||
sample["rejected"] = f"{sample['rejected_response']}<|eot_id|>"
|
||||
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
sample["chosen"] = f"{sample[chosen_key]}<|eot_id|>"
|
||||
sample["rejected"] = f"{sample[rejected_key]}<|eot_id|>"
|
||||
return sample
|
||||
|
||||
return transform_fn
|
||||
|
||||
@@ -66,10 +66,7 @@ class EvalFirstStepCallback(
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
if (
|
||||
args.evaluation_strategy == IntervalStrategy.STEPS
|
||||
and state.global_step == 1
|
||||
):
|
||||
if args.eval_strategy == IntervalStrategy.STEPS and state.global_step == 1:
|
||||
control.should_evaluate = True
|
||||
return control
|
||||
|
||||
|
||||
@@ -393,7 +393,7 @@ class ModelInputConfig(BaseModel):
|
||||
default=None, json_schema_extra={"description": "transformers processor class"}
|
||||
)
|
||||
trust_remote_code: Optional[bool] = None
|
||||
tensor_parallel: Optional[Union[Literal["auto"], bool]] = "auto"
|
||||
|
||||
model_kwargs: Optional[Dict[str, Any]] = None
|
||||
|
||||
@field_validator("trust_remote_code")
|
||||
|
||||
@@ -1187,15 +1187,9 @@ class ModelLoader:
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
self.post_loading_set_env()
|
||||
|
||||
# TODO resume_from_checkpoint handling
|
||||
return self.model, lora_config
|
||||
|
||||
def post_loading_set_env(self):
|
||||
if self.cfg.tensor_parallel == "auto" and self.model.supports_tp_plan:
|
||||
os.environ["ACCELERATE_USE_TP"] = "true"
|
||||
|
||||
|
||||
def load_model(
|
||||
cfg: DictDefault,
|
||||
|
||||
Reference in New Issue
Block a user