various fixes
This commit is contained in:
@@ -66,13 +66,18 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
target_seq_len = len(logprobs)
|
target_seq_len = len(logprobs)
|
||||||
input_seq_len = len(sample["input_ids"])
|
input_seq_len = len(sample["input_ids"])
|
||||||
input_padding_len = input_seq_len - target_seq_len
|
input_padding_len = input_seq_len - target_seq_len
|
||||||
# get non-zero top-k
|
# get non-zero top-k (prune None logprobs from vllm data step)
|
||||||
top_k_vals = [
|
top_k_vals = [
|
||||||
len(logprobs[i])
|
len(logprobs[i])
|
||||||
for i in range(len(logprobs))
|
for i in range(len(logprobs))
|
||||||
if logprobs[i] is not None and len(logprobs[i])
|
if logprobs[i] is not None and len(logprobs[i])
|
||||||
]
|
]
|
||||||
top_k = max(set(top_k_vals), key=top_k_vals.count)
|
max_top_k = max(set(top_k_vals), key=top_k_vals.count)
|
||||||
|
min_top_k = min(set(top_k_vals), key=top_k_vals.count)
|
||||||
|
top_k = min(max_top_k, min_top_k)
|
||||||
|
if top_k == 0:
|
||||||
|
raise ValueError("No non-zero top-k logprobs found.")
|
||||||
|
|
||||||
target_logprobs = []
|
target_logprobs = []
|
||||||
target_token_ids = []
|
target_token_ids = []
|
||||||
target_mask = []
|
target_mask = []
|
||||||
@@ -98,10 +103,6 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
target_token_ids.append(list(range(top_k)))
|
target_token_ids.append(list(range(top_k)))
|
||||||
target_mask.append([0] * top_k)
|
target_mask.append([0] * top_k)
|
||||||
|
|
||||||
# for _ in range(target_seq_len):
|
|
||||||
# # TODO also check against sample["labels"]
|
|
||||||
# target_mask.append([1] * top_k)
|
|
||||||
|
|
||||||
for position in range(input_padding_len, input_seq_len):
|
for position in range(input_padding_len, input_seq_len):
|
||||||
if sample["labels"][position] == -100:
|
if sample["labels"][position] == -100:
|
||||||
target_mask.append([0] * top_k)
|
target_mask.append([0] * top_k)
|
||||||
|
|||||||
@@ -61,6 +61,19 @@ def loss(
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
A KD loss function that is TorchScript-friendly.
|
A KD loss function that is TorchScript-friendly.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
student_logits (torch.Tensor): The logits of the student model.
|
||||||
|
Shape: [B, student_seq_len, vocab_size]
|
||||||
|
target_token_ids (torch.Tensor): The top-k teacher/target token IDs
|
||||||
|
Shape: [B, teacher_seq_len, top_k]
|
||||||
|
target_logprobs (torch.Tensor): The top-k teacher/target logprobs, these should already be re-normalized.
|
||||||
|
Shape: [B, teacher_seq_len, top_k]
|
||||||
|
target_mask (torch.Tensor): The mask for valid tokens.
|
||||||
|
Shape: [B, teacher_seq_len, top_k]
|
||||||
|
num_items_in_batch (int, optional): The number of items in the batch.
|
||||||
|
kd_temperature (float, optional): The temperature for KD.
|
||||||
|
Default: 1.0
|
||||||
"""
|
"""
|
||||||
|
|
||||||
target_logprobs = target_logprobs.float()
|
target_logprobs = target_logprobs.float()
|
||||||
@@ -122,9 +135,9 @@ def loss(
|
|||||||
|
|
||||||
def topk_kd_loss_with_zscore(
|
def topk_kd_loss_with_zscore(
|
||||||
student_logits: torch.Tensor, # [B, seq_len, vocab_size]
|
student_logits: torch.Tensor, # [B, seq_len, vocab_size]
|
||||||
teacher_topk_ids: torch.Tensor, # [B, seq_len, K]
|
target_token_ids: torch.Tensor, # [B, seq_len, K]
|
||||||
teacher_topk_logprobs: torch.Tensor, # [B, seq_len, K], sums to 1.0 in prob space
|
target_logprobs: torch.Tensor, # [B, seq_len, K], sums to 1.0 in prob space
|
||||||
teacher_mask: torch.Tensor, # [B, seq_len, K] or [B, seq_len]
|
target_mask: torch.Tensor, # [B, seq_len, K] or [B, seq_len]
|
||||||
kd_temperature: float = 1.0, # classic KD temperature
|
kd_temperature: float = 1.0, # classic KD temperature
|
||||||
zscore_base_temp: float = 1.0, # from the paper
|
zscore_base_temp: float = 1.0, # from the paper
|
||||||
num_items_in_batch: int = -1,
|
num_items_in_batch: int = -1,
|
||||||
@@ -134,15 +147,15 @@ def topk_kd_loss_with_zscore(
|
|||||||
from "Logit Standardization in Knowledge Distillation".
|
from "Logit Standardization in Knowledge Distillation".
|
||||||
"""
|
"""
|
||||||
|
|
||||||
teacher_topk_logprobs = teacher_topk_logprobs.float()
|
target_logprobs = target_logprobs.float()
|
||||||
|
|
||||||
B, teacher_seq_len, K = teacher_topk_logprobs.shape # pylint: disable=invalid-name
|
B, teacher_seq_len, K = target_logprobs.shape # pylint: disable=invalid-name
|
||||||
# 1) Gather the student's top-k logits to match teacher
|
# 1) Gather the student's top-k logits to match teacher
|
||||||
student_logits_for_kd = student_logits[
|
student_logits_for_kd = student_logits[
|
||||||
:, :teacher_seq_len, :
|
:, :teacher_seq_len, :
|
||||||
] # [B, seq_len, vocab]
|
] # [B, seq_len, vocab]
|
||||||
student_topk_logits = torch.gather(
|
student_topk_logits = torch.gather(
|
||||||
student_logits_for_kd, dim=-1, index=teacher_topk_ids
|
student_logits_for_kd, dim=-1, index=target_token_ids
|
||||||
) # [B, seq_len, K]
|
) # [B, seq_len, K]
|
||||||
|
|
||||||
student_topk_logits = student_topk_logits.float()
|
student_topk_logits = student_topk_logits.float()
|
||||||
@@ -154,18 +167,18 @@ def topk_kd_loss_with_zscore(
|
|||||||
# 3) Convert teacher logprobs -> treat them as “logits” for z-score
|
# 3) Convert teacher logprobs -> treat them as “logits” for z-score
|
||||||
# (They differ by +some_constant from real logits, but in z-score
|
# (They differ by +some_constant from real logits, but in z-score
|
||||||
# that constant is subtracted out anyway.)
|
# that constant is subtracted out anyway.)
|
||||||
teacher_logits_for_zscore = teacher_topk_logprobs # rename variable for clarity
|
teacher_logits_for_zscore = target_logprobs # rename variable for clarity
|
||||||
|
|
||||||
# 4) Z-score teacher and student
|
# 4) Z-score teacher and student
|
||||||
# If teacher_mask is 2D, expand to 3D for the K dimension
|
# If target_mask is 2D, expand to 3D for the K dimension
|
||||||
if teacher_mask.dim() == 2 and teacher_mask.shape[:2] == (B, teacher_seq_len):
|
if target_mask.dim() == 2 and target_mask.shape[:2] == (B, teacher_seq_len):
|
||||||
teacher_mask = teacher_mask.unsqueeze(-1).expand(-1, -1, K)
|
target_mask = target_mask.unsqueeze(-1).expand(-1, -1, K)
|
||||||
|
|
||||||
teacher_z = zscore_standardize(
|
teacher_z = zscore_standardize(
|
||||||
teacher_logits_for_zscore, mask=teacher_mask, base_temperature=zscore_base_temp
|
teacher_logits_for_zscore, mask=target_mask, base_temperature=zscore_base_temp
|
||||||
)
|
)
|
||||||
student_z = zscore_standardize(
|
student_z = zscore_standardize(
|
||||||
student_topk_logits, mask=teacher_mask, base_temperature=zscore_base_temp
|
student_topk_logits, mask=target_mask, base_temperature=zscore_base_temp
|
||||||
)
|
)
|
||||||
|
|
||||||
# 5) Convert to log-probs for KL
|
# 5) Convert to log-probs for KL
|
||||||
@@ -173,7 +186,7 @@ def topk_kd_loss_with_zscore(
|
|||||||
student_logprobs_z = student_z - torch.logsumexp(student_z, dim=-1, keepdim=True)
|
student_logprobs_z = student_z - torch.logsumexp(student_z, dim=-1, keepdim=True)
|
||||||
|
|
||||||
# 6) Restrict to valid tokens if needed
|
# 6) Restrict to valid tokens if needed
|
||||||
valid_mask = teacher_mask.bool() # shape [B, seq_len, K]
|
valid_mask = target_mask.bool() # shape [B, seq_len, K]
|
||||||
teacher_probs_z = teacher_logprobs_z.exp()
|
teacher_probs_z = teacher_logprobs_z.exp()
|
||||||
teacher_probs_z = teacher_probs_z[valid_mask]
|
teacher_probs_z = teacher_probs_z[valid_mask]
|
||||||
teacher_logprobs_z = teacher_logprobs_z[valid_mask]
|
teacher_logprobs_z = teacher_logprobs_z[valid_mask]
|
||||||
|
|||||||
@@ -76,10 +76,10 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
|||||||
|
|
||||||
if self.args.kd_zscore_base_temp:
|
if self.args.kd_zscore_base_temp:
|
||||||
loss_kd = topk_kd_loss_with_zscore(
|
loss_kd = topk_kd_loss_with_zscore(
|
||||||
student_logits=shift_logits,
|
shift_logits,
|
||||||
teacher_topk_ids=target_token_ids_for_loss,
|
target_token_ids_for_loss,
|
||||||
teacher_topk_logprobs=target_logprobs_for_loss,
|
target_logprobs_for_loss,
|
||||||
teacher_mask=target_mask_for_loss,
|
target_mask_for_loss,
|
||||||
kd_temperature=self.args.kd_temperature,
|
kd_temperature=self.args.kd_temperature,
|
||||||
zscore_base_temp=self.args.kd_zscore_base_temp,
|
zscore_base_temp=self.args.kd_zscore_base_temp,
|
||||||
num_items_in_batch=num_items_in_batch,
|
num_items_in_batch=num_items_in_batch,
|
||||||
|
|||||||
Reference in New Issue
Block a user