KD dataset loading and KD with logprobs
This commit is contained in:
@@ -16,7 +16,11 @@ from typing import List, Type, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
from transformers import EarlyStoppingCallback, TrainerCallback, DataCollatorWithFlattening
|
from transformers import (
|
||||||
|
DataCollatorWithFlattening,
|
||||||
|
EarlyStoppingCallback,
|
||||||
|
TrainerCallback,
|
||||||
|
)
|
||||||
from trl.trainer.utils import RewardDataCollatorWithPadding
|
from trl.trainer.utils import RewardDataCollatorWithPadding
|
||||||
|
|
||||||
from axolotl.core.trainers.base import (
|
from axolotl.core.trainers.base import (
|
||||||
@@ -29,6 +33,7 @@ from axolotl.core.trainers.base import (
|
|||||||
AxolotlTrainer,
|
AxolotlTrainer,
|
||||||
ReLoRATrainer,
|
ReLoRATrainer,
|
||||||
)
|
)
|
||||||
|
from axolotl.core.trainers.kd import AxolotlKDTrainer
|
||||||
from axolotl.core.training_args import (
|
from axolotl.core.training_args import (
|
||||||
AxolotlCPOConfig,
|
AxolotlCPOConfig,
|
||||||
AxolotlDPOConfig,
|
AxolotlDPOConfig,
|
||||||
@@ -38,7 +43,6 @@ from axolotl.core.training_args import (
|
|||||||
AxolotlTrainingArguments,
|
AxolotlTrainingArguments,
|
||||||
)
|
)
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.integrations.liger.trainer.dpo_trainer import AxolotlLigerDPOTrainer
|
|
||||||
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||||
from axolotl.monkeypatch.relora import ReLoRACallback
|
from axolotl.monkeypatch.relora import ReLoRACallback
|
||||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||||
@@ -282,6 +286,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
return AxolotlMambaTrainer
|
return AxolotlMambaTrainer
|
||||||
if self.cfg.reward_model:
|
if self.cfg.reward_model:
|
||||||
return AxolotlRewardTrainer
|
return AxolotlRewardTrainer
|
||||||
|
if self.cfg.trainer == "kd":
|
||||||
|
return AxolotlKDTrainer
|
||||||
return AxolotlTrainer
|
return AxolotlTrainer
|
||||||
|
|
||||||
def build(self, total_num_steps):
|
def build(self, total_num_steps):
|
||||||
@@ -988,10 +994,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
"precompute_ref_log_probs"
|
"precompute_ref_log_probs"
|
||||||
] = self.cfg.precompute_ref_log_probs
|
] = self.cfg.precompute_ref_log_probs
|
||||||
if self.cfg.rl in ["dpo", "ipo"]:
|
if self.cfg.rl in ["dpo", "ipo"]:
|
||||||
if self.cfg.liger_pref_rl:
|
trainer_cls = AxolotlDPOTrainer
|
||||||
trainer_cls = AxolotlLigerDPOTrainer
|
|
||||||
else:
|
|
||||||
trainer_cls = AxolotlDPOTrainer
|
|
||||||
trainer_cls_args = [self.model, self.model_ref]
|
trainer_cls_args = [self.model, self.model_ref]
|
||||||
elif self.cfg.rl == "orpo":
|
elif self.cfg.rl == "orpo":
|
||||||
trainer_cls = AxolotlORPOTrainer
|
trainer_cls = AxolotlORPOTrainer
|
||||||
|
|||||||
@@ -176,6 +176,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
self.bench_data_collator = bench_data_collator
|
self.bench_data_collator = bench_data_collator
|
||||||
self.eval_data_collator = eval_data_collator
|
self.eval_data_collator = eval_data_collator
|
||||||
self.dataset_tags = dataset_tags
|
self.dataset_tags = dataset_tags
|
||||||
|
self._signature_columns = None # workaround for pylint
|
||||||
super().__init__(*_args, **kwargs)
|
super().__init__(*_args, **kwargs)
|
||||||
self.train_data_collator = self.data_collator
|
self.train_data_collator = self.data_collator
|
||||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||||
|
|||||||
112
src/axolotl/core/trainers/kd.py
Normal file
112
src/axolotl/core/trainers/kd.py
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
"""
|
||||||
|
KD trainer
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from axolotl.core.trainers.base import AxolotlTrainer
|
||||||
|
|
||||||
|
|
||||||
|
def kd_loss_function(
    student_logits,
    target_token_ids,
    target_logprobs,
    num_items_in_batch: Optional[int] = None,
    **kwargs,  # pylint: disable=unused-argument
):
    """
    Forward-KL distillation loss restricted to the teacher's top-K tokens.

    For every teacher position the student logits are gathered at the teacher's
    top-K token ids, renormalized over those K entries only, and compared with
    the teacher distribution:

        L_kl = sum_k p^T_k (log p^T_k - log p^S_k)

    Args:
        student_logits: student forward-pass logits, [B, seq_len, vocab_size].
        target_token_ids: teacher top-K token ids, [B, teacher_seq_len, K].
            Any integer dtype is accepted; it is cast to int64 for the gather.
        target_logprobs: teacher logprobs for those top-K tokens,
            [B, teacher_seq_len, K] (already temperature-scaled upstream).
        num_items_in_batch: when truthy, the summed loss is divided by this
            value (gradient-accumulation normalization); otherwise the mean
            over all positions is returned.
        **kwargs: accepted and ignored.

    Returns:
        Scalar tensor holding the KD loss.
    """
    teacher_seq_len = target_token_ids.shape[1]

    # Keep only the trailing student positions so the sequence axis lines up
    # with the teacher-provided length: [B, teacher_seq_len, vocab_size]
    student_logits_for_kd = student_logits[:, -teacher_seq_len:, :]

    # torch.gather rejects any index dtype other than int64, so normalize the
    # ids here instead of relying on the collator's dtype (no-op for int64).
    gather_index = target_token_ids.to(torch.int64)

    # Student logits at the teacher's top-K tokens: [B, teacher_seq_len, K]
    student_logits_topk = torch.gather(
        student_logits_for_kd, dim=-1, index=gather_index
    )

    # Convert student top-K logits to logprobs (normalized over the K entries)
    student_logprobs_topk = student_logits_topk - torch.logsumexp(
        student_logits_topk, dim=-1, keepdim=True
    )

    # Teacher probabilities are simply exp of the teacher logprobs
    teacher_probs = target_logprobs.exp()

    # Forward KL per position: [B, teacher_seq_len]
    kd_loss_per_position = (
        teacher_probs * (target_logprobs - student_logprobs_topk)
    ).sum(dim=-1)

    # Gradient accumulation fix: normalize by the global item count when it is
    # supplied; a falsy value (None or 0) falls back to the plain mean.
    if num_items_in_batch:
        kd_loss = kd_loss_per_position.sum() / num_items_in_batch  # Scalar
    else:
        kd_loss = kd_loss_per_position.mean()  # Scalar

    return kd_loss
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlKDTrainer(AxolotlTrainer):
    """
    Custom trainer subclass for Knowledge Distillation (KD)
    """

    def _set_signature_columns_if_needed(self):
        """
        Extend the parent's signature columns with the KD-specific fields
        ("target_logprobs" and "target_token_ids") when they are not already
        listed. Nothing is added if the parent left the columns empty/None.
        """
        super()._set_signature_columns_if_needed()
        if self._signature_columns:
            missing = [
                column
                for column in ("target_logprobs", "target_token_ids")
                if column not in self._signature_columns
            ]
            if missing:
                self._signature_columns += missing

    def compute_loss(
        self, model, inputs, return_outputs=False, num_items_in_batch=None
    ):
        """
        Compute the KD loss for one batch.

        The teacher fields "target_logprobs" and "target_token_ids" are popped
        from `inputs` (mutating the caller's dict), the model is run on the
        remaining inputs, and `kd_loss_function` is applied to the resulting
        "logits".

        Returns the loss, or `(loss, outputs)` when `return_outputs` is true.
        """
        teacher_logprobs = inputs.pop("target_logprobs")
        teacher_token_ids = inputs.pop("target_token_ids")

        # Forward num_items_in_batch to the model only when it accepts loss kwargs
        if self.model_accepts_loss_kwargs:
            extra_kwargs = (
                {}
                if num_items_in_batch is None
                else {"num_items_in_batch": num_items_in_batch}
            )
            inputs = {**inputs, **extra_kwargs}
        outputs = model(**inputs)

        loss_kd = kd_loss_function(
            outputs["logits"],
            teacher_token_ids,
            teacher_logprobs,
            num_items_in_batch=num_items_in_batch,
        )

        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        past_index = self.args.past_index
        if past_index >= 0:
            # pylint: disable-next=attribute-defined-outside-init
            self._past = outputs[past_index]

        # Scale the loss by the process count when tokens are averaged across devices
        if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
            loss_kd *= self.accelerator.num_processes

        if return_outputs:
            return loss_kd, outputs
        return loss_kd
|
||||||
@@ -5,6 +5,7 @@ HF Chat Templates prompt strategy
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
from transformers import ProcessorMixin
|
from transformers import ProcessorMixin
|
||||||
|
|
||||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||||
@@ -459,6 +460,84 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
return prompt.get(self.images, None)
|
return prompt.get(self.images, None)
|
||||||
|
|
||||||
|
|
||||||
|
class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
    """
    Handle fields for logprob KD
    """

    def __init__(
        self,
        prompter,
        tokenizer,
        train_on_inputs,
        sequence_len,
        roles_to_train=None,
        train_on_eos=None,
        logprobs_field="logprobs",
        temperature=1.0,
    ):
        """
        Store the KD-specific settings, then defer to the parent strategy.

        `logprobs_field` is the key of the tokenized sample holding the teacher
        logprobs; `temperature` is the divisor applied to them in
        `transform_logprobs`.
        """
        self.logprobs_field = logprobs_field
        self.temperature = temperature
        super().__init__(
            prompter,
            tokenizer,
            train_on_inputs,
            sequence_len,
            roles_to_train=roles_to_train,
            train_on_eos=train_on_eos,
        )

    def transform_logprobs(self, sample):
        """
        Replace the raw logprobs field of `sample` with "target_logprobs" and
        "target_token_ids".

        Each position in the raw field is an iterable of entries with a
        "logprob" value and a "token" string in "token_id:###" format. The
        logprobs of a position are divided by the temperature and renormalized
        with logsumexp before being stored as plain Python lists.
        """
        per_position_entries = sample.pop(self.logprobs_field)
        scaled_logprobs = []
        token_ids = []

        for entries in per_position_entries:
            # Single pass per entry: read the logprob, then parse the id out
            # of the "token_id:###" format
            parsed = [
                (entry["logprob"], int(entry["token"].split(":")[1]))
                for entry in entries
            ]
            values = [logprob for logprob, _ in parsed]
            ids = [token_id for _, token_id in parsed]

            # Apply temperature scaling at data load time
            # log p_k^(T) = (log p_k / T) - logsumexp(log p_j / T)
            scaled = torch.tensor(values, dtype=torch.float) / self.temperature
            scaled = scaled - torch.logsumexp(scaled, dim=0, keepdim=True)

            scaled_logprobs.append(scaled.tolist())
            token_ids.append(ids)

        # Update sample with transformed logprobs
        sample["target_logprobs"] = scaled_logprobs
        sample["target_token_ids"] = token_ids

        return sample

    def tokenize_prompt(self, prompt):
        """Tokenize via the parent strategy, then transform the logprobs field."""
        return self.transform_logprobs(super().tokenize_prompt(prompt))
|
||||||
|
|
||||||
|
|
||||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None):
|
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
ds_cfg = ds_cfg or {}
|
ds_cfg = ds_cfg or {}
|
||||||
@@ -491,7 +570,16 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None
|
|||||||
"train_on_eos": ds_cfg.get("train_on_eos", "turn"),
|
"train_on_eos": ds_cfg.get("train_on_eos", "turn"),
|
||||||
}
|
}
|
||||||
|
|
||||||
strategy = ChatTemplateStrategy(
|
strategy_cls = ChatTemplateStrategy
|
||||||
|
if logprobs_field := ds_cfg.get("logprobs_field"):
|
||||||
|
strategy_params["logprobs_field"] = logprobs_field
|
||||||
|
if temperature := ds_cfg.get("temperature"):
|
||||||
|
strategy_params["temperature"] = temperature
|
||||||
|
|
||||||
|
if cfg.trainer == "kd" or logprobs_field:
|
||||||
|
strategy_cls = ChatTemplateStrategyWithKD
|
||||||
|
|
||||||
|
strategy = strategy_cls(
|
||||||
ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
|
ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -177,6 +177,7 @@ class SFTDataset(BaseModel):
|
|||||||
message_field_content: Optional[str] = None
|
message_field_content: Optional[str] = None
|
||||||
message_field_training: Optional[str] = None
|
message_field_training: Optional[str] = None
|
||||||
message_field_training_detail: Optional[str] = None
|
message_field_training_detail: Optional[str] = None
|
||||||
|
logprobs_field: Optional[str] = None
|
||||||
roles_to_train: Optional[List[str]] = None
|
roles_to_train: Optional[List[str]] = None
|
||||||
train_on_eos: Optional[str] = None
|
train_on_eos: Optional[str] = None
|
||||||
roles: Optional[Dict[str, List[str]]] = None
|
roles: Optional[Dict[str, List[str]]] = None
|
||||||
@@ -621,6 +622,8 @@ class AxolotlInputConfig(
|
|||||||
bool
|
bool
|
||||||
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
|
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
|
||||||
|
|
||||||
|
trainer: Optional[Literal["kd"]] = None
|
||||||
|
|
||||||
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
|
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
|
||||||
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
|
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
|
||||||
shuffle_merged_datasets: Optional[bool] = True
|
shuffle_merged_datasets: Optional[bool] = True
|
||||||
|
|||||||
Reference in New Issue
Block a user