KD dataset loading and KD with logprobs

This commit is contained in:
Wing Lian
2024-12-18 15:16:45 -05:00
parent 88b3198894
commit 303cfa71aa
5 changed files with 214 additions and 7 deletions

View File

@@ -16,7 +16,11 @@ from typing import List, Type, Union
import torch import torch
import transformers import transformers
from transformers import EarlyStoppingCallback, TrainerCallback, DataCollatorWithFlattening from transformers import (
DataCollatorWithFlattening,
EarlyStoppingCallback,
TrainerCallback,
)
from trl.trainer.utils import RewardDataCollatorWithPadding from trl.trainer.utils import RewardDataCollatorWithPadding
from axolotl.core.trainers.base import ( from axolotl.core.trainers.base import (
@@ -29,6 +33,7 @@ from axolotl.core.trainers.base import (
AxolotlTrainer, AxolotlTrainer,
ReLoRATrainer, ReLoRATrainer,
) )
from axolotl.core.trainers.kd import AxolotlKDTrainer
from axolotl.core.training_args import ( from axolotl.core.training_args import (
AxolotlCPOConfig, AxolotlCPOConfig,
AxolotlDPOConfig, AxolotlDPOConfig,
@@ -38,7 +43,6 @@ from axolotl.core.training_args import (
AxolotlTrainingArguments, AxolotlTrainingArguments,
) )
from axolotl.integrations.base import PluginManager from axolotl.integrations.base import PluginManager
from axolotl.integrations.liger.trainer.dpo_trainer import AxolotlLigerDPOTrainer
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback from axolotl.monkeypatch.relora import ReLoRACallback
from axolotl.utils import is_comet_available, is_mlflow_available from axolotl.utils import is_comet_available, is_mlflow_available
@@ -282,6 +286,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return AxolotlMambaTrainer return AxolotlMambaTrainer
if self.cfg.reward_model: if self.cfg.reward_model:
return AxolotlRewardTrainer return AxolotlRewardTrainer
if self.cfg.trainer == "kd":
return AxolotlKDTrainer
return AxolotlTrainer return AxolotlTrainer
def build(self, total_num_steps): def build(self, total_num_steps):
@@ -988,10 +994,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
"precompute_ref_log_probs" "precompute_ref_log_probs"
] = self.cfg.precompute_ref_log_probs ] = self.cfg.precompute_ref_log_probs
if self.cfg.rl in ["dpo", "ipo"]: if self.cfg.rl in ["dpo", "ipo"]:
if self.cfg.liger_pref_rl: trainer_cls = AxolotlDPOTrainer
trainer_cls = AxolotlLigerDPOTrainer
else:
trainer_cls = AxolotlDPOTrainer
trainer_cls_args = [self.model, self.model_ref] trainer_cls_args = [self.model, self.model_ref]
elif self.cfg.rl == "orpo": elif self.cfg.rl == "orpo":
trainer_cls = AxolotlORPOTrainer trainer_cls = AxolotlORPOTrainer

View File

@@ -176,6 +176,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
self.bench_data_collator = bench_data_collator self.bench_data_collator = bench_data_collator
self.eval_data_collator = eval_data_collator self.eval_data_collator = eval_data_collator
self.dataset_tags = dataset_tags self.dataset_tags = dataset_tags
self._signature_columns = None # workaround for pylint
super().__init__(*_args, **kwargs) super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(lambda: defaultdict(list)) self._stored_metrics = defaultdict(lambda: defaultdict(list))

View File

