make it work
@@ -13,46 +13,52 @@ def kd_loss_function(
     student_logits,
     target_token_ids,
     target_logprobs,
+    target_mask,
     num_items_in_batch: Optional[int] = None,
     **kwargs,  # pylint: disable=unused-argument
 ):
     # student_logits: [B, seq_len, vocab_size] from the student's forward pass
     # target_token_ids: [B, teacher_seq_len, K] top-K token IDs from teacher
     # target_logprobs: [B, teacher_seq_len, K] teacher logprobs for these top-K tokens
+    # teacher_mask: [B, teacher_seq_len, K], where 1 indicates a valid token and 0 indicates padding

     # Determine the teacher sequence length
     teacher_seq_len = target_token_ids.shape[1]

-    # Slice the student logits to match the teacher-provided seq length
+    # Slice student logits to match the teacher-provided sequence length
     student_logits_for_kd = student_logits[
         :, -teacher_seq_len:, :
-    ]  # Now [B, teacher_seq_len, vocab_size]
+    ]  # [B, teacher_seq_len, vocab_size]

     # Gather student logits for teacher's top-K tokens
     student_logits_topk = torch.gather(
         student_logits_for_kd, dim=-1, index=target_token_ids
     )  # [B, teacher_seq_len, K]

-    # Convert student top-K logits to logprobs
+    # Convert student top-k logits to logprobs
     student_logprobs_topk = student_logits_topk - torch.logsumexp(
         student_logits_topk, dim=-1, keepdim=True
-    )
+    )  # [B, seq_len, K]

-    # teacher_probs are simply exp of teacher_logprobs (already scaled)
+    # Convert teacher_mask to boolean for indexing
+    valid_mask = target_mask.bool()
+
+    # Prune tensors to only keep valid tokens
+    # This will result in 1D arrays of only valid positions
+    student_logprobs_topk = student_logprobs_topk[valid_mask]  # [N_valid_tokens]
+    target_logprobs = target_logprobs[valid_mask]  # [N_valid_tokens]
+
+    # Since teacher_logprobs are already normalized, just exponentiate to get probabilities
     teacher_probs = target_logprobs.exp()

-    # Compute forward KL
-    # L_kl = sum_k p^T_k (log p^T_k - log p^S_k)
-    kd_loss_per_position = (
-        teacher_probs * (target_logprobs - student_logprobs_topk)
-    ).sum(
-        dim=-1
-    )  # [B, teacher_seq_len]
+    # Compute forward KL:
+    # KL = sum p^T_k (log p^T_k - log p^S_k), summed over all valid tokens.
+    kd_loss_per_token = teacher_probs * (target_logprobs - student_logprobs_topk)
+    kd_loss = kd_loss_per_token.sum()

-    # gradient accumulation fixes
-    if num_items_in_batch:
-        kd_loss = kd_loss_per_position.sum() / num_items_in_batch  # Scalar
+    # Normalize by number of items or mean over valid tokens
+    if num_items_in_batch is not None:
+        # If you know how many items should be considered in the batch
+        kd_loss = kd_loss / num_items_in_batch
     else:
-        kd_loss = kd_loss_per_position.mean()  # Scalar
+        # Otherwise, just average over all valid tokens
+        kd_loss = kd_loss / kd_loss_per_token.size(0)

     return kd_loss
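The loss above can be sanity-checked in isolation. The sketch below is not part of the commit: it fabricates toy shapes and teacher top-K data (all values are made up) and assumes kd_loss_function is importable exactly as defined in this hunk.

import torch

B, seq_len, vocab_size, K = 2, 5, 32, 4
teacher_seq_len = 3

student_logits = torch.randn(B, seq_len, vocab_size)

# Fabricated teacher top-K data: token ids, normalized logprobs, validity mask
target_token_ids = torch.randint(0, vocab_size, (B, teacher_seq_len, K))
target_logprobs = torch.log_softmax(torch.randn(B, teacher_seq_len, K), dim=-1)
target_mask = torch.ones(B, teacher_seq_len, K, dtype=torch.int)
target_mask[:, -1, :] = 0  # pretend the last teacher position is padding

loss = kd_loss_function(
    student_logits,
    target_token_ids,
    target_logprobs,
    target_mask,
    num_items_in_batch=None,
)
print(loss)  # scalar: forward KL averaged over the valid top-K entries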
@@ -70,6 +76,8 @@ class AxolotlKDTrainer(AxolotlTrainer):
             columns_to_add.append("target_logprobs")
         if "target_token_ids" not in self._signature_columns:
             columns_to_add.append("target_token_ids")
+        if "target_mask" not in self._signature_columns:
+            columns_to_add.append("target_mask")
         if columns_to_add:
             self._signature_columns += columns_to_add

@@ -83,6 +91,7 @@ class AxolotlKDTrainer(AxolotlTrainer):
         """
         target_logprobs = inputs.pop("target_logprobs")
         target_token_ids = inputs.pop("target_token_ids")
+        target_mask = inputs.pop("target_mask")

         if self.model_accepts_loss_kwargs:
             loss_kwargs = {}
@@ -96,6 +105,7 @@ class AxolotlKDTrainer(AxolotlTrainer):
             student_logits,
             target_token_ids,
             target_logprobs,
+            target_mask,
             num_items_in_batch=num_items_in_batch,
         )

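Pieced together from the fragments above, the trainer override roughly does the following. This is a condensed, assumed sketch rather than the commit's exact compute_loss, which also handles loss_kwargs and the model_accepts_loss_kwargs branch.

def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
    # Pull the teacher-provided KD fields out before the model forward pass
    target_logprobs = inputs.pop("target_logprobs")
    target_token_ids = inputs.pop("target_token_ids")
    target_mask = inputs.pop("target_mask")

    outputs = model(**inputs)
    student_logits = outputs.logits

    loss = kd_loss_function(
        student_logits,
        target_token_ids,
        target_logprobs,
        target_mask,
        num_items_in_batch=num_items_in_batch,
    )
    return (loss, outputs) if return_outputs else loss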
@@ -490,8 +490,23 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):

     def transform_logprobs(self, sample):
         logprobs = sample.pop(self.logprobs_field)
         target_seq_len = len(logprobs)
         input_seq_len = len(sample["input_ids"])
         padding_len = input_seq_len - target_seq_len
         top_k = len(logprobs[0])
         target_logprobs = []
         target_token_ids = []
+        target_mask = []

         # fill with -inf for padding_len tokens for top_k tokens
         # extend target_logprobs with a padding_len x top_k 2D list filled with -inf
         for _ in range(padding_len):
             target_logprobs.append([-float("inf")] * top_k)
             target_token_ids.append(list(range(top_k)))
+            target_mask.append([0] * top_k)

+        for _ in range(target_seq_len):
+            target_mask.append([1] * top_k)
+
         for _, token_pos_logprobs in enumerate(logprobs):
             # Initialize collections for logprobs and token_ids
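For a concrete picture of what transform_logprobs builds, here is a toy walk-through with invented sizes (padding_len = 2, target_seq_len = 1, top_k = 3). It only mirrors the mask/padding scaffolding from the loops above, not the actual logprob filling that follows in the method.

padding_len, target_seq_len, top_k = 2, 1, 3

target_logprobs = [[-float("inf")] * top_k for _ in range(padding_len)]
target_token_ids = [list(range(top_k)) for _ in range(padding_len)]
target_mask = [[0] * top_k for _ in range(padding_len)]
target_mask += [[1] * top_k for _ in range(target_seq_len)]

# target_mask is now [[0, 0, 0], [0, 0, 0], [1, 1, 1]]: prompt positions are
# masked out, teacher-scored positions are kept; the loop over
# token_pos_logprobs then appends the real top-k logprobs and token ids.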
@@ -519,6 +534,7 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
             # Apply temperature scaling at data load time
             # log p_k^(T) = (log p_k / T) - logsumexp(log p_j / T)
             position_logprobs_tensor = position_logprobs_tensor / self.temperature
+            # normalize to probabilities so they sum up to 1
             position_logprobs_tensor = position_logprobs_tensor - torch.logsumexp(
                 position_logprobs_tensor, dim=0, keepdim=True
             )
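A standalone numeric check of the temperature scaling above (values invented): dividing the logprobs by T flattens the distribution, and the logsumexp subtraction renormalizes it over the K retained tokens.

import torch

temperature = 2.0
position_logprobs_tensor = torch.log(torch.tensor([0.7, 0.2, 0.1]))

position_logprobs_tensor = position_logprobs_tensor / temperature
position_logprobs_tensor = position_logprobs_tensor - torch.logsumexp(
    position_logprobs_tensor, dim=0, keepdim=True
)

print(position_logprobs_tensor.exp())        # ~[0.52, 0.28, 0.20], flatter than [0.7, 0.2, 0.1]
print(position_logprobs_tensor.exp().sum())  # ~1.0, still a valid distribution over the top-k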
@@ -531,6 +547,7 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
         # Update sample with transformed logprobs
         sample["target_logprobs"] = target_logprobs
         sample["target_token_ids"] = target_token_ids
+        sample["target_mask"] = target_mask

         return sample

@@ -538,7 +555,9 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
         logprobs = prompt.pop(self.logprobs_field)
         tokenized_prompt = super().tokenize_prompt(prompt)
         tokenized_prompt[self.logprobs_field] = logprobs
-        return self.transform_logprobs(tokenized_prompt)
+        tokenized_prompt = self.transform_logprobs(tokenized_prompt)
+
+        return tokenized_prompt


 def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None):
@@ -1,5 +1,6 @@
 """
-DataCollator for axolotl to handle KD fields
+DataCollator for axolotl to handle KD fields without using -inf for padding,
+and with a teacher_mask to identify padded positions.
 """

 from dataclasses import dataclass
@@ -17,6 +18,9 @@ from axolotl.utils.collators.batching import DataCollatorForSeq2Seq
 class DataCollatorForKD(DataCollatorForSeq2Seq):
     """
     Data collator for KD, including handling KD-specific fields.
+
+    This version avoids using -inf and instead uses a large negative value for padding
+    target_logprobs. It also creates a teacher_mask to indicate which entries are valid.
     """

     tokenizer: PreTrainedTokenizerBase
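The reason for avoiding -inf is presumably the usual IEEE pitfall: once a padded logprob is -inf, the 0 * -inf product that shows up in the masked KL term is NaN, whereas a large finite value stays finite (and the target_mask excludes these positions anyway). A quick standalone illustration, not from the commit:

import torch

student_logprob = torch.tensor(-2.3)

pad_logprob = torch.tensor(float("-inf"))
teacher_prob = pad_logprob.exp()                       # 0.0
print(teacher_prob * (pad_logprob - student_logprob))  # tensor(nan): 0 * -inf

pad_logprob = torch.tensor(-1e9)
teacher_prob = pad_logprob.exp()                       # still 0.0 after underflow
print(teacher_prob * (pad_logprob - student_logprob))  # tensor(-0.): finite, safe to sum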
@@ -32,7 +36,9 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
         if return_tensors is None:
             return_tensors = self.return_tensors

-        # Extract labels and position_ids first (as in original code)
+        padding_side = self.tokenizer.padding_side
+
+        # Pad labels and position_ids first
         for feature_name, pad_token_id in [
             ("labels", self.label_pad_token_id),
             ("position_ids", self.position_pad_token_id),
@@ -46,7 +52,6 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
                     // self.pad_to_multiple_of
                 ) * self.pad_to_multiple_of

-            padding_side = self.tokenizer.padding_side
             for f in features:  # pylint: disable=invalid-name
                 remainder = [pad_token_id] * (max_len - len(f[feature_name]))
                 if isinstance(f[feature_name], list):
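As a small worked example of the pad_to_multiple_of rounding that the expression above finishes (the numerator half of the formula presumably sits just outside this hunk), with invented numbers:

pad_to_multiple_of = 8
max_label_length = 13  # longest "labels" sequence in the batch (made-up value)

max_len = (
    (max_label_length + pad_to_multiple_of - 1) // pad_to_multiple_of
) * pad_to_multiple_of

print(max_len)  # 16: every sequence is padded up to the next multiple of 8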
@@ -69,63 +74,77 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
         # Handle target_logprobs and target_token_ids manually
         target_logprobs_list = []
         target_token_ids_list = []
+        target_mask_list = []
         has_teacher_data = ("target_logprobs" in features[0]) and (
             "target_token_ids" in features[0]
         )

         if has_teacher_data:
-            # Extract these fields
+            # Extract and remove from features
             for f in features:  # pylint: disable=invalid-name
                 target_logprobs_list.append(f.pop("target_logprobs"))
                 target_token_ids_list.append(f.pop("target_token_ids"))
+                target_mask_list.append(f.pop("target_mask"))

-            # Determine max lengths to pad
+            # Determine max lengths
             max_teacher_seq_len = max(len(seq) for seq in target_logprobs_list)
             max_k = max(len(seq_k) for seq in target_logprobs_list for seq_k in seq)

             # Pad target_logprobs and target_token_ids
             padded_target_logprobs = []
             padded_target_token_ids = []
-            for t_logprobs, t_ids in zip(target_logprobs_list, target_token_ids_list):
-                # Pad seq dimension
+            padded_teacher_mask_list = []

+            for t_logprobs, t_ids, t_mask in zip(
+                target_logprobs_list, target_token_ids_list, target_mask_list
+            ):
                 t_logprobs_padded = []
                 t_ids_padded = []
-                for i in range(  # pylint: disable=consider-using-enumerate
-                    len(t_logprobs)
+                t_mask_padded = []

+                for lp, ids, mask in zip(  # pylint: disable=invalid-name
+                    t_logprobs, t_ids, t_mask
                 ):
-                    lp = t_logprobs[i]  # pylint: disable=invalid-name
-                    ids = t_ids[i]
                     # Pad K dimension
                     lp_len = len(lp)
                     if lp_len < max_k:
-                        lp = lp + [-float("inf")] * (  # pylint: disable=invalid-name
-                            max_k - lp_len
-                        )  # or some pad value that won't break exp()
-                        ids = ids + [0] * (max_k - lp_len)
+                        # Use -1e9 for padding logprobs and 0 for token_ids
+                        pad_len = max_k - lp_len
+                        lp = lp + [-1e9] * pad_len  # pylint: disable=invalid-name
+                        ids = ids + [0] * pad_len
+                        mask = mask + [0] * pad_len
                     else:
                         lp = lp[:max_k]  # pylint: disable=invalid-name
                         ids = ids[:max_k]
+                        mask = mask[:max_k]

                     t_logprobs_padded.append(lp)
                     t_ids_padded.append(ids)
+                    t_mask_padded.append(mask)

                 # If sequence is shorter than max_teacher_seq_len
                 seq_len_diff = max_teacher_seq_len - len(t_logprobs_padded)
                 if seq_len_diff > 0:
                     # Pad sequences fully if needed
                     t_logprobs_padded.extend(
-                        [[-float("inf")] * max_k for _ in range(seq_len_diff)]
+                        [[-1e9] * max_k for _ in range(seq_len_diff)]
                     )
                     t_ids_padded.extend([[0] * max_k for _ in range(seq_len_diff)])
+                    t_mask_padded.extend([[0] * max_k for _ in range(seq_len_diff)])

                 padded_target_logprobs.append(t_logprobs_padded)
                 padded_target_token_ids.append(t_ids_padded)
+                padded_teacher_mask_list.append(t_mask_padded)

             # Convert to tensors
             padded_target_logprobs = torch.tensor(
                 padded_target_logprobs, dtype=torch.float
             )
             # We can store token_ids as long tensor
             padded_target_token_ids = torch.tensor(
                 padded_target_token_ids, dtype=torch.long
             )
+            padded_teacher_mask_list = torch.tensor(
+                padded_teacher_mask_list, dtype=torch.int
+            )

-        # Now pad using tokenizer for the remaining fields (input_ids, attention_mask, etc.)
+        # Pad using tokenizer for regular fields
         features = self.tokenizer.pad(
             features,
             padding=self.padding,
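Condensed sketch of the padding scheme implemented above (not the commit's exact code; the helper name pad_teacher_field is invented): each sample's ragged [seq][k] teacher data is brought to a common [max_teacher_seq_len][max_k] shape, with -1e9 logprobs, token id 0, and mask 0 in the padded slots, before being converted to float/long/int tensors as in the block above.

def pad_teacher_field(rows, max_seq, max_k, pad_value):
    # Pad the k dimension of every row, then pad the sequence dimension
    padded = [row[:max_k] + [pad_value] * max(0, max_k - len(row)) for row in rows]
    padded.extend([[pad_value] * max_k for _ in range(max_seq - len(padded))])
    return padded

sample_logprobs = [[-0.1, -2.0], [-0.3]]  # 2 teacher positions, ragged top-k
max_seq, max_k = 3, 2

logprobs = pad_teacher_field(sample_logprobs, max_seq, max_k, -1e9)
mask = pad_teacher_field([[1, 1], [1]], max_seq, max_k, 0)

print(logprobs)  # [[-0.1, -2.0], [-0.3, -1000000000.0], [-1000000000.0, -1000000000.0]]
print(mask)      # [[1, 1], [1, 0], [0, 0]]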
@@ -134,12 +153,13 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
             return_tensors=return_tensors,
         )

-        # Add back the teacher data if it exists
+        # Add back teacher data if present
         if has_teacher_data:
             features["target_logprobs"] = padded_target_logprobs
             features["target_token_ids"] = padded_target_token_ids
+            features["target_mask"] = padded_teacher_mask_list

-        # Prepare decoder_input_ids if applicable
+        # Prepare decoder_input_ids if the model supports it
         if (
             "labels" in features
             and self.model is not None