Compare commits

..

3 Commits

Author SHA1 Message Date
NanoCode012
8428b3f2c7 feat: add dpo liger 2024-12-16 22:19:27 +07:00
Wing Lian
02629c7cdf parity for nightly ci - make sure to install setuptools (#2176) [skip ci] 2024-12-11 20:14:55 -05:00
Wing Lian
78a4aa86d6 evaluation_strategy was fully deprecated in recent release (#2169) [skip ci] 2024-12-11 20:14:24 -05:00
7 changed files with 333 additions and 147 deletions

View File

@@ -44,6 +44,11 @@ jobs:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging setuptools wheel
- name: Install PyTorch
run: |
pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu

View File

@@ -1,58 +0,0 @@
base_model: NousResearch/Meta-Llama-3.1-8B
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./outputs/out
sequence_len: 8192
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 2e-5
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
tensor_parallel: 'auto'
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 100
evals_per_epoch: 2
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <|end_of_text|>

View File

@@ -1,73 +0,0 @@
base_model: NousResearch/Meta-Llama-3.1-8B
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_modules_to_save:
- embed_tokens
- lm_head
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
tensor_parallel: 'auto'
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <|end_of_text|>

View File

@@ -14,17 +14,22 @@ import os
import sys
from abc import abstractmethod
from collections import defaultdict
from contextlib import nullcontext
from dataclasses import dataclass, field
from functools import wraps
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Type, Union
import torch
import torch.nn.functional as F
import transformers
from datasets import Dataset
from liger_kernel.chunked_loss.fused_linear_preference import (
LigerFusedLinearPreferenceBase,
)
from packaging import version
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from torch import amp, nn
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import (
@@ -1077,6 +1082,15 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
self.dataset_tags = dataset_tags
self.optimizer = None
from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss
self.liger_loss = LigerFusedLinearDPOLoss(
ignore_index=self.label_pad_token_id,
beta=self.beta,
compute_nll_loss=True, # not same as rpo_alpha hasattr(self.args, "rpo_alpha") and self.args.rpo_alpha is not None,
use_ref_model=not self.reference_free,
)
def create_optimizer(self):
if self.args.loraplus_lr_ratio is None:
return super().create_optimizer()
@@ -1180,6 +1194,309 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
# transformers<=4.46
return super(DPOTrainer, self).log(logs) # pylint: disable=bad-super-call
def get_batch_loss_metrics(
    self,
    model,
    batch: dict[str, Union[list, torch.LongTensor]],
    train_eval: Literal["train", "eval"] = "train",
):
    """Compute the DPO loss and other metrics using Liger kernel.

    Overrides the upstream TRL implementation so the loss is computed by
    the fused ``LigerFusedLinearDPOLoss`` (``self.liger_loss``) directly
    from the last hidden states plus the ``lm_head`` projection weights,
    instead of materializing full vocabulary logits for the policy and
    reference models.

    Args:
        model: The policy model being trained.
        batch: Tokenized preference batch (chosen/rejected pairs).
        train_eval: Selects the metric-key prefix; ``"eval"`` prefixes
            every metric with ``eval_``.

    Returns:
        Tuple ``(loss, metrics)`` where ``metrics`` maps metric names to
        CPU scalar tensors.

    Raises:
        ValueError: If ``self.liger_loss`` was never initialized.
    """
    # return super().get_batch_loss_metrics(model, batch, train_eval)
    if not self.liger_loss:
        raise ValueError("Liger loss not initialized")

    metrics = {}
    # concatenated_forward (overridden below) returns hidden states and
    # labels instead of log-probs so the fused kernel can apply lm_head.
    model_output = self.concatenated_forward(model, batch)

    # Get the lm_head weights and bias
    lin_weight = model.lm_head.weight
    lin_bias = getattr(model.lm_head, "bias", None)
    hidden_states = model_output["hidden_states"]
    labels = model_output["labels"]

    if not self.reference_free:
        # Adapted from DPO's compute_ref_log_probs
        compte_ref_context_manager = (
            amp.autocast("cuda")
            if self._peft_has_been_casted_to_bf16
            else nullcontext()
        )
        with torch.no_grad(), compte_ref_context_manager:  # type: ignore
            if self.ref_model is None:
                # PEFT case: the "reference" model is the policy model
                # with adapters disabled via null_ref_context().
                with self.null_ref_context():
                    ref_model_output = self.concatenated_forward(self.model, batch)
                    ref_weight = self.model.lm_head.weight
                    ref_bias = getattr(self.model.lm_head, "bias", None)
                    ref_hidden_states = ref_model_output["hidden_states"]
            else:
                ref_model_output = self.concatenated_forward(self.ref_model, batch)
                ref_weight = self.ref_model.lm_head.weight
                ref_bias = getattr(self.ref_model.lm_head, "bias", None)
                ref_hidden_states = ref_model_output["hidden_states"]

            # Reference log-probs via the chunked fused forward; the NLL
            # term is not needed for the reference model.
            (
                ref_chosen_logps,
                ref_rejected_logps,
                _ref_chosen_logits,
                _ref_rejected_logits,
                _ref_chosen_nll_loss,
            ) = LigerFusedLinearPreferenceBase.chunk_forward(
                input_chunk=ref_hidden_states,
                weight=ref_weight,
                target_chunk=labels,
                bias=ref_bias,
                # ignore_index=ignore_index,
                compute_nll_loss=False,
            )
    else:
        ref_hidden_states = None
        ref_weight = None
        ref_bias = None

    # Compute loss using Liger kernel
    loss, return_vars = self.liger_loss(
        lin_weight=lin_weight,
        _input=hidden_states,
        target=labels,
        bias=lin_bias,  # TODO: check whether to pass bias as FCLE doesn't
        ref_input=ref_hidden_states,
        ref_weight=ref_weight,
        ref_bias=ref_bias,
    )
    (
        policy_chosen_logps,
        policy_rejected_logps,
        policy_chosen_logits_mean,
        policy_rejected_logits_mean,
        policy_nll_loss,
    ) = return_vars

    # Calculate rewards
    if not self.reference_free:
        chosen_rewards = (
            self.beta * (policy_chosen_logps - (ref_chosen_logps)).detach()
        )
        rejected_rewards = (
            self.beta * (policy_rejected_logps - (ref_rejected_logps)).detach()
        )
    else:
        # No reference model: reward is the scaled policy log-prob alone.
        chosen_rewards = self.beta * policy_chosen_logps
        rejected_rewards = self.beta * policy_rejected_logps
    reward_accuracies = (chosen_rewards > rejected_rewards).float()

    prefix = "eval_" if train_eval == "eval" else ""
    metrics.update(
        {
            f"{prefix}rewards/chosen": chosen_rewards.mean().cpu(),
            f"{prefix}rewards/rejected": rejected_rewards.mean().cpu(),
            f"{prefix}rewards/accuracies": reward_accuracies.mean().cpu(),
            f"{prefix}rewards/margins": (chosen_rewards - rejected_rewards)
            .mean()
            .cpu(),
            f"{prefix}logps/chosen": policy_chosen_logps.mean().cpu(),
            f"{prefix}logps/rejected": policy_rejected_logps.mean().cpu(),
            f"{prefix}logits/chosen": policy_chosen_logits_mean.cpu(),
            f"{prefix}logits/rejected": policy_rejected_logits_mean.cpu(),
        }
    )
    # Only report the NLL term when RPO is actually configured.
    if hasattr(self.args, "rpo_alpha") and self.args.rpo_alpha is not None:
        metrics[f"{prefix}nll_loss"] = policy_nll_loss.cpu()

    # TODO: Handle use_weighting, aux_loss_enabled as in upstream
    return loss, metrics
def concatenated_forward(
    self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
):
    """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.

    We do this to avoid doing two forward passes, because it's faster for FSDP.

    Overridden base function to return the hidden states and labels for the
    loss calculation (the fused Liger loss applies lm_head itself), in
    addition to the log-prob/logit outputs the upstream version returns.

    Args:
        model: Model to run the forward pass on.
        batch: Preference batch with ``prompt_*`` and ``completion_*``
            tensors (chosen examples first, then rejected).

    Returns:
        Dict with ``chosen_logps``/``rejected_logps``,
        ``mean_chosen_logits``/``mean_rejected_logits``,
        ``hidden_states`` (last layer, shifted), ``labels`` (shifted), and
        optionally ``policy_weights``, ``nll_loss``, and ``aux_loss``.
    """
    # First half of the concatenated batch is chosen, second half rejected.
    num_examples = batch["prompt_input_ids"].shape[0]  # type: ignore

    concatenated_batch = self.concatenated_inputs(
        batch, padding_value=self.padding_value
    )

    model_kwargs = {}
    if self.aux_loss_enabled:
        model_kwargs["output_router_logits"] = True

    # Add to get the hidden states for the loss
    model_kwargs["output_hidden_states"] = True

    # Add the pixel values and attention masks for vision models
    if "pixel_values" in concatenated_batch:
        model_kwargs["pixel_values"] = concatenated_batch["pixel_values"]
    if "pixel_attention_mask" in concatenated_batch:
        model_kwargs["pixel_attention_mask"] = concatenated_batch[
            "pixel_attention_mask"
        ]
    if "image_sizes" in concatenated_batch:
        model_kwargs["image_sizes"] = concatenated_batch["image_sizes"]

    prompt_input_ids = concatenated_batch["prompt_input_ids"]
    prompt_attention_mask = concatenated_batch["prompt_attention_mask"]
    completion_input_ids = concatenated_batch["completion_input_ids"]
    completion_attention_mask = concatenated_batch["completion_attention_mask"]
    if self.is_encoder_decoder:
        labels = completion_input_ids
        labels[completion_attention_mask == 0] = self.label_pad_token_id
        outputs = model(
            input_ids=prompt_input_ids,
            attention_mask=prompt_attention_mask,
            labels=labels,  # we need the labels for the logits to be returned
            **model_kwargs,
        )
        logits = outputs.logits
        hidden_states = outputs.decoder_hidden_states[-1]
        loss_mask = completion_attention_mask.bool()
    else:
        # Concatenate the prompt and completion inputs
        input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1)
        attention_mask = torch.cat(
            (prompt_attention_mask, completion_attention_mask), dim=1
        )
        # Mask the prompt but not the completion for the loss
        loss_mask = torch.cat(
            (torch.zeros_like(prompt_attention_mask), completion_attention_mask),
            dim=1,
        )

        # Flush left to reduce the memory usage
        # [[0, 0, x, x, x, x],  ->  [[x, x, x, x],
        #  [0, x, x, x, 0, 0]]       [x, x, x, 0]]
        for i in range(attention_mask.size(0)):
            first_one_idx = torch.nonzero(attention_mask[i])[0].item()
            input_ids[i] = torch.roll(input_ids[i], shifts=-first_one_idx)  # type: ignore
            attention_mask[i] = torch.roll(attention_mask[i], shifts=-first_one_idx)  # type: ignore
            loss_mask[i] = torch.roll(loss_mask[i], shifts=-first_one_idx)  # type: ignore

        # Get the first column idx that is all zeros and remove every column after that
        empty_cols = torch.sum(attention_mask, dim=0) == 0
        first_empty_col = (
            torch.nonzero(empty_cols)[0].item()
            if empty_cols.any()
            else attention_mask.size(1)
        )
        input_ids = input_ids[:, :first_empty_col]  # type: ignore
        attention_mask = attention_mask[:, :first_empty_col]  # type: ignore
        loss_mask = loss_mask[:, :first_empty_col]  # type: ignore

        # Truncate right
        if self.args.max_length is not None:
            input_ids = input_ids[:, : self.args.max_length]
            attention_mask = attention_mask[:, : self.args.max_length]
            loss_mask = loss_mask[:, : self.args.max_length]

        # if self.use_num_logits_to_keep:
        #     # Compute num_logits_to_keep based on loss_mask pattern:
        #     # [[0, 0, 0, x, x, x, x],
        #     #  [0, 0, 0, x, x, x, 0]]
        #     #         ^ start computing logits from here ([:, -(7-3+1):])
        #     first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min()
        #     num_logits_to_keep = loss_mask.shape[1] - first_compute_index
        #     model_kwargs["num_logits_to_keep"] = num_logits_to_keep.item() + 1  # +1 for the first label

        outputs = model(
            input_ids=input_ids, attention_mask=attention_mask, **model_kwargs
        )

        # Offset the logits by one to align with the labels
        logits = outputs.logits[:, :-1, :]
        hidden_states = outputs.hidden_states[-1][:, :-1, :]
        labels = input_ids[:, 1:].clone()
        loss_mask = loss_mask[:, 1:].bool()

        # if self.use_num_logits_to_keep:
        #     # Align labels with logits
        #     # logits:    -,  -, [x2, x3, x4, x5, x6]
        #     #                     ^ --------- ^ after logits[:, :-1, :]
        #     # labels:   [y0, y1, y2, y3, y4, y5, y6]
        #     #                     ^ --------- ^ with num_logits_to_keep=4, [:, -4:]
        #     # loss_mask: [0,  0,  0,  1,  1,  1,  1]
        #     labels = labels[:, -num_logits_to_keep:]
        #     loss_mask = loss_mask[:, -num_logits_to_keep:]
        #     hidden_states = hidden_states[:, -num_logits_to_keep:, :]

    if logits.shape[:2] != labels.shape[:2]:
        # for llava, the returned logits include the image tokens (placed before the text tokens)
        seq_len = labels.shape[1]
        logits = logits[:, -seq_len:]
        hidden_states = hidden_states[:, -seq_len:]

    # Compute the log probabilities of the labels
    labels[
        ~loss_mask
    ] = 0  # dummy token; we'll ignore the losses on these tokens later
    per_token_logps = torch.gather(
        logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)
    ).squeeze(2)
    per_token_logps[~loss_mask] = 0
    all_logps = per_token_logps.sum(-1)

    output = {}

    if self.use_weighting:
        with torch.no_grad():
            # Eq (2) of the WPO paper: https://huggingface.co/papers/2406.11827
            logprobs = F.log_softmax(logits, dim=-1)
            weights_adjustment_factor = torch.logsumexp(
                2 * logprobs, dim=-1
            )  # same as sum(probs**2) in log space
            per_token_logps_adjusted = per_token_logps - weights_adjustment_factor
            all_weights = (per_token_logps_adjusted * loss_mask).sum(
                -1
            ) / loss_mask.sum(-1)
            chosen_weights = all_weights[:num_examples]
            rejected_weights = all_weights[num_examples:]
            output["policy_weights"] = torch.clamp(
                torch.exp(chosen_weights + rejected_weights), max=1
            )

    if self.args.rpo_alpha is not None:
        # Only use the chosen logits for the RPO loss
        chosen_logits = logits[:num_examples]
        chosen_labels = labels[:num_examples]

        # Compute the log probabilities of the labels
        output["nll_loss"] = F.cross_entropy(
            torch.flatten(chosen_logits, end_dim=1),
            torch.flatten(chosen_labels, end_dim=1),
            ignore_index=0,
        )

    if self.loss_type == "ipo":
        # IPO normalizes by sequence length (average log-prob).
        all_logps = all_logps / loss_mask.sum(-1)

    output["chosen_logps"] = all_logps[:num_examples]
    output["rejected_logps"] = all_logps[num_examples:]
    output["mean_chosen_logits"] = logits[:num_examples][
        loss_mask[:num_examples]
    ].mean()
    output["mean_rejected_logits"] = logits[num_examples:][
        loss_mask[num_examples:]
    ].mean()
    # Extra outputs consumed by the Liger-based get_batch_loss_metrics.
    output["hidden_states"] = hidden_states
    output["labels"] = labels
    if self.aux_loss_enabled:
        output["aux_loss"] = outputs.aux_loss
    return output
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
"""
@@ -1319,10 +1636,6 @@ class TrainerBuilderBase(abc.ABC):
if hasattr(model, "add_model_tags"):
model.add_model_tags(["axolotl"])
if self.cfg.tensor_parallel == "auto" and self.model.supports_tp_plan:
os.environ["ACCELERATE_USE_TP"] = "true"
# self.model =
@property
def model_ref(self):
return self._model_ref
@@ -2167,6 +2480,14 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
report_to = []
if self.cfg.use_wandb:
report_to.append("wandb")
if self.cfg.wandb_name:
training_args_kwargs["run_name"] = self.cfg.wandb_name
training_args_kwargs["report_to"] = report_to
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
output_dir=self.cfg.output_dir,
per_device_train_batch_size=self.cfg.micro_batch_size,

View File

@@ -66,10 +66,7 @@ class EvalFirstStepCallback(
control: TrainerControl,
**kwargs,
):
if (
args.evaluation_strategy == IntervalStrategy.STEPS
and state.global_step == 1
):
if args.eval_strategy == IntervalStrategy.STEPS and state.global_step == 1:
control.should_evaluate = True
return control

View File

@@ -393,7 +393,7 @@ class ModelInputConfig(BaseModel):
default=None, json_schema_extra={"description": "transformers processor class"}
)
trust_remote_code: Optional[bool] = None
tensor_parallel: Optional[Union[Literal["auto"], bool]] = "auto"
model_kwargs: Optional[Dict[str, Any]] = None
@field_validator("trust_remote_code")

View File

@@ -1187,15 +1187,9 @@ class ModelLoader:
gc.collect()
torch.cuda.empty_cache()
self.post_loading_set_env()
# TODO resume_from_checkpoint handling
return self.model, lora_config
def post_loading_set_env(self):
    """Set process-level environment flags that depend on the loaded model.

    Enables accelerate's tensor-parallel path when the config requests
    ``tensor_parallel: auto`` and the loaded model advertises a TP plan.
    """
    if self.cfg.tensor_parallel == "auto" and self.model.supports_tp_plan:
        os.environ["ACCELERATE_USE_TP"] = "true"
def load_model(
cfg: DictDefault,