KD dataset loading and KD with logprobs

This commit is contained in:
Wing Lian
2024-12-18 15:16:45 -05:00
parent 02c9898a95
commit 39daeb2c79
5 changed files with 214 additions and 7 deletions

View File

@@ -16,7 +16,11 @@ from typing import List, Type, Union
import torch
import transformers
from transformers import EarlyStoppingCallback, TrainerCallback, DataCollatorWithFlattening
from transformers import (
DataCollatorWithFlattening,
EarlyStoppingCallback,
TrainerCallback,
)
from trl.trainer.utils import RewardDataCollatorWithPadding
from axolotl.core.trainers.base import (
@@ -29,6 +33,7 @@ from axolotl.core.trainers.base import (
AxolotlTrainer,
ReLoRATrainer,
)
from axolotl.core.trainers.kd import AxolotlKDTrainer
from axolotl.core.training_args import (
AxolotlCPOConfig,
AxolotlDPOConfig,
@@ -38,7 +43,6 @@ from axolotl.core.training_args import (
AxolotlTrainingArguments,
)
from axolotl.integrations.base import PluginManager
from axolotl.integrations.liger.trainer.dpo_trainer import AxolotlLigerDPOTrainer
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback
from axolotl.utils import is_comet_available, is_mlflow_available
@@ -282,6 +286,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return AxolotlMambaTrainer
if self.cfg.reward_model:
return AxolotlRewardTrainer
if self.cfg.trainer == "kd":
return AxolotlKDTrainer
return AxolotlTrainer
def build(self, total_num_steps):
@@ -988,10 +994,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
"precompute_ref_log_probs"
] = self.cfg.precompute_ref_log_probs
if self.cfg.rl in ["dpo", "ipo"]:
if self.cfg.liger_pref_rl:
trainer_cls = AxolotlLigerDPOTrainer
else:
trainer_cls = AxolotlDPOTrainer
trainer_cls = AxolotlDPOTrainer
trainer_cls_args = [self.model, self.model_ref]
elif self.cfg.rl == "orpo":
trainer_cls = AxolotlORPOTrainer

View File

@@ -176,6 +176,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
self.bench_data_collator = bench_data_collator
self.eval_data_collator = eval_data_collator
self.dataset_tags = dataset_tags
self._signature_columns = None # workaround for pylint
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(lambda: defaultdict(list))

View File

@@ -0,0 +1,112 @@
"""
KD trainer
"""
from typing import Optional
import torch
from axolotl.core.trainers.base import AxolotlTrainer
def kd_loss_function(
    student_logits,
    target_token_ids,
    target_logprobs,
    num_items_in_batch: Optional[int] = None,
    **kwargs,  # pylint: disable=unused-argument
):
    """Forward-KL knowledge-distillation loss over the teacher's top-K tokens.

    Args:
        student_logits: [B, seq_len, vocab_size] logits from the student's
            forward pass.
        target_token_ids: [B, teacher_seq_len, K] top-K token ids chosen by
            the teacher at each position.
        target_logprobs: [B, teacher_seq_len, K] teacher log-probabilities for
            those top-K tokens (already temperature-scaled and normalized).
        num_items_in_batch: when truthy, normalize by this count instead of
            averaging over positions (gradient-accumulation fix).

    Returns:
        Scalar KD loss tensor.
    """
    # Align the student sequence with the teacher-provided positions by
    # keeping only the trailing `kd_len` steps of the student logits.
    kd_len = target_token_ids.shape[1]
    aligned_logits = student_logits[:, -kd_len:, :]  # [B, kd_len, vocab_size]

    # Pull out the student's logits at the teacher's top-K token ids, then
    # renormalize over that K-subset to get log-probabilities.
    topk_logits = torch.gather(
        aligned_logits, dim=-1, index=target_token_ids
    )  # [B, kd_len, K]
    topk_logprobs = torch.log_softmax(topk_logits, dim=-1)

    # Forward KL per position: sum_k p^T_k * (log p^T_k - log p^S_k).
    # Teacher probs are just exp() of the (already normalized) logprobs.
    per_position = (
        target_logprobs.exp() * (target_logprobs - topk_logprobs)
    ).sum(dim=-1)  # [B, kd_len]

    # Gradient-accumulation fix: divide by the true item count when given;
    # otherwise fall back to a plain mean over positions.
    if num_items_in_batch:
        return per_position.sum() / num_items_in_batch
    return per_position.mean()
class AxolotlKDTrainer(AxolotlTrainer):
    """
    Custom trainer subclass for Knowledge Distillation (KD)
    """

    def _set_signature_columns_if_needed(self):
        # Let the base trainer build its column list first, then make sure
        # the KD-specific dataset fields survive column pruning.
        super()._set_signature_columns_if_needed()
        if not self._signature_columns:
            return
        for extra_column in ("target_logprobs", "target_token_ids"):
            if extra_column not in self._signature_columns:
                self._signature_columns.append(extra_column)

    def compute_loss(
        self, model, inputs, return_outputs=False, num_items_in_batch=None
    ):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.

        Subclass and override for custom behavior.
        """
        # Strip the teacher targets before the forward pass — the model
        # itself must not see them.
        teacher_logprobs = inputs.pop("target_logprobs")
        teacher_token_ids = inputs.pop("target_token_ids")

        if self.model_accepts_loss_kwargs:
            extra_kwargs = {}
            if num_items_in_batch is not None:
                extra_kwargs["num_items_in_batch"] = num_items_in_batch
            inputs = {**inputs, **extra_kwargs}

        outputs = model(**inputs)
        loss_kd = kd_loss_function(
            outputs["logits"],
            teacher_token_ids,
            teacher_logprobs,
            num_items_in_batch=num_items_in_batch,
        )

        if self.args.past_index >= 0:
            # Save past state if it exists
            # TODO: this needs to be fixed and made cleaner later.
            self._past = outputs[  # pylint: disable=attribute-defined-outside-init
                self.args.past_index
            ]

        # Undo per-device averaging so the loss matches the global token count.
        if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
            loss_kd *= self.accelerator.num_processes

        return (loss_kd, outputs) if return_outputs else loss_kd

View File

@@ -5,6 +5,7 @@ HF Chat Templates prompt strategy
import logging
from typing import Any, Dict, List, Optional
import torch
from transformers import ProcessorMixin
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
@@ -459,6 +460,84 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
return prompt.get(self.images, None)
class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
    """
    Chat-template tokenization strategy that additionally prepares teacher
    top-K logprobs for knowledge distillation (KD).
    """

    def __init__(
        self,
        prompter,
        tokenizer,
        train_on_inputs,
        sequence_len,
        roles_to_train=None,
        train_on_eos=None,
        logprobs_field="logprobs",
        temperature=1.0,
    ):
        # Dataset column holding the teacher's per-position top-K logprobs.
        self.logprobs_field = logprobs_field
        # Distillation temperature applied to the teacher logprobs at load time.
        self.temperature = temperature
        super().__init__(
            prompter,
            tokenizer,
            train_on_inputs,
            sequence_len,
            roles_to_train=roles_to_train,
            train_on_eos=train_on_eos,
        )

    def transform_logprobs(self, sample):
        """Convert raw teacher logprob entries into ``target_logprobs`` and
        ``target_token_ids`` lists on *sample*, applying temperature scaling.

        Each element of the popped logprobs column is a list of top-K entries
        for one sequence position; every entry carries a ``logprob`` float and
        a ``token`` string serialized as ``"token_id:<id>"``.
        """
        logprobs = sample.pop(self.logprobs_field)

        target_logprobs = []
        target_token_ids = []
        # One iteration per sequence position (no index needed — the previous
        # enumerate() discarded it anyway).
        for token_pos_logprobs in logprobs:
            position_logprobs = []
            position_token_ids = []
            for entry in token_pos_logprobs:
                position_logprobs.append(entry["logprob"])
                # Parse the integer token id out of the "token_id:###" format.
                position_token_ids.append(int(entry["token"].split(":")[1]))

            # Apply temperature scaling at data load time and renormalize in
            # log space: log p_k^(T) = (log p_k / T) - logsumexp_j(log p_j / T)
            scaled = (
                torch.tensor(position_logprobs, dtype=torch.float)
                / self.temperature
            )
            scaled = scaled - torch.logsumexp(scaled, dim=0, keepdim=True)

            target_logprobs.append(scaled.tolist())
            target_token_ids.append(position_token_ids)

        # Update sample with transformed logprobs
        sample["target_logprobs"] = target_logprobs
        sample["target_token_ids"] = target_token_ids
        return sample

    def tokenize_prompt(self, prompt):
        """Tokenize as usual, then attach the KD teacher targets."""
        tokenized_prompt = super().tokenize_prompt(prompt)
        return self.transform_logprobs(tokenized_prompt)
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None):
# pylint: disable=duplicate-code
ds_cfg = ds_cfg or {}
@@ -491,7 +570,16 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None
"train_on_eos": ds_cfg.get("train_on_eos", "turn"),
}
strategy = ChatTemplateStrategy(
strategy_cls = ChatTemplateStrategy
if logprobs_field := ds_cfg.get("logprobs_field"):
strategy_params["logprobs_field"] = logprobs_field
if temperature := ds_cfg.get("temperature"):
strategy_params["temperature"] = temperature
if cfg.trainer == "kd" or logprobs_field:
strategy_cls = ChatTemplateStrategyWithKD
strategy = strategy_cls(
ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
)

View File

@@ -176,6 +176,7 @@ class SFTDataset(BaseModel):
message_field_content: Optional[str] = None
message_field_training: Optional[str] = None
message_field_training_detail: Optional[str] = None
logprobs_field: Optional[str] = None
roles_to_train: Optional[List[str]] = None
train_on_eos: Optional[str] = None
roles: Optional[Dict[str, List[str]]] = None
@@ -613,6 +614,8 @@ class AxolotlInputConfig(
bool
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
trainer: Optional[Literal["kd"]] = None
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
shuffle_merged_datasets: Optional[bool] = True