Compare commits
5 Commits
feat/pref_
...
activation
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7ac9cbebb9 | ||
|
|
15f2fa4c8e | ||
|
|
43a2f9a155 | ||
|
|
8b79f1cbf6 | ||
|
|
3872d5eaed |
@@ -14,22 +14,17 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from contextlib import nullcontext
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Literal, Optional, Type, Union
|
from typing import Any, Dict, List, Literal, Optional, Type, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
|
||||||
import transformers
|
import transformers
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from liger_kernel.chunked_loss.fused_linear_preference import (
|
|
||||||
LigerFusedLinearPreferenceBase,
|
|
||||||
)
|
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from peft.optimizers import create_loraplus_optimizer
|
from peft.optimizers import create_loraplus_optimizer
|
||||||
from torch import amp, nn
|
from torch import nn
|
||||||
from torch.optim.lr_scheduler import OneCycleLR
|
from torch.optim.lr_scheduler import OneCycleLR
|
||||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||||
from transformers import (
|
from transformers import (
|
||||||
@@ -1001,6 +996,15 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
return super()._save_checkpoint(model, trial, **kwargs)
|
return super()._save_checkpoint(model, trial, **kwargs)
|
||||||
|
|
||||||
|
def _evaluate(self, *args, **kwargs):
    """Run the parent ``_evaluate`` and release cached memory afterwards.

    All positional and keyword arguments are forwarded unchanged to
    ``super()._evaluate``; its return value is returned as-is.
    """
    metrics = super()._evaluate(*args, **kwargs)

    # cleanup memory after evals: collect unreachable Python objects first,
    # then ask torch to release its cached CUDA memory
    gc.collect()
    torch.cuda.empty_cache()

    return metrics
|
||||||
|
|
||||||
|
|
||||||
class AxolotlMambaTrainer(AxolotlTrainer):
|
class AxolotlMambaTrainer(AxolotlTrainer):
|
||||||
"""
|
"""
|
||||||
@@ -1082,15 +1086,6 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
|||||||
self.dataset_tags = dataset_tags
|
self.dataset_tags = dataset_tags
|
||||||
self.optimizer = None
|
self.optimizer = None
|
||||||
|
|
||||||
from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss
|
|
||||||
|
|
||||||
self.liger_loss = LigerFusedLinearDPOLoss(
|
|
||||||
ignore_index=self.label_pad_token_id,
|
|
||||||
beta=self.beta,
|
|
||||||
compute_nll_loss=True, # not same as rpo_alpha hasattr(self.args, "rpo_alpha") and self.args.rpo_alpha is not None,
|
|
||||||
use_ref_model=not self.reference_free,
|
|
||||||
)
|
|
||||||
|
|
||||||
def create_optimizer(self):
|
def create_optimizer(self):
|
||||||
if self.args.loraplus_lr_ratio is None:
|
if self.args.loraplus_lr_ratio is None:
|
||||||
return super().create_optimizer()
|
return super().create_optimizer()
|
||||||
@@ -1194,309 +1189,6 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
|||||||
# transformers<=4.46
|
# transformers<=4.46
|
||||||
return super(DPOTrainer, self).log(logs) # pylint: disable=bad-super-call
|
return super(DPOTrainer, self).log(logs) # pylint: disable=bad-super-call
|
||||||
|
|
||||||
def get_batch_loss_metrics(
    self,
    model,
    batch: dict[str, Union[list, torch.LongTensor]],
    train_eval: Literal["train", "eval"] = "train",
):
    """Compute the DPO loss and other metrics using Liger kernel.

    Args:
        model: Policy model; its ``lm_head`` weight (and bias, if any) are
            handed to the fused loss together with the last hidden states.
        batch: Preference batch forwarded to ``self.concatenated_forward``.
        train_eval: ``"eval"`` prefixes every metric key with ``eval_``.

    Returns:
        ``(loss, metrics)`` where ``loss`` is the value returned by
        ``self.liger_loss`` and ``metrics`` maps metric names to CPU tensors
        that are detached from the autograd graph.

    Raises:
        ValueError: If ``self.liger_loss`` is not initialized.
    """
    if not self.liger_loss:
        raise ValueError("Liger loss not initialized")

    metrics = {}

    model_output = self.concatenated_forward(model, batch)

    # Get the lm_head weights and bias
    lin_weight = model.lm_head.weight
    lin_bias = getattr(model.lm_head, "bias", None)

    hidden_states = model_output["hidden_states"]
    labels = model_output["labels"]

    if not self.reference_free:
        # Adapted from DPO's compute_ref_log_probs
        compute_ref_context_manager = (
            amp.autocast("cuda")
            if self._peft_has_been_casted_to_bf16
            else nullcontext()
        )
        with torch.no_grad(), compute_ref_context_manager:  # type: ignore
            if self.ref_model is None:
                # No separate reference model: reuse the policy model inside
                # the trainer's null-reference context.
                with self.null_ref_context():
                    ref_model_output = self.concatenated_forward(self.model, batch)
                    ref_weight = self.model.lm_head.weight
                    ref_bias = getattr(self.model.lm_head, "bias", None)

                ref_hidden_states = ref_model_output["hidden_states"]
            else:
                ref_model_output = self.concatenated_forward(self.ref_model, batch)
                ref_weight = self.ref_model.lm_head.weight
                ref_bias = getattr(self.ref_model.lm_head, "bias", None)

                ref_hidden_states = ref_model_output["hidden_states"]

            # NOTE(review): ignore_index is left at chunk_forward's default
            # here, while concatenated_forward fills masked label positions
            # with 0 -- confirm the two agree.
            (
                ref_chosen_logps,
                ref_rejected_logps,
                _ref_chosen_logits,
                _ref_rejected_logits,
                _ref_chosen_nll_loss,
            ) = LigerFusedLinearPreferenceBase.chunk_forward(
                input_chunk=ref_hidden_states,
                weight=ref_weight,
                target_chunk=labels,
                bias=ref_bias,
                compute_nll_loss=False,
            )
    else:
        ref_hidden_states = None
        ref_weight = None
        ref_bias = None

    # Compute loss using Liger kernel
    loss, return_vars = self.liger_loss(
        lin_weight=lin_weight,
        _input=hidden_states,
        target=labels,
        bias=lin_bias,  # TODO: check whether to pass bias as FCLE doesn't
        ref_input=ref_hidden_states,
        ref_weight=ref_weight,
        ref_bias=ref_bias,
    )

    (
        policy_chosen_logps,
        policy_rejected_logps,
        policy_chosen_logits_mean,
        policy_rejected_logits_mean,
        policy_nll_loss,
    ) = return_vars

    # Everything below is logging only, so cut it off from the autograd graph
    # once; otherwise the metrics dict keeps the whole graph alive.
    chosen_logps = policy_chosen_logps.detach()
    rejected_logps = policy_rejected_logps.detach()

    # Calculate rewards
    if not self.reference_free:
        chosen_rewards = self.beta * (chosen_logps - ref_chosen_logps)
        rejected_rewards = self.beta * (rejected_logps - ref_rejected_logps)
    else:
        chosen_rewards = self.beta * chosen_logps
        rejected_rewards = self.beta * rejected_logps

    reward_accuracies = (chosen_rewards > rejected_rewards).float()

    prefix = "eval_" if train_eval == "eval" else ""
    metrics.update(
        {
            f"{prefix}rewards/chosen": chosen_rewards.mean().cpu(),
            f"{prefix}rewards/rejected": rejected_rewards.mean().cpu(),
            f"{prefix}rewards/accuracies": reward_accuracies.mean().cpu(),
            f"{prefix}rewards/margins": (chosen_rewards - rejected_rewards)
            .mean()
            .cpu(),
            f"{prefix}logps/chosen": chosen_logps.mean().cpu(),
            f"{prefix}logps/rejected": rejected_logps.mean().cpu(),
            f"{prefix}logits/chosen": policy_chosen_logits_mean.detach().cpu(),
            f"{prefix}logits/rejected": policy_rejected_logits_mean.detach().cpu(),
        }
    )

    if hasattr(self.args, "rpo_alpha") and self.args.rpo_alpha is not None:
        metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().cpu()

    # TODO: Handle use_weighting, aux_loss_enabled as in upstream

    return loss, metrics
|
|
||||||
|
|
||||||
def concatenated_forward(
    self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
):
    """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.

    We do this to avoid doing two forward passes, because it's faster for FSDP.

    Overridden base function to return the hidden states and labels for the loss calculation.

    Returns a dict with ``chosen_logps``, ``rejected_logps``,
    ``mean_chosen_logits``, ``mean_rejected_logits``, ``hidden_states`` and
    ``labels``; ``policy_weights``, ``nll_loss`` and ``aux_loss`` are added
    only when the corresponding option is enabled.
    """
    # The first ``num_examples`` rows of every concatenated tensor are treated
    # as the chosen samples and the remaining rows as the rejected samples
    # (see the slicing at the end of this function).
    num_examples = batch["prompt_input_ids"].shape[0]  # type: ignore

    concatenated_batch = self.concatenated_inputs(
        batch, padding_value=self.padding_value
    )

    model_kwargs = {}
    if self.aux_loss_enabled:
        model_kwargs["output_router_logits"] = True

    # Add to get the hidden states for the loss
    model_kwargs["output_hidden_states"] = True

    # Add the pixel values and attention masks for vision models
    if "pixel_values" in concatenated_batch:
        model_kwargs["pixel_values"] = concatenated_batch["pixel_values"]
    if "pixel_attention_mask" in concatenated_batch:
        model_kwargs["pixel_attention_mask"] = concatenated_batch[
            "pixel_attention_mask"
        ]
    if "image_sizes" in concatenated_batch:
        model_kwargs["image_sizes"] = concatenated_batch["image_sizes"]

    prompt_input_ids = concatenated_batch["prompt_input_ids"]
    prompt_attention_mask = concatenated_batch["prompt_attention_mask"]
    completion_input_ids = concatenated_batch["completion_input_ids"]
    completion_attention_mask = concatenated_batch["completion_attention_mask"]
    if self.is_encoder_decoder:
        # NOTE: ``labels`` aliases ``completion_input_ids`` here, so the
        # masking below modifies that tensor in place.
        labels = completion_input_ids
        labels[completion_attention_mask == 0] = self.label_pad_token_id
        outputs = model(
            input_ids=prompt_input_ids,
            attention_mask=prompt_attention_mask,
            labels=labels,  # we need the labels for the logits to be returned
            **model_kwargs,
        )
        logits = outputs.logits
        # last entry of the decoder hidden states tuple
        hidden_states = outputs.decoder_hidden_states[-1]
        loss_mask = completion_attention_mask.bool()
    else:
        # Concatenate the prompt and completion inputs
        input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1)
        attention_mask = torch.cat(
            (prompt_attention_mask, completion_attention_mask), dim=1
        )
        # Mask the prompt but not the completion for the loss
        loss_mask = torch.cat(
            (torch.zeros_like(prompt_attention_mask), completion_attention_mask),
            dim=1,
        )

        # Flush left to reduce the memory usage
        # [[0, 0, x, x, x, x],  ->  [[x, x, x, x],
        #  [0, x, x, x, 0, 0]]       [x, x, x, 0]]
        for i in range(attention_mask.size(0)):
            # index of the first attended position in row i; the row is
            # rotated left by that amount so it starts at column 0
            first_one_idx = torch.nonzero(attention_mask[i])[0].item()
            input_ids[i] = torch.roll(input_ids[i], shifts=-first_one_idx)  # type: ignore
            attention_mask[i] = torch.roll(attention_mask[i], shifts=-first_one_idx)  # type: ignore
            loss_mask[i] = torch.roll(loss_mask[i], shifts=-first_one_idx)  # type: ignore

        # Get the first column idx that is all zeros and remove every column after that
        empty_cols = torch.sum(attention_mask, dim=0) == 0
        first_empty_col = (
            torch.nonzero(empty_cols)[0].item()
            if empty_cols.any()
            else attention_mask.size(1)
        )
        input_ids = input_ids[:, :first_empty_col]  # type: ignore
        attention_mask = attention_mask[:, :first_empty_col]  # type: ignore
        loss_mask = loss_mask[:, :first_empty_col]  # type: ignore

        # Truncate right
        if self.args.max_length is not None:
            input_ids = input_ids[:, : self.args.max_length]
            attention_mask = attention_mask[:, : self.args.max_length]
            loss_mask = loss_mask[:, : self.args.max_length]

        # if self.use_num_logits_to_keep:
        #     # Compute num_logits_to_keep based on loss_mask pattern:
        #     # [[0, 0, 0, x, x, x, x],
        #     #  [0, 0, 0, x, x, x, 0]]
        #     #         ^ start computing logits from here ([:, -(7-3+1):])
        #     first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min()
        #     num_logits_to_keep = loss_mask.shape[1] - first_compute_index
        #     model_kwargs["num_logits_to_keep"] = num_logits_to_keep.item() + 1  # +1 for the first label

        outputs = model(
            input_ids=input_ids, attention_mask=attention_mask, **model_kwargs
        )

        # Offset the logits by one to align with the labels
        logits = outputs.logits[:, :-1, :]
        # the last hidden states get the same one-position offset as the logits
        hidden_states = outputs.hidden_states[-1][:, :-1, :]
        labels = input_ids[:, 1:].clone()
        loss_mask = loss_mask[:, 1:].bool()

        # if self.use_num_logits_to_keep:
        #     # Align labels with logits
        #     # logits:    -,  -, [x2, x3, x4, x5, x6]
        #     #                     ^ --------- ^       after logits[:, :-1, :]
        #     # labels:   [y0, y1, y2, y3, y4, y5, y6]
        #     #                         ^ --------- ^   with num_logits_to_keep=4, [:, -4:]
        #     # loss_mask: [0,  0,  0,  1,  1,  1,  1]
        #     labels = labels[:, -num_logits_to_keep:]
        #     loss_mask = loss_mask[:, -num_logits_to_keep:]
        #     hidden_states = hidden_states[:, -num_logits_to_keep:, :]

    if logits.shape[:2] != labels.shape[:2]:
        # for llava, the returned logits include the image tokens (placed before the text tokens)
        seq_len = labels.shape[1]
        logits = logits[:, -seq_len:]
        hidden_states = hidden_states[:, -seq_len:]

    # Compute the log probabilities of the labels
    labels[
        ~loss_mask
    ] = 0  # dummy token; we'll ignore the losses on these tokens later
    per_token_logps = torch.gather(
        logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)
    ).squeeze(2)
    # zero the masked positions so they do not contribute to the per-row sum
    per_token_logps[~loss_mask] = 0
    all_logps = per_token_logps.sum(-1)

    output = {}

    if self.use_weighting:
        with torch.no_grad():
            # Eq (2) of the WPO paper: https://huggingface.co/papers/2406.11827
            logprobs = F.log_softmax(logits, dim=-1)
            weights_adjustment_factor = torch.logsumexp(
                2 * logprobs, dim=-1
            )  # same as sum(probs**2) in log space
            per_token_logps_adjusted = per_token_logps - weights_adjustment_factor
            all_weights = (per_token_logps_adjusted * loss_mask).sum(
                -1
            ) / loss_mask.sum(-1)
            chosen_weights = all_weights[:num_examples]
            rejected_weights = all_weights[num_examples:]
            output["policy_weights"] = torch.clamp(
                torch.exp(chosen_weights + rejected_weights), max=1
            )

    if self.args.rpo_alpha is not None:
        # Only use the chosen logits for the RPO loss
        chosen_logits = logits[:num_examples]
        chosen_labels = labels[:num_examples]

        # Compute the log probabilities of the labels
        # ignore_index=0 matches the dummy token written into masked label
        # positions above
        output["nll_loss"] = F.cross_entropy(
            torch.flatten(chosen_logits, end_dim=1),
            torch.flatten(chosen_labels, end_dim=1),
            ignore_index=0,
        )

    if self.loss_type == "ipo":
        # IPO uses the mean, not the sum, of the unmasked token log-probs
        all_logps = all_logps / loss_mask.sum(-1)

    output["chosen_logps"] = all_logps[:num_examples]
    output["rejected_logps"] = all_logps[num_examples:]
    output["mean_chosen_logits"] = logits[:num_examples][
        loss_mask[:num_examples]
    ].mean()
    output["mean_rejected_logits"] = logits[num_examples:][
        loss_mask[num_examples:]
    ].mean()
    # Extra outputs compared with the base implementation. Note that masked
    # positions of ``labels`` hold the dummy token 0 at this point.
    output["hidden_states"] = hidden_states
    output["labels"] = labels

    if self.aux_loss_enabled:
        output["aux_loss"] = outputs.aux_loss

    return output
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
||||||
"""
|
"""
|
||||||
@@ -2480,14 +2172,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.dpo_use_weighting is not None:
|
if self.cfg.dpo_use_weighting is not None:
|
||||||
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
|
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
|
||||||
|
|
||||||
report_to = []
|
|
||||||
if self.cfg.use_wandb:
|
|
||||||
report_to.append("wandb")
|
|
||||||
if self.cfg.wandb_name:
|
|
||||||
training_args_kwargs["run_name"] = self.cfg.wandb_name
|
|
||||||
|
|
||||||
training_args_kwargs["report_to"] = report_to
|
|
||||||
|
|
||||||
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
||||||
output_dir=self.cfg.output_dir,
|
output_dir=self.cfg.output_dir,
|
||||||
per_device_train_batch_size=self.cfg.micro_batch_size,
|
per_device_train_batch_size=self.cfg.micro_batch_size,
|
||||||
|
|||||||
0
src/axolotl/monkeypatch/models/__init__.py
Normal file
0
src/axolotl/monkeypatch/models/__init__.py
Normal file
170
src/axolotl/monkeypatch/models/llama/modeling_llama.py
Normal file
170
src/axolotl/monkeypatch/models/llama/modeling_llama.py
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
import contextlib
|
||||||
|
import inspect
|
||||||
|
import types
|
||||||
|
|
||||||
|
from torchtune.training import OffloadActivations
|
||||||
|
from transformers import LlamaConfig, LlamaForCausalLM
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.unsloth_ import detab_code
|
||||||
|
|
||||||
|
HF_MODEL_OUTPUTS = """
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
""".lstrip()
|
||||||
|
|
||||||
|
PATCHED_HF_MODEL_OUTPUTS = """
|
||||||
|
with self.act_offloading_ctx_manager:
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
""".lstrip()
|
||||||
|
|
||||||
|
LCE_MODEL_OUTPUTS = """
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
)
|
||||||
|
""".lstrip()
|
||||||
|
|
||||||
|
PATCHED_LCE_OUTPUTS = """
|
||||||
|
with self.act_offloading_ctx_manager:
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
)
|
||||||
|
""".lstrip()
|
||||||
|
|
||||||
|
HF_GA_FORWARD_1 = """
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
""".lstrip()
|
||||||
|
|
||||||
|
PATCHED_HF_GA_FORWARD_1 = """
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# remove num_items_in_batch otherwise self.model attempts to pass it to flash_attention
|
||||||
|
num_items_in_batch = kwargs.pop("num_items_in_batch", None)
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
""".lstrip()
|
||||||
|
|
||||||
|
HF_GA_FORWARD_2 = """
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
|
||||||
|
""".lstrip()
|
||||||
|
|
||||||
|
PATCHED_HF_GA_FORWARD_2 = """
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, num_items_in_batch=num_items_in_batch, **kwargs)
|
||||||
|
""".lstrip()
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlLlamaForCausalLM(LlamaForCausalLM):
    """``LlamaForCausalLM`` whose ``forward`` is rebuilt from patched source text.

    Each ``classmethod`` below takes the source of the current forward,
    applies plain string replacements to it, executes the result and installs
    the new function as ``cls.forward``. The raw source text is cached on the
    class so patches can be stacked: ``inspect.getsource`` cannot read back a
    function that was created from an in-memory string.
    """

    # context manager the patched forward enters around the ``self.model`` call
    act_offloading_ctx_manager = contextlib.nullcontext()
    # raw (still indented) source of the currently installed forward, if any
    _forward_source = None
    # function whose name and globals are used when executing ``_forward_source``
    _forward_reference = None

    def __init__(self, config: LlamaConfig):
        super().__init__(config)

    @classmethod
    def _current_forward_source(cls):
        """Return the raw source of the installed forward, seeding it from HF if unset."""
        if cls._forward_source is None:
            cls._forward_reference = LlamaForCausalLM.forward
            cls._forward_source = inspect.getsource(LlamaForCausalLM.forward)
        return cls._forward_source

    @classmethod
    def _install_forward(cls, source, filename):
        """Execute ``source`` and install the function it defines as ``cls.forward``.

        ``source`` is the raw, still indented function source; it is cached
        unchanged so later patches match against the same text.
        """
        reference = inspect.unwrap(cls._forward_reference)
        dedented_source, _ = detab_code(source)
        # Run in a copy of the defining module's globals so the decorators and
        # helper names the source refers to resolve as they did originally.
        namespace = dict(reference.__globals__)
        # The source comes from installed library code, not from user input.
        exec(compile(dedented_source, filename, "exec"), namespace)  # pylint: disable=exec-used
        # Assign the plain function: Python binds it to each *instance* on
        # attribute access. Binding it to ``cls`` would pass the class as ``self``.
        cls.forward = namespace[reference.__name__]
        cls._forward_source = source

    @classmethod
    def set_forward(cls):
        """Install an unmodified copy of ``LlamaForCausalLM.forward`` on this class."""
        cls._forward_reference = LlamaForCausalLM.forward
        cls._install_forward(inspect.getsource(LlamaForCausalLM.forward), "<forward>")

    @classmethod
    def enable_act_offloading(cls):
        """Wrap the ``self.model`` call in the offloading context and enable it."""
        forward_source = cls._current_forward_source().replace(
            HF_MODEL_OUTPUTS, PATCHED_HF_MODEL_OUTPUTS
        )
        # replace forward method with patched version
        cls._install_forward(forward_source, "<llama_forward_w_act_offloading>")
        cls.act_offloading_ctx_manager = OffloadActivations()

    @classmethod
    def enable_liger_fce(cls, enable_act_offloading=True):
        """Use Liger's ``lce_forward`` as this class's forward.

        With ``enable_act_offloading`` the ``self.model`` call in the Liger
        forward is wrapped in ``self.act_offloading_ctx_manager``. The context
        manager itself is left untouched here; it stays a no-op until
        ``enable_act_offloading()`` swaps it.
        """
        from liger_kernel.transformers.model.llama import (
            lce_forward as llama_lce_forward,
        )

        cls._forward_reference = llama_lce_forward
        lce_source = inspect.getsource(llama_lce_forward)

        if enable_act_offloading:
            lce_source = lce_source.replace(LCE_MODEL_OUTPUTS, PATCHED_LCE_OUTPUTS)
            # replace forward method with patched version
            cls._install_forward(lce_source, "<llama_lce_forward_w_act_offloading>")
        else:
            cls.forward = llama_lce_forward
            cls._forward_source = lce_source

    @classmethod
    def patch_hf_ga(cls):
        """Patch the forward so ``num_items_in_batch`` reaches the loss function only."""
        # bugfix patch for gradient accumulation
        forward_source = cls._current_forward_source()
        forward_source = forward_source.replace(
            HF_GA_FORWARD_1, PATCHED_HF_GA_FORWARD_1
        )
        forward_source = forward_source.replace(
            HF_GA_FORWARD_2, PATCHED_HF_GA_FORWARD_2
        )
        # replace forward method with patched version
        cls._install_forward(forward_source, "<llama_forward_ga_fix>")
