Compare commits
3 Commits
enable_tp
...
feat/pref_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8428b3f2c7 | ||
|
|
02629c7cdf | ||
|
|
78a4aa86d6 |
5
.github/workflows/tests-nightly.yml
vendored
5
.github/workflows/tests-nightly.yml
vendored
@@ -44,6 +44,11 @@ jobs:
|
|||||||
python-version: ${{ matrix.python_version }}
|
python-version: ${{ matrix.python_version }}
|
||||||
cache: 'pip' # caching pip dependencies
|
cache: 'pip' # caching pip dependencies
|
||||||
|
|
||||||
|
- name: upgrade pip
|
||||||
|
run: |
|
||||||
|
pip3 install --upgrade pip
|
||||||
|
pip3 install --upgrade packaging setuptools wheel
|
||||||
|
|
||||||
- name: Install PyTorch
|
- name: Install PyTorch
|
||||||
run: |
|
run: |
|
||||||
pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu
|
pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu
|
||||||
|
|||||||
@@ -1,58 +0,0 @@
|
|||||||
base_model: NousResearch/Meta-Llama-3.1-8B
|
|
||||||
|
|
||||||
load_in_8bit: false
|
|
||||||
load_in_4bit: false
|
|
||||||
strict: false
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: tatsu-lab/alpaca
|
|
||||||
type: alpaca
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0.05
|
|
||||||
output_dir: ./outputs/out
|
|
||||||
|
|
||||||
sequence_len: 8192
|
|
||||||
sample_packing: true
|
|
||||||
pad_to_sequence_len: true
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 8
|
|
||||||
micro_batch_size: 1
|
|
||||||
num_epochs: 1
|
|
||||||
optimizer: paged_adamw_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 2e-5
|
|
||||||
|
|
||||||
train_on_inputs: false
|
|
||||||
group_by_length: false
|
|
||||||
bf16: auto
|
|
||||||
fp16:
|
|
||||||
tf32: false
|
|
||||||
|
|
||||||
tensor_parallel: 'auto'
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
gradient_checkpointing_kwargs:
|
|
||||||
use_reentrant: false
|
|
||||||
early_stopping_patience:
|
|
||||||
resume_from_checkpoint:
|
|
||||||
logging_steps: 1
|
|
||||||
xformers_attention:
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
warmup_steps: 100
|
|
||||||
evals_per_epoch: 2
|
|
||||||
eval_table_size:
|
|
||||||
saves_per_epoch: 1
|
|
||||||
debug:
|
|
||||||
deepspeed:
|
|
||||||
weight_decay: 0.0
|
|
||||||
fsdp:
|
|
||||||
fsdp_config:
|
|
||||||
special_tokens:
|
|
||||||
pad_token: <|end_of_text|>
|
|
||||||
@@ -1,73 +0,0 @@
|
|||||||
base_model: NousResearch/Meta-Llama-3.1-8B
|
|
||||||
model_type: LlamaForCausalLM
|
|
||||||
tokenizer_type: AutoTokenizer
|
|
||||||
|
|
||||||
load_in_8bit: true
|
|
||||||
load_in_4bit: false
|
|
||||||
strict: false
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: mhenrichsen/alpaca_2k_test
|
|
||||||
type: alpaca
|
|
||||||
dataset_prepared_path:
|
|
||||||
val_set_size: 0.05
|
|
||||||
output_dir: ./outputs/lora-out
|
|
||||||
|
|
||||||
sequence_len: 4096
|
|
||||||
sample_packing: true
|
|
||||||
eval_sample_packing: false
|
|
||||||
pad_to_sequence_len: true
|
|
||||||
|
|
||||||
adapter: lora
|
|
||||||
lora_model_dir:
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_target_linear: true
|
|
||||||
lora_fan_in_fan_out:
|
|
||||||
lora_modules_to_save:
|
|
||||||
- embed_tokens
|
|
||||||
- lm_head
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 2
|
|
||||||
num_epochs: 4
|
|
||||||
optimizer: adamw_bnb_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
train_on_inputs: false
|
|
||||||
group_by_length: false
|
|
||||||
bf16: auto
|
|
||||||
fp16:
|
|
||||||
tf32: false
|
|
||||||
|
|
||||||
tensor_parallel: 'auto'
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
early_stopping_patience:
|
|
||||||
resume_from_checkpoint:
|
|
||||||
local_rank:
|
|
||||||
logging_steps: 1
|
|
||||||
xformers_attention:
|
|
||||||
flash_attention: true
|
|
||||||
s2_attention:
|
|
||||||
|
|
||||||
warmup_steps: 10
|
|
||||||
evals_per_epoch: 4
|
|
||||||
eval_table_size:
|
|
||||||
eval_max_new_tokens: 128
|
|
||||||
saves_per_epoch: 1
|
|
||||||
debug:
|
|
||||||
deepspeed:
|
|
||||||
weight_decay: 0.0
|
|
||||||
fsdp:
|
|
||||||
fsdp_config:
|
|
||||||
special_tokens:
|
|
||||||
pad_token: <|end_of_text|>
|
|
||||||
@@ -14,17 +14,22 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from contextlib import nullcontext
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Literal, Optional, Type, Union
|
from typing import Any, Dict, List, Literal, Optional, Type, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
import transformers
|
import transformers
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
|
from liger_kernel.chunked_loss.fused_linear_preference import (
|
||||||
|
LigerFusedLinearPreferenceBase,
|
||||||
|
)
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from peft.optimizers import create_loraplus_optimizer
|
from peft.optimizers import create_loraplus_optimizer
|
||||||
from torch import nn
|
from torch import amp, nn
|
||||||
from torch.optim.lr_scheduler import OneCycleLR
|
from torch.optim.lr_scheduler import OneCycleLR
|
||||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||||
from transformers import (
|
from transformers import (
|
||||||
@@ -1077,6 +1082,15 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
|||||||
self.dataset_tags = dataset_tags
|
self.dataset_tags = dataset_tags
|
||||||
self.optimizer = None
|
self.optimizer = None
|
||||||
|
|
||||||
|
from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss
|
||||||
|
|
||||||
|
self.liger_loss = LigerFusedLinearDPOLoss(
|
||||||
|
ignore_index=self.label_pad_token_id,
|
||||||
|
beta=self.beta,
|
||||||
|
compute_nll_loss=True, # not same as rpo_alpha hasattr(self.args, "rpo_alpha") and self.args.rpo_alpha is not None,
|
||||||
|
use_ref_model=not self.reference_free,
|
||||||
|
)
|
||||||
|
|
||||||
def create_optimizer(self):
|
def create_optimizer(self):
|
||||||
if self.args.loraplus_lr_ratio is None:
|
if self.args.loraplus_lr_ratio is None:
|
||||||
return super().create_optimizer()
|
return super().create_optimizer()
|
||||||
@@ -1180,6 +1194,309 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
|||||||
# transformers<=4.46
|
# transformers<=4.46
|
||||||
return super(DPOTrainer, self).log(logs) # pylint: disable=bad-super-call
|
return super(DPOTrainer, self).log(logs) # pylint: disable=bad-super-call
|
||||||
|
|
||||||
|
def get_batch_loss_metrics(
    self,
    model,
    batch: dict[str, Union[list, torch.LongTensor]],
    train_eval: Literal["train", "eval"] = "train",
):
    """Compute the DPO loss and logging metrics using the Liger fused kernel.

    The policy (and, unless ``self.reference_free`` is set, the reference
    model) is run through ``self.concatenated_forward`` to obtain the last
    hidden states; those are handed to ``self.liger_loss`` together with the
    ``lm_head`` weight/bias.

    Args:
        model: The policy model; must expose ``lm_head``.
        batch: The preference batch forwarded to ``concatenated_forward``.
        train_eval: ``"eval"`` prefixes every metric key with ``eval_``.

    Returns:
        A ``(loss, metrics)`` tuple where ``metrics`` maps metric names to
        detached CPU tensors.

    Raises:
        ValueError: If ``self.liger_loss`` has not been initialised.
    """
    # getattr() so that a missing attribute surfaces as the intended ValueError
    # rather than an AttributeError.
    if getattr(self, "liger_loss", None) is None:
        raise ValueError("Liger loss not initialized")

    policy_output = self.concatenated_forward(model, batch)
    hidden_states = policy_output["hidden_states"]
    labels = policy_output["labels"]

    # The fused kernel applies the output projection itself, so it needs the
    # lm_head parameters rather than pre-computed logits.
    lin_weight = model.lm_head.weight
    lin_bias = getattr(model.lm_head, "bias", None)

    ref_hidden_states = None
    ref_weight = None
    ref_bias = None
    ref_chosen_logps = None
    ref_rejected_logps = None

    if not self.reference_free:
        # Adapted from DPO's compute_ref_log_probs
        ref_autocast = (
            amp.autocast("cuda")
            if self._peft_has_been_casted_to_bf16
            else nullcontext()
        )
        if self.ref_model is None:
            # No separate reference model: reuse the policy inside
            # null_ref_context().
            ref_source = self.model
            ref_context = self.null_ref_context()
        else:
            ref_source = self.ref_model
            ref_context = nullcontext()

        with torch.no_grad(), ref_autocast:  # type: ignore
            with ref_context:
                ref_output = self.concatenated_forward(ref_source, batch)
                ref_weight = ref_source.lm_head.weight
                ref_bias = getattr(ref_source.lm_head, "bias", None)
                ref_hidden_states = ref_output["hidden_states"]

            # Only the two log-prob tensors are needed for the reward metrics.
            (
                ref_chosen_logps,
                ref_rejected_logps,
                *_unused_ref_outputs,
            ) = LigerFusedLinearPreferenceBase.chunk_forward(
                input_chunk=ref_hidden_states,
                weight=ref_weight,
                target_chunk=labels,
                bias=ref_bias,
                # Keep masking consistent with how self.liger_loss is built.
                ignore_index=self.label_pad_token_id,
                compute_nll_loss=False,
            )

    # Compute loss using Liger kernel
    loss, return_vars = self.liger_loss(
        lin_weight=lin_weight,
        _input=hidden_states,
        target=labels,
        bias=lin_bias,  # TODO: check whether to pass bias as FCLE doesn't
        ref_input=ref_hidden_states,
        ref_weight=ref_weight,
        ref_bias=ref_bias,
    )

    (
        policy_chosen_logps,
        policy_rejected_logps,
        policy_chosen_logits_mean,
        policy_rejected_logits_mean,
        policy_nll_loss,
    ) = return_vars

    # Rewards are only logged, never back-propagated: detach them in both
    # branches so the metrics do not keep the autograd graph alive.
    if self.reference_free:
        chosen_rewards = self.beta * policy_chosen_logps.detach()
        rejected_rewards = self.beta * policy_rejected_logps.detach()
    else:
        chosen_rewards = (
            self.beta * (policy_chosen_logps - ref_chosen_logps).detach()
        )
        rejected_rewards = (
            self.beta * (policy_rejected_logps - ref_rejected_logps).detach()
        )

    reward_accuracies = (chosen_rewards > rejected_rewards).float()

    prefix = "eval_" if train_eval == "eval" else ""
    metrics = {
        f"{prefix}rewards/chosen": chosen_rewards.mean().cpu(),
        f"{prefix}rewards/rejected": rejected_rewards.mean().cpu(),
        f"{prefix}rewards/accuracies": reward_accuracies.mean().cpu(),
        f"{prefix}rewards/margins": (chosen_rewards - rejected_rewards)
        .mean()
        .cpu(),
        f"{prefix}logps/chosen": policy_chosen_logps.detach().mean().cpu(),
        f"{prefix}logps/rejected": policy_rejected_logps.detach().mean().cpu(),
        f"{prefix}logits/chosen": policy_chosen_logits_mean.detach().cpu(),
        f"{prefix}logits/rejected": policy_rejected_logits_mean.detach().cpu(),
    }

    if getattr(self.args, "rpo_alpha", None) is not None:
        metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().cpu()

    # TODO: Handle use_weighting, aux_loss_enabled as in upstream

    return loss, metrics
|
||||||
|
|
||||||
|
def concatenated_forward(
    self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
):
    """Run ``model`` once on the chosen and rejected inputs concatenated together.

    We do this to avoid doing two forward passes, because it's faster for FSDP.

    Overrides the base implementation so that the last hidden states and the
    aligned labels are returned as well, for use by the fused Liger loss.

    Args:
        model: The model to run (policy or reference).
        batch: Preference batch; ``batch["prompt_input_ids"]`` gives the
            number of examples, the rest is consumed via
            ``self.concatenated_inputs``.

    Returns:
        A dict with ``chosen_logps``, ``rejected_logps``,
        ``mean_chosen_logits``, ``mean_rejected_logits``, ``hidden_states``
        and ``labels``; plus ``policy_weights`` when ``self.use_weighting``,
        ``nll_loss`` when ``self.args.rpo_alpha`` is set, and ``aux_loss``
        when ``self.aux_loss_enabled``. Positions excluded from the loss are
        set to ``self.label_pad_token_id`` in ``labels``.
    """
    num_examples = batch["prompt_input_ids"].shape[0]  # type: ignore

    concatenated_batch = self.concatenated_inputs(
        batch, padding_value=self.padding_value
    )

    # output_hidden_states is always requested: the fused loss consumes the
    # last hidden states.
    model_kwargs = {"output_hidden_states": True}
    if self.aux_loss_enabled:
        model_kwargs["output_router_logits"] = True

    # Add the pixel values and attention masks for vision models
    for vision_key in ("pixel_values", "pixel_attention_mask", "image_sizes"):
        if vision_key in concatenated_batch:
            model_kwargs[vision_key] = concatenated_batch[vision_key]

    prompt_input_ids = concatenated_batch["prompt_input_ids"]
    prompt_attention_mask = concatenated_batch["prompt_attention_mask"]
    completion_input_ids = concatenated_batch["completion_input_ids"]
    completion_attention_mask = concatenated_batch["completion_attention_mask"]

    if self.is_encoder_decoder:
        labels = completion_input_ids
        labels[completion_attention_mask == 0] = self.label_pad_token_id
        outputs = model(
            input_ids=prompt_input_ids,
            attention_mask=prompt_attention_mask,
            labels=labels,  # we need the labels for the logits to be returned
            **model_kwargs,
        )
        logits = outputs.logits
        hidden_states = outputs.decoder_hidden_states[-1]
        loss_mask = completion_attention_mask.bool()
    else:
        # Concatenate the prompt and completion inputs
        input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1)
        attention_mask = torch.cat(
            (prompt_attention_mask, completion_attention_mask), dim=1
        )
        # Mask the prompt but not the completion for the loss
        loss_mask = torch.cat(
            (torch.zeros_like(prompt_attention_mask), completion_attention_mask),
            dim=1,
        )

        # Flush left to reduce the memory usage
        # [[0, 0, x, x, x, x],  ->  [[x, x, x, x],
        #  [0, x, x, x, 0, 0]]       [x, x, x, 0]]
        for row in range(attention_mask.size(0)):
            shift = -torch.nonzero(attention_mask[row])[0].item()
            input_ids[row] = torch.roll(input_ids[row], shifts=shift)
            attention_mask[row] = torch.roll(attention_mask[row], shifts=shift)
            loss_mask[row] = torch.roll(loss_mask[row], shifts=shift)

        # Get the first column idx that is all zeros and remove every column after that
        empty_cols = torch.sum(attention_mask, dim=0) == 0
        if empty_cols.any():
            first_empty_col = torch.nonzero(empty_cols)[0].item()
        else:
            first_empty_col = attention_mask.size(1)
        input_ids = input_ids[:, :first_empty_col]
        attention_mask = attention_mask[:, :first_empty_col]
        loss_mask = loss_mask[:, :first_empty_col]

        # Truncate right
        max_length = self.args.max_length
        if max_length is not None:
            input_ids = input_ids[:, :max_length]
            attention_mask = attention_mask[:, :max_length]
            loss_mask = loss_mask[:, :max_length]

        # TODO: re-enable upstream's `use_num_logits_to_keep` optimisation.
        # When enabled, labels, loss_mask and hidden_states must all be sliced
        # to the last `num_logits_to_keep` positions.

        outputs = model(
            input_ids=input_ids, attention_mask=attention_mask, **model_kwargs
        )

        # Offset the logits by one to align with the labels
        logits = outputs.logits[:, :-1, :]
        hidden_states = outputs.hidden_states[-1][:, :-1, :]
        labels = input_ids[:, 1:].clone()
        loss_mask = loss_mask[:, 1:].bool()

    if logits.shape[:2] != labels.shape[:2]:
        # for llava, the returned logits include the image tokens (placed before the text tokens)
        seq_len = labels.shape[1]
        logits = logits[:, -seq_len:]
        hidden_states = hidden_states[:, -seq_len:]

    # Compute the log probabilities of the labels
    labels[~loss_mask] = 0  # dummy token; we'll ignore the losses on these tokens later
    per_token_logps = torch.gather(
        logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)
    ).squeeze(2)
    per_token_logps[~loss_mask] = 0
    all_logps = per_token_logps.sum(-1)

    output = {}

    if self.use_weighting:
        with torch.no_grad():
            # Eq (2) of the WPO paper: https://huggingface.co/papers/2406.11827
            logprobs = F.log_softmax(logits, dim=-1)
            weights_adjustment_factor = torch.logsumexp(
                2 * logprobs, dim=-1
            )  # same as sum(probs**2) in log space
            per_token_logps_adjusted = per_token_logps - weights_adjustment_factor
            all_weights = (per_token_logps_adjusted * loss_mask).sum(
                -1
            ) / loss_mask.sum(-1)
            chosen_weights = all_weights[:num_examples]
            rejected_weights = all_weights[num_examples:]
            output["policy_weights"] = torch.clamp(
                torch.exp(chosen_weights + rejected_weights), max=1
            )

    if self.args.rpo_alpha is not None:
        # Only use the chosen logits for the RPO loss
        chosen_logits = logits[:num_examples]
        chosen_labels = labels[:num_examples]

        # Masked positions carry the dummy token 0 here, hence ignore_index=0.
        output["nll_loss"] = F.cross_entropy(
            torch.flatten(chosen_logits, end_dim=1),
            torch.flatten(chosen_labels, end_dim=1),
            ignore_index=0,
        )

    if self.loss_type == "ipo":
        all_logps = all_logps / loss_mask.sum(-1)

    output["chosen_logps"] = all_logps[:num_examples]
    output["rejected_logps"] = all_logps[num_examples:]
    output["mean_chosen_logits"] = logits[:num_examples][
        loss_mask[:num_examples]
    ].mean()
    output["mean_rejected_logits"] = logits[num_examples:][
        loss_mask[num_examples:]
    ].mean()
    output["hidden_states"] = hidden_states
    # The dummy token 0 was only needed for the gather above. Downstream loss
    # code masks on label_pad_token_id, so restore it on the returned labels;
    # otherwise prompt/padding positions would be scored as real targets.
    output["labels"] = labels.masked_fill(~loss_mask, self.label_pad_token_id)

    if self.aux_loss_enabled:
        output["aux_loss"] = outputs.aux_loss

    return output
