various fixes

This commit is contained in:
Wing Lian
2025-01-30 10:39:18 -05:00
parent c434951dd6
commit f11227a35a
3 changed files with 37 additions and 23 deletions

View File

@@ -66,13 +66,18 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
target_seq_len = len(logprobs) target_seq_len = len(logprobs)
input_seq_len = len(sample["input_ids"]) input_seq_len = len(sample["input_ids"])
input_padding_len = input_seq_len - target_seq_len input_padding_len = input_seq_len - target_seq_len
# get non-zero top-k # get non-zero top-k (prune None logprobs from vllm data step)
top_k_vals = [ top_k_vals = [
len(logprobs[i]) len(logprobs[i])
for i in range(len(logprobs)) for i in range(len(logprobs))
if logprobs[i] is not None and len(logprobs[i]) if logprobs[i] is not None and len(logprobs[i])
] ]
top_k = max(set(top_k_vals), key=top_k_vals.count) max_top_k = max(set(top_k_vals), key=top_k_vals.count)
min_top_k = min(set(top_k_vals), key=top_k_vals.count)
top_k = min(max_top_k, min_top_k)
if top_k == 0:
raise ValueError("No non-zero top-k logprobs found.")
target_logprobs = [] target_logprobs = []
target_token_ids = [] target_token_ids = []
target_mask = [] target_mask = []
@@ -98,10 +103,6 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
target_token_ids.append(list(range(top_k))) target_token_ids.append(list(range(top_k)))
target_mask.append([0] * top_k) target_mask.append([0] * top_k)
# for _ in range(target_seq_len):
# # TODO also check against sample["labels"]
# target_mask.append([1] * top_k)
for position in range(input_padding_len, input_seq_len): for position in range(input_padding_len, input_seq_len):
if sample["labels"][position] == -100: if sample["labels"][position] == -100:
target_mask.append([0] * top_k) target_mask.append([0] * top_k)

View File

