various fixes

Wing Lian
2025-01-30 10:39:18 -05:00
parent c434951dd6
commit f11227a35a
3 changed files with 37 additions and 23 deletions

@@ -66,13 +66,18 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
         target_seq_len = len(logprobs)
         input_seq_len = len(sample["input_ids"])
         input_padding_len = input_seq_len - target_seq_len
-        # get non-zero top-k
+        # get non-zero top-k (prune None logprobs from vllm data step)
         top_k_vals = [
            len(logprobs[i])
            for i in range(len(logprobs))
            if logprobs[i] is not None and len(logprobs[i])
        ]
-        top_k = max(set(top_k_vals), key=top_k_vals.count)
+        max_top_k = max(set(top_k_vals), key=top_k_vals.count)
+        min_top_k = min(set(top_k_vals), key=top_k_vals.count)
+        top_k = min(max_top_k, min_top_k)
+        if top_k == 0:
+            raise ValueError("No non-zero top-k logprobs found.")
         target_logprobs = []
         target_token_ids = []
         target_mask = []
@@ -98,10 +103,6 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
                 target_token_ids.append(list(range(top_k)))
                 target_mask.append([0] * top_k)
-        # for _ in range(target_seq_len):
-        #     # TODO also check against sample["labels"]
-        #     target_mask.append([1] * top_k)
         for position in range(input_padding_len, input_seq_len):
             if sample["labels"][position] == -100:
                 target_mask.append([0] * top_k)
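A standalone sketch of the new top-k selection, for illustration only (it assumes logprobs is a list of per-position {token_id: logprob} dicts from the vLLM data step, with None at positions that returned no logprobs):

    # toy data: two valid positions, one None, one empty dict
    logprobs = [{1: -0.1, 7: -2.3}, None, {4: -0.5, 2: -1.2}, {}]

    top_k_vals = [len(lp) for lp in logprobs if lp is not None and len(lp)]
    max_top_k = max(set(top_k_vals), key=top_k_vals.count)  # most frequent k
    min_top_k = min(set(top_k_vals), key=top_k_vals.count)  # least frequent k
    top_k = min(max_top_k, min_top_k)  # clamp to the smaller of the two
    if top_k == 0:
        raise ValueError("No non-zero top-k logprobs found.")
    print(top_k)  # -> 2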

@@ -61,6 +61,19 @@ def loss(
 ) -> torch.Tensor:
     """
     A KD loss function that is TorchScript-friendly.
+
+    Arguments:
+        student_logits (torch.Tensor): The logits of the student model.
+            Shape: [B, student_seq_len, vocab_size]
+        target_token_ids (torch.Tensor): The top-k teacher/target token IDs.
+            Shape: [B, teacher_seq_len, top_k]
+        target_logprobs (torch.Tensor): The top-k teacher/target logprobs; these should already be re-normalized.
+            Shape: [B, teacher_seq_len, top_k]
+        target_mask (torch.Tensor): The mask for valid tokens.
+            Shape: [B, teacher_seq_len, top_k]
+        num_items_in_batch (int, optional): The number of items in the batch.
+        kd_temperature (float, optional): The temperature for KD.
+            Default: 1.0
     """
     target_logprobs = target_logprobs.float()
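For reference, a hypothetical invocation with the documented shapes; the positional order here just follows the docstring above and is an assumption, not the verified signature:

    import torch

    B, seq_len, vocab, top_k = 2, 8, 32000, 64
    student_logits = torch.randn(B, seq_len, vocab)
    target_token_ids = torch.randint(0, vocab, (B, seq_len, top_k))
    # log_softmax over the k entries re-normalizes them, as the docstring requires
    target_logprobs = torch.log_softmax(torch.randn(B, seq_len, top_k), dim=-1)
    target_mask = torch.ones(B, seq_len, top_k)

    kd_loss = loss(
        student_logits,
        target_token_ids,
        target_logprobs,
        target_mask,
        num_items_in_batch=B * seq_len,
        kd_temperature=1.0,
    )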
@@ -122,9 +135,9 @@ def loss(
 def topk_kd_loss_with_zscore(
     student_logits: torch.Tensor,  # [B, seq_len, vocab_size]
-    teacher_topk_ids: torch.Tensor,  # [B, seq_len, K]
-    teacher_topk_logprobs: torch.Tensor,  # [B, seq_len, K], sums to 1.0 in prob space
-    teacher_mask: torch.Tensor,  # [B, seq_len, K] or [B, seq_len]
+    target_token_ids: torch.Tensor,  # [B, seq_len, K]
+    target_logprobs: torch.Tensor,  # [B, seq_len, K], sums to 1.0 in prob space
+    target_mask: torch.Tensor,  # [B, seq_len, K] or [B, seq_len]
     kd_temperature: float = 1.0,  # classic KD temperature
     zscore_base_temp: float = 1.0,  # from the paper
     num_items_in_batch: int = -1,
@@ -134,15 +147,15 @@ def topk_kd_loss_with_zscore(
     from "Logit Standardization in Knowledge Distillation".
     """
-    teacher_topk_logprobs = teacher_topk_logprobs.float()
+    target_logprobs = target_logprobs.float()
-    B, teacher_seq_len, K = teacher_topk_logprobs.shape  # pylint: disable=invalid-name
+    B, teacher_seq_len, K = target_logprobs.shape  # pylint: disable=invalid-name

     # 1) Gather the student's top-k logits to match teacher
     student_logits_for_kd = student_logits[
         :, :teacher_seq_len, :
     ]  # [B, seq_len, vocab]
     student_topk_logits = torch.gather(
-        student_logits_for_kd, dim=-1, index=teacher_topk_ids
+        student_logits_for_kd, dim=-1, index=target_token_ids
     )  # [B, seq_len, K]
     student_topk_logits = student_topk_logits.float()
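In isolation, the gather above picks the student's logits at the target token ids; a toy illustration (B=1, seq_len=2, vocab=5, K=2):

    import torch

    student_logits_for_kd = torch.arange(10.0).reshape(1, 2, 5)
    target_token_ids = torch.tensor([[[4, 0], [1, 3]]])
    student_topk_logits = torch.gather(
        student_logits_for_kd, dim=-1, index=target_token_ids
    )
    print(student_topk_logits)  # tensor([[[4., 0.], [6., 8.]]])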
@@ -154,18 +167,18 @@ def topk_kd_loss_with_zscore(
     # 3) Convert teacher logprobs -> treat them as “logits” for z-score
     #    (They differ by +some_constant from real logits, but in z-score
     #    that constant is subtracted out anyway.)
-    teacher_logits_for_zscore = teacher_topk_logprobs  # rename variable for clarity
+    teacher_logits_for_zscore = target_logprobs  # rename variable for clarity

     # 4) Z-score teacher and student
-    # If teacher_mask is 2D, expand to 3D for the K dimension
-    if teacher_mask.dim() == 2 and teacher_mask.shape[:2] == (B, teacher_seq_len):
-        teacher_mask = teacher_mask.unsqueeze(-1).expand(-1, -1, K)
+    # If target_mask is 2D, expand to 3D for the K dimension
+    if target_mask.dim() == 2 and target_mask.shape[:2] == (B, teacher_seq_len):
+        target_mask = target_mask.unsqueeze(-1).expand(-1, -1, K)
     teacher_z = zscore_standardize(
-        teacher_logits_for_zscore, mask=teacher_mask, base_temperature=zscore_base_temp
+        teacher_logits_for_zscore, mask=target_mask, base_temperature=zscore_base_temp
     )
     student_z = zscore_standardize(
-        student_topk_logits, mask=teacher_mask, base_temperature=zscore_base_temp
+        student_topk_logits, mask=target_mask, base_temperature=zscore_base_temp
     )

     # 5) Convert to log-probs for KL
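zscore_standardize itself is not part of this diff. A minimal sketch of what a masked z-score over the top-k dimension might look like, following the logit-standardization paper (an assumption for illustration, not the repo's implementation):

    import torch

    def zscore_standardize_sketch(
        logits: torch.Tensor,  # [B, seq_len, K]
        mask: torch.Tensor,    # [B, seq_len, K], 1 = valid
        base_temperature: float = 1.0,
        eps: float = 1e-6,
    ) -> torch.Tensor:
        mask = mask.to(logits.dtype)
        count = mask.sum(dim=-1, keepdim=True).clamp(min=1.0)
        # masked mean and variance over the K dimension
        mean = (logits * mask).sum(dim=-1, keepdim=True) / count
        var = ((logits - mean) ** 2 * mask).sum(dim=-1, keepdim=True) / count
        std = (var + eps).sqrt()
        return (logits - mean) / (std * base_temperature)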
@@ -173,7 +186,7 @@ def topk_kd_loss_with_zscore(
     student_logprobs_z = student_z - torch.logsumexp(student_z, dim=-1, keepdim=True)

     # 6) Restrict to valid tokens if needed
-    valid_mask = teacher_mask.bool()  # shape [B, seq_len, K]
+    valid_mask = target_mask.bool()  # shape [B, seq_len, K]
     teacher_probs_z = teacher_logprobs_z.exp()
     teacher_probs_z = teacher_probs_z[valid_mask]
     teacher_logprobs_z = teacher_logprobs_z[valid_mask]
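The masked selection above feeds a forward KL between the standardized distributions. A self-contained toy of that final step (shapes and the final reduction are assumptions):

    import torch

    teacher_logprobs_z = torch.log_softmax(torch.randn(2, 4, 8), dim=-1)
    student_logprobs_z = torch.log_softmax(torch.randn(2, 4, 8), dim=-1)
    valid_mask = torch.rand(2, 4, 8) > 0.2  # bool [B, seq_len, K]

    # KL(teacher || student) = sum p_t * (log p_t - log p_s) over valid entries
    teacher_probs_z = teacher_logprobs_z.exp()
    kd = (
        teacher_probs_z[valid_mask]
        * (teacher_logprobs_z[valid_mask] - student_logprobs_z[valid_mask])
    ).sum()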

@@ -76,10 +76,10 @@ class AxolotlKDTrainer(AxolotlTrainer):
         if self.args.kd_zscore_base_temp:
             loss_kd = topk_kd_loss_with_zscore(
-                student_logits=shift_logits,
-                teacher_topk_ids=target_token_ids_for_loss,
-                teacher_topk_logprobs=target_logprobs_for_loss,
-                teacher_mask=target_mask_for_loss,
+                shift_logits,
+                target_token_ids_for_loss,
+                target_logprobs_for_loss,
+                target_mask_for_loss,
                 kd_temperature=self.args.kd_temperature,
                 zscore_base_temp=self.args.kd_zscore_base_temp,
                 num_items_in_batch=num_items_in_batch,
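Equivalently, using the renamed keywords from the new signature (shown only to make the positional call above explicit):

    loss_kd = topk_kd_loss_with_zscore(
        student_logits=shift_logits,
        target_token_ids=target_token_ids_for_loss,
        target_logprobs=target_logprobs_for_loss,
        target_mask=target_mask_for_loss,
        kd_temperature=self.args.kd_temperature,
        zscore_base_temp=self.args.kd_zscore_base_temp,
        num_items_in_batch=num_items_in_batch,
    )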