|
||||||
|
|
||||||
|
|
||||||
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
||||||
"""
|
"""
|
||||||
@@ -1319,10 +1636,6 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
if hasattr(model, "add_model_tags"):
|
if hasattr(model, "add_model_tags"):
|
||||||
model.add_model_tags(["axolotl"])
|
model.add_model_tags(["axolotl"])
|
||||||
|
|
||||||
if self.cfg.tensor_parallel == "auto" and self.model.supports_tp_plan:
|
|
||||||
os.environ["ACCELERATE_USE_TP"] = "true"
|
|
||||||
# self.model =
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def model_ref(self):
|
def model_ref(self):
|
||||||
return self._model_ref
|
return self._model_ref
|
||||||
@@ -2167,6 +2480,14 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.dpo_use_weighting is not None:
|
if self.cfg.dpo_use_weighting is not None:
|
||||||
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
|
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
|
||||||
|
|
||||||
|
report_to = []
|
||||||
|
if self.cfg.use_wandb:
|
||||||
|
report_to.append("wandb")
|
||||||
|
if self.cfg.wandb_name:
|
||||||
|
training_args_kwargs["run_name"] = self.cfg.wandb_name
|
||||||
|
|
||||||
|
training_args_kwargs["report_to"] = report_to
|
||||||
|
|
||||||
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
||||||
output_dir=self.cfg.output_dir,
|
output_dir=self.cfg.output_dir,
|
||||||
per_device_train_batch_size=self.cfg.micro_batch_size,
|
per_device_train_batch_size=self.cfg.micro_batch_size,
|
||||||
|
|||||||
@@ -66,10 +66,7 @@ class EvalFirstStepCallback(
|
|||||||
control: TrainerControl,
|
control: TrainerControl,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if (
|
if args.eval_strategy == IntervalStrategy.STEPS and state.global_step == 1:
|
||||||
args.evaluation_strategy == IntervalStrategy.STEPS
|
|
||||||
and state.global_step == 1
|
|
||||||
):
|
|
||||||
control.should_evaluate = True
|
control.should_evaluate = True
|
||||||
return control
|
return control
|
||||||
|
|
||||||
|
|||||||
@@ -393,7 +393,7 @@ class ModelInputConfig(BaseModel):
|
|||||||
default=None, json_schema_extra={"description": "transformers processor class"}
|
default=None, json_schema_extra={"description": "transformers processor class"}
|
||||||
)
|
)
|
||||||
trust_remote_code: Optional[bool] = None
|
trust_remote_code: Optional[bool] = None
|
||||||
tensor_parallel: Optional[Union[Literal["auto"], bool]] = "auto"
|
|
||||||
model_kwargs: Optional[Dict[str, Any]] = None
|
model_kwargs: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
@field_validator("trust_remote_code")
|
@field_validator("trust_remote_code")
|
||||||
|
|||||||
@@ -1187,15 +1187,9 @@ class ModelLoader:
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
self.post_loading_set_env()
|
|
||||||
|
|
||||||
# TODO resume_from_checkpoint handling
|
# TODO resume_from_checkpoint handling
|
||||||
return self.model, lora_config
|
return self.model, lora_config
|
||||||
|
|
||||||
def post_loading_set_env(self):
|
|
||||||
if self.cfg.tensor_parallel == "auto" and self.model.supports_tp_plan:
|
|
||||||
os.environ["ACCELERATE_USE_TP"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
def load_model(
|
def load_model(
|
||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
|
|||||||
Reference in New Issue
Block a user