KD dataset loading and KD with logprobs
This commit is contained in:
@@ -16,7 +16,11 @@ from typing import List, Type, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from transformers import EarlyStoppingCallback, TrainerCallback, DataCollatorWithFlattening
|
||||
from transformers import (
|
||||
DataCollatorWithFlattening,
|
||||
EarlyStoppingCallback,
|
||||
TrainerCallback,
|
||||
)
|
||||
from trl.trainer.utils import RewardDataCollatorWithPadding
|
||||
|
||||
from axolotl.core.trainers.base import (
|
||||
@@ -29,6 +33,7 @@ from axolotl.core.trainers.base import (
|
||||
AxolotlTrainer,
|
||||
ReLoRATrainer,
|
||||
)
|
||||
from axolotl.core.trainers.kd import AxolotlKDTrainer
|
||||
from axolotl.core.training_args import (
|
||||
AxolotlCPOConfig,
|
||||
AxolotlDPOConfig,
|
||||
@@ -38,7 +43,6 @@ from axolotl.core.training_args import (
|
||||
AxolotlTrainingArguments,
|
||||
)
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.integrations.liger.trainer.dpo_trainer import AxolotlLigerDPOTrainer
|
||||
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||
from axolotl.monkeypatch.relora import ReLoRACallback
|
||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||
@@ -282,6 +286,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
return AxolotlMambaTrainer
|
||||
if self.cfg.reward_model:
|
||||
return AxolotlRewardTrainer
|
||||
if self.cfg.trainer == "kd":
|
||||
return AxolotlKDTrainer
|
||||
return AxolotlTrainer
|
||||
|
||||
def build(self, total_num_steps):
|
||||
@@ -988,10 +994,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
"precompute_ref_log_probs"
|
||||
] = self.cfg.precompute_ref_log_probs
|
||||
if self.cfg.rl in ["dpo", "ipo"]:
|
||||
if self.cfg.liger_pref_rl:
|
||||
trainer_cls = AxolotlLigerDPOTrainer
|
||||
else:
|
||||
trainer_cls = AxolotlDPOTrainer
|
||||
trainer_cls = AxolotlDPOTrainer
|
||||
trainer_cls_args = [self.model, self.model_ref]
|
||||
elif self.cfg.rl == "orpo":
|
||||
trainer_cls = AxolotlORPOTrainer
|
||||
|
||||
@@ -176,6 +176,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
self.bench_data_collator = bench_data_collator
|
||||
self.eval_data_collator = eval_data_collator
|
||||
self.dataset_tags = dataset_tags
|
||||
self._signature_columns = None # workaround for pylint
|
||||
super().__init__(*_args, **kwargs)
|
||||
self.train_data_collator = self.data_collator
|
||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
112
src/axolotl/core/trainers/kd.py
Normal file
112
src/axolotl/core/trainers/kd.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""
|
||||
KD trainer
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.core.trainers.base import AxolotlTrainer
|
||||
|
||||
|
||||
def kd_loss_function(
    student_logits,
    target_token_ids,
    target_logprobs,
    num_items_in_batch: Optional[int] = None,
    **kwargs,  # pylint: disable=unused-argument
):
    """
    Forward-KL knowledge-distillation loss against teacher top-K logprobs.

    Args:
        student_logits: ``[B, seq_len, vocab_size]`` logits from the student's
            forward pass.
        target_token_ids: ``[B, teacher_seq_len, K]`` top-K token ids chosen by
            the teacher at each position.
        target_logprobs: ``[B, teacher_seq_len, K]`` teacher logprobs for those
            tokens (already temperature-scaled upstream).
        num_items_in_batch: when set, the summed loss is divided by this count
            instead of averaged per position (gradient-accumulation fix).

    Returns:
        Scalar KD loss tensor.
    """
    kd_len = target_token_ids.shape[1]

    # Align student logits with the teacher-provided window: keep only the
    # trailing `kd_len` positions -> [B, kd_len, vocab_size].
    aligned_logits = student_logits[:, -kd_len:, :]

    # Pick out the student's logits at the teacher's top-K token ids
    # -> [B, kd_len, K].
    topk_logits = torch.gather(aligned_logits, dim=-1, index=target_token_ids)

    # Normalize over the K candidates only (log-softmax restricted to the
    # teacher's support), matching how the teacher logprobs were renormalized.
    topk_logprobs = topk_logits - torch.logsumexp(topk_logits, dim=-1, keepdim=True)

    # Teacher probabilities come straight from exp of the stored logprobs.
    teacher_probs = target_logprobs.exp()

    # Forward KL per position: sum_k p^T_k * (log p^T_k - log p^S_k)
    # -> [B, kd_len]
    per_position_kl = (teacher_probs * (target_logprobs - topk_logprobs)).sum(dim=-1)

    # Reduce: divide by the true token count when provided (grad-accum safe),
    # otherwise fall back to a plain per-position mean.
    if num_items_in_batch:
        return per_position_kl.sum() / num_items_in_batch
    return per_position_kl.mean()
|
||||
|
||||
|
||||
class AxolotlKDTrainer(AxolotlTrainer):
    """
    Trainer subclass for knowledge distillation (KD): replaces the standard
    LM loss with a forward-KL loss against teacher-provided top-K logprobs.
    """

    def _set_signature_columns_if_needed(self):
        # Let the base trainer populate its column list first, then make sure
        # the teacher-supplied KD columns survive dataset column pruning.
        super()._set_signature_columns_if_needed()
        if not self._signature_columns:
            return
        missing = [
            column
            for column in ("target_logprobs", "target_token_ids")
            if column not in self._signature_columns
        ]
        if missing:
            self._signature_columns += missing

    def compute_loss(
        self, model, inputs, return_outputs=False, num_items_in_batch=None
    ):
        """
        Compute the KD loss for a batch.

        Mirrors the HF ``Trainer.compute_loss`` contract but ignores the
        model's own loss: the teacher columns are stripped from ``inputs``,
        the student is run forward, and the forward-KL against the teacher
        top-K logprobs is returned instead.
        """
        # Pull the KD supervision out of the batch before the forward pass so
        # the model never sees these extra keys.
        target_logprobs = inputs.pop("target_logprobs")
        target_token_ids = inputs.pop("target_token_ids")

        if self.model_accepts_loss_kwargs:
            # Forward num_items_in_batch to the model for models that support
            # the gradient-accumulation-aware loss kwarg.
            extra_kwargs = {}
            if num_items_in_batch is not None:
                extra_kwargs["num_items_in_batch"] = num_items_in_batch
            inputs = {**inputs, **extra_kwargs}
        outputs = model(**inputs)

        loss_kd = kd_loss_function(
            outputs["logits"],
            target_token_ids,
            target_logprobs,
            num_items_in_batch=num_items_in_batch,
        )

        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[  # pylint: disable=attribute-defined-outside-init
                self.args.past_index
            ]

        # Undo the per-device averaging the accelerator applies when the loss
        # kwargs path is active, matching upstream Trainer behavior.
        if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
            loss_kd *= self.accelerator.num_processes

        return (loss_kd, outputs) if return_outputs else loss_kd
|
||||
@@ -5,6 +5,7 @@ HF Chat Templates prompt strategy
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from transformers import ProcessorMixin
|
||||
|
||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||
@@ -459,6 +460,84 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
return prompt.get(self.images, None)
|
||||
|
||||
|
||||
class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
    """
    Chat-template tokenization strategy that additionally transforms a
    teacher-logprobs field into ``target_logprobs`` / ``target_token_ids``
    columns for logprob-based knowledge distillation.
    """

    def __init__(
        self,
        prompter,
        tokenizer,
        train_on_inputs,
        sequence_len,
        roles_to_train=None,
        train_on_eos=None,
        logprobs_field="logprobs",
        temperature=1.0,
    ):
        # Name of the dataset column holding teacher top-K logprobs, and the
        # distillation temperature applied to them at load time.
        self.logprobs_field = logprobs_field
        self.temperature = temperature
        super().__init__(
            prompter,
            tokenizer,
            train_on_inputs,
            sequence_len,
            roles_to_train=roles_to_train,
            train_on_eos=train_on_eos,
        )

    def transform_logprobs(self, sample):
        """
        Pop the raw teacher-logprobs column off ``sample`` and replace it with
        temperature-scaled ``target_logprobs`` and ``target_token_ids`` lists.

        Each entry in the raw column is a per-position list of dicts with a
        ``"logprob"`` float and a ``"token"`` string of the form
        ``"token_id:<id>"`` (assumed format — entries that deviate will raise).
        """
        logprobs = sample.pop(self.logprobs_field)
        target_logprobs = []
        target_token_ids = []

        for token_pos_logprobs in logprobs:
            # Collect logprob values and parse token ids from the
            # "token_id:###" format.
            position_logprobs = [entry["logprob"] for entry in token_pos_logprobs]
            position_token_ids = [
                int(entry["token"].split(":")[1]) for entry in token_pos_logprobs
            ]

            # Convert to a tensor for the temperature rescaling below.
            position_logprobs_tensor = torch.tensor(
                position_logprobs, dtype=torch.float
            )

            # Apply temperature scaling at data load time:
            # log p_k^(T) = (log p_k / T) - logsumexp(log p_j / T)
            position_logprobs_tensor = position_logprobs_tensor / self.temperature
            position_logprobs_tensor = position_logprobs_tensor - torch.logsumexp(
                position_logprobs_tensor, dim=0, keepdim=True
            )

            target_logprobs.append(position_logprobs_tensor.tolist())
            target_token_ids.append(position_token_ids)

        # Update sample with transformed logprobs
        sample["target_logprobs"] = target_logprobs
        sample["target_token_ids"] = target_token_ids

        return sample

    def tokenize_prompt(self, prompt):
        # Tokenize as usual, then attach the KD supervision columns.
        tokenized_prompt = super().tokenize_prompt(prompt)
        return self.transform_logprobs(tokenized_prompt)
|
||||
|
||||
|
||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None):
|
||||
# pylint: disable=duplicate-code
|
||||
ds_cfg = ds_cfg or {}
|
||||
@@ -491,7 +570,16 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None
|
||||
"train_on_eos": ds_cfg.get("train_on_eos", "turn"),
|
||||
}
|
||||
|
||||
strategy = ChatTemplateStrategy(
|
||||
strategy_cls = ChatTemplateStrategy
|
||||
if logprobs_field := ds_cfg.get("logprobs_field"):
|
||||
strategy_params["logprobs_field"] = logprobs_field
|
||||
if temperature := ds_cfg.get("temperature"):
|
||||
strategy_params["temperature"] = temperature
|
||||
|
||||
if cfg.trainer == "kd" or logprobs_field:
|
||||
strategy_cls = ChatTemplateStrategyWithKD
|
||||
|
||||
strategy = strategy_cls(
|
||||
ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
|
||||
)
|
||||
|
||||
|
||||
@@ -176,6 +176,7 @@ class SFTDataset(BaseModel):
|
||||
message_field_content: Optional[str] = None
|
||||
message_field_training: Optional[str] = None
|
||||
message_field_training_detail: Optional[str] = None
|
||||
logprobs_field: Optional[str] = None
|
||||
roles_to_train: Optional[List[str]] = None
|
||||
train_on_eos: Optional[str] = None
|
||||
roles: Optional[Dict[str, List[str]]] = None
|
||||
@@ -613,6 +614,8 @@ class AxolotlInputConfig(
|
||||
bool
|
||||
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
|
||||
|
||||
trainer: Optional[Literal["kd"]] = None
|
||||
|
||||
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
|
||||
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
|
||||
shuffle_merged_datasets: Optional[bool] = True
|
||||
|
||||
Reference in New Issue
Block a user