@@ -0,0 +1,112 @@
"""
KD trainer
"""
from typing import Optional
import torch
from axolotl.core.trainers.base import AxolotlTrainer
def kd_loss_function(
    student_logits,
    target_token_ids,
    target_logprobs,
    num_items_in_batch: Optional[int] = None,
    **kwargs,  # pylint: disable=unused-argument
):
    """
    Forward-KL distillation loss evaluated on the teacher's top-K tokens only.

    Args:
        student_logits: [B, seq_len, vocab_size] logits from the student's
            forward pass.
        target_token_ids: [B, teacher_seq_len, K] top-K token ids from the teacher.
        target_logprobs: [B, teacher_seq_len, K] teacher logprobs for those
            top-K tokens (expected to be already temperature-scaled).
        num_items_in_batch: when truthy, the summed loss is divided by this
            value instead of taking the mean over positions.
        **kwargs: accepted and ignored.

    Returns:
        Scalar tensor holding the KD loss.
    """
    # The teacher may cover fewer positions than the student produced; keep
    # only the trailing positions so both line up -> [B, teacher_seq_len, vocab_size]
    num_teacher_positions = target_token_ids.shape[1]
    aligned_logits = student_logits[:, -num_teacher_positions:, :]

    # Student logits at the teacher's top-K token ids -> [B, teacher_seq_len, K]
    topk_logits = torch.gather(aligned_logits, dim=-1, index=target_token_ids)

    # Turn them into logprobs, normalising over the K gathered entries only
    topk_normalizer = torch.logsumexp(topk_logits, dim=-1, keepdim=True)
    student_topk_logprobs = topk_logits - topk_normalizer

    # Forward KL per position: sum_k p^T_k * (log p^T_k - log p^S_k)
    logprob_gap = target_logprobs - student_topk_logprobs
    per_position_kl = (target_logprobs.exp() * logprob_gap).sum(
        dim=-1
    )  # [B, teacher_seq_len]

    # gradient accumulation fixes: normalise by the supplied item count if given
    if not num_items_in_batch:
        return per_position_kl.mean()
    return per_position_kl.sum() / num_items_in_batch
class AxolotlKDTrainer(AxolotlTrainer):
    """
    Knowledge Distillation (KD) trainer: computes the loss from teacher top-K
    logprobs carried in the batch instead of the model's own returned loss.
    """

    def _set_signature_columns_if_needed(self):
        super()._set_signature_columns_if_needed()
        # Nothing to extend when the parent left the columns unset or empty
        if not self._signature_columns:
            return
        # NOTE(review): presumably these are added so the KD fields are not
        # dropped as unused dataset columns -- confirm against the HF Trainer
        missing_columns = [
            column
            for column in ("target_logprobs", "target_token_ids")
            if column not in self._signature_columns
        ]
        if missing_columns:
            self._signature_columns += missing_columns

    def compute_loss(
        self, model, inputs, return_outputs=False, num_items_in_batch=None
    ):
        """
        Run the student forward pass and return the KD loss.

        The teacher fields ("target_logprobs", "target_token_ids") are removed
        from `inputs` before the model is called. Returns `(loss, outputs)` when
        `return_outputs` is set, otherwise just the loss.
        """
        # Teacher fields feed the loss only; take them out of the model inputs
        target_logprobs = inputs.pop("target_logprobs")
        target_token_ids = inputs.pop("target_token_ids")

        if self.model_accepts_loss_kwargs:
            extra_kwargs = (
                {}
                if num_items_in_batch is None
                else {"num_items_in_batch": num_items_in_batch}
            )
            inputs = {**inputs, **extra_kwargs}

        outputs = model(**inputs)

        loss_kd = kd_loss_function(
            outputs["logits"],
            target_token_ids,
            target_logprobs,
            num_items_in_batch=num_items_in_batch,
        )

        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[  # pylint: disable=attribute-defined-outside-init
                self.args.past_index
            ]

        scale_by_processes = (
            self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs
        )
        if scale_by_processes:
            loss_kd *= self.accelerator.num_processes

        if return_outputs:
            return loss_kd, outputs
        return loss_kd

View File

