From f11227a35a6093cbf852e7c4820a5da3391f551c Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 30 Jan 2025 10:39:18 -0500 Subject: [PATCH] various fixes --- src/axolotl/integrations/kd/chat_template.py | 13 ++++--- .../kd/topk_logprob/forward_kl.py | 39 ++++++++++++------- src/axolotl/integrations/kd/trainer.py | 8 ++-- 3 files changed, 37 insertions(+), 23 deletions(-) diff --git a/src/axolotl/integrations/kd/chat_template.py b/src/axolotl/integrations/kd/chat_template.py index 5a7e4f40d..699728e9f 100644 --- a/src/axolotl/integrations/kd/chat_template.py +++ b/src/axolotl/integrations/kd/chat_template.py @@ -66,13 +66,18 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy): target_seq_len = len(logprobs) input_seq_len = len(sample["input_ids"]) input_padding_len = input_seq_len - target_seq_len - # get non-zero top-k + # get non-zero top-k (prune None logprobs from vllm data step) top_k_vals = [ len(logprobs[i]) for i in range(len(logprobs)) if logprobs[i] is not None and len(logprobs[i]) ] - top_k = max(set(top_k_vals), key=top_k_vals.count) + max_top_k = max(set(top_k_vals), key=top_k_vals.count) + min_top_k = min(set(top_k_vals), key=top_k_vals.count) + top_k = min(max_top_k, min_top_k) + if top_k == 0: + raise ValueError("No non-zero top-k logprobs found.") + target_logprobs = [] target_token_ids = [] target_mask = [] @@ -98,10 +103,6 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy): target_token_ids.append(list(range(top_k))) target_mask.append([0] * top_k) - # for _ in range(target_seq_len): - # # TODO also check against sample["labels"] - # target_mask.append([1] * top_k) - for position in range(input_padding_len, input_seq_len): if sample["labels"][position] == -100: target_mask.append([0] * top_k) diff --git a/src/axolotl/integrations/kd/topk_logprob/forward_kl.py b/src/axolotl/integrations/kd/topk_logprob/forward_kl.py index a61011ad5..ea263da2a 100644 --- a/src/axolotl/integrations/kd/topk_logprob/forward_kl.py +++ b/src/axolotl/integrations/kd/topk_logprob/forward_kl.py @@ -61,6 +61,19 @@ def loss( ) -> torch.Tensor: """ A KD loss function that is TorchScript-friendly. + + Arguments: + student_logits (torch.Tensor): The logits of the student model. + Shape: [B, student_seq_len, vocab_size] + target_token_ids (torch.Tensor): The top-k teacher/target token IDs + Shape: [B, teacher_seq_len, top_k] + target_logprobs (torch.Tensor): The top-k teacher/target logprobs, these should already be re-normalized. + Shape: [B, teacher_seq_len, top_k] + target_mask (torch.Tensor): The mask for valid tokens. + Shape: [B, teacher_seq_len, top_k] + num_items_in_batch (int, optional): The number of items in the batch. + kd_temperature (float, optional): The temperature for KD. + Default: 1.0 """ target_logprobs = target_logprobs.float() @@ -122,9 +135,9 @@ def loss( def topk_kd_loss_with_zscore( student_logits: torch.Tensor, # [B, seq_len, vocab_size] - teacher_topk_ids: torch.Tensor, # [B, seq_len, K] - teacher_topk_logprobs: torch.Tensor, # [B, seq_len, K], sums to 1.0 in prob space - teacher_mask: torch.Tensor, # [B, seq_len, K] or [B, seq_len] + target_token_ids: torch.Tensor, # [B, seq_len, K] + target_logprobs: torch.Tensor, # [B, seq_len, K], sums to 1.0 in prob space + target_mask: torch.Tensor, # [B, seq_len, K] or [B, seq_len] kd_temperature: float = 1.0, # classic KD temperature zscore_base_temp: float = 1.0, # from the paper num_items_in_batch: int = -1, @@ -134,15 +147,15 @@ def topk_kd_loss_with_zscore( from "Logit Standardization in Knowledge Distillation". """ - teacher_topk_logprobs = teacher_topk_logprobs.float() + target_logprobs = target_logprobs.float() - B, teacher_seq_len, K = teacher_topk_logprobs.shape # pylint: disable=invalid-name + B, teacher_seq_len, K = target_logprobs.shape # pylint: disable=invalid-name # 1) Gather the student's top-k logits to match teacher student_logits_for_kd = student_logits[ :, :teacher_seq_len, : ] # [B, seq_len, vocab] student_topk_logits = torch.gather( - student_logits_for_kd, dim=-1, index=teacher_topk_ids + student_logits_for_kd, dim=-1, index=target_token_ids ) # [B, seq_len, K] student_topk_logits = student_topk_logits.float() @@ -154,18 +167,18 @@ def topk_kd_loss_with_zscore( # 3) Convert teacher logprobs -> treat them as “logits” for z-score # (They differ by +some_constant from real logits, but in z-score # that constant is subtracted out anyway.) - teacher_logits_for_zscore = teacher_topk_logprobs # rename variable for clarity + teacher_logits_for_zscore = target_logprobs # rename variable for clarity # 4) Z-score teacher and student - # If teacher_mask is 2D, expand to 3D for the K dimension - if teacher_mask.dim() == 2 and teacher_mask.shape[:2] == (B, teacher_seq_len): - teacher_mask = teacher_mask.unsqueeze(-1).expand(-1, -1, K) + # If target_mask is 2D, expand to 3D for the K dimension + if target_mask.dim() == 2 and target_mask.shape[:2] == (B, teacher_seq_len): + target_mask = target_mask.unsqueeze(-1).expand(-1, -1, K) teacher_z = zscore_standardize( - teacher_logits_for_zscore, mask=teacher_mask, base_temperature=zscore_base_temp + teacher_logits_for_zscore, mask=target_mask, base_temperature=zscore_base_temp ) student_z = zscore_standardize( - student_topk_logits, mask=teacher_mask, base_temperature=zscore_base_temp + student_topk_logits, mask=target_mask, base_temperature=zscore_base_temp ) # 5) Convert to log-probs for KL @@ -173,7 +186,7 @@ def topk_kd_loss_with_zscore( student_logprobs_z = student_z - torch.logsumexp(student_z, dim=-1, keepdim=True) # 6) Restrict to valid tokens if needed - valid_mask = teacher_mask.bool() # shape [B, seq_len, K] + valid_mask = target_mask.bool() # shape [B, seq_len, K] teacher_probs_z = teacher_logprobs_z.exp() teacher_probs_z = teacher_probs_z[valid_mask] teacher_logprobs_z = teacher_logprobs_z[valid_mask] diff --git a/src/axolotl/integrations/kd/trainer.py b/src/axolotl/integrations/kd/trainer.py index 9d99a7e1d..47599344c 100644 --- a/src/axolotl/integrations/kd/trainer.py +++ b/src/axolotl/integrations/kd/trainer.py @@ -76,10 +76,10 @@ class AxolotlKDTrainer(AxolotlTrainer): if self.args.kd_zscore_base_temp: loss_kd = topk_kd_loss_with_zscore( - student_logits=shift_logits, - teacher_topk_ids=target_token_ids_for_loss, - teacher_topk_logprobs=target_logprobs_for_loss, - teacher_mask=target_mask_for_loss, + shift_logits, + target_token_ids_for_loss, + target_logprobs_for_loss, + target_mask_for_loss, kd_temperature=self.args.kd_temperature, zscore_base_temp=self.args.kd_zscore_base_temp, num_items_in_batch=num_items_in_batch,