|
||||||
|
|
||||||
|
|
||||||
|
def replace_auto_model():
    """Register ``AxolotlLlamaForCausalLM`` as the causal-LM class for ``LlamaConfig``.

    Also installs the class's own copy of the forward method via
    ``set_forward()``.

    Returns:
        The ``AxolotlLlamaForCausalLM`` class, so callers can apply further patches.
    """
    from transformers import LlamaConfig
    from transformers.models.auto import MODEL_FOR_CAUSAL_LM_MAPPING

    # The auto mapping resolves lookups through its own registry, so an
    # override has to go through ``register``; ``exist_ok=True`` permits
    # replacing the built-in Llama entry.
    MODEL_FOR_CAUSAL_LM_MAPPING.register(
        LlamaConfig, AxolotlLlamaForCausalLM, exist_ok=True
    )
    AxolotlLlamaForCausalLM.set_forward()

    return AxolotlLlamaForCausalLM
|
||||||
@@ -679,6 +679,7 @@ class AxolotlInputConfig(
|
|||||||
default=False
|
default=False
|
||||||
)
|
)
|
||||||
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
|
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
|
||||||
|
activation_offloading: Optional[bool] = None
|
||||||
|
|
||||||
unfrozen_parameters: Optional[List[str]] = None
|
unfrozen_parameters: Optional[List[str]] = None
|
||||||
|
|
||||||
|
|||||||
@@ -380,6 +380,15 @@ class ModelLoader:
|
|||||||
plugin_manager = PluginManager.get_instance()
|
plugin_manager = PluginManager.get_instance()
|
||||||
plugin_manager.pre_model_load(self.cfg)
|
plugin_manager.pre_model_load(self.cfg)
|
||||||
|
|
||||||
|
if self.cfg.model_config_type == "llama":
|
||||||
|
from axolotl.monkeypatch.models.llama.modeling_llama import replace_auto_model
|
||||||
|
|
||||||
|
AxolotlLlamaForCausalLM = replace_auto_model()
|
||||||
|
|
||||||
|
AxolotlLlamaForCausalLM.patch_hf_ga()
|
||||||
|
if self.cfg.activation_offloading:
|
||||||
|
AxolotlLlamaForCausalLM.enable_act_offloading()
|
||||||
|
|
||||||
if self.cfg.fsdp:
|
if self.cfg.fsdp:
|
||||||
from axolotl.monkeypatch.trainer_fsdp_optim import (
|
from axolotl.monkeypatch.trainer_fsdp_optim import (
|
||||||
patch_training_loop_for_fsdp,
|
patch_training_loop_for_fsdp,
|
||||||
@@ -1183,6 +1192,8 @@ class ModelLoader:
|
|||||||
|
|
||||||
self.apply_lora_patch()
|
self.apply_lora_patch()
|
||||||
|
|
||||||
|
# self.apply_patches_to_model()
|
||||||
|
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|||||||
Reference in New Issue
Block a user