more KD updates
This commit is contained in:
@@ -49,6 +49,7 @@ class KDPlugin(BasePlugin):
|
|||||||
"kd_alpha": cfg.kd_alpha,
|
"kd_alpha": cfg.kd_alpha,
|
||||||
"kd_temperature": cfg.kd_temperature,
|
"kd_temperature": cfg.kd_temperature,
|
||||||
"kd_beta": cfg.kd_beta,
|
"kd_beta": cfg.kd_beta,
|
||||||
|
"kd_normalize_topk": cfg.kd_normalize_topk,
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_collator_cls_and_kwargs(self, cfg, is_eval=False):
|
def get_collator_cls_and_kwargs(self, cfg, is_eval=False):
|
||||||
@@ -72,6 +73,7 @@ class KDPlugin(BasePlugin):
|
|||||||
"kd_temperature": cfg.kd_temperature,
|
"kd_temperature": cfg.kd_temperature,
|
||||||
"kd_online_server": cfg.kd_online_server,
|
"kd_online_server": cfg.kd_online_server,
|
||||||
"kd_online_timeout": cfg.kd_online_timeout,
|
"kd_online_timeout": cfg.kd_online_timeout,
|
||||||
|
"kd_normalize_topk": cfg.kd_normalize_topk,
|
||||||
}
|
}
|
||||||
|
|
||||||
if use_batch_sampler_collator:
|
if use_batch_sampler_collator:
|
||||||
|
|||||||
@@ -42,6 +42,9 @@ class KDArgs(BaseModel):
|
|||||||
kd_alpha: float | None = None # loss coefficient for KD loss
|
kd_alpha: float | None = None # loss coefficient for KD loss
|
||||||
kd_temperature: float | None = None # temperature for sampling during KD
|
kd_temperature: float | None = None # temperature for sampling during KD
|
||||||
kd_beta: float | None = None # beta coefficient for ratio of fwd and reverse KL
|
kd_beta: float | None = None # beta coefficient for ratio of fwd and reverse KL
|
||||||
|
kd_normalize_topk: bool | None = (
|
||||||
|
None # whether to normalize student logits during KD
|
||||||
|
)
|
||||||
|
|
||||||
# TODO online kd
|
# TODO online kd
|
||||||
kd_online_server_base_url: str | None = None
|
kd_online_server_base_url: str | None = None
|
||||||
@@ -67,3 +70,6 @@ class KDTrainingArgsMixin:
|
|||||||
kd_alpha: float | None = None # loss coefficient for KD loss
|
kd_alpha: float | None = None # loss coefficient for KD loss
|
||||||
kd_temperature: float | None = None # temperature for sampling during KD
|
kd_temperature: float | None = None # temperature for sampling during KD
|
||||||
kd_beta: float | None = None # beta coefficient for ratio of fwd and reverse KL
|
kd_beta: float | None = None # beta coefficient for ratio of fwd and reverse KL
|
||||||
|
kd_normalize_topk: float | None = (
|
||||||
|
None # whether to normalize student logits during KD
|
||||||
|
)
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import torch
|
|||||||
from orjson import orjson
|
from orjson import orjson
|
||||||
|
|
||||||
from axolotl.integrations.kd.collator import KDBatchSamplerDataCollatorForSeq2Seq
|
from axolotl.integrations.kd.collator import KDBatchSamplerDataCollatorForSeq2Seq
|
||||||
|
from axolotl.integrations.kd.utils import normalize_logprobs
|
||||||
from axolotl.utils.data.utils import retry_on_request_exceptions
|
from axolotl.utils.data.utils import retry_on_request_exceptions
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
@@ -58,6 +59,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
kd_online_server: Optional[str] = "vllm",
|
kd_online_server: Optional[str] = "vllm",
|
||||||
kd_online_timeout: Optional[int] = 120,
|
kd_online_timeout: Optional[int] = 120,
|
||||||
kd_cache_dir: Optional[str] = None,
|
kd_cache_dir: Optional[str] = None,
|
||||||
|
kd_normalize_topk: Optional[bool] = True,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
@@ -78,6 +80,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
self.http_session = requests.Session()
|
self.http_session = requests.Session()
|
||||||
self.kd_online_timeout = kd_online_timeout
|
self.kd_online_timeout = kd_online_timeout
|
||||||
self.kd_cache_dir = kd_cache_dir
|
self.kd_cache_dir = kd_cache_dir
|
||||||
|
self.kd_normalize_topk = kd_normalize_topk
|
||||||
|
|
||||||
def _normalize_logprobs(self, raw_logprobs: List[float]) -> List[float]:
|
def _normalize_logprobs(self, raw_logprobs: List[float]) -> List[float]:
|
||||||
"""
|
"""
|
||||||
@@ -88,70 +91,15 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
[-float("inf")] * self.kd_online_topk if self.kd_online_topk > 0 else []
|
[-float("inf")] * self.kd_online_topk if self.kd_online_topk > 0 else []
|
||||||
)
|
)
|
||||||
|
|
||||||
# Ensure raw_logprobs matches kd_online_topk length for tensor operations
|
raw_logprobs_tensor = torch.tensor(raw_logprobs, dtype=torch.float32)
|
||||||
# This should ideally be handled by the caller ensuring correct padding/truncation first
|
return normalize_logprobs(raw_logprobs_tensor, self.kd_online_topk).tolist()
|
||||||
if len(raw_logprobs) != self.kd_online_topk:
|
|
||||||
# This case should be rare if pre-padding/truncation is done correctly
|
|
||||||
LOG.warning(
|
|
||||||
f"Logprobs length mismatch in _normalize_logprobs. "
|
|
||||||
f"Expected {self.kd_online_topk}, got {len(raw_logprobs)}. Will pad/truncate."
|
|
||||||
)
|
|
||||||
padded_logprobs = raw_logprobs[: self.kd_online_topk]
|
|
||||||
if len(padded_logprobs) < self.kd_online_topk:
|
|
||||||
padded_logprobs.extend(
|
|
||||||
[-float("inf")] * (self.kd_online_topk - len(padded_logprobs))
|
|
||||||
)
|
|
||||||
raw_logprobs = padded_logprobs
|
|
||||||
|
|
||||||
try:
|
|
||||||
position_logprobs_tensor = torch.tensor(raw_logprobs, dtype=torch.float32)
|
|
||||||
|
|
||||||
# Convert logprobs at T_online to probabilities
|
|
||||||
# use log sum exp trick to avoid underflow
|
|
||||||
position_logprobs_lse = torch.logsumexp(
|
|
||||||
position_logprobs_tensor, dim=-1, keepdim=True
|
|
||||||
)
|
|
||||||
teacher_probs_t_online = torch.exp(
|
|
||||||
position_logprobs_tensor - position_logprobs_lse
|
|
||||||
)
|
|
||||||
|
|
||||||
# Normalize probabilities (sum to 1)
|
|
||||||
# This is important if the top-k from server aren't a full distribution
|
|
||||||
teacher_probs_t_online_sum = teacher_probs_t_online.sum(dim=0, keepdim=True)
|
|
||||||
if teacher_probs_t_online_sum.item() > 1e-9:
|
|
||||||
teacher_probs_t_online = (
|
|
||||||
teacher_probs_t_online / teacher_probs_t_online_sum
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# If sum is zero, create uniform distribution to avoid NaN/Inf later
|
|
||||||
# This can happen if all raw_logprobs are -inf
|
|
||||||
if self.kd_online_topk > 0:
|
|
||||||
teacher_probs_t_online = (
|
|
||||||
torch.ones_like(teacher_probs_t_online) / self.kd_online_topk
|
|
||||||
)
|
|
||||||
# else: leave as is, will result in -inf logprobs
|
|
||||||
#
|
|
||||||
# teacher_probs_t_online = teacher_probs_t_online / teacher_probs_t_online.sum(
|
|
||||||
# dim=0, keepdim=True
|
|
||||||
# )
|
|
||||||
final_logprobs_tensor = torch.log(teacher_probs_t_online)
|
|
||||||
|
|
||||||
return final_logprobs_tensor.tolist()
|
|
||||||
|
|
||||||
except Exception as e: # pylint: disable=broad-exception-caught
|
|
||||||
LOG.error(
|
|
||||||
f"Error during online logprob scaling: {e}. Returning raw logprobs.",
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
# Fallback to (padded/truncated) raw logprobs if scaling fails
|
|
||||||
return raw_logprobs
|
|
||||||
|
|
||||||
@retry_on_request_exceptions(max_retries=10, delay=5)
|
@retry_on_request_exceptions(max_retries=10, delay=5)
|
||||||
def fetch_online_logprobs_sglang(
|
def fetch_online_logprobs_sglang(
|
||||||
self, batch_input_ids: List[List[int]], labels: List[List[int]]
|
self, batch_input_ids: List[List[int]], labels: List[List[int]]
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Fetches logprobs from an online teacher served by vllm for a batch of input_ids.
|
Fetches logprobs from an online teacher served by sglang for a batch of input_ids.
|
||||||
Assumes API returns token IDs as strings in logprob dictionary keys.
|
Assumes API returns token IDs as strings in logprob dictionary keys.
|
||||||
"""
|
"""
|
||||||
api_endpoint = f"{self.kd_online_server_base_url}/generate"
|
api_endpoint = f"{self.kd_online_server_base_url}/generate"
|
||||||
@@ -267,10 +215,18 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
current_target_token_ids.append(
|
current_target_token_ids.append(
|
||||||
pos_token_ids[: self.kd_online_topk]
|
pos_token_ids[: self.kd_online_topk]
|
||||||
)
|
)
|
||||||
scaled_logprobs_for_position = self._normalize_logprobs(
|
|
||||||
pos_logprobs_raw[: self.kd_online_topk]
|
if self.kd_normalize_topk:
|
||||||
)
|
normalized_logprobs_for_position = self._normalize_logprobs(
|
||||||
current_target_logprobs.append(scaled_logprobs_for_position)
|
pos_logprobs_raw[: self.kd_online_topk]
|
||||||
|
)
|
||||||
|
current_target_logprobs.append(
|
||||||
|
normalized_logprobs_for_position
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
current_target_logprobs.append(
|
||||||
|
pos_logprobs_raw[: self.kd_online_topk]
|
||||||
|
)
|
||||||
|
|
||||||
# Mask depends on the corresponding label for the student
|
# Mask depends on the corresponding label for the student
|
||||||
if label == self.DEFAULT_LABEL_PAD_TOKEN_ID:
|
if label == self.DEFAULT_LABEL_PAD_TOKEN_ID:
|
||||||
@@ -442,12 +398,17 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
pos_token_ids[: self.kd_online_topk]
|
pos_token_ids[: self.kd_online_topk]
|
||||||
)
|
)
|
||||||
|
|
||||||
# normalized_logprobs_for_position = self._normalize_logprobs(pos_logprobs_raw[:self.kd_online_topk])
|
if self.kd_normalize_topk:
|
||||||
# current_target_logprobs.append(normalized_logprobs_for_position)
|
normalized_logprobs_for_position = self._normalize_logprobs(
|
||||||
# don't normalize for now as the probs seem to sum to 1.0 already
|
pos_logprobs_raw[: self.kd_online_topk]
|
||||||
current_target_logprobs.append(
|
)
|
||||||
pos_logprobs_raw[: self.kd_online_topk]
|
current_target_logprobs.append(
|
||||||
)
|
normalized_logprobs_for_position
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
current_target_logprobs.append(
|
||||||
|
pos_logprobs_raw[: self.kd_online_topk]
|
||||||
|
)
|
||||||
|
|
||||||
# Mask depends on the corresponding label for the student
|
# Mask depends on the corresponding label for the student
|
||||||
if label == self.DEFAULT_LABEL_PAD_TOKEN_ID:
|
if label == self.DEFAULT_LABEL_PAD_TOKEN_ID:
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ from liger_kernel.chunked_loss.fused_linear_distillation import (
|
|||||||
LigerFusedLinearDistillationBase,
|
LigerFusedLinearDistillationBase,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from axolotl.integrations.kd.utils import normalize_logprobs
|
||||||
|
|
||||||
|
|
||||||
class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
||||||
"""
|
"""
|
||||||
@@ -21,6 +23,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
|||||||
target_logprobs_chunk: torch.Tensor, # [chunk_size, top_k], already temp-scaled and normalized logprobs
|
target_logprobs_chunk: torch.Tensor, # [chunk_size, top_k], already temp-scaled and normalized logprobs
|
||||||
target_mask_chunk: torch.Tensor, # [chunk_size, top_k]
|
target_mask_chunk: torch.Tensor, # [chunk_size, top_k]
|
||||||
beta: float = 0.0,
|
beta: float = 0.0,
|
||||||
|
normalize_topk: bool = True,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Compute Top-K KL divergence loss for a chunk.
|
Compute Top-K KL divergence loss for a chunk.
|
||||||
@@ -33,9 +36,11 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
|||||||
0.0 for Forward KL (P_teacher || P_student).
|
0.0 for Forward KL (P_teacher || P_student).
|
||||||
1.0 for Reverse KL (P_student || P_teacher).
|
1.0 for Reverse KL (P_student || P_teacher).
|
||||||
0.5 for Symmetric KL (average of Forward and Reverse).
|
0.5 for Symmetric KL (average of Forward and Reverse).
|
||||||
|
normalize_topk: Whether to normalize the log probabilities
|
||||||
Returns:
|
Returns:
|
||||||
Sum of KL divergence losses for the chunk.
|
Sum of KL divergence losses for the chunk.
|
||||||
"""
|
"""
|
||||||
|
topk = target_token_ids_chunk.shape[-1]
|
||||||
student_logits_temp_scaled = ( # [chunk_size, vocab_size]
|
student_logits_temp_scaled = ( # [chunk_size, vocab_size]
|
||||||
student_logits_temp_scaled.float()
|
student_logits_temp_scaled.float()
|
||||||
)
|
)
|
||||||
@@ -56,6 +61,12 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
|||||||
student_logits_topk_temp_scaled - student_lse
|
student_logits_topk_temp_scaled - student_lse
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# we have the top-k student logprobs, normalize them
|
||||||
|
if normalize_topk:
|
||||||
|
student_logprobs_topk_temp_scaled = normalize_logprobs(
|
||||||
|
student_logprobs_topk_temp_scaled, topk
|
||||||
|
)
|
||||||
|
|
||||||
valid_mask = target_mask_chunk.to(torch.bool) # [chunk_size, top_k]
|
valid_mask = target_mask_chunk.to(torch.bool) # [chunk_size, top_k]
|
||||||
|
|
||||||
student_logprobs_topk_valid = student_logprobs_topk_temp_scaled[valid_mask]
|
student_logprobs_topk_valid = student_logprobs_topk_temp_scaled[valid_mask]
|
||||||
@@ -67,7 +78,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
|||||||
# Student probabilities P_student from log P_student
|
# Student probabilities P_student from log P_student
|
||||||
student_probs_topk_valid = student_logprobs_topk_valid.exp()
|
student_probs_topk_valid = student_logprobs_topk_valid.exp()
|
||||||
|
|
||||||
kd_loss_per_token = torch.zeros_like(target_logprobs_valid)
|
# kd_loss_per_token = torch.zeros_like(target_logprobs_valid)
|
||||||
|
|
||||||
# KL divergence: sum(P_teacher * (log P_teacher - log P_student))
|
# KL divergence: sum(P_teacher * (log P_teacher - log P_student))
|
||||||
# = sum(P_teacher * log P_teacher) - sum(P_teacher * log P_student)
|
# = sum(P_teacher * log P_teacher) - sum(P_teacher * log P_student)
|
||||||
@@ -75,18 +86,33 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
|||||||
# or as sum(P_teacher * (log_softmax_teacher - log_softmax_student))
|
# or as sum(P_teacher * (log_softmax_teacher - log_softmax_student))
|
||||||
# Here, target_logprobs_valid are log_softmax_teacher.
|
# Here, target_logprobs_valid are log_softmax_teacher.
|
||||||
# student_logprobs_topk_valid are log_softmax_student (for the selected K indices).
|
# student_logprobs_topk_valid are log_softmax_student (for the selected K indices).
|
||||||
if beta < 1.0: # Contribution from Forward KL
|
if beta == 0.0: # Contribution from Forward KL
|
||||||
fwd_kl_per_token = teacher_probs_valid * (
|
fwd_kl_per_token = teacher_probs_valid * (
|
||||||
target_logprobs_valid - student_logprobs_topk_valid
|
target_logprobs_valid - student_logprobs_topk_valid
|
||||||
)
|
)
|
||||||
kd_loss_per_token += (1.0 - beta) * fwd_kl_per_token
|
kd_loss = fwd_kl_per_token.sum()
|
||||||
if beta > 0.0: # Contribution from Reverse KL
|
elif beta == 1.0: # Contribution from Reverse KL
|
||||||
rev_kl_per_token = student_probs_topk_valid * (
|
rev_kl_per_token = student_probs_topk_valid * (
|
||||||
student_logprobs_topk_valid - target_logprobs_valid
|
student_logprobs_topk_valid - target_logprobs_valid
|
||||||
)
|
)
|
||||||
kd_loss_per_token += beta * rev_kl_per_token
|
kd_loss = rev_kl_per_token.sum()
|
||||||
|
else:
|
||||||
kd_loss = kd_loss_per_token.sum()
|
# JSD - Jensen-Shannon Divergence / Symmetric
|
||||||
|
mean_probs = (
|
||||||
|
1 - beta
|
||||||
|
) * student_probs_topk_valid + beta * teacher_probs_valid
|
||||||
|
log_mean_probs = mean_probs.log()
|
||||||
|
student_kl = F.kl_div(
|
||||||
|
log_mean_probs,
|
||||||
|
student_logprobs_topk_valid,
|
||||||
|
reduction="sum",
|
||||||
|
log_target=True,
|
||||||
|
)
|
||||||
|
teacher_kl = F.kl_div(
|
||||||
|
log_mean_probs, target_logprobs_valid, reduction="sum", log_target=True
|
||||||
|
)
|
||||||
|
jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
|
||||||
|
kd_loss = jsd_loss
|
||||||
|
|
||||||
return kd_loss
|
return kd_loss
|
||||||
|
|
||||||
@@ -109,6 +135,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
|||||||
compute_ce_loss: bool = True,
|
compute_ce_loss: bool = True,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
beta: float = 0.0,
|
beta: float = 0.0,
|
||||||
|
normalize_topk: bool = True,
|
||||||
):
|
):
|
||||||
# Compute student logits for the chunk from hidden states and LM head
|
# Compute student logits for the chunk from hidden states and LM head
|
||||||
# student_input_chunk: [chunk_size, hidden_dim]
|
# student_input_chunk: [chunk_size, hidden_dim]
|
||||||
@@ -144,6 +171,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
|||||||
target_logprobs_chunk,
|
target_logprobs_chunk,
|
||||||
target_mask_chunk,
|
target_mask_chunk,
|
||||||
beta=beta,
|
beta=beta,
|
||||||
|
normalize_topk=normalize_topk,
|
||||||
)
|
)
|
||||||
|
|
||||||
return soft_loss, ce_loss
|
return soft_loss, ce_loss
|
||||||
@@ -167,6 +195,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
|||||||
compiled: bool = False,
|
compiled: bool = False,
|
||||||
chunk_size: int = 1024,
|
chunk_size: int = 1024,
|
||||||
compute_ce_loss: bool = True,
|
compute_ce_loss: bool = True,
|
||||||
|
normalize_topk: bool = True,
|
||||||
):
|
):
|
||||||
CHUNK_SIZE = chunk_size # pylint: disable=invalid-name
|
CHUNK_SIZE = chunk_size # pylint: disable=invalid-name
|
||||||
grad_weight_acc = torch.zeros_like(student_lm_head_weight)
|
grad_weight_acc = torch.zeros_like(student_lm_head_weight)
|
||||||
@@ -211,6 +240,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
|||||||
compute_ce_loss=compute_ce_loss,
|
compute_ce_loss=compute_ce_loss,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
beta=beta,
|
beta=beta,
|
||||||
|
normalize_topk=normalize_topk,
|
||||||
)
|
)
|
||||||
|
|
||||||
def accumulate_chunk_grads(
|
def accumulate_chunk_grads(
|
||||||
@@ -311,7 +341,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
|||||||
ctx.bias_was_none = student_lm_head_bias is None
|
ctx.bias_was_none = student_lm_head_bias is None
|
||||||
ctx.orig_dims = (B, N, D, K)
|
ctx.orig_dims = (B, N, D, K)
|
||||||
|
|
||||||
# since this is packed, there is simply a single batch, so batchmean reduciton of kl-div is simply the accumulated sum
|
# since this is packed, there is simply a single batch, so batchmean reduction of kl-div is simply the accumulated sum
|
||||||
# we still need to scale the kd_loss by the temp^2
|
# we still need to scale the kd_loss by the temp^2
|
||||||
kd_loss_acc = kd_loss_acc * (temperature**2)
|
kd_loss_acc = kd_loss_acc * (temperature**2)
|
||||||
final_loss = weight_soft_loss * kd_loss_acc + weight_hard_loss * ce_loss_acc
|
final_loss = weight_soft_loss * kd_loss_acc + weight_hard_loss * ce_loss_acc
|
||||||
@@ -397,6 +427,7 @@ class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module):
|
|||||||
compiled: bool = True,
|
compiled: bool = True,
|
||||||
chunk_size: int = 1024,
|
chunk_size: int = 1024,
|
||||||
compute_ce_loss: bool = True,
|
compute_ce_loss: bool = True,
|
||||||
|
normalize_topk: bool = True,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if not (0.0 <= weight_hard_loss <= 1.0 and 0.0 <= weight_soft_loss <= 1.0):
|
if not (0.0 <= weight_hard_loss <= 1.0 and 0.0 <= weight_soft_loss <= 1.0):
|
||||||
@@ -412,6 +443,7 @@ class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module):
|
|||||||
self.compiled = compiled
|
self.compiled = compiled
|
||||||
self.chunk_size = chunk_size
|
self.chunk_size = chunk_size
|
||||||
self.compute_ce_loss = compute_ce_loss
|
self.compute_ce_loss = compute_ce_loss
|
||||||
|
self.normalize_topk = normalize_topk
|
||||||
|
|
||||||
if not self.compute_ce_loss and self.weight_hard_loss > 0.0:
|
if not self.compute_ce_loss and self.weight_hard_loss > 0.0:
|
||||||
print(
|
print(
|
||||||
@@ -449,4 +481,5 @@ class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module):
|
|||||||
self.compiled,
|
self.compiled,
|
||||||
self.chunk_size,
|
self.chunk_size,
|
||||||
self.compute_ce_loss,
|
self.compute_ce_loss,
|
||||||
|
self.normalize_topk,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
|||||||
self.args.kd_temperature,
|
self.args.kd_temperature,
|
||||||
self.args.kd_beta,
|
self.args.kd_beta,
|
||||||
compute_ce_loss=bool(self.args.kd_ce_alpha),
|
compute_ce_loss=bool(self.args.kd_ce_alpha),
|
||||||
|
normalize_topk=self.args.kd_normalize_topk,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _set_signature_columns_if_needed(self):
|
def _set_signature_columns_if_needed(self):
|
||||||
|
|||||||
39
src/axolotl/integrations/kd/utils.py
Normal file
39
src/axolotl/integrations/kd/utils.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
"""Helper KD utils"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import FloatTensor
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_logprobs(logprobs: FloatTensor, topk: int) -> FloatTensor:
    """
    Re-normalizes top-k raw logprobs so that, along the last dimension, the
    corresponding probabilities sum to 1, and returns them as logprobs.

    The last dimension is first forced to length ``topk``: extra entries are
    truncated and missing entries are padded with ``-inf`` (probability 0).

    Args:
        logprobs: tensor of log probabilities of shape ``[..., k]``.
        topk: required size of the last dimension of the result.

    Returns:
        Tensor of shape ``[..., topk]`` holding the re-normalized logprobs.
        Rows that contain only ``-inf`` are returned as a uniform distribution
        (``-log(topk)`` everywhere) rather than NaN.
    """
    num_entries = logprobs.shape[-1]
    if num_entries > topk:
        # Too many candidates: keep only the first `topk`. A negative padding
        # length would otherwise make torch.full raise.
        logprobs = logprobs[..., :topk]
    elif num_entries < topk:
        # Too few candidates: pad the last dimension with -inf so the padded
        # slots carry zero probability mass after normalization.
        padding_tensor = torch.full(
            (*logprobs.shape[:-1], topk - num_entries),
            float("-inf"),
            dtype=logprobs.dtype,
            device=logprobs.device,
        )
        logprobs = torch.cat((logprobs, padding_tensor), dim=-1)

    # A row consisting solely of -inf has no probability mass to normalize and
    # would produce NaN. Swap such rows for zeros so that log_softmax yields a
    # uniform distribution for them instead.
    all_masked = torch.isneginf(logprobs).all(dim=-1, keepdim=True)
    logprobs = torch.where(all_masked, torch.zeros_like(logprobs), logprobs)

    # log_softmax == logprobs - logsumexp(logprobs). Staying in log space avoids
    # the exp -> divide -> log round-trip, which underflows small probabilities
    # to -inf and can produce non-finite gradients.
    return torch.log_softmax(logprobs, dim=-1)
|
||||||
Reference in New Issue
Block a user