Compare commits
5 Commits
liger-dpo
...
activation
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7ac9cbebb9 | ||
|
|
15f2fa4c8e | ||
|
|
43a2f9a155 | ||
|
|
8b79f1cbf6 | ||
|
|
3872d5eaed |
@@ -12,7 +12,7 @@ liger-kernel==0.4.2
|
||||
|
||||
packaging==23.2
|
||||
peft==0.14.0
|
||||
transformers==4.47.0
|
||||
transformers>=4.46.3
|
||||
tokenizers>=0.20.1
|
||||
accelerate==1.2.0
|
||||
datasets==3.1.0
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,220 +0,0 @@
|
||||
"""
|
||||
extra axolotl specific training args
|
||||
"""
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from transformers import TrainingArguments
|
||||
from trl import CPOConfig, DPOConfig, KTOConfig, ORPOConfig, RewardConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlTrainingMixins:
|
||||
"""
|
||||
Mixin class for the Axolotl training args.
|
||||
"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
model_type: Optional[str] = field(
|
||||
default=None, metadata={"help": "HF model configuration model_type."}
|
||||
)
|
||||
lr_quadratic_warmup: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use quadratic warmup for cosine scheduling."},
|
||||
)
|
||||
pretraining: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Indicates to trainer whether we are doing continued pretraining."
|
||||
},
|
||||
)
|
||||
sample_packing: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use sample packing for efficient training."},
|
||||
)
|
||||
multipack_real_batches: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use real batches for efficient training."},
|
||||
)
|
||||
eval_sample_packing: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={"help": "Use sample packing for efficient evals."},
|
||||
)
|
||||
sample_packing_efficiency: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "Sample packing efficiency for calculating batch length."},
|
||||
)
|
||||
sample_packing_bin_size: int = field(
|
||||
default=200,
|
||||
metadata={
|
||||
"help": "The max number of samples that packed sample can contain after packing. Increase for better packing."
|
||||
},
|
||||
)
|
||||
sample_packing_group_size: int = field(
|
||||
default=100000,
|
||||
metadata={
|
||||
"help": "The number of samples to group together for packing. Increase for better packing."
|
||||
},
|
||||
)
|
||||
max_seq_length: int = field(
|
||||
default=2048,
|
||||
metadata={"help": "The maximum sequence length the model can handle"},
|
||||
)
|
||||
relora_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how often to reset for ReLoRA"},
|
||||
)
|
||||
relora_warmup_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
||||
)
|
||||
relora_anneal_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
||||
)
|
||||
relora_prune_ratio: Optional[float] = field(
|
||||
default=0.9,
|
||||
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
|
||||
)
|
||||
bench_split: Optional[str] = field(
|
||||
default="eval", metadata={"help": "The benchmark split to run on"}
|
||||
)
|
||||
bench_dataset: Optional[str] = field(
|
||||
default="pharaouk/dharma-1/dharma_1_mini.json",
|
||||
metadata={
|
||||
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
|
||||
},
|
||||
)
|
||||
do_bench_eval: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
|
||||
)
|
||||
do_causal_lm_eval: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
|
||||
)
|
||||
max_bench_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
|
||||
},
|
||||
)
|
||||
bench_source_max_len: int = field(
|
||||
default=2048, metadata={"help": "Maximum source sequence length for bench."}
|
||||
)
|
||||
dataloader_prefetch_factor: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "prefetch_factor argument to the dataloader"},
|
||||
)
|
||||
cosine_min_lr_ratio: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
|
||||
)
|
||||
cosine_constant_lr_ratio: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
|
||||
},
|
||||
)
|
||||
loraplus_lr_ratio: Optional[float] = field(
|
||||
default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
|
||||
)
|
||||
loraplus_lr_embedding: Optional[float] = field(
|
||||
default=1e-6,
|
||||
metadata={"help": "loraplus learning rate for lora embedding layers."},
|
||||
)
|
||||
embedding_lr_scale: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={"help": "Scale the learning rate for the embedding layers."},
|
||||
)
|
||||
embedding_lr: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={"help": "absolute learning rate for the embedding layers."},
|
||||
)
|
||||
qlora: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "whether this is a qlora training"},
|
||||
)
|
||||
orpo_alpha: Optional[float] = field(
|
||||
default=None,
|
||||
)
|
||||
lisa_n_layers: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "the number of activate layers in LISA"},
|
||||
)
|
||||
lisa_step_interval: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how often to switch layers in LISA"},
|
||||
)
|
||||
lisa_layers_attribute: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "path under the model to access the layers"},
|
||||
)
|
||||
curriculum_sampling: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={"help": "whether to use sequential sampling for curriculum learning"},
|
||||
)
|
||||
alternate_optimizer: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "workaround to pass an alternate optimizer to the HF trainer"
|
||||
},
|
||||
)
|
||||
alternate_lr_scheduler_type: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
|
||||
},
|
||||
)
|
||||
chat_template: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Chat template converting chat messages to text"},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
||||
"""
|
||||
Training arguments for Causal trainer
|
||||
|
||||
This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value
|
||||
so it can't be used as a mixin.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
|
||||
"""
|
||||
DPO config for DPO training
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
|
||||
"""
|
||||
ORPO config for ORPO training
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlKTOConfig(AxolotlTrainingMixins, KTOConfig):
|
||||
"""
|
||||
KTO config for KTO training
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig):
|
||||
"""
|
||||
CPO config for CPO training
|
||||
"""
|
||||
|
||||
simpo_gamma: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={"help": "simpo gamma parameter"},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlRewardConfig(AxolotlTrainingMixins, RewardConfig):
|
||||
"""
|
||||
Reward config for Reward training
|
||||
"""
|
||||
@@ -36,8 +36,6 @@ class LigerArgs(BaseModel):
|
||||
liger_cross_entropy: Optional[bool] = None
|
||||
liger_fused_linear_cross_entropy: Optional[bool] = None
|
||||
|
||||
liger_pref_rl: Optional[bool] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_deprecated_swiglu(cls, data):
|
||||
|
||||
@@ -1,253 +0,0 @@
|
||||
"""
|
||||
integration of liger dpo kernels with dpotrainer
|
||||
"""
|
||||
from typing import Dict, List, Literal, Union
|
||||
|
||||
import torch
|
||||
from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss
|
||||
from liger_kernel.transformers.trainer.orpo_trainer import _FSDPForwardRedirection
|
||||
from torch import nn
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel
|
||||
|
||||
from axolotl.core.trainers.base import AxolotlDPOTrainer
|
||||
|
||||
|
||||
class AxolotlLigerDPOTrainer(AxolotlDPOTrainer):
|
||||
"""
|
||||
Extend the DPO Trainer to use LIGER kernels for DPO
|
||||
"""
|
||||
|
||||
def concatenated_forward(
|
||||
self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
|
||||
):
|
||||
"""
|
||||
Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together,
|
||||
and compute the DPO loss using Liger's fused kernel.
|
||||
|
||||
This method replaces the original `concatenated_forward` implementation to use Liger.
|
||||
"""
|
||||
|
||||
# Prepare concatenated inputs
|
||||
concatenated_batch = self.concatenated_inputs(batch, self.padding_value)
|
||||
|
||||
# Extract concatenated inputs
|
||||
prompt_input_ids = concatenated_batch["prompt_input_ids"]
|
||||
prompt_attention_mask = concatenated_batch["prompt_attention_mask"]
|
||||
completion_input_ids = concatenated_batch["completion_input_ids"]
|
||||
completion_attention_mask = concatenated_batch["completion_attention_mask"]
|
||||
|
||||
# For encoder-decoder models, you'd need to construct decoder_input_ids, etc.
|
||||
# This example assumes a causal decoder-only model.
|
||||
input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1)
|
||||
attention_mask = torch.cat(
|
||||
(prompt_attention_mask, completion_attention_mask), dim=1
|
||||
)
|
||||
|
||||
# Align inputs by removing leading padding
|
||||
for i in range(attention_mask.size(0)):
|
||||
first_one_idx = torch.nonzero(attention_mask[i])[0].item()
|
||||
input_ids[i] = torch.roll(input_ids[i], shifts=-first_one_idx)
|
||||
attention_mask[i] = torch.roll(attention_mask[i], shifts=-first_one_idx)
|
||||
|
||||
# Remove trailing empty columns
|
||||
empty_cols = torch.sum(attention_mask, dim=0) == 0
|
||||
if empty_cols.any():
|
||||
first_empty_col = torch.nonzero(empty_cols)[0].item()
|
||||
input_ids = input_ids[:, :first_empty_col]
|
||||
attention_mask = attention_mask[:, :first_empty_col]
|
||||
|
||||
if self.args.max_length is not None:
|
||||
input_ids = input_ids[:, : self.args.max_length]
|
||||
attention_mask = attention_mask[:, : self.args.max_length]
|
||||
|
||||
# Labels are completion_input_ids shifted by one token right
|
||||
# For causal LM, labels are the completion part only
|
||||
labels = torch.cat(
|
||||
(torch.zeros_like(prompt_input_ids), completion_input_ids), dim=1
|
||||
)
|
||||
labels = labels[:, 1:] # shift left by one
|
||||
attention_mask = attention_mask[:, 1:]
|
||||
labels = labels[:, : attention_mask.size(1)]
|
||||
|
||||
# Mask out the prompt portion from loss
|
||||
labels[~attention_mask.bool()] = self.label_pad_token_id
|
||||
|
||||
# Prepare reference model hidden states if ref_model exists
|
||||
use_ref_model = self.ref_model is not None and not self.reference_free
|
||||
|
||||
# Run main model forward to get hidden states
|
||||
# If using FSDP, redirect forward calls
|
||||
if isinstance(model, FullyShardedDataParallel):
|
||||
outputs = _FSDPForwardRedirection()(
|
||||
model,
|
||||
model._fsdp_wrapped_module.model, # pylint: disable=protected-access
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
use_cache=False,
|
||||
)
|
||||
else:
|
||||
# If model is a DataParallel, unwrap
|
||||
if isinstance(model, torch.nn.DataParallel):
|
||||
model = model.module
|
||||
outputs = model.model(
|
||||
input_ids, attention_mask=attention_mask, use_cache=False
|
||||
)
|
||||
|
||||
last_hidden_state = outputs.last_hidden_state
|
||||
|
||||
ref_last_hidden_state = None
|
||||
if use_ref_model:
|
||||
ref_model = self.ref_model
|
||||
if isinstance(ref_model, FullyShardedDataParallel):
|
||||
with torch.no_grad():
|
||||
ref_outputs = _FSDPForwardRedirection()(
|
||||
ref_model,
|
||||
ref_model._fsdp_wrapped_module.model, # pylint: disable=protected-accessåå
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
use_cache=False,
|
||||
)
|
||||
else:
|
||||
if isinstance(ref_model, torch.nn.DataParallel):
|
||||
ref_model = ref_model.module
|
||||
with torch.no_grad():
|
||||
ref_outputs = ref_model.model(
|
||||
input_ids, attention_mask=attention_mask, use_cache=False
|
||||
)
|
||||
ref_last_hidden_state = ref_outputs.last_hidden_state
|
||||
|
||||
# Retrieve lm_head parameters
|
||||
lm_head = model.lm_head
|
||||
ref_lm_head = (
|
||||
self.ref_model.lm_head
|
||||
if (use_ref_model and self.ref_model is not None)
|
||||
else None
|
||||
)
|
||||
|
||||
# Use Liger fused DPO loss
|
||||
dpo_loss_fn = LigerFusedLinearDPOLoss(
|
||||
ignore_index=self.label_pad_token_id,
|
||||
beta=self.beta,
|
||||
compute_nll_loss=False,
|
||||
compiled=True,
|
||||
use_ref_model=use_ref_model,
|
||||
)
|
||||
|
||||
# call fused Liger DPO
|
||||
if use_ref_model:
|
||||
loss_acc, aux_outputs = dpo_loss_fn(
|
||||
lm_head.weight, # lin_weight
|
||||
last_hidden_state, # _input
|
||||
labels, # target
|
||||
bias=lm_head.bias,
|
||||
ref_input=ref_last_hidden_state,
|
||||
ref_weight=ref_lm_head.weight,
|
||||
ref_bias=ref_lm_head.bias,
|
||||
)
|
||||
|
||||
(
|
||||
policy_chosen_logps,
|
||||
policy_rejected_logps,
|
||||
policy_chosen_logits_mean,
|
||||
policy_rejected_logits_mean,
|
||||
policy_nll_loss,
|
||||
) = aux_outputs[:5]
|
||||
|
||||
else:
|
||||
# No reference model scenario: Liger kernel treats ref_logps as 0
|
||||
loss_acc, aux_outputs = dpo_loss_fn(
|
||||
lm_head.weight,
|
||||
last_hidden_state,
|
||||
labels,
|
||||
bias=lm_head.bias,
|
||||
)
|
||||
(
|
||||
policy_chosen_logps,
|
||||
policy_rejected_logps,
|
||||
policy_chosen_logits_mean,
|
||||
policy_rejected_logits_mean,
|
||||
policy_nll_loss,
|
||||
) = aux_outputs[:5]
|
||||
|
||||
# Add aux loss if enabled
|
||||
if self.aux_loss_enabled and hasattr(outputs, "aux_loss"):
|
||||
loss_acc = loss_acc + self.aux_loss_coef * outputs.aux_loss
|
||||
|
||||
# Add RPO loss if requested (RPO is a variant that adds NLL loss)
|
||||
if self.args.rpo_alpha is not None:
|
||||
# policy_nll_loss: average negative log-likelihood of chosen completions
|
||||
loss_acc = loss_acc + self.args.rpo_alpha * policy_nll_loss.mean()
|
||||
|
||||
return (
|
||||
loss_acc,
|
||||
policy_chosen_logps,
|
||||
policy_rejected_logps,
|
||||
policy_chosen_logits_mean,
|
||||
policy_rejected_logits_mean,
|
||||
policy_nll_loss,
|
||||
)
|
||||
|
||||
def get_batch_loss_metrics(
|
||||
self,
|
||||
model,
|
||||
batch: Dict[str, Union[List, torch.LongTensor]],
|
||||
train_eval: Literal["train", "eval"] = "train",
|
||||
):
|
||||
"""
|
||||
Compute the DPO loss and other metrics for a given batch using the Liger fused kernel.
|
||||
"""
|
||||
metrics = {}
|
||||
|
||||
(
|
||||
loss,
|
||||
policy_chosen_logps,
|
||||
policy_rejected_logps,
|
||||
policy_chosen_logits_mean,
|
||||
policy_rejected_logits_mean,
|
||||
policy_nll_loss,
|
||||
) = self.concatenated_forward(model, batch)
|
||||
|
||||
# For metrics, we approximate chosen/rejected rewards as beta * (log π(y) - log π_ref(y)) if ref model used.
|
||||
# If no ref model is used, we can't compute reward_accuracies meaningfully. For simplicity, we assume ref_model presence.
|
||||
if self.ref_model is not None and not self.reference_free:
|
||||
# If you want full parity with original DPOTrainer metrics (like chosen_rewards, rejected_rewards),
|
||||
# you'd need to run reference forward or store reference log ps. The Liger kernel currently doesn't
|
||||
# return ref_chosen_logps/ref_rejected_logps explicitly. By design, Liger directly computes DPO.
|
||||
#
|
||||
# Here we approximate chosen_rewards and rejected_rewards from the difference in chosen/rejected logps.
|
||||
# Since Liger DPO does not output ref logps separately, you may need to modify the Liger kernel to
|
||||
# also output them if you need all the metrics. For now, we'll skip them or provide a placeholder.
|
||||
|
||||
# Placeholder: chosen/rejected "rewards" can't be retrieved directly from Liger as-is.
|
||||
# If needed, integrate ref_chosen_logps/ref_rejected_logps into Liger kernel returns.
|
||||
chosen_rewards = policy_chosen_logps * self.beta # approximation
|
||||
rejected_rewards = policy_rejected_logps * self.beta # approximation
|
||||
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
||||
metrics[f"{train_eval}_rewards/chosen"] = chosen_rewards.mean().cpu().item()
|
||||
metrics[f"{train_eval}_rewards/rejected"] = (
|
||||
rejected_rewards.mean().cpu().item()
|
||||
)
|
||||
metrics[f"{train_eval}_rewards/accuracies"] = (
|
||||
reward_accuracies.mean().cpu().item()
|
||||
)
|
||||
metrics[f"{train_eval}_rewards/margins"] = (
|
||||
(chosen_rewards - rejected_rewards).mean().cpu().item()
|
||||
)
|
||||
|
||||
metrics[f"{train_eval}_logps/chosen"] = policy_chosen_logps.mean().cpu().item()
|
||||
metrics[f"{train_eval}_logps/rejected"] = (
|
||||
policy_rejected_logps.mean().cpu().item()
|
||||
)
|
||||
metrics[f"{train_eval}_logits/chosen"] = (
|
||||
policy_chosen_logits_mean.detach().cpu().item()
|
||||
)
|
||||
metrics[f"{train_eval}_logits/rejected"] = (
|
||||
policy_rejected_logits_mean.detach().cpu().item()
|
||||
)
|
||||
|
||||
if self.args.rpo_alpha is not None:
|
||||
metrics[f"{train_eval}_nll_loss"] = (
|
||||
policy_nll_loss.mean().detach().cpu().item()
|
||||
)
|
||||
|
||||
return loss.mean(), metrics
|
||||
170
src/axolotl/monkeypatch/models/llama/modeling_llama.py
Normal file
170
src/axolotl/monkeypatch/models/llama/modeling_llama.py
Normal file
@@ -0,0 +1,170 @@
|
||||
import contextlib
|
||||
import inspect
|
||||
import types
|
||||
|
||||
from torchtune.training import OffloadActivations
|
||||
from transformers import LlamaConfig, LlamaForCausalLM
|
||||
|
||||
from axolotl.monkeypatch.unsloth_ import detab_code
|
||||
|
||||
HF_MODEL_OUTPUTS = """
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
""".lstrip()
|
||||
|
||||
PATCHED_HF_MODEL_OUTPUTS = """
|
||||
with self.act_offloading_ctx_manager:
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
""".lstrip()
|
||||
|
||||
LCE_MODEL_OUTPUTS = """
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
""".lstrip()
|
||||
|
||||
PATCHED_LCE_OUTPUTS = """
|
||||
with self.act_offloading_ctx_manager:
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
""".lstrip()
|
||||
|
||||
HF_GA_FORWARD_1 = """
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
""".lstrip()
|
||||
|
||||
PATCHED_HF_GA_FORWARD_1 = """
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# remove num_items_in_batch otherwise self.model attempts to pass it to flash_attention
|
||||
num_items_in_batch = kwargs.pop("num_items_in_batch", None)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
""".lstrip()
|
||||
|
||||
HF_GA_FORWARD_2 = """
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
|
||||
""".lstrip()
|
||||
|
||||
PATCHED_HF_GA_FORWARD_2 = """
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, num_items_in_batch=num_items_in_batch, **kwargs)
|
||||
""".lstrip()
|
||||
|
||||
|
||||
class AxolotlLlamaForCausalLM(LlamaForCausalLM):
|
||||
act_offloading_ctx_manager = contextlib.nullcontext()
|
||||
|
||||
def __init__(self, config: LlamaConfig):
|
||||
super().__init__(config)
|
||||
|
||||
@classmethod
|
||||
def set_forward(cls):
|
||||
forward_source = inspect.getsource(LlamaForCausalLM.forward)
|
||||
forward_source, _ = detab_code(forward_source)
|
||||
cls.forward = types.MethodType(
|
||||
compile(forward_source, "<forward>", "exec"), cls
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def enable_act_offloading(cls):
|
||||
forward_source = inspect.getsource(cls.forward)
|
||||
forward_source = forward_source.replace(
|
||||
HF_MODEL_OUTPUTS, PATCHED_HF_MODEL_OUTPUTS
|
||||
)
|
||||
forward_source, _ = detab_code(forward_source)
|
||||
# replace forward method with patched version
|
||||
cls.forward = types.MethodType(
|
||||
compile(forward_source, "<llama_forward_w_act_offloading>", "exec"), cls
|
||||
)
|
||||
cls.act_offloading_ctx_manager = OffloadActivations()
|
||||
|
||||
@classmethod
|
||||
def enable_liger_fce(cls, enable_act_offloading=True):
|
||||
from liger_kernel.transformers.model.llama import (
|
||||
lce_forward as llama_lce_forward,
|
||||
)
|
||||
|
||||
if enable_act_offloading:
|
||||
lce_source = inspect.getsource(llama_lce_forward)
|
||||
lce_source = lce_source.replace(LCE_MODEL_OUTPUTS, PATCHED_LCE_OUTPUTS)
|
||||
# replace forward method with patched version
|
||||
cls.forward = types.MethodType(
|
||||
compile(lce_source, "<llama_lce_forward_w_act_offloading>", "exec"),
|
||||
cls,
|
||||
)
|
||||
else:
|
||||
cls.forward = types.methodType(llama_lce_forward, cls)
|
||||
|
||||
@classmethod
|
||||
def patch_hf_ga(cls):
|
||||
# bugfix patch for gradient accumulation
|
||||
forward_source = inspect.getsource(cls.forward)
|
||||
forward_source = forward_source.replace(
|
||||
HF_GA_FORWARD_1, PATCHED_HF_GA_FORWARD_1
|
||||
)
|
||||
forward_source = forward_source.replace(
|
||||
HF_GA_FORWARD_2, PATCHED_HF_GA_FORWARD_2
|
||||
)
|
||||
forward_source, _ = detab_code(forward_source)
|
||||
# replace forward method with patched version
|
||||
cls.forward = types.MethodType(
|
||||
compile(forward_source, "<llama_forward_ga_fix>", "exec"), cls
|
||||
)
|
||||
|
||||
|
||||
def replace_auto_model():
|
||||
from transformers import LlamaConfig
|
||||
from transformers.models.auto import MODEL_FOR_CAUSAL_LM_MAPPING
|
||||
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING[LlamaConfig] = AxolotlLlamaForCausalLM
|
||||
AxolotlLlamaForCausalLM.set_forward()
|
||||
|
||||
return AxolotlLlamaForCausalLM
|
||||
@@ -8,36 +8,17 @@ def argilla(
|
||||
**kwargs,
|
||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
def transform_fn(sample):
|
||||
if "prompt" in sample.keys():
|
||||
prompt_key = "prompt"
|
||||
elif "input" in sample.keys():
|
||||
prompt_key = "input"
|
||||
elif "question" in sample.keys():
|
||||
prompt_key = "question"
|
||||
else:
|
||||
prompt_key = "instruction"
|
||||
|
||||
if "chosen" in sample.keys():
|
||||
chosen_key = "chosen"
|
||||
else:
|
||||
chosen_key = "chosen_response"
|
||||
|
||||
if "rejected" in sample.keys():
|
||||
rejected_key = "rejected"
|
||||
else:
|
||||
rejected_key = "rejected_response"
|
||||
|
||||
if "system" in sample and sample["system"]:
|
||||
sample["prompt"] = (
|
||||
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
|
||||
f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
|
||||
f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
else:
|
||||
sample[
|
||||
"prompt"
|
||||
] = f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
|
||||
sample["chosen"] = f"{sample[chosen_key]}<|im_end|>"
|
||||
sample["rejected"] = f"{sample[rejected_key]}<|im_end|>"
|
||||
] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
sample["chosen"] = f"{sample['chosen_response']}<|im_end|>"
|
||||
sample["rejected"] = f"{sample['rejected_response']}<|im_end|>"
|
||||
return sample
|
||||
|
||||
return transform_fn
|
||||
|
||||
@@ -8,37 +8,17 @@ def argilla(
|
||||
**kwargs,
|
||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
def transform_fn(sample):
|
||||
# pylint: disable=duplicate-code
|
||||
if "prompt" in sample.keys():
|
||||
prompt_key = "prompt"
|
||||
elif "input" in sample.keys():
|
||||
prompt_key = "input"
|
||||
elif "question" in sample.keys():
|
||||
prompt_key = "question"
|
||||
else:
|
||||
prompt_key = "instruction"
|
||||
|
||||
if "chosen" in sample.keys():
|
||||
chosen_key = "chosen"
|
||||
else:
|
||||
chosen_key = "chosen_response"
|
||||
|
||||
if "rejected" in sample.keys():
|
||||
rejected_key = "rejected"
|
||||
else:
|
||||
rejected_key = "rejected_response"
|
||||
|
||||
if "system" in sample and sample["system"]:
|
||||
sample["prompt"] = (
|
||||
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
|
||||
f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
else:
|
||||
sample[
|
||||
"prompt"
|
||||
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
sample["chosen"] = f"{sample[chosen_key]}<|eot_id|>"
|
||||
sample["rejected"] = f"{sample[rejected_key]}<|eot_id|>"
|
||||
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
sample["chosen"] = f"{sample['chosen_response']}<|eot_id|>"
|
||||
sample["rejected"] = f"{sample['rejected_response']}<|eot_id|>"
|
||||
return sample
|
||||
|
||||
return transform_fn
|
||||
|
||||
@@ -679,6 +679,7 @@ class AxolotlInputConfig(
|
||||
default=False
|
||||
)
|
||||
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
|
||||
activation_offloading: Optional[bool] = None
|
||||
|
||||
unfrozen_parameters: Optional[List[str]] = None
|
||||
|
||||
|
||||
@@ -380,6 +380,15 @@ class ModelLoader:
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
plugin_manager.pre_model_load(self.cfg)
|
||||
|
||||
if self.cfg.model_config_type == "llama":
|
||||
from axolotl.monkeypatch.models.llama.modeling_llama import replace_auto_model
|
||||
|
||||
AxolotlLlamaForCausalLM = replace_auto_model()
|
||||
|
||||
AxolotlLlamaForCausalLM.patch_hf_ga()
|
||||
if self.cfg.activation_offloading:
|
||||
AxolotlLlamaForCausalLM.enable_act_offloading()
|
||||
|
||||
if self.cfg.fsdp:
|
||||
from axolotl.monkeypatch.trainer_fsdp_optim import (
|
||||
patch_training_loop_for_fsdp,
|
||||
@@ -1183,6 +1192,8 @@ class ModelLoader:
|
||||
|
||||
self.apply_lora_patch()
|
||||
|
||||
# self.apply_patches_to_model()
|
||||
|
||||
for _ in range(3):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
Reference in New Issue
Block a user