logsumexp trick:

@@ -1,3 +1,4 @@
+import pandas as pd
 import requests
 import logging
 from typing import List, Optional, Dict, Any

@@ -58,7 +59,9 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
         position_logprobs_tensor = torch.tensor(raw_logprobs, dtype=torch.float32)

         # Convert logprobs at T_online to probabilities
-        teacher_probs_t_online = position_logprobs_tensor.exp()
+        # use the log-sum-exp trick to avoid underflow
+        position_logprobs_lse = torch.logsumexp(position_logprobs_tensor, dim=-1, keepdim=True)
+        teacher_probs_t_online = torch.exp(position_logprobs_tensor - position_logprobs_lse)

         # Normalize probabilities (sum to 1)
         # This is important if the top-k from the server isn't a full distribution

@@ -75,10 +78,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
-        # teacher_probs_t_online = teacher_probs_t_online / teacher_probs_t_online.sum(
-        #     dim=0, keepdim=True
-        # )
-        # TODO Convert back to logprobs using log sum exp trick
+        # Convert back to logprobs using the log-sum-exp trick
+        teacher_probs_t_online_max = torch.logsumexp(teacher_probs_t_online, dim=-1, keepdim=True)
+        final_logprobs_tensor = teacher_probs_t_online - teacher_probs_t_online_max
+        # final_logprobs_tensor = torch.log(teacher_probs_t_online)
-        final_logprobs_tensor = torch.log(teacher_probs_t_online)

         return final_logprobs_tensor.tolist()
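
Note: the two hunks above convert the raw top-k logprobs to probabilities and back with torch.logsumexp. A minimal standalone sketch of the same renormalization (the helper name is ours, not from this commit) fuses the round trip into a single log-space subtraction, since log(p_i / sum_j p_j) = lp_i - logsumexp(lp). One caveat worth flagging: logsumexp computes log(sum(exp(x))), so the stable round trip only holds when its input is log-probabilities; applying it to teacher_probs_t_online, which already holds probabilities, yields a log-softmax over probabilities instead.

    import torch

    def renormalize_topk_logprobs(raw_logprobs: list[float]) -> list[float]:
        # Hypothetical helper: renormalize a truncated top-k distribution
        # entirely in log space. No exp()/log() round trip is needed, so
        # very negative logprobs cannot underflow to zero probability.
        lp = torch.tensor(raw_logprobs, dtype=torch.float32)
        return (lp - torch.logsumexp(lp, dim=-1, keepdim=True)).tolist()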

@@ -259,6 +259,9 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
             return ret_logprobs_data

         for sequence_data, seq_input_ids, seq_labels in zip(choices, batch_input_ids, labels):
+            # seq_input_ids: List[int]
+            # seq_labels: List[int]
+
             current_target_logprobs = []
             current_target_token_ids = []
             current_target_mask = []

@@ -289,13 +292,19 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
                 LOG.warning(f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence.")
                 input_top_logprobs = []  # Treat as empty

+            # basic check that the logprob data length matches the input length, so there is no need to handle padding
+            assert len(seq_input_ids) == len(input_top_logprobs)
+
+            # generate a hash over seq_input_ids and convert it to an int
+            hash_input_ids: int = hash(tuple(seq_input_ids))
+
             for i, input_id, label in zip(range(len(seq_input_ids)), seq_input_ids, seq_labels):
                 if i < len(input_top_logprobs) and input_top_logprobs[i] is None:
                     # this is always the case for the first token:
                     # there is never logprob data for the first token, since that is a true input,
                     # so we replace the None value with padding data
                     current_target_logprobs.append([-float("inf")] * self.kd_online_topk)
-                    current_target_token_ids.append([0] * self.kd_online_topk)
+                    current_target_token_ids.append(list(range(self.kd_online_topk)))
                     current_target_mask.append([0] * self.kd_online_topk)
                 elif i < len(input_top_logprobs) and input_top_logprobs[i] is not None:
                     pos_top_logprobs_data: dict[str, dict] = input_top_logprobs[i]
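
Note: the padding branches in this method all append the same (logprobs, token_ids, mask) triple, so they could plausibly be factored into one helper (hypothetical, not part of this commit). Using list(range(topk)) instead of [0] * topk gives each padded slot a distinct dummy token id; the all-zero mask is what actually keeps these positions out of the distillation loss.

    def padded_position(topk: int) -> tuple[list[float], list[int], list[int]]:
        # Positions with no teacher data get -inf logprobs (zero probability
        # mass), distinct dummy token ids, and a zero mask so downstream
        # KD loss terms ignore them entirely.
        return [-float("inf")] * topk, list(range(topk)), [0] * topk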

@@ -305,7 +314,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
                         len(pos_top_logprobs_data.keys()) > 0):  # [logprob, token_id, token_str]
                     LOG.warning(f"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. Padding this position.")
                     current_target_logprobs.append([-float("inf")] * self.kd_online_topk)
-                    current_target_token_ids.append([0] * self.kd_online_topk)
+                    current_target_token_ids.append(list(range(self.kd_online_topk)))
                     current_target_mask.append([0] * self.kd_online_topk)
                     continue

@@ -318,7 +327,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
                 # Ensure correct length (top_k)
                 if len(pos_logprobs_raw) < self.kd_online_topk:
                     pad_len = self.kd_online_topk - len(pos_logprobs_raw)
-                    LOG.debug(f"Padding position {i} with {pad_len} top-k tokens and logprobs.")
+                    LOG.warning(f"Padding position {i} with {pad_len} top-k tokens and logprobs.")
                     pos_logprobs_raw.extend([-float("inf")] * pad_len)
                     pos_token_ids.extend([0] * pad_len)  # Pad with token_id 0

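Note: the length fix-up above pads a short server response out to exactly kd_online_topk columns so every position has the same width. The same invariant as a standalone sketch (the function name is ours):

    def pad_topk(pos_logprobs: list[float], pos_token_ids: list[int], topk: int) -> None:
        # Extend both lists in place to exactly `topk` columns; the -inf
        # logprobs carry no probability mass and token id 0 is a placeholder.
        pad_len = topk - len(pos_logprobs)
        if pad_len > 0:
            pos_logprobs.extend([-float("inf")] * pad_len)
            pos_token_ids.extend([0] * pad_len)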

@@ -339,13 +348,18 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
                 else:
                     # Pad if there are no logprobs for this position (either due to a length mismatch or a None entry)
                     current_target_logprobs.append([-float("inf")] * self.kd_online_topk)
-                    current_target_token_ids.append([0] * self.kd_online_topk)
+                    current_target_token_ids.append(list(range(self.kd_online_topk)))
                     current_target_mask.append([0] * self.kd_online_topk)

             ret_logprobs_data["target_token_ids"].append(current_target_token_ids)
             ret_logprobs_data["target_logprobs"].append(current_target_logprobs)
             ret_logprobs_data["target_mask"].append(current_target_mask)

+            with open(f"/tmp/target_logprobs_{hash_input_ids}.parquet", "wb") as f:
+                pd.DataFrame(current_target_logprobs).to_parquet(f, index=False)
+            with open(f"/tmp/target_token_ids_{hash_input_ids}.parquet", "wb") as f:
+                pd.DataFrame(current_target_token_ids).to_parquet(f, index=False)
+
         except requests.exceptions.RequestException as e:
             LOG.error(f"Error fetching logprobs from online teacher: {e}")
             raise e
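
Note: the debug dumps above key the parquet files on hash(tuple(seq_input_ids)). For tuples of ints that value is deterministic within a platform, but it differs between 32- and 64-bit builds and can be negative, which makes for awkward file names. If stable, portable names matter, a content digest is one alternative (a sketch, not part of this commit):

    import hashlib

    def stable_sequence_id(seq_input_ids: list[int]) -> str:
        # Reproducible alternative to hash(tuple(...)): digest the raw ids
        # so the same sequence always maps to the same /tmp file name.
        payload = ",".join(map(str, seq_input_ids)).encode("utf-8")
        return hashlib.sha1(payload).hexdigest()[:16]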

@@ -403,11 +417,12 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
         # api_responses_for_sub_batch has keys: "target_token_ids", "target_logprobs", "target_mask".
         # Each value is a list corresponding to items_for_api_call.
         for i, item_to_update in enumerate(items_for_api_call):
             # Check if the API call was successful and returned data for this item.
             # fetch_online_logprobs returns a dict with empty lists if the API call fails or the response is malformed.
+            # TODO figure out which input in sub_batch_features corresponds to which entry in the original `features` object, so the superclass can handle the batch properly.
             if api_responses_for_sub_batch and \
-                    i < len(api_responses_for_sub_batch.get("target_token_ids", [])):  # Check bounds
+                    i < len(api_responses_for_sub_batch):  # Check bounds
                 assert len(api_responses_for_sub_batch["target_token_ids"][i]) == len(item_to_update["input_ids"])
                 assert len(api_responses_for_sub_batch["target_logprobs"][i]) == len(item_to_update["input_ids"])
                 assert len(api_responses_for_sub_batch["target_mask"][i]) == len(item_to_update["labels"])
                 item_to_update["target_token_ids"] = api_responses_for_sub_batch["target_token_ids"][i]
                 item_to_update["target_logprobs"] = api_responses_for_sub_batch["target_logprobs"][i]
                 item_to_update["target_mask"] = api_responses_for_sub_batch["target_mask"][i]
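
Note: of the two bounds checks shown in this hunk, len(api_responses_for_sub_batch) counts the dict's keys (always three), not the number of returned sequences, so it does not actually bound i. A guard along these lines (a hypothetical helper) makes the intent explicit:

    def has_response_for(responses: dict, i: int) -> bool:
        # The response dict holds three parallel lists; bound-check against
        # a list's length, not len(responses), which only counts the keys.
        return i < len(responses.get("target_token_ids", []))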

@@ -88,18 +88,6 @@ def kldiv_forward_llama_like(
     )


-def apply_kernel_to_qwen2():
-    from transformers.models.qwen2 import modeling_qwen2
-
-    modeling_qwen2.Qwen2ForCausalLM.forward = kldiv_forward_llama_like
-
-
-def apply_kernel_to_llama():
-    from transformers.models.llama import modeling_llama
-
-    modeling_llama.LlamaForCausalLM.forward = kldiv_forward_llama_like
-
-
 def apply_kernel(model_type):
     # Dynamically import the module and attention class
     module_path = f"transformers.models.{model_type}.modeling_{model_type}"
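
Note: the hunk above drops the per-model apply_kernel_to_qwen2/apply_kernel_to_llama helpers in favor of the dynamic apply_kernel(model_type). A sketch of how the dynamic version plausibly continues; only the module_path line is from the diff, and the ForCausalLM class-name convention is an assumption that happens to hold for llama and qwen2:

    import importlib

    def apply_kernel(model_type: str) -> None:
        # Resolve e.g. transformers.models.llama.modeling_llama at runtime
        # and patch its causal-LM forward with the KL-div variant defined
        # in this file.
        module_path = f"transformers.models.{model_type}.modeling_{model_type}"
        module = importlib.import_module(module_path)
        cls = getattr(module, f"{model_type.capitalize()}ForCausalLM")
        cls.forward = kldiv_forward_llama_like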