make it work

This commit is contained in:
Wing Lian
2024-12-19 00:28:02 -05:00
parent e2aba41939
commit 1ea225129f
3 changed files with 93 additions and 44 deletions

View File

@@ -13,46 +13,52 @@ def kd_loss_function(
student_logits,
target_token_ids,
target_logprobs,
target_mask,
num_items_in_batch: Optional[int] = None,
**kwargs, # pylint: disable=unused-argument
):
# student_logits: [B, seq_len, vocab_size] from the student's forward pass
# target_token_ids: [B, teacher_seq_len, K] top-K token IDs from teacher
# target_logprobs: [B, teacher_seq_len, K] teacher logprobs for these top-K tokens
# target_mask: [B, teacher_seq_len, K], where 1 indicates a valid token and 0 indicates padding
# Determine the teacher sequence length
teacher_seq_len = target_token_ids.shape[1]
# Slice the student logits to match the teacher-provided seq length
# Slice student logits to match the teacher-provided sequence length
student_logits_for_kd = student_logits[
:, -teacher_seq_len:, :
] # Now [B, teacher_seq_len, vocab_size]
] # [B, teacher_seq_len, vocab_size]
# Gather student logits for teacher's top-K tokens
student_logits_topk = torch.gather(
student_logits_for_kd, dim=-1, index=target_token_ids
) # [B, teacher_seq_len, K]
# Convert student top-K logits to logprobs
# Convert student top-k logits to logprobs
student_logprobs_topk = student_logits_topk - torch.logsumexp(
student_logits_topk, dim=-1, keepdim=True
)
) # [B, seq_len, K]
# teacher_probs are simply exp of teacher_logprobs (already scaled)
# Convert teacher_mask to boolean for indexing
valid_mask = target_mask.bool()
# Prune tensors to only keep valid tokens
# This will result in 1D arrays of only valid positions
student_logprobs_topk = student_logprobs_topk[valid_mask] # [N_valid_tokens]
target_logprobs = target_logprobs[valid_mask] # [N_valid_tokens]
# Since teacher_logprobs are already normalized, just exponentiate to get probabilities
teacher_probs = target_logprobs.exp()
# Compute forward KL
# L_kl = sum_k p^T_k (log p^T_k - log p^S_k)
kd_loss_per_position = (
teacher_probs * (target_logprobs - student_logprobs_topk)
).sum(
dim=-1
) # [B, teacher_seq_len]
# Compute forward KL:
# KL = sum p^T_k (log p^T_k - log p^S_k), summed over all valid tokens.
kd_loss_per_token = teacher_probs * (target_logprobs - student_logprobs_topk)
kd_loss = kd_loss_per_token.sum()
# gradient accumulation fixes
if num_items_in_batch:
kd_loss = kd_loss_per_position.sum() / num_items_in_batch # Scalar
# Normalize by number of items or mean over valid tokens
if num_items_in_batch is not None:
# If you know how many items should be considered in the batch
kd_loss = kd_loss / num_items_in_batch
else:
kd_loss = kd_loss_per_position.mean() # Scalar
# Otherwise, just average over all valid tokens
kd_loss = kd_loss / kd_loss_per_token.size(0)
return kd_loss
@@ -70,6 +76,8 @@ class AxolotlKDTrainer(AxolotlTrainer):
columns_to_add.append("target_logprobs")
if "target_token_ids" not in self._signature_columns:
columns_to_add.append("target_token_ids")
if "target_mask" not in self._signature_columns:
columns_to_add.append("target_mask")
if columns_to_add:
self._signature_columns += columns_to_add
@@ -83,6 +91,7 @@ class AxolotlKDTrainer(AxolotlTrainer):
"""
target_logprobs = inputs.pop("target_logprobs")
target_token_ids = inputs.pop("target_token_ids")
target_mask = inputs.pop("target_mask")
if self.model_accepts_loss_kwargs:
loss_kwargs = {}
@@ -96,6 +105,7 @@ class AxolotlKDTrainer(AxolotlTrainer):
student_logits,
target_token_ids,
target_logprobs,
target_mask,
num_items_in_batch=num_items_in_batch,
)

View File

@@ -490,8 +490,23 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
def transform_logprobs(self, sample):
logprobs = sample.pop(self.logprobs_field)
target_seq_len = len(logprobs)
input_seq_len = len(sample["input_ids"])
padding_len = input_seq_len - target_seq_len
top_k = len(logprobs[0])
target_logprobs = []
target_token_ids = []
target_mask = []
# fill with -inf for padding_len tokens for top_k tokens
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf
for _ in range(padding_len):
target_logprobs.append([-float("inf")] * top_k)
target_token_ids.append(list(range(top_k)))
target_mask.append([0] * top_k)
for _ in range(target_seq_len):
target_mask.append([1] * top_k)
for _, token_pos_logprobs in enumerate(logprobs):
# Initialize collections for logprobs and token_ids
@@ -519,6 +534,7 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
# Apply temperature scaling at data load time
# log p_k^(T) = (log p_k / T) - logsumexp(log p_j / T)
position_logprobs_tensor = position_logprobs_tensor / self.temperature
# normalize to probabilities so they sum up to 1
position_logprobs_tensor = position_logprobs_tensor - torch.logsumexp(
position_logprobs_tensor, dim=0, keepdim=True
)
@@ -531,6 +547,7 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
# Update sample with transformed logprobs
sample["target_logprobs"] = target_logprobs
sample["target_token_ids"] = target_token_ids
sample["target_mask"] = target_mask
return sample
@@ -538,7 +555,9 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
logprobs = prompt.pop(self.logprobs_field)
tokenized_prompt = super().tokenize_prompt(prompt)
tokenized_prompt[self.logprobs_field] = logprobs
return self.transform_logprobs(tokenized_prompt)
tokenized_prompt = self.transform_logprobs(tokenized_prompt)
return tokenized_prompt
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None):

View File

@@ -1,5 +1,6 @@
"""
DataCollator for axolotl to handle KD fields
DataCollator for axolotl to handle KD fields without using -inf for padding,
and with a target_mask to identify padded positions.
"""
from dataclasses import dataclass
@@ -17,6 +18,9 @@ from axolotl.utils.collators.batching import DataCollatorForSeq2Seq
class DataCollatorForKD(DataCollatorForSeq2Seq):
"""
Data collator for KD, including handling KD-specific fields.
This version avoids using -inf and instead uses a large negative value for padding
target_logprobs. It also creates a target_mask to indicate which entries are valid.
"""
tokenizer: PreTrainedTokenizerBase
@@ -32,7 +36,9 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
if return_tensors is None:
return_tensors = self.return_tensors
# Extract labels and position_ids first (as in original code)
padding_side = self.tokenizer.padding_side
# Pad labels and position_ids first
for feature_name, pad_token_id in [
("labels", self.label_pad_token_id),
("position_ids", self.position_pad_token_id),
@@ -46,7 +52,6 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
// self.pad_to_multiple_of
) * self.pad_to_multiple_of
padding_side = self.tokenizer.padding_side
for f in features: # pylint: disable=invalid-name
remainder = [pad_token_id] * (max_len - len(f[feature_name]))
if isinstance(f[feature_name], list):
@@ -69,63 +74,77 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
# Handle target_logprobs and target_token_ids manually
target_logprobs_list = []
target_token_ids_list = []
target_mask_list = []
has_teacher_data = ("target_logprobs" in features[0]) and (
"target_token_ids" in features[0]
)
if has_teacher_data:
# Extract these fields
# Extract and remove from features
for f in features: # pylint: disable=invalid-name
target_logprobs_list.append(f.pop("target_logprobs"))
target_token_ids_list.append(f.pop("target_token_ids"))
target_mask_list.append(f.pop("target_mask"))
# Determine max lengths to pad
# Determine max lengths
max_teacher_seq_len = max(len(seq) for seq in target_logprobs_list)
max_k = max(len(seq_k) for seq in target_logprobs_list for seq_k in seq)
# Pad target_logprobs and target_token_ids
padded_target_logprobs = []
padded_target_token_ids = []
for t_logprobs, t_ids in zip(target_logprobs_list, target_token_ids_list):
# Pad seq dimension
padded_teacher_mask_list = []
for t_logprobs, t_ids, t_mask in zip(
target_logprobs_list, target_token_ids_list, target_mask_list
):
t_logprobs_padded = []
t_ids_padded = []
for i in range( # pylint: disable=consider-using-enumerate
len(t_logprobs)
t_mask_padded = []
for lp, ids, mask in zip( # pylint: disable=invalid-name
t_logprobs, t_ids, t_mask
):
lp = t_logprobs[i] # pylint: disable=invalid-name
ids = t_ids[i]
# Pad K dimension
lp_len = len(lp)
if lp_len < max_k:
lp = lp + [-float("inf")] * ( # pylint: disable=invalid-name
max_k - lp_len
) # or some pad value that won't break exp()
ids = ids + [0] * (max_k - lp_len)
# Use -1e9 for padding logprobs and 0 for token_ids
pad_len = max_k - lp_len
lp = lp + [-1e9] * pad_len # pylint: disable=invalid-name
ids = ids + [0] * pad_len
mask = mask + [0] * pad_len
else:
lp = lp[:max_k] # pylint: disable=invalid-name
ids = ids[:max_k]
mask = mask[:max_k]
t_logprobs_padded.append(lp)
t_ids_padded.append(ids)
t_mask_padded.append(mask)
# If sequence is shorter than max_teacher_seq_len
seq_len_diff = max_teacher_seq_len - len(t_logprobs_padded)
if seq_len_diff > 0:
# Pad sequences fully if needed
t_logprobs_padded.extend(
[[-float("inf")] * max_k for _ in range(seq_len_diff)]
[[-1e9] * max_k for _ in range(seq_len_diff)]
)
t_ids_padded.extend([[0] * max_k for _ in range(seq_len_diff)])
t_mask_padded.extend([[0] * max_k for _ in range(seq_len_diff)])
padded_target_logprobs.append(t_logprobs_padded)
padded_target_token_ids.append(t_ids_padded)
padded_teacher_mask_list.append(t_mask_padded)
# Convert to tensors
padded_target_logprobs = torch.tensor(
padded_target_logprobs, dtype=torch.float
)
# We can store token_ids as long tensor
padded_target_token_ids = torch.tensor(
padded_target_token_ids, dtype=torch.long
)
padded_teacher_mask_list = torch.tensor(
padded_teacher_mask_list, dtype=torch.int
)
# Now pad using tokenizer for the remaining fields (input_ids, attention_mask, etc.)
# Pad using tokenizer for regular fields
features = self.tokenizer.pad(
features,
padding=self.padding,
@@ -134,12 +153,13 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
return_tensors=return_tensors,
)
# Add back the teacher data if it exists
# Add back teacher data if present
if has_teacher_data:
features["target_logprobs"] = padded_target_logprobs
features["target_token_ids"] = padded_target_token_ids
features["target_mask"] = padded_teacher_mask_list
# Prepare decoder_input_ids if applicable
# Prepare decoder_input_ids if the model supports it
if (
"labels" in features
and self.model is not None