various fixes
This commit is contained in:
@@ -66,13 +66,18 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
||||
target_seq_len = len(logprobs)
|
||||
input_seq_len = len(sample["input_ids"])
|
||||
input_padding_len = input_seq_len - target_seq_len
|
||||
# get non-zero top-k
|
||||
# get non-zero top-k (prune None logprobs from vllm data step)
|
||||
top_k_vals = [
|
||||
len(logprobs[i])
|
||||
for i in range(len(logprobs))
|
||||
if logprobs[i] is not None and len(logprobs[i])
|
||||
]
|
||||
top_k = max(set(top_k_vals), key=top_k_vals.count)
|
||||
max_top_k = max(set(top_k_vals), key=top_k_vals.count)
|
||||
min_top_k = min(set(top_k_vals), key=top_k_vals.count)
|
||||
top_k = min(max_top_k, min_top_k)
|
||||
if top_k == 0:
|
||||
raise ValueError("No non-zero top-k logprobs found.")
|
||||
|
||||
target_logprobs = []
|
||||
target_token_ids = []
|
||||
target_mask = []
|
||||
@@ -98,10 +103,6 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
||||
target_token_ids.append(list(range(top_k)))
|
||||
target_mask.append([0] * top_k)
|
||||
|
||||
# for _ in range(target_seq_len):
|
||||
# # TODO also check against sample["labels"]
|
||||
# target_mask.append([1] * top_k)
|
||||
|
||||
for position in range(input_padding_len, input_seq_len):
|
||||
if sample["labels"][position] == -100:
|
||||
target_mask.append([0] * top_k)
|
||||
|
||||
@@ -61,6 +61,19 @@ def loss(
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
A KD loss function that is TorchScript-friendly.
|
||||
|
||||
Arguments:
|
||||
student_logits (torch.Tensor): The logits of the student model.
|
||||
Shape: [B, student_seq_len, vocab_size]
|
||||
target_token_ids (torch.Tensor): The top-k teacher/target token IDs
|
||||
Shape: [B, teacher_seq_len, top_k]
|
||||
target_logprobs (torch.Tensor): The top-k teacher/target logprobs, these should already be re-normalized.
|
||||
Shape: [B, teacher_seq_len, top_k]
|
||||
target_mask (torch.Tensor): The mask for valid tokens.
|
||||
Shape: [B, teacher_seq_len, top_k]
|
||||
num_items_in_batch (int, optional): The number of items in the batch.
|
||||
kd_temperature (float, optional): The temperature for KD.
|
||||
Default: 1.0
|
||||
"""
|
||||
|
||||
target_logprobs = target_logprobs.float()
|
||||
@@ -122,9 +135,9 @@ def loss(
|
||||
|
||||
def topk_kd_loss_with_zscore(
|
||||
student_logits: torch.Tensor, # [B, seq_len, vocab_size]
|
||||
teacher_topk_ids: torch.Tensor, # [B, seq_len, K]
|
||||
teacher_topk_logprobs: torch.Tensor, # [B, seq_len, K], sums to 1.0 in prob space
|
||||
teacher_mask: torch.Tensor, # [B, seq_len, K] or [B, seq_len]
|
||||
target_token_ids: torch.Tensor, # [B, seq_len, K]
|
||||
target_logprobs: torch.Tensor, # [B, seq_len, K], sums to 1.0 in prob space
|
||||
target_mask: torch.Tensor, # [B, seq_len, K] or [B, seq_len]
|
||||
kd_temperature: float = 1.0, # classic KD temperature
|
||||
zscore_base_temp: float = 1.0, # from the paper
|
||||
num_items_in_batch: int = -1,
|
||||
@@ -134,15 +147,15 @@ def topk_kd_loss_with_zscore(
|
||||
from "Logit Standardization in Knowledge Distillation".
|
||||
"""
|
||||
|
||||
teacher_topk_logprobs = teacher_topk_logprobs.float()
|
||||
target_logprobs = target_logprobs.float()
|
||||
|
||||
B, teacher_seq_len, K = teacher_topk_logprobs.shape # pylint: disable=invalid-name
|
||||
B, teacher_seq_len, K = target_logprobs.shape # pylint: disable=invalid-name
|
||||
# 1) Gather the student's top-k logits to match teacher
|
||||
student_logits_for_kd = student_logits[
|
||||
:, :teacher_seq_len, :
|
||||
] # [B, seq_len, vocab]
|
||||
student_topk_logits = torch.gather(
|
||||
student_logits_for_kd, dim=-1, index=teacher_topk_ids
|
||||
student_logits_for_kd, dim=-1, index=target_token_ids
|
||||
) # [B, seq_len, K]
|
||||
|
||||
student_topk_logits = student_topk_logits.float()
|
||||
@@ -154,18 +167,18 @@ def topk_kd_loss_with_zscore(
|
||||
# 3) Convert teacher logprobs -> treat them as “logits” for z-score
|
||||
# (They differ by +some_constant from real logits, but in z-score
|
||||
# that constant is subtracted out anyway.)
|
||||
teacher_logits_for_zscore = teacher_topk_logprobs # rename variable for clarity
|
||||
teacher_logits_for_zscore = target_logprobs # rename variable for clarity
|
||||
|
||||
# 4) Z-score teacher and student
|
||||
# If teacher_mask is 2D, expand to 3D for the K dimension
|
||||
if teacher_mask.dim() == 2 and teacher_mask.shape[:2] == (B, teacher_seq_len):
|
||||
teacher_mask = teacher_mask.unsqueeze(-1).expand(-1, -1, K)
|
||||
# If target_mask is 2D, expand to 3D for the K dimension
|
||||
if target_mask.dim() == 2 and target_mask.shape[:2] == (B, teacher_seq_len):
|
||||
target_mask = target_mask.unsqueeze(-1).expand(-1, -1, K)
|
||||
|
||||
teacher_z = zscore_standardize(
|
||||
teacher_logits_for_zscore, mask=teacher_mask, base_temperature=zscore_base_temp
|
||||
teacher_logits_for_zscore, mask=target_mask, base_temperature=zscore_base_temp
|
||||
)
|
||||
student_z = zscore_standardize(
|
||||
student_topk_logits, mask=teacher_mask, base_temperature=zscore_base_temp
|
||||
student_topk_logits, mask=target_mask, base_temperature=zscore_base_temp
|
||||
)
|
||||
|
||||
# 5) Convert to log-probs for KL
|
||||
@@ -173,7 +186,7 @@ def topk_kd_loss_with_zscore(
|
||||
student_logprobs_z = student_z - torch.logsumexp(student_z, dim=-1, keepdim=True)
|
||||
|
||||
# 6) Restrict to valid tokens if needed
|
||||
valid_mask = teacher_mask.bool() # shape [B, seq_len, K]
|
||||
valid_mask = target_mask.bool() # shape [B, seq_len, K]
|
||||
teacher_probs_z = teacher_logprobs_z.exp()
|
||||
teacher_probs_z = teacher_probs_z[valid_mask]
|
||||
teacher_logprobs_z = teacher_logprobs_z[valid_mask]
|
||||
|
||||
@@ -76,10 +76,10 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
||||
|
||||
if self.args.kd_zscore_base_temp:
|
||||
loss_kd = topk_kd_loss_with_zscore(
|
||||
student_logits=shift_logits,
|
||||
teacher_topk_ids=target_token_ids_for_loss,
|
||||
teacher_topk_logprobs=target_logprobs_for_loss,
|
||||
teacher_mask=target_mask_for_loss,
|
||||
shift_logits,
|
||||
target_token_ids_for_loss,
|
||||
target_logprobs_for_loss,
|
||||
target_mask_for_loss,
|
||||
kd_temperature=self.args.kd_temperature,
|
||||
zscore_base_temp=self.args.kd_zscore_base_temp,
|
||||
num_items_in_batch=num_items_in_batch,
|
||||
|
||||
Reference in New Issue
Block a user