Compare commits
6 Commits
patch_lora
...
grpo_liger
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1a09d5e844 | ||
|
|
cf61b4aba7 | ||
|
|
14d274efe6 | ||
|
|
954e192f38 | ||
|
|
8dfadc2b3c | ||
|
|
23a9fcb0a7 |
@@ -12,6 +12,7 @@ to leverage operator fusion and tensor re-use in order to improve speed and redu
|
||||
memory usage during the forward and backward passes of these calculations.
|
||||
|
||||
We currently support several common model architectures, including (but not limited to):
|
||||
|
||||
- `llama`
|
||||
- `mistral`
|
||||
- `qwen2`
|
||||
|
||||
@@ -9,6 +9,7 @@ import logging
|
||||
from trl.trainer.grpo_trainer import RewardFunc
|
||||
|
||||
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
|
||||
from axolotl.utils.config.models.input.v0_4_1.trl import TRLConfig
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
@@ -30,31 +31,21 @@ class GRPOStrategy:
|
||||
|
||||
@classmethod
|
||||
def set_training_args_kwargs(cls, cfg):
|
||||
grpo_args_kwargs = {}
|
||||
if cfg.trl and cfg.trl.use_vllm:
|
||||
grpo_args_kwargs["use_vllm"] = cfg.trl.use_vllm
|
||||
if cfg.trl and cfg.trl.vllm_device:
|
||||
grpo_args_kwargs["vllm_device"] = cfg.trl.vllm_device
|
||||
else:
|
||||
grpo_args_kwargs["vllm_device"] = "auto"
|
||||
if cfg.trl and cfg.trl.vllm_gpu_memory_utilization:
|
||||
grpo_args_kwargs[
|
||||
"vllm_gpu_memory_utilization"
|
||||
] = cfg.trl.vllm_gpu_memory_utilization
|
||||
if cfg.trl and cfg.trl.vllm_max_model_len:
|
||||
grpo_args_kwargs["vllm_max_model_len"] = cfg.trl.vllm_max_model_len
|
||||
if cfg.trl and cfg.trl.num_generations:
|
||||
grpo_args_kwargs["num_generations"] = cfg.trl.num_generations
|
||||
if cfg.trl and cfg.trl.sync_ref_model:
|
||||
grpo_args_kwargs["sync_ref_model"] = cfg.trl.sync_ref_model
|
||||
if cfg.trl and cfg.trl.ref_model_mixup_alpha:
|
||||
grpo_args_kwargs[
|
||||
"ref_model_mixup_alpha"
|
||||
] = cfg.trl.ref_model_mixup_alpha
|
||||
if cfg.trl and cfg.trl.ref_model_sync_steps:
|
||||
grpo_args_kwargs["ref_model_sync_steps"] = cfg.trl.ref_model_sync_steps
|
||||
grpo_args_kwargs["max_completion_length"] = cfg.trl.max_completion_length
|
||||
grpo_args_kwargs["log_completions"] = cfg.trl.log_completions
|
||||
training_kwargs = [
|
||||
"use_vllm",
|
||||
"vllm_device",
|
||||
"vllm_gpu_memory_utilization",
|
||||
"vllm_max_model_len",
|
||||
"vllm_dtype",
|
||||
"use_liger_loss",
|
||||
"num_generations",
|
||||
"log_completions",
|
||||
"sync_ref_model",
|
||||
"ref_model_mixup_alpha",
|
||||
"ref_model_sync_steps",
|
||||
"max_completion_length",
|
||||
]
|
||||
grpo_args_kwargs = {k: cfg.trl[k] for k in training_kwargs if cfg.trl[k]}
|
||||
return grpo_args_kwargs
|
||||
|
||||
@classmethod
|
||||
@@ -71,9 +62,7 @@ class GRPOStrategy:
|
||||
def set_trainer_kwargs(cls, cfg):
|
||||
trainer_kwargs = {}
|
||||
if cfg.trl and cfg.trl.reward_processing_classes:
|
||||
trainer_kwargs[
|
||||
"reward_processing_classes"
|
||||
] = cfg.trl.reward_processing_classes
|
||||
trainer_kwargs["reward_processing_classes"] = cfg.trl.reward_processing_classes
|
||||
return trainer_kwargs
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -13,3 +13,4 @@ class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
|
||||
"""
|
||||
Axolotl GRPO Config for GRPO training
|
||||
"""
|
||||
use_liger_loss: bool = False
|
||||
|
||||
@@ -1,13 +1,24 @@
|
||||
"""
|
||||
Axolotl GRPO trainer
|
||||
"""
|
||||
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from accelerate.utils import is_peft_model
|
||||
from accelerate.utils.other import is_compiled_module
|
||||
import torch
|
||||
from transformers import PreTrainedModel
|
||||
from trl import GRPOConfig, GRPOTrainer
|
||||
from trl.models import unwrap_model_for_generation
|
||||
|
||||
from axolotl.core.trainers.base import SchedulerMixin
|
||||
from transformers.utils import is_liger_kernel_available
|
||||
|
||||
if is_liger_kernel_available():
|
||||
from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss
|
||||
|
||||
from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
|
||||
from accelerate.utils import broadcast_object_list, gather_object
|
||||
from trl.trainer.utils import pad
|
||||
|
||||
|
||||
# mypy: ignore-errors
|
||||
@@ -20,7 +31,20 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
self.use_liger_loss = kwargs["args"].use_liger_loss
|
||||
if self.use_liger_loss:
|
||||
if not is_liger_kernel_available():
|
||||
raise ValueError(
|
||||
"You set `use_liger_loss=True` but the liger kernel is not available. "
|
||||
"Please install liger-kernel first: `pip install liger-kernel`"
|
||||
)
|
||||
self.grpo_loss_fn = LigerFusedLinearGRPOLoss(
|
||||
beta=self.beta,
|
||||
compiled=is_compiled_module(self.model),
|
||||
use_ref_model=True,
|
||||
num_generations=self.args.num_generations,
|
||||
)
|
||||
# pylint: disable=access-member-before-definition
|
||||
# Enable gradient checkpointing if requested
|
||||
if kwargs["args"].gradient_checkpointing:
|
||||
@@ -29,9 +53,7 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
||||
self.model.config.use_cache = False
|
||||
|
||||
# Enable gradient checkpointing on the base model for PEFT
|
||||
if is_peft_model(self.model) and hasattr(
|
||||
self.model.base_model, "gradient_checkpointing_enable"
|
||||
):
|
||||
if is_peft_model(self.model) and hasattr(self.model.base_model, "gradient_checkpointing_enable"):
|
||||
self.model.base_model.gradient_checkpointing_enable()
|
||||
# Enable gradient checkpointing for non-PEFT models
|
||||
elif hasattr(self.model, "gradient_checkpointing_enable"):
|
||||
@@ -39,15 +61,12 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
||||
self.model = self._enable_gradient_checkpointing(self.model, kwargs["args"])
|
||||
# pylint: enable=access-member-before-definition
|
||||
|
||||
def _enable_gradient_checkpointing(
|
||||
self, model: PreTrainedModel, args: GRPOConfig
|
||||
) -> PreTrainedModel:
|
||||
def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: GRPOConfig) -> PreTrainedModel:
|
||||
"""Enables gradient checkpointing for the model."""
|
||||
# pylint: disable=unused-argument,redefined-builtin
|
||||
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
|
||||
use_reentrant = (
|
||||
"use_reentrant" not in gradient_checkpointing_kwargs
|
||||
or gradient_checkpointing_kwargs["use_reentrant"]
|
||||
"use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
|
||||
)
|
||||
|
||||
if use_reentrant:
|
||||
@@ -58,9 +77,7 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
||||
def make_inputs_require_grad(module, input, output):
|
||||
output.requires_grad_(True)
|
||||
|
||||
model.get_input_embeddings().register_forward_hook(
|
||||
make_inputs_require_grad
|
||||
)
|
||||
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
||||
|
||||
return model
|
||||
# pylint: enable=unused-argument,redefined-builtin
|
||||
@@ -72,26 +89,18 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
||||
gather_deepspeed3_params=self.args.ds3_gather_for_generation,
|
||||
) as unwrapped_model:
|
||||
if is_compiled_module(unwrapped_model):
|
||||
unwrapped_model = (
|
||||
unwrapped_model._orig_mod # pylint: disable=protected-access
|
||||
)
|
||||
unwrapped_model = unwrapped_model._orig_mod # pylint: disable=protected-access
|
||||
if is_peft_model(unwrapped_model):
|
||||
unwrapped_model.merge_adapter()
|
||||
state_dict = unwrapped_model.state_dict()
|
||||
unwrapped_model.unmerge_adapter()
|
||||
# Remove base_model and base_layer prefixes
|
||||
state_dict = {
|
||||
k.removeprefix("base_model.model.")
|
||||
.removeprefix("base_model.model.")
|
||||
.replace(".base_layer", ""): v
|
||||
k.removeprefix("base_model.model.").removeprefix("base_model.model.").replace(".base_layer", ""): v
|
||||
for k, v in state_dict.items()
|
||||
}
|
||||
# Remove values with adapter prefix (example: "_lora")
|
||||
state_dict = {
|
||||
k: v
|
||||
for k, v in state_dict.items()
|
||||
if unwrapped_model.prefix not in k
|
||||
}
|
||||
state_dict = {k: v for k, v in state_dict.items() if unwrapped_model.prefix not in k}
|
||||
# When module to save, remove its prefix and discard the original module
|
||||
state_dict = {
|
||||
k.replace("modules_to_save.default.", ""): v
|
||||
@@ -101,7 +110,217 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
||||
else:
|
||||
state_dict = unwrapped_model.state_dict()
|
||||
if self.accelerator.is_main_process:
|
||||
llm_model = (
|
||||
self.llm.llm_engine.model_executor.driver_worker.model_runner.model
|
||||
)
|
||||
llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
|
||||
llm_model.load_weights(state_dict.items())
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
|
||||
if self.use_liger_loss:
|
||||
if return_outputs:
|
||||
raise ValueError("The GRPOTrainer does not support returning outputs")
|
||||
|
||||
device = self.accelerator.device
|
||||
prompts = [x["prompt"] for x in inputs]
|
||||
prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
|
||||
prompt_inputs = self.processing_class(
|
||||
prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
|
||||
)
|
||||
prompt_inputs = super()._prepare_inputs(prompt_inputs)
|
||||
|
||||
if self.max_prompt_length is not None:
|
||||
prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -self.max_prompt_length :]
|
||||
prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -self.max_prompt_length :]
|
||||
|
||||
# Generate completions using either vLLM or regular generation
|
||||
if self.args.use_vllm:
|
||||
# First, have main process load weights if needed
|
||||
if self.state.global_step != self._last_loaded_step:
|
||||
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
|
||||
state_dict = unwrapped_model.state_dict()
|
||||
if self.accelerator.is_main_process:
|
||||
llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
|
||||
llm_model.load_weights(state_dict.items())
|
||||
self._last_loaded_step = self.state.global_step
|
||||
|
||||
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
|
||||
all_prompts_text = gather_object(prompts_text)
|
||||
if self.accelerator.is_main_process:
|
||||
outputs = self.llm.generate(all_prompts_text, sampling_params=self.sampling_params, use_tqdm=False)
|
||||
completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
|
||||
else:
|
||||
completion_ids = [None] * len(all_prompts_text) * self.num_generations
|
||||
|
||||
# Broadcast the completions from the main process to all processes, ensuring each process receives its
|
||||
# corresponding slice.
|
||||
completion_ids = broadcast_object_list(completion_ids, from_process=0)
|
||||
process_slice = slice(
|
||||
self.accelerator.process_index * len(prompts) * self.num_generations,
|
||||
(self.accelerator.process_index + 1) * len(prompts) * self.num_generations,
|
||||
)
|
||||
completion_ids = completion_ids[process_slice]
|
||||
|
||||
# Pad the completions, and concatenate them with the prompts
|
||||
completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
|
||||
completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
|
||||
prompt_inputs_repeated = torch.repeat_interleave(
|
||||
prompt_inputs["input_ids"], self.num_generations, dim=0
|
||||
)
|
||||
prompt_completion_ids = torch.cat([prompt_inputs_repeated, completion_ids], dim=1)
|
||||
else:
|
||||
# Regular generation path
|
||||
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
|
||||
prompt_completion_ids = unwrapped_model.generate(
|
||||
**prompt_inputs, generation_config=self.generation_config
|
||||
)
|
||||
|
||||
prompt_length = prompt_inputs["input_ids"].size(1)
|
||||
completion_ids = prompt_completion_ids[:, prompt_length:]
|
||||
|
||||
# Get the per-token log probabilities for the completions for the model and the reference model
|
||||
def get_per_token_logps(model, input_ids, num_logits_to_keep):
|
||||
# We add 1 to `num_logits_to_keep` because the last logits of the sequence is later excluded
|
||||
outputs = model(input_ids, num_logits_to_keep=num_logits_to_keep + 1)
|
||||
hidden_states = outputs.last_hidden_state[:, :-1]
|
||||
logits = outputs.logits # (B, L, V)
|
||||
logits = logits[
|
||||
:, :-1, :
|
||||
] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
|
||||
|
||||
# Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
|
||||
per_token_logps = []
|
||||
for logits_row, input_ids_row in zip(logits, input_ids[:, -num_logits_to_keep:]):
|
||||
log_probs = logits_row.log_softmax(dim=-1)
|
||||
token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
|
||||
per_token_logps.append(token_log_prob)
|
||||
return torch.stack(per_token_logps), hidden_states
|
||||
|
||||
num_logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
|
||||
per_token_logps, hidden_states = get_per_token_logps(model, prompt_completion_ids, num_logits_to_keep)
|
||||
|
||||
with torch.inference_mode():
|
||||
if self.ref_model is not None:
|
||||
ref_per_token_logps, ref_hidden_states = get_per_token_logps(
|
||||
self.ref_model, prompt_completion_ids, num_logits_to_keep
|
||||
)
|
||||
else:
|
||||
with self.accelerator.unwrap_model(model).disable_adapter():
|
||||
ref_per_token_logps, ref_hidden_states = get_per_token_logps(
|
||||
model, prompt_completion_ids, num_logits_to_keep
|
||||
)
|
||||
|
||||
# done in liger
|
||||
# Compute the KL divergence between the model and the reference model
|
||||
# per_token_kl = (
|
||||
# torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
|
||||
# )
|
||||
|
||||
# Mask everything after the first EOS token
|
||||
is_eos = completion_ids == self.processing_class.eos_token_id
|
||||
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
|
||||
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
|
||||
sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
|
||||
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
|
||||
|
||||
# Decode the generated completions
|
||||
completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
|
||||
if is_conversational(inputs[0]):
|
||||
completions = [[{"role": "assistant", "content": completion}] for completion in completions]
|
||||
|
||||
# Compute the rewards
|
||||
prompts = [prompt for prompt in prompts for _ in range(self.num_generations)]
|
||||
|
||||
rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
|
||||
for i, (reward_func, reward_processing_class) in enumerate(
|
||||
zip(self.reward_funcs, self.reward_processing_classes)
|
||||
):
|
||||
if isinstance(reward_func, PreTrainedModel):
|
||||
if is_conversational(inputs[0]):
|
||||
messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
|
||||
texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
|
||||
else:
|
||||
texts = [p + c for p, c in zip(prompts, completions)]
|
||||
reward_inputs = reward_processing_class(
|
||||
texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
|
||||
)
|
||||
reward_inputs = super()._prepare_inputs(reward_inputs)
|
||||
with torch.inference_mode():
|
||||
rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,)
|
||||
else:
|
||||
# Repeat all input columns (but "prompt" and "completion") to match the number of generations
|
||||
reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
|
||||
for key in reward_kwargs:
|
||||
for example in inputs:
|
||||
# Repeat each value in the column for `num_generations` times
|
||||
reward_kwargs[key].extend([example[key]] * self.num_generations)
|
||||
output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
|
||||
rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
|
||||
|
||||
# Sum the rewards from all reward functions
|
||||
rewards = rewards_per_func.sum(dim=1)
|
||||
|
||||
# done in liger
|
||||
# # Compute grouped-wise rewards
|
||||
# mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
|
||||
# std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
|
||||
|
||||
# done in liger
|
||||
# # Normalize the rewards to compute the advantages
|
||||
# mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
|
||||
# std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
|
||||
# advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
|
||||
|
||||
# done in liger
|
||||
# x - x.detach() allows for preserving gradients from x
|
||||
# per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
|
||||
# per_token_loss = -(per_token_loss - self.beta * per_token_kl)
|
||||
# loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
|
||||
|
||||
# Log the metrics
|
||||
completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
|
||||
self._metrics["completion_length"].append(completion_length)
|
||||
|
||||
reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
|
||||
for i, reward_func in enumerate(self.reward_funcs):
|
||||
if isinstance(reward_func, PreTrainedModel):
|
||||
reward_func_name = reward_func.config._name_or_path.split("/")[-1]
|
||||
else:
|
||||
reward_func_name = reward_func.__name__
|
||||
self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
|
||||
|
||||
self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
|
||||
|
||||
lm_head = model.get_output_embeddings()
|
||||
|
||||
if self.ref_model is not None:
|
||||
ref_lm_head = self.ref_model.get_output_embeddings()
|
||||
else:
|
||||
with self.null_ref_context():
|
||||
ref_lm_head = model.get_output_embeddings()
|
||||
ref_weight = ref_lm_head.weight
|
||||
ref_bias = ref_lm_head.bias if hasattr(ref_lm_head, "bias") else None
|
||||
|
||||
loss, metrics = self.grpo_loss_fn(
|
||||
lm_head,
|
||||
hidden_states, # this is the hidden states from the model
|
||||
completion_mask,
|
||||
rewards,
|
||||
bias=lm_head.bias if hasattr(lm_head, "bias") else None,
|
||||
ref_input=ref_hidden_states, # this is the hidden states from the ref model
|
||||
ref_weight=ref_weight,
|
||||
ref_bias=ref_bias,
|
||||
)
|
||||
else:
|
||||
super().compute_loss(model, inputs, return_outputs, num_items_in_batch)
|
||||
|
||||
@contextmanager
|
||||
def null_ref_context(self):
|
||||
"""Context manager for handling null reference model (that is, peft adapter manipulation)."""
|
||||
with (
|
||||
self.accelerator.unwrap_model(self.model).disable_adapter()
|
||||
if self.is_peft_model and not self.ref_adapter_name
|
||||
else nullcontext()
|
||||
):
|
||||
if self.ref_adapter_name:
|
||||
self.model.set_adapter(self.ref_adapter_name)
|
||||
yield
|
||||
if self.ref_adapter_name:
|
||||
self.model.set_adapter(self.model_adapter_name or "default")
|
||||
|
||||
@@ -4,12 +4,13 @@ import importlib
|
||||
import inspect
|
||||
import logging
|
||||
import types
|
||||
from typing import Type
|
||||
|
||||
import torch
|
||||
from accelerate.logging import get_logger
|
||||
from peft import PeftModelForCausalLM
|
||||
from torch import nn
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers import AutoConfig
|
||||
|
||||
from axolotl.kernels.lora import (
|
||||
apply_lora_mlp_geglu,
|
||||
@@ -95,90 +96,108 @@ def original_apply_o(self: nn.Module, hidden_states: torch.Tensor) -> torch.Tens
|
||||
return attn_output
|
||||
|
||||
|
||||
# pylint: disable=protected-access
|
||||
def patch_self_attn_lora(model: PreTrainedModel):
|
||||
def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
|
||||
"""
|
||||
Patches the attention classes in a transformer model with optimized LoRA implementations.
|
||||
Get the appropriate attention class by inspecting the model config.
|
||||
Uses dynamic import to support any model architecture that follows
|
||||
the standard transformers naming convention.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
|
||||
Returns:
|
||||
The appropriate attention class for the model.
|
||||
|
||||
Raises:
|
||||
ValueError: If `base_model` not specified or attention class cannot be imported
|
||||
ImportError: If the model module or attention class doesn't exist
|
||||
"""
|
||||
if "base_model" not in cfg:
|
||||
raise ValueError("base_model must be specified in config")
|
||||
|
||||
# Get model config without loading the model
|
||||
model_config = AutoConfig.from_pretrained(cfg["base_model"])
|
||||
model_type = model_config.model_type
|
||||
|
||||
# Special case for model_type = "qwen2"
|
||||
if model_type == "qwen2":
|
||||
from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention
|
||||
|
||||
return Qwen2Attention
|
||||
|
||||
try:
|
||||
# Dynamically import the module and attention class
|
||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
||||
module = __import__(
|
||||
module_path, fromlist=[f"{model_type.capitalize()}Attention"]
|
||||
)
|
||||
attention_cls = getattr(module, f"{model_type.capitalize()}Attention")
|
||||
|
||||
return attention_cls
|
||||
except (ImportError, AttributeError) as e:
|
||||
raise ValueError(
|
||||
f"Could not import attention class for model_type: {model_type}. "
|
||||
f"Error: {str(e)}"
|
||||
) from e
|
||||
|
||||
|
||||
# pylint: disable=protected-access
|
||||
def patch_self_attn_lora(cfg: DictDefault):
|
||||
"""
|
||||
Given an `axolotl` config, this method patches the inferred attention class forward
|
||||
pass with optimized LoRA implementations.
|
||||
|
||||
It modifies the attention class to use optimized QKV and output projections. The
|
||||
original implementation is preserved and can be restored if needed.
|
||||
|
||||
Args:
|
||||
model: A HuggingFace transformers model.
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the required code blocks are not found in the attention
|
||||
implementation.
|
||||
"""
|
||||
# Find all attention modules in the model
|
||||
attention_modules = [
|
||||
module
|
||||
for module in model.modules()
|
||||
if "attention" in module.__class__.__name__.lower()
|
||||
and hasattr(module, "forward")
|
||||
]
|
||||
attention_cls = get_attention_cls_from_config(cfg)
|
||||
|
||||
if not attention_modules:
|
||||
LOG.warning("No attention modules found in model")
|
||||
# Check if already patched
|
||||
if hasattr(attention_cls, "_original_forward"):
|
||||
LOG.info(f"{attention_cls.__name__} already patched")
|
||||
return
|
||||
|
||||
attention_classes = {type(module) for module in attention_modules}
|
||||
LOG.info(f"Found attention classes: {[cls.__name__ for cls in attention_classes]}")
|
||||
self_attn_forward = inspect.getsource(attention_cls.forward)
|
||||
attention_cls._original_forward = self_attn_forward
|
||||
self_attn_forward, _ = detab_code(self_attn_forward)
|
||||
|
||||
for attention_cls in attention_classes:
|
||||
# Skip if already patched
|
||||
if hasattr(attention_cls, "_original_forward"):
|
||||
LOG.info(f"{attention_cls.__name__} already patched")
|
||||
continue
|
||||
assert ORIGINAL_QKV_CODE in self_attn_forward, "Original QKV code not found"
|
||||
assert ORIGINAL_O_CODE in self_attn_forward, "Original O code not found"
|
||||
|
||||
# Get and store original forward implementation
|
||||
self_attn_forward = inspect.getsource(attention_cls.forward)
|
||||
attention_cls._original_forward = self_attn_forward
|
||||
self_attn_forward = self_attn_forward.replace(ORIGINAL_QKV_CODE, PATCHED_QKV_CODE)
|
||||
self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
|
||||
self_attn_forward = self_attn_forward.replace(
|
||||
"def forward(",
|
||||
"def axolotl_attn_forward(",
|
||||
1,
|
||||
)
|
||||
|
||||
# Remove indentation
|
||||
self_attn_forward, _ = detab_code(self_attn_forward)
|
||||
# Load necessary imports
|
||||
module_name = attention_cls.__module__
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
# Verify required code blocks exist
|
||||
assert (
|
||||
ORIGINAL_QKV_CODE in self_attn_forward
|
||||
), f"Original QKV code not found in {attention_cls.__name__}"
|
||||
assert (
|
||||
ORIGINAL_O_CODE in self_attn_forward
|
||||
), f"Original O code not found in {attention_cls.__name__}"
|
||||
items_to_import = []
|
||||
for item in dir(module):
|
||||
if item in self_attn_forward:
|
||||
items_to_import.append(item)
|
||||
|
||||
# Replace code blocks
|
||||
self_attn_forward = self_attn_forward.replace(
|
||||
ORIGINAL_QKV_CODE, PATCHED_QKV_CODE
|
||||
)
|
||||
self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
|
||||
self_attn_forward = self_attn_forward.replace(
|
||||
"def forward(",
|
||||
"def axolotl_attn_forward(",
|
||||
1,
|
||||
)
|
||||
exec( # pylint: disable=exec-used # nosec B102
|
||||
f"from {module_name} import ({', '.join(items_to_import)})",
|
||||
globals(),
|
||||
)
|
||||
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
|
||||
|
||||
# Import necessary symbols from the attention module
|
||||
module_name = attention_cls.__module__
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
items_to_import = []
|
||||
for item in dir(module):
|
||||
if item in self_attn_forward:
|
||||
items_to_import.append(item)
|
||||
|
||||
if items_to_import:
|
||||
exec( # pylint: disable=exec-used # nosec B102
|
||||
f"from {module_name} import ({', '.join(items_to_import)})",
|
||||
globals(),
|
||||
)
|
||||
|
||||
# Execute the new implementation
|
||||
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
|
||||
|
||||
LOG.info(f"Patched attention class with LoRA optims: {attention_cls.__name__}")
|
||||
attention_cls.forward = (
|
||||
axolotl_attn_forward # pylint: disable=undefined-variable # noqa: F821
|
||||
)
|
||||
LOG.info(f"Patched attention class with LoRA optims: {attention_cls.__name__}")
|
||||
attention_cls.forward = (
|
||||
axolotl_attn_forward # pylint: disable=undefined-variable # noqa: F821
|
||||
)
|
||||
|
||||
|
||||
def apply_lora_kernel_patches(
|
||||
|
||||
@@ -33,3 +33,4 @@ class TRLConfig(BaseModel):
|
||||
sync_ref_model: Optional[bool] = False
|
||||
ref_model_mixup_alpha: Optional[float] = 0.9
|
||||
ref_model_sync_steps: Optional[int] = 64
|
||||
use_liger_loss: Optional[bool] = False
|
||||
|
||||
@@ -439,6 +439,11 @@ class ModelLoader:
|
||||
|
||||
patch_mistral_cross_entropy()
|
||||
|
||||
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
|
||||
from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
|
||||
|
||||
patch_self_attn_lora(self.cfg)
|
||||
|
||||
def patch_attention(self) -> None:
|
||||
if hasattr(self.model_config, "model_type"):
|
||||
if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
|
||||
@@ -1023,12 +1028,6 @@ class ModelLoader:
|
||||
integrate_rope_embeddings()
|
||||
|
||||
def apply_lora_patch(self) -> None:
|
||||
"""Applies patching relevant to LoRA Triton kernels if enabled."""
|
||||
if self.cfg.lora_qkv_kernel or self.cfg.lora_o_kernel:
|
||||
from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
|
||||
|
||||
patch_self_attn_lora(self.model)
|
||||
|
||||
if (
|
||||
self.cfg.lora_mlp_kernel
|
||||
or self.cfg.lora_qkv_kernel
|
||||
@@ -1182,7 +1181,6 @@ class ModelLoader:
|
||||
if self.cfg.adapter is not None:
|
||||
log_gpu_memory_usage(LOG, "after adapters", self.model.device)
|
||||
|
||||
# TODO: Deprecate this.
|
||||
self.apply_unsloth_lora_patch()
|
||||
self.apply_lora_patch()
|
||||
|
||||
@@ -1203,7 +1201,9 @@ def load_model(
|
||||
reference_model: bool = False,
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
|
||||
"""Load a model for a given configuration and tokenizer."""
|
||||
"""
|
||||
Load a model for a given configuration and tokenizer.
|
||||
"""
|
||||
loader = ModelLoader(
|
||||
cfg,
|
||||
tokenizer,
|
||||
|
||||
@@ -5,12 +5,12 @@ import numpy as np
|
||||
|
||||
|
||||
def get_dataset_lengths(dataset):
|
||||
if "length" in dataset.data.column_names:
|
||||
lengths = np.array(dataset.data.column("length"))
|
||||
elif "position_ids" in dataset.data.column_names:
|
||||
position_ids = dataset.data.column("position_ids")
|
||||
if "length" in dataset.column_names:
|
||||
lengths = np.array(dataset["length"])
|
||||
elif "position_ids" in dataset.column_names:
|
||||
position_ids = dataset["position_ids"]
|
||||
lengths = np.array([x[-1] + 1 for x in position_ids])
|
||||
else:
|
||||
input_ids = dataset.data.column("input_ids")
|
||||
lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
|
||||
input_ids = dataset["input_ids"]
|
||||
lengths = np.array([len(seq) for seq in input_ids])
|
||||
return lengths
|
||||
|
||||
@@ -9,14 +9,16 @@ from transformers import AutoModelForCausalLM, LlamaForCausalLM
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
from transformers.models.llama.modeling_llama import LlamaAttention
|
||||
|
||||
from axolotl.cli.utils import load_model_and_tokenizer
|
||||
from axolotl.kernels.lora import (
|
||||
apply_lora_mlp_geglu,
|
||||
apply_lora_mlp_swiglu,
|
||||
apply_lora_o,
|
||||
apply_lora_qkv,
|
||||
)
|
||||
from axolotl.monkeypatch.lora_kernels import apply_lora_kernel_patches
|
||||
from axolotl.monkeypatch.lora_kernels import (
|
||||
apply_lora_kernel_patches,
|
||||
patch_self_attn_lora,
|
||||
)
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
MODEL_CONFIGS = [
|
||||
@@ -63,45 +65,15 @@ def small_llama_model():
|
||||
return LlamaForCausalLM(LlamaConfig(**config))
|
||||
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
@pytest.fixture
|
||||
def minimal_cfg():
|
||||
"Config of real HuggingFace Hub model"
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
|
||||
"learning_rate": 0.000001,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
}
|
||||
],
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.0,
|
||||
"lora_target_linear": True,
|
||||
"sequence_len": 1024,
|
||||
"lora_mlp_kernel": True,
|
||||
"lora_qkv_kernel": True,
|
||||
"lora_o_kernel": True,
|
||||
}
|
||||
)
|
||||
|
||||
return cfg
|
||||
|
||||
|
||||
def test_attention_patching_integration(minimal_cfg):
|
||||
def test_attention_patching_integration():
|
||||
"""Test attention patching in integration context."""
|
||||
cfg = {"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0"}
|
||||
|
||||
# Store the original implementation
|
||||
original_forward = getattr(LlamaAttention, "forward")
|
||||
|
||||
# Load model
|
||||
_, _ = load_model_and_tokenizer(cfg=minimal_cfg)
|
||||
# Apply patch
|
||||
patch_self_attn_lora(cfg)
|
||||
|
||||
# Get the new forward method
|
||||
patched_forward = LlamaAttention.forward
|
||||
@@ -404,10 +376,38 @@ def test_model_architecture(model_config):
|
||||
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
def test_kernel_training_integration(minimal_cfg):
|
||||
def test_kernel_training_integration():
|
||||
"""Test model loading with kernel patches enabled."""
|
||||
from axolotl.cli.utils import load_model_and_tokenizer
|
||||
|
||||
# Create minimal config
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
|
||||
"learning_rate": 0.000001,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
}
|
||||
],
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.0,
|
||||
"lora_target_linear": True,
|
||||
"sequence_len": 1024,
|
||||
"lora_mlp_kernel": True,
|
||||
"lora_qkv_kernel": True,
|
||||
"lora_o_kernel": True,
|
||||
}
|
||||
)
|
||||
|
||||
# Load model
|
||||
model, _ = load_model_and_tokenizer(cfg=minimal_cfg)
|
||||
model, _ = load_model_and_tokenizer(cfg=cfg)
|
||||
|
||||
# Verify correct activation function
|
||||
layer = model.model.model.layers[0]
|
||||
|
||||
@@ -125,6 +125,12 @@ def fixture_llama3_tokenizer():
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(name="smollm2_tokenizer", scope="session", autouse=True)
|
||||
def fixture_smollm2_tokenizer():
|
||||
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(name="mistralv03_tokenizer", scope="session", autouse=True)
|
||||
def fixture_mistralv03_tokenizer():
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
|
||||
61
tests/prompt_strategies/test_dpo_chatml.py
Normal file
61
tests/prompt_strategies/test_dpo_chatml.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""
|
||||
Tests for loading DPO preference datasets with chatml formatting
|
||||
"""
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.prompt_strategies.dpo import load as load_dpo
|
||||
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
@pytest.fixture(name="minimal_dpo_cfg")
|
||||
def fixture_cfg():
|
||||
return DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
|
||||
"rl": "dpo",
|
||||
"learning_rate": 0.000001,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"sequence_len": 2048,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class TestDPOChatml:
|
||||
"""
|
||||
Test loading DPO preference datasets with chatml formatting
|
||||
"""
|
||||
|
||||
def test_default(self, minimal_dpo_cfg):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"datasets": [
|
||||
{
|
||||
"path": "argilla/distilabel-intel-orca-dpo-pairs",
|
||||
"type": "chatml",
|
||||
"split": "train[:1%]",
|
||||
}
|
||||
]
|
||||
}
|
||||
| minimal_dpo_cfg
|
||||
)
|
||||
|
||||
# test that dpo.load works
|
||||
load_dpo("chatml", cfg)
|
||||
# now actually load the datasets with the strategy
|
||||
train_ds, _ = load_prepare_preference_datasets(cfg)
|
||||
assert train_ds[0]["prompt"].startswith("<|im_start|>")
|
||||
assert train_ds[0]["prompt"].endswith("<|im_start|>assistant\n")
|
||||
assert "chosen" in train_ds[0]
|
||||
assert "rejected" in train_ds[0]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -7,6 +7,7 @@ from transformers import AutoTokenizer
|
||||
from axolotl.datasets import TokenizedPromptDataset
|
||||
from axolotl.prompt_strategies.completion import load
|
||||
from axolotl.utils.collators import V2BatchSamplerDataCollatorForSeq2Seq
|
||||
from axolotl.utils.data.utils import drop_long_seq_in_dataset
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
|
||||
@@ -18,11 +19,6 @@ def fixture_tokenizer():
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(name="max_seq_length")
|
||||
def fixture_max_seq_length():
|
||||
return 4096
|
||||
|
||||
|
||||
class TestBatchedSamplerPacking:
|
||||
"""
|
||||
Test class for packing streaming dataset sequences
|
||||
@@ -37,6 +33,7 @@ class TestBatchedSamplerPacking:
|
||||
(2, 2),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("max_seq_length", [4096, 512])
|
||||
def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length):
|
||||
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
|
||||
|
||||
@@ -62,6 +59,9 @@ class TestBatchedSamplerPacking:
|
||||
dataset,
|
||||
)
|
||||
train_dataset = concatenate_datasets([dataset_wrapper])
|
||||
|
||||
train_dataset = drop_long_seq_in_dataset(train_dataset, cfg)
|
||||
|
||||
lengths = get_dataset_lengths(train_dataset)
|
||||
batch_sampler = MultipackBatchSampler(
|
||||
sampler=RandomSampler(train_dataset),
|
||||
@@ -90,7 +90,7 @@ class TestBatchedSamplerPacking:
|
||||
batch_idxs.extend(pack)
|
||||
|
||||
for batch in loader:
|
||||
assert len(batch["input_ids"]) <= batch_size * max_seq_length
|
||||
assert batch["input_ids"].numel() <= batch_size * max_seq_length
|
||||
assert batch["input_ids"].shape[1] == max_seq_length
|
||||
|
||||
original_idxs = set(range(len(train_dataset)))
|
||||
|
||||
Reference in New Issue
Block a user