more KD updates
@@ -49,6 +49,7 @@ class KDPlugin(BasePlugin):
             "kd_alpha": cfg.kd_alpha,
             "kd_temperature": cfg.kd_temperature,
             "kd_beta": cfg.kd_beta,
+            "kd_normalize_topk": cfg.kd_normalize_topk,
         }

     def get_collator_cls_and_kwargs(self, cfg, is_eval=False):
@@ -72,6 +73,7 @@ class KDPlugin(BasePlugin):
             "kd_temperature": cfg.kd_temperature,
             "kd_online_server": cfg.kd_online_server,
             "kd_online_timeout": cfg.kd_online_timeout,
+            "kd_normalize_topk": cfg.kd_normalize_topk,
         }

         if use_batch_sampler_collator:
@@ -42,6 +42,9 @@ class KDArgs(BaseModel):
     kd_alpha: float | None = None  # loss coefficient for KD loss
     kd_temperature: float | None = None  # temperature for sampling during KD
     kd_beta: float | None = None  # beta coefficient for ratio of fwd and reverse KL
+    kd_normalize_topk: bool | None = (
+        None  # whether to normalize student logits during KD
+    )

     # TODO online kd
     kd_online_server_base_url: str | None = None
@@ -67,3 +70,6 @@ class KDTrainingArgsMixin:
     kd_alpha: float | None = None  # loss coefficient for KD loss
     kd_temperature: float | None = None  # temperature for sampling during KD
     kd_beta: float | None = None  # beta coefficient for ratio of fwd and reverse KL
+    kd_normalize_topk: bool | None = (
+        None  # whether to normalize student logits during KD
+    )
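Both the pydantic model and the training-args mixin expose the same optional flag, and since it defaults to None, existing configs are unaffected. A minimal usage sketch (values illustrative; the exact import path for KDArgs is assumed):

from axolotl.integrations.kd.args import KDArgs  # module path assumed

kd_args = KDArgs(
    kd_alpha=1.0,        # loss coefficient for KD loss
    kd_temperature=1.0,
    kd_beta=0.0,         # 0.0 -> pure forward KL (see the loss changes below)
    kd_normalize_topk=True,  # re-normalize top-k logprobs during KD
)
assert kd_args.kd_normalize_topk is True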
@@ -12,6 +12,7 @@ import torch
 from orjson import orjson

 from axolotl.integrations.kd.collator import KDBatchSamplerDataCollatorForSeq2Seq
+from axolotl.integrations.kd.utils import normalize_logprobs
 from axolotl.utils.data.utils import retry_on_request_exceptions

 LOG = logging.getLogger(__name__)
@@ -58,6 +59,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
         kd_online_server: Optional[str] = "vllm",
         kd_online_timeout: Optional[int] = 120,
         kd_cache_dir: Optional[str] = None,
+        kd_normalize_topk: Optional[bool] = True,
         **kwargs: Any,
     ):
         super().__init__(*args, **kwargs)
@@ -78,6 +80,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
         self.http_session = requests.Session()
         self.kd_online_timeout = kd_online_timeout
         self.kd_cache_dir = kd_cache_dir
+        self.kd_normalize_topk = kd_normalize_topk

     def _normalize_logprobs(self, raw_logprobs: List[float]) -> List[float]:
         """
@@ -88,70 +91,15 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
             [-float("inf")] * self.kd_online_topk if self.kd_online_topk > 0 else []
         )

-        # Ensure raw_logprobs matches kd_online_topk length for tensor operations
-        # This should ideally be handled by the caller ensuring correct padding/truncation first
-        if len(raw_logprobs) != self.kd_online_topk:
-            # This case should be rare if pre-padding/truncation is done correctly
-            LOG.warning(
-                f"Logprobs length mismatch in _normalize_logprobs. "
-                f"Expected {self.kd_online_topk}, got {len(raw_logprobs)}. Will pad/truncate."
-            )
-            padded_logprobs = raw_logprobs[: self.kd_online_topk]
-            if len(padded_logprobs) < self.kd_online_topk:
-                padded_logprobs.extend(
-                    [-float("inf")] * (self.kd_online_topk - len(padded_logprobs))
-                )
-            raw_logprobs = padded_logprobs
-
-        try:
-            position_logprobs_tensor = torch.tensor(raw_logprobs, dtype=torch.float32)
-
-            # Convert logprobs at T_online to probabilities
-            # use log sum exp trick to avoid underflow
-            position_logprobs_lse = torch.logsumexp(
-                position_logprobs_tensor, dim=-1, keepdim=True
-            )
-            teacher_probs_t_online = torch.exp(
-                position_logprobs_tensor - position_logprobs_lse
-            )
-
-            # Normalize probabilities (sum to 1)
-            # This is important if the top-k from server aren't a full distribution
-            teacher_probs_t_online_sum = teacher_probs_t_online.sum(dim=0, keepdim=True)
-            if teacher_probs_t_online_sum.item() > 1e-9:
-                teacher_probs_t_online = (
-                    teacher_probs_t_online / teacher_probs_t_online_sum
-                )
-            else:
-                # If sum is zero, create uniform distribution to avoid NaN/Inf later
-                # This can happen if all raw_logprobs are -inf
-                if self.kd_online_topk > 0:
-                    teacher_probs_t_online = (
-                        torch.ones_like(teacher_probs_t_online) / self.kd_online_topk
-                    )
-                # else: leave as is, will result in -inf logprobs
-                #
-                # teacher_probs_t_online = teacher_probs_t_online / teacher_probs_t_online.sum(
-                #     dim=0, keepdim=True
-                # )
-            final_logprobs_tensor = torch.log(teacher_probs_t_online)
-
-            return final_logprobs_tensor.tolist()
-
-        except Exception as e:  # pylint: disable=broad-exception-caught
-            LOG.error(
-                f"Error during online logprob scaling: {e}. Returning raw logprobs.",
-                exc_info=True,
-            )
-            # Fallback to (padded/truncated) raw logprobs if scaling fails
-            return raw_logprobs
+        raw_logprobs_tensor = torch.tensor(raw_logprobs, dtype=torch.float32)
+        return normalize_logprobs(raw_logprobs_tensor, self.kd_online_topk).tolist()

     @retry_on_request_exceptions(max_retries=10, delay=5)
     def fetch_online_logprobs_sglang(
         self, batch_input_ids: List[List[int]], labels: List[List[int]]
     ):
         """
-        Fetches logprobs from an online teacher served by vllm for a batch of input_ids.
+        Fetches logprobs from an online teacher served by sglang for a batch of input_ids.
         Assumes API returns token IDs as strings in logprob dictionary keys.
         """
         api_endpoint = f"{self.kd_online_server_base_url}/generate"
@@ -267,10 +215,18 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
                 current_target_token_ids.append(
                     pos_token_ids[: self.kd_online_topk]
                 )
-                scaled_logprobs_for_position = self._normalize_logprobs(
-                    pos_logprobs_raw[: self.kd_online_topk]
-                )
-                current_target_logprobs.append(scaled_logprobs_for_position)
+
+                if self.kd_normalize_topk:
+                    normalized_logprobs_for_position = self._normalize_logprobs(
+                        pos_logprobs_raw[: self.kd_online_topk]
+                    )
+                    current_target_logprobs.append(
+                        normalized_logprobs_for_position
+                    )
+                else:
+                    current_target_logprobs.append(
+                        pos_logprobs_raw[: self.kd_online_topk]
+                    )

                 # Mask depends on the corresponding label for the student
                 if label == self.DEFAULT_LABEL_PAD_TOKEN_ID:
@@ -442,12 +398,17 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
                     pos_token_ids[: self.kd_online_topk]
                 )

-                # normalized_logprobs_for_position = self._normalize_logprobs(pos_logprobs_raw[:self.kd_online_topk])
-                # current_target_logprobs.append(normalized_logprobs_for_position)
-                # don't normalize for now as the probs seem to sum to 1.0 already
-                current_target_logprobs.append(
-                    pos_logprobs_raw[: self.kd_online_topk]
-                )
+                if self.kd_normalize_topk:
+                    normalized_logprobs_for_position = self._normalize_logprobs(
+                        pos_logprobs_raw[: self.kd_online_topk]
+                    )
+                    current_target_logprobs.append(
+                        normalized_logprobs_for_position
+                    )
+                else:
+                    current_target_logprobs.append(
+                        pos_logprobs_raw[: self.kd_online_topk]
+                    )

                 # Mask depends on the corresponding label for the student
                 if label == self.DEFAULT_LABEL_PAD_TOKEN_ID:
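Both fetch paths now gate per-position normalization on the same kd_normalize_topk flag instead of hard-coding one behavior. A standalone sketch of the gating (function name hypothetical; normalize_logprobs is the new helper added in this commit):

from typing import List

import torch

from axolotl.integrations.kd.utils import normalize_logprobs


def select_position_logprobs(
    pos_logprobs_raw: List[float], topk: int, normalize_topk: bool = True
) -> List[float]:
    # Keep only the top-k teacher logprobs for this position
    kept = pos_logprobs_raw[:topk]
    if normalize_topk:
        # Re-normalize the truncated distribution so its probabilities sum to 1
        kept_tensor = torch.tensor(kept, dtype=torch.float32)
        return normalize_logprobs(kept_tensor, topk).tolist()
    # Pass the server's raw logprobs through (useful when they already sum to ~1)
    return kept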
@@ -8,6 +8,8 @@ from liger_kernel.chunked_loss.fused_linear_distillation import (
     LigerFusedLinearDistillationBase,
 )

+from axolotl.integrations.kd.utils import normalize_logprobs
+

 class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
     """
@@ -21,6 +23,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
         target_logprobs_chunk: torch.Tensor,  # [chunk_size, top_k], already temp-scaled and normalized logprobs
         target_mask_chunk: torch.Tensor,  # [chunk_size, top_k]
         beta: float = 0.0,
+        normalize_topk: bool = True,
     ) -> torch.Tensor:
         """
         Compute Top-K KL divergence loss for a chunk.
@@ -33,9 +36,11 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
                 0.0 for Forward KL (P_teacher || P_student).
                 1.0 for Reverse KL (P_student || P_teacher).
                 0.5 for Symmetric KL (average of Forward and Reverse).
+            normalize_topk: Whether to normalize the log probabilities
         Returns:
             Sum of KL divergence losses for the chunk.
         """
+        topk = target_token_ids_chunk.shape[-1]
         student_logits_temp_scaled = (  # [chunk_size, vocab_size]
             student_logits_temp_scaled.float()
         )
@@ -56,6 +61,12 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
             student_logits_topk_temp_scaled - student_lse
         )

+        # we have the top-k student logprobs, normalize them
+        if normalize_topk:
+            student_logprobs_topk_temp_scaled = normalize_logprobs(
+                student_logprobs_topk_temp_scaled, topk
+            )
+
         valid_mask = target_mask_chunk.to(torch.bool)  # [chunk_size, top_k]

         student_logprobs_topk_valid = student_logprobs_topk_temp_scaled[valid_mask]
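A note on the likely motivation: gathering the student's logprobs at the teacher's top-k token ids leaves a truncated distribution that sums to less than 1 (the remaining mass lies outside the top-k). Re-normalizing over the same top-k support, as the new normalize_topk branch does, means the KL terms below compare two proper distributions; setting the flag to False reproduces the previous unnormalized behavior.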
@@ -67,7 +78,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
         # Student probabilities P_student from log P_student
         student_probs_topk_valid = student_logprobs_topk_valid.exp()

-        kd_loss_per_token = torch.zeros_like(target_logprobs_valid)
+        # kd_loss_per_token = torch.zeros_like(target_logprobs_valid)

         # KL divergence: sum(P_teacher * (log P_teacher - log P_student))
         # = sum(P_teacher * log P_teacher) - sum(P_teacher * log P_student)
@@ -75,18 +86,33 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
         # or as sum(P_teacher * (log_softmax_teacher - log_softmax_student))
         # Here, target_logprobs_valid are log_softmax_teacher.
         # student_logprobs_topk_valid are log_softmax_student (for the selected K indices).
-        if beta < 1.0:  # Contribution from Forward KL
+        if beta == 0.0:  # Contribution from Forward KL
             fwd_kl_per_token = teacher_probs_valid * (
                 target_logprobs_valid - student_logprobs_topk_valid
             )
-            kd_loss_per_token += (1.0 - beta) * fwd_kl_per_token
-        if beta > 0.0:  # Contribution from Reverse KL
+            kd_loss = fwd_kl_per_token.sum()
+        elif beta == 1.0:  # Contribution from Reverse KL
             rev_kl_per_token = student_probs_topk_valid * (
                 student_logprobs_topk_valid - target_logprobs_valid
             )
-            kd_loss_per_token += beta * rev_kl_per_token
-
-        kd_loss = kd_loss_per_token.sum()
+            kd_loss = rev_kl_per_token.sum()
+        else:
+            # JSD - Jensen-Shannon Divergence / Symmetric
+            mean_probs = (
+                1 - beta
+            ) * student_probs_topk_valid + beta * teacher_probs_valid
+            log_mean_probs = mean_probs.log()
+            student_kl = F.kl_div(
+                log_mean_probs,
+                student_logprobs_topk_valid,
+                reduction="sum",
+                log_target=True,
+            )
+            teacher_kl = F.kl_div(
+                log_mean_probs, target_logprobs_valid, reduction="sum", log_target=True
+            )
+            jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
+            kd_loss = jsd_loss

         return kd_loss
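For intuition, a self-contained sketch of the three branches on a toy distribution (plain PyTorch outside the fused kernel; values illustrative):

import torch
import torch.nn.functional as F

teacher_logprobs = torch.log(torch.tensor([0.7, 0.2, 0.1]))
student_logprobs = torch.log(torch.tensor([0.5, 0.3, 0.2]))
teacher_probs = teacher_logprobs.exp()
student_probs = student_logprobs.exp()

# beta == 0.0: forward KL, sum(P_teacher * (log P_teacher - log P_student))
fwd_kl = (teacher_probs * (teacher_logprobs - student_logprobs)).sum()

# beta == 1.0: reverse KL, sum(P_student * (log P_student - log P_teacher))
rev_kl = (student_probs * (student_logprobs - teacher_logprobs)).sum()

# 0 < beta < 1: skewed JSD against the beta-weighted mixture distribution
beta = 0.5
mean_probs = (1 - beta) * student_probs + beta * teacher_probs
log_mean_probs = mean_probs.log()
student_kl = F.kl_div(log_mean_probs, student_logprobs, reduction="sum", log_target=True)
teacher_kl = F.kl_div(log_mean_probs, teacher_logprobs, reduction="sum", log_target=True)
jsd = beta * teacher_kl + (1 - beta) * student_kl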
@@ -109,6 +135,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
         compute_ce_loss: bool = True,
         temperature: float = 1.0,
         beta: float = 0.0,
+        normalize_topk: bool = True,
     ):
         # Compute student logits for the chunk from hidden states and LM head
         # student_input_chunk: [chunk_size, hidden_dim]
@@ -144,6 +171,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
             target_logprobs_chunk,
             target_mask_chunk,
             beta=beta,
+            normalize_topk=normalize_topk,
         )

         return soft_loss, ce_loss
@@ -167,6 +195,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
         compiled: bool = False,
         chunk_size: int = 1024,
         compute_ce_loss: bool = True,
+        normalize_topk: bool = True,
    ):
         CHUNK_SIZE = chunk_size  # pylint: disable=invalid-name
         grad_weight_acc = torch.zeros_like(student_lm_head_weight)
@@ -211,6 +240,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
             compute_ce_loss=compute_ce_loss,
             temperature=temperature,
             beta=beta,
+            normalize_topk=normalize_topk,
         )

     def accumulate_chunk_grads(
@@ -311,7 +341,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
         ctx.bias_was_none = student_lm_head_bias is None
         ctx.orig_dims = (B, N, D, K)

-        # since this is packed, there is simply a single batch, so batchmean reduciton of kl-div is simply the accumulated sum
+        # since this is packed, there is simply a single batch, so batchmean reduction of kl-div is simply the accumulated sum
         # we still need to scale the kd_loss by the temp^2
         kd_loss_acc = kd_loss_acc * (temperature**2)
         final_loss = weight_soft_loss * kd_loss_acc + weight_hard_loss * ce_loss_acc
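The temp^2 factor is the standard correction from Hinton et al. (2015): gradients of a soft-target loss computed at temperature T shrink by roughly 1/T^2, so multiplying the KD term by T^2 keeps it on the same footing as the hard-label cross-entropy when the two are mixed.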
@@ -397,6 +427,7 @@ class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module):
         compiled: bool = True,
         chunk_size: int = 1024,
         compute_ce_loss: bool = True,
+        normalize_topk: bool = True,
     ):
         super().__init__()
         if not (0.0 <= weight_hard_loss <= 1.0 and 0.0 <= weight_soft_loss <= 1.0):
@@ -412,6 +443,7 @@ class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module):
         self.compiled = compiled
         self.chunk_size = chunk_size
         self.compute_ce_loss = compute_ce_loss
+        self.normalize_topk = normalize_topk

         if not self.compute_ce_loss and self.weight_hard_loss > 0.0:
             print(
@@ -449,4 +481,5 @@ class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module):
             self.compiled,
             self.chunk_size,
             self.compute_ce_loss,
+            self.normalize_topk,
         )
@@ -35,6 +35,7 @@ class AxolotlKDTrainer(AxolotlTrainer):
             self.args.kd_temperature,
             self.args.kd_beta,
             compute_ce_loss=bool(self.args.kd_ce_alpha),
+            normalize_topk=self.args.kd_normalize_topk,
         )

     def _set_signature_columns_if_needed(self):
src/axolotl/integrations/kd/utils.py (new file, 39 lines)
@@ -0,0 +1,39 @@
+"""Helper KD utils"""
+
+import torch
+from torch import FloatTensor
+
+
+def normalize_logprobs(logprobs: FloatTensor, topk: int) -> FloatTensor:
+    """
+    Re-normalizes top-k raw logprobs as probabilities, and converts back to logprobs.
+    """
+    # Ensure raw_logprobs matches kd_online_topk length for tensor operations
+    # This should ideally be handled by the caller ensuring correct padding/truncation first
+    if logprobs.shape[-1] != topk:
+        # pad last dimension of logprobs to match topk length with -inf
+        padding_len = topk - logprobs.shape[-1]
+        padding_tensor = torch.full(
+            (
+                *logprobs.shape[:-1],
+                padding_len,
+            ),  # Takes all dimensions of logprobs except the last, then appends padding_len
+            float("-inf"),
+            dtype=logprobs.dtype,
+            device=logprobs.device,
+        )
+        logprobs = torch.cat((logprobs, padding_tensor), dim=-1)
+
+    # Convert logprobs at T_online to probabilities
+    # use log sum exp trick to avoid underflow
+    position_logprobs_lse = torch.logsumexp(logprobs, dim=-1, keepdim=True)
+    teacher_probs_t_online = torch.exp(logprobs - position_logprobs_lse)
+
+    # Normalize probabilities (sum to 1)
+    # This is important if the top-k from server aren't a full distribution
+    teacher_probs_t_online_sum = teacher_probs_t_online.sum(dim=-1, keepdim=True)
+    teacher_probs_t_online = teacher_probs_t_online / teacher_probs_t_online_sum
+
+    final_logprobs_tensor = torch.log(teacher_probs_t_online)
+
+    return final_logprobs_tensor
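A quick sketch of the helper's behavior (illustrative values): inputs shorter than topk are padded with -inf, which carries zero probability mass, and the output exponentiates back to a distribution that sums to 1:

import torch
from axolotl.integrations.kd.utils import normalize_logprobs

# Three raw top-k teacher logprobs, requested top-k of 4 -> one -inf pad
raw = torch.log(torch.tensor([0.5, 0.25, 0.05]))
normalized = normalize_logprobs(raw, topk=4)

print(normalized.exp())        # tensor([0.6250, 0.3125, 0.0625, 0.0000])
print(normalized.exp().sum())  # tensor(1.)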