@@ -61,6 +61,19 @@ def loss(
) -> torch.Tensor: ) -> torch.Tensor:
""" """
A KD loss function that is TorchScript-friendly. A KD loss function that is TorchScript-friendly.
Arguments:
student_logits (torch.Tensor): The logits of the student model.
Shape: [B, student_seq_len, vocab_size]
target_token_ids (torch.Tensor): The top-k teacher/target token IDs
Shape: [B, teacher_seq_len, top_k]
target_logprobs (torch.Tensor): The top-k teacher/target logprobs, these should already be re-normalized.
Shape: [B, teacher_seq_len, top_k]
target_mask (torch.Tensor): The mask for valid tokens.
Shape: [B, teacher_seq_len, top_k]
num_items_in_batch (int, optional): The number of items in the batch.
kd_temperature (float, optional): The temperature for KD.
Default: 1.0
""" """
target_logprobs = target_logprobs.float() target_logprobs = target_logprobs.float()
@@ -122,9 +135,9 @@ def loss(
def topk_kd_loss_with_zscore( def topk_kd_loss_with_zscore(
student_logits: torch.Tensor, # [B, seq_len, vocab_size] student_logits: torch.Tensor, # [B, seq_len, vocab_size]
teacher_topk_ids: torch.Tensor, # [B, seq_len, K] target_token_ids: torch.Tensor, # [B, seq_len, K]
teacher_topk_logprobs: torch.Tensor, # [B, seq_len, K], sums to 1.0 in prob space target_logprobs: torch.Tensor, # [B, seq_len, K], sums to 1.0 in prob space
teacher_mask: torch.Tensor, # [B, seq_len, K] or [B, seq_len] target_mask: torch.Tensor, # [B, seq_len, K] or [B, seq_len]
kd_temperature: float = 1.0, # classic KD temperature kd_temperature: float = 1.0, # classic KD temperature
zscore_base_temp: float = 1.0, # from the paper zscore_base_temp: float = 1.0, # from the paper
num_items_in_batch: int = -1, num_items_in_batch: int = -1,
@@ -134,15 +147,15 @@ def topk_kd_loss_with_zscore(
from "Logit Standardization in Knowledge Distillation". from "Logit Standardization in Knowledge Distillation".
""" """
teacher_topk_logprobs = teacher_topk_logprobs.float() target_logprobs = target_logprobs.float()
B, teacher_seq_len, K = teacher_topk_logprobs.shape # pylint: disable=invalid-name B, teacher_seq_len, K = target_logprobs.shape # pylint: disable=invalid-name
# 1) Gather the student's top-k logits to match teacher # 1) Gather the student's top-k logits to match teacher
student_logits_for_kd = student_logits[ student_logits_for_kd = student_logits[
:, :teacher_seq_len, : :, :teacher_seq_len, :
] # [B, seq_len, vocab] ] # [B, seq_len, vocab]
student_topk_logits = torch.gather( student_topk_logits = torch.gather(
student_logits_for_kd, dim=-1, index=teacher_topk_ids student_logits_for_kd, dim=-1, index=target_token_ids
) # [B, seq_len, K] ) # [B, seq_len, K]
student_topk_logits = student_topk_logits.float() student_topk_logits = student_topk_logits.float()
@@ -154,18 +167,18 @@ def topk_kd_loss_with_zscore(
# 3) Convert teacher logprobs -> treat them as “logits” for z-score # 3) Convert teacher logprobs -> treat them as “logits” for z-score
# (They differ by +some_constant from real logits, but in z-score # (They differ by +some_constant from real logits, but in z-score
# that constant is subtracted out anyway.) # that constant is subtracted out anyway.)
teacher_logits_for_zscore = teacher_topk_logprobs # rename variable for clarity teacher_logits_for_zscore = target_logprobs # rename variable for clarity
# 4) Z-score teacher and student # 4) Z-score teacher and student
# If teacher_mask is 2D, expand to 3D for the K dimension # If target_mask is 2D, expand to 3D for the K dimension
if teacher_mask.dim() == 2 and teacher_mask.shape[:2] == (B, teacher_seq_len): if target_mask.dim() == 2 and target_mask.shape[:2] == (B, teacher_seq_len):
teacher_mask = teacher_mask.unsqueeze(-1).expand(-1, -1, K) target_mask = target_mask.unsqueeze(-1).expand(-1, -1, K)
teacher_z = zscore_standardize( teacher_z = zscore_standardize(
teacher_logits_for_zscore, mask=teacher_mask, base_temperature=zscore_base_temp teacher_logits_for_zscore, mask=target_mask, base_temperature=zscore_base_temp
) )
student_z = zscore_standardize( student_z = zscore_standardize(
student_topk_logits, mask=teacher_mask, base_temperature=zscore_base_temp student_topk_logits, mask=target_mask, base_temperature=zscore_base_temp
) )
# 5) Convert to log-probs for KL # 5) Convert to log-probs for KL
@@ -173,7 +186,7 @@ def topk_kd_loss_with_zscore(
student_logprobs_z = student_z - torch.logsumexp(student_z, dim=-1, keepdim=True) student_logprobs_z = student_z - torch.logsumexp(student_z, dim=-1, keepdim=True)
# 6) Restrict to valid tokens if needed # 6) Restrict to valid tokens if needed
valid_mask = teacher_mask.bool() # shape [B, seq_len, K] valid_mask = target_mask.bool() # shape [B, seq_len, K]
teacher_probs_z = teacher_logprobs_z.exp() teacher_probs_z = teacher_logprobs_z.exp()
teacher_probs_z = teacher_probs_z[valid_mask] teacher_probs_z = teacher_probs_z[valid_mask]
teacher_logprobs_z = teacher_logprobs_z[valid_mask] teacher_logprobs_z = teacher_logprobs_z[valid_mask]

View File

@@ -76,10 +76,10 @@ class AxolotlKDTrainer(AxolotlTrainer):
if self.args.kd_zscore_base_temp: if self.args.kd_zscore_base_temp:
loss_kd = topk_kd_loss_with_zscore( loss_kd = topk_kd_loss_with_zscore(
student_logits=shift_logits, shift_logits,
teacher_topk_ids=target_token_ids_for_loss, target_token_ids_for_loss,
teacher_topk_logprobs=target_logprobs_for_loss, target_logprobs_for_loss,
teacher_mask=target_mask_for_loss, target_mask_for_loss,
kd_temperature=self.args.kd_temperature, kd_temperature=self.args.kd_temperature,
zscore_base_temp=self.args.kd_zscore_base_temp, zscore_base_temp=self.args.kd_zscore_base_temp,
num_items_in_batch=num_items_in_batch, num_items_in_batch=num_items_in_batch,