From c7b1db329edf0f22e680582f6a7d68a3d459d88a Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 26 May 2025 21:12:19 -0400
Subject: [PATCH] logsumexp trick:

---
 .../kd/collator_online_teacher.py             | 39 +++++++++++++------
 src/axolotl/integrations/kd/kernels/models.py | 12 ------
 2 files changed, 27 insertions(+), 24 deletions(-)

diff --git a/src/axolotl/integrations/kd/collator_online_teacher.py b/src/axolotl/integrations/kd/collator_online_teacher.py
index 4593401a3..1a0173760 100644
--- a/src/axolotl/integrations/kd/collator_online_teacher.py
+++ b/src/axolotl/integrations/kd/collator_online_teacher.py
@@ -1,3 +1,4 @@
+import pandas as pd
 import requests
 import logging
 from typing import List, Optional, Dict, Any
@@ -58,7 +59,9 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
         position_logprobs_tensor = torch.tensor(raw_logprobs, dtype=torch.float32)
 
         # Convert logprobs at T_online to probabilities
-        teacher_probs_t_online = position_logprobs_tensor.exp()
+        # use log sum exp trick to avoid underflow
+        position_logprobs_lse = torch.logsumexp(position_logprobs_tensor, dim=-1, keepdim=True)
+        teacher_probs_t_online = torch.exp(position_logprobs_tensor - position_logprobs_lse)
 
         # Normalize probabilities (sum to 1)
         # This is important if the top-k from server aren't a full distribution
@@ -75,10 +78,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
         # teacher_probs_t_online = teacher_probs_t_online / teacher_probs_t_online.sum(
         #     dim=0, keepdim=True
         # )
-        # TODO Convert back to logprobs using log sum exp trick
-        teacher_probs_t_online_max = torch.logsumexp(teacher_probs_t_online, dim=-1, keepdim=True)
-        final_logprobs_tensor = teacher_probs_t_online - teacher_probs_t_online_max
-        # final_logprobs_tensor = torch.log(teacher_probs_t_online)
+        final_logprobs_tensor = torch.log(teacher_probs_t_online)
 
         return final_logprobs_tensor.tolist()
 
@@ -259,6 +259,9 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
                 return ret_logprobs_data
 
             for sequence_data, seq_input_ids, seq_labels in zip(choices, batch_input_ids, labels):
+                # seq_input_ids: List[int]
+                # seq_labels: List[int]
+
                 current_target_logprobs = []
                 current_target_token_ids = []
                 current_target_mask = []
@@ -289,13 +292,19 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
                     LOG.warning(f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence.")
                     input_top_logprobs = []  # Treat as empty
 
+                # basic check that the logprob data len matches the input len, so no need to handle padding
+                assert len(seq_input_ids) == len(input_top_logprobs)
+
+                # generate a hash over seq_input_ids and convert it to an int
+                hash_input_ids: int = hash(tuple(seq_input_ids))
+
                 for i, input_id, label in zip(range(len(seq_input_ids)), seq_input_ids, seq_labels):
                     if i < len(input_top_logprobs) and input_top_logprobs[i] is None:
                         # this is always the case for the first token.
                        # there is never logprob data for the first token since that's a true input
                        # so we replace the None value with padding data
                        current_target_logprobs.append([-float("inf")] * self.kd_online_topk)
-                        current_target_token_ids.append([0] * self.kd_online_topk)
+                        current_target_token_ids.append(list(range(self.kd_online_topk)))
                         current_target_mask.append([0] * self.kd_online_topk)
                     elif i < len(input_top_logprobs) and input_top_logprobs[i] is not None:
                         pos_top_logprobs_data: dict[str, dict] = input_top_logprobs[i]
@@ -305,7 +314,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
                                 len(pos_top_logprobs_data.keys()) > 0):  # [logprob, token_id, token_str]
                             LOG.warning(f"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. Padding this position.")
                             current_target_logprobs.append([-float("inf")] * self.kd_online_topk)
-                            current_target_token_ids.append([0] * self.kd_online_topk)
+                            current_target_token_ids.append(list(range(self.kd_online_topk)))
                             current_target_mask.append([0] * self.kd_online_topk)
                             continue
 
@@ -318,7 +327,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
                         # Ensure correct length (top_k)
                         if len(pos_logprobs_raw) < self.kd_online_topk:
                             pad_len = self.kd_online_topk - len(pos_logprobs_raw)
-                            LOG.debug(f"Padding position {i} with {pad_len} top-k tokens and logprobs.")
+                            LOG.warning(f"Padding position {i} with {pad_len} top-k tokens and logprobs.")
                             pos_logprobs_raw.extend([-float("inf")] * pad_len)
                             pos_token_ids.extend([0] * pad_len)  # Pad with 0 token_id
 
@@ -339,13 +348,18 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
                     else:
                         # Pad if no logprobs for this position (either due to length mismatch or None entry)
                         current_target_logprobs.append([-float("inf")] * self.kd_online_topk)
-                        current_target_token_ids.append([0] * self.kd_online_topk)
+                        current_target_token_ids.append(list(range(self.kd_online_topk)))
                         current_target_mask.append([0] * self.kd_online_topk)
 
                 ret_logprobs_data["target_token_ids"].append(current_target_token_ids)
                 ret_logprobs_data["target_logprobs"].append(current_target_logprobs)
                 ret_logprobs_data["target_mask"].append(current_target_mask)
 
+                with open(f"/tmp/target_logprobs_{hash_input_ids}.parquet", "wb") as f:
+                    pd.DataFrame(current_target_logprobs).to_parquet(f, index=False)
+                with open(f"/tmp/target_token_ids_{hash_input_ids}.parquet", "wb") as f:
+                    pd.DataFrame(current_target_token_ids).to_parquet(f, index=False)
+
         except requests.exceptions.RequestException as e:
             LOG.error(f"Error fetching logprobs from online teacher: {e}")
             raise e
 
@@ -403,11 +417,12 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
             # api_responses_for_sub_batch has keys: "target_token_ids", "target_logprobs", "target_mask"
             # Each value is a list, corresponding to items_for_api_call
             for i, item_to_update in enumerate(items_for_api_call):
-                # Check if API call was successful and returned data for this item.
-                # fetch_online_logprobs returns dict with empty lists if API fails or malformed.
+                # TODO make sure to figure out which input in sub_batch_features to update the batch in the original `features` object so the super class can handle it properly.
                 if api_responses_for_sub_batch and \
-                        i < len(api_responses_for_sub_batch.get("target_token_ids", [])):  # Check bounds
+                        i < len(api_responses_for_sub_batch["target_token_ids"]):  # Check bounds
+                    assert len(api_responses_for_sub_batch["target_token_ids"][i]) == len(item_to_update["input_ids"])
+                    assert len(api_responses_for_sub_batch["target_logprobs"][i]) == len(item_to_update["input_ids"])
+                    assert len(api_responses_for_sub_batch["target_mask"][i]) == len(item_to_update["labels"])
                     item_to_update["target_token_ids"] = api_responses_for_sub_batch["target_token_ids"][i]
                     item_to_update["target_logprobs"] = api_responses_for_sub_batch["target_logprobs"][i]
                     item_to_update["target_mask"] = api_responses_for_sub_batch["target_mask"][i]
diff --git a/src/axolotl/integrations/kd/kernels/models.py b/src/axolotl/integrations/kd/kernels/models.py
index 9006d92b0..f346ad21a 100644
--- a/src/axolotl/integrations/kd/kernels/models.py
+++ b/src/axolotl/integrations/kd/kernels/models.py
@@ -88,18 +88,6 @@ def kldiv_forward_llama_like(
     )
 
 
-def apply_kernel_to_qwen2():
-    from transformers.models.qwen2 import modeling_qwen2
-
-    modeling_qwen2.Qwen2ForCausalLM.forward = kldiv_forward_llama_like
-
-
-def apply_kernel_to_llama():
-    from transformers.models.llama import modeling_llama
-
-    modeling_llama.LlamaForCausalLM.forward = kldiv_forward_llama_like
-
-
 def apply_kernel(model_type):
     # Dynamically import the module and attention class
     module_path = f"transformers.models.{model_type}.modeling_{model_type}"
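
Note (reviewer sketch, not part of the patch): the collator change above renormalizes the
teacher's top-k logprobs by subtracting their log-sum-exp in log space, instead of calling
.exp() on raw logprobs and normalizing the resulting (possibly underflowed) probabilities.
Since log(exp(lp - lse)) == lp - lse, the exp()/log() round trip could even be skipped when
only logprobs are needed. A minimal standalone illustration (the function name and example
values are hypothetical, not from the patch):

    import torch

    def renormalize_topk_logprobs(raw_logprobs: list[float]) -> torch.Tensor:
        """Renormalize top-k teacher logprobs so the kept mass sums to 1."""
        lp = torch.tensor(raw_logprobs, dtype=torch.float32)
        # subtracting logsumexp in log space avoids the underflow that
        # exp() -> normalize -> log() hits for very negative logprobs
        return lp - torch.logsumexp(lp, dim=-1, keepdim=True)

    # usage: top-3 logprobs for one position; renormalized mass sums to ~1.0
    print(renormalize_topk_logprobs([-0.5, -2.0, -45.0]).exp().sum())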
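
Note (reviewer sketch, not part of the patch): the removed apply_kernel_to_qwen2 /
apply_kernel_to_llama helpers are subsumed by the dynamic apply_kernel(model_type) that
remains in models.py. Assuming every supported architecture follows the
transformers.models.<model_type>.modeling_<model_type> convention, the dispatch looks
roughly like the following (the class-name derivation is an assumption for illustration,
not the file's actual code; kldiv_forward_llama_like is the kernel defined earlier in
models.py):

    import importlib

    def apply_kernel(model_type: str) -> None:
        # e.g. "llama" -> transformers.models.llama.modeling_llama
        module_path = f"transformers.models.{model_type}.modeling_{model_type}"
        module = importlib.import_module(module_path)
        # e.g. "qwen2" -> Qwen2ForCausalLM (hypothetical name derivation)
        causal_lm_cls = getattr(module, f"{model_type.capitalize()}ForCausalLM")
        # monkey-patch forward with the KD kernel defined earlier in models.py
        causal_lm_cls.forward = kldiv_forward_llama_like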