@@ -5,6 +5,7 @@ HF Chat Templates prompt strategy
import logging import logging
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import torch
from transformers import ProcessorMixin from transformers import ProcessorMixin
from axolotl.prompt_tokenizers import PromptTokenizingStrategy from axolotl.prompt_tokenizers import PromptTokenizingStrategy
@@ -459,6 +460,84 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
return prompt.get(self.images, None) return prompt.get(self.images, None)
class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
    """
    Handle fields for logprob KD

    On top of the parent's tokenization, converts the per-position teacher
    logprob entries stored under `logprobs_field` into the `target_logprobs`
    and `target_token_ids` fields.
    """

    def __init__(
        self,
        prompter,
        tokenizer,
        train_on_inputs,
        sequence_len,
        roles_to_train=None,
        train_on_eos=None,
        logprobs_field="logprobs",
        temperature=1.0,
    ):
        """
        Args:
            logprobs_field: key in the sample that holds the teacher logprobs.
            temperature: divisor applied to the teacher logprobs before they
                are renormalised; must be strictly positive.

        Raises:
            ValueError: if `temperature` is not strictly positive.

        The remaining arguments are forwarded unchanged to the parent strategy.
        """
        # Dividing by zero gives inf/NaN targets and a negative value inverts
        # the teacher ranking; both would silently corrupt the dataset.
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")
        self.logprobs_field = logprobs_field
        self.temperature = temperature
        super().__init__(
            prompter,
            tokenizer,
            train_on_inputs,
            sequence_len,
            roles_to_train=roles_to_train,
            train_on_eos=train_on_eos,
        )

    def transform_logprobs(self, sample):
        """
        Replace the raw teacher logprobs in `sample` with KD training targets.

        Pops `self.logprobs_field` from `sample`; each element is the list of
        entries for one position, every entry carrying a "logprob" value and a
        "token" string in "token_id:###" format. Writes two parallel
        per-position lists, `target_logprobs` (temperature-scaled and
        renormalised) and `target_token_ids`, and returns the same `sample`.
        """
        logprobs = sample.pop(self.logprobs_field)
        target_logprobs = []
        target_token_ids = []

        for token_pos_logprobs in logprobs:
            position_logprobs = []
            position_token_ids = []
            for entry in token_pos_logprobs:
                position_logprobs.append(entry["logprob"])
                # Parse token_id from the "token_id:###" format
                position_token_ids.append(int(entry["token"].split(":")[1]))

            position_logprobs_tensor = torch.tensor(
                position_logprobs, dtype=torch.float
            )

            # Apply temperature scaling at data load time
            # log p_k^(T) = (log p_k / T) - logsumexp(log p_j / T)
            position_logprobs_tensor = position_logprobs_tensor / self.temperature
            position_logprobs_tensor = position_logprobs_tensor - torch.logsumexp(
                position_logprobs_tensor, dim=0, keepdim=True
            )

            target_logprobs.append(position_logprobs_tensor.tolist())
            target_token_ids.append(position_token_ids)

        # Update sample with transformed logprobs
        sample["target_logprobs"] = target_logprobs
        sample["target_token_ids"] = target_token_ids

        return sample

    def tokenize_prompt(self, prompt):
        """Tokenize via the parent strategy, then attach the KD target fields."""
        tokenized_prompt = super().tokenize_prompt(prompt)
        return self.transform_logprobs(tokenized_prompt)
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None): def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
ds_cfg = ds_cfg or {} ds_cfg = ds_cfg or {}
@@ -491,7 +570,16 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None
"train_on_eos": ds_cfg.get("train_on_eos", "turn"), "train_on_eos": ds_cfg.get("train_on_eos", "turn"),
} }
strategy = ChatTemplateStrategy( strategy_cls = ChatTemplateStrategy
if logprobs_field := ds_cfg.get("logprobs_field"):
strategy_params["logprobs_field"] = logprobs_field
if temperature := ds_cfg.get("temperature"):
strategy_params["temperature"] = temperature
if cfg.trainer == "kd" or logprobs_field:
strategy_cls = ChatTemplateStrategyWithKD
strategy = strategy_cls(
ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
) )

View File

@@ -177,6 +177,7 @@ class SFTDataset(BaseModel):
message_field_content: Optional[str] = None message_field_content: Optional[str] = None
message_field_training: Optional[str] = None message_field_training: Optional[str] = None
message_field_training_detail: Optional[str] = None message_field_training_detail: Optional[str] = None
logprobs_field: Optional[str] = None
roles_to_train: Optional[List[str]] = None roles_to_train: Optional[List[str]] = None
train_on_eos: Optional[str] = None train_on_eos: Optional[str] = None
roles: Optional[Dict[str, List[str]]] = None roles: Optional[Dict[str, List[str]]] = None
@@ -621,6 +622,8 @@ class AxolotlInputConfig(
bool bool
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer. ] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
trainer: Optional[Literal["kd"]] = None
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
shuffle_merged_datasets: Optional[bool] = True shuffle_merged_datasets: Optional[bool] = True