make it work
This commit is contained in:
@@ -13,46 +13,52 @@ def kd_loss_function(
|
|||||||
student_logits,
|
student_logits,
|
||||||
target_token_ids,
|
target_token_ids,
|
||||||
target_logprobs,
|
target_logprobs,
|
||||||
|
target_mask,
|
||||||
num_items_in_batch: Optional[int] = None,
|
num_items_in_batch: Optional[int] = None,
|
||||||
**kwargs, # pylint: disable=unused-argument
|
|
||||||
):
|
):
|
||||||
# student_logits: [B, seq_len, vocab_size] from the student's forward pass
|
# teacher_mask: [B, teacher_seq_len, K], where 1 indicates a valid token and 0 indicates padding
|
||||||
# target_token_ids: [B, teacher_seq_len, K] top-K token IDs from teacher
|
|
||||||
# target_logprobs: [B, teacher_seq_len, K] teacher logprobs for these top-K tokens
|
|
||||||
|
|
||||||
|
# Determine the teacher sequence length
|
||||||
teacher_seq_len = target_token_ids.shape[1]
|
teacher_seq_len = target_token_ids.shape[1]
|
||||||
|
|
||||||
# Slice the student logits to match the teacher-provided seq length
|
# Slice student logits to match the teacher-provided sequence length
|
||||||
student_logits_for_kd = student_logits[
|
student_logits_for_kd = student_logits[
|
||||||
:, -teacher_seq_len:, :
|
:, -teacher_seq_len:, :
|
||||||
] # Now [B, teacher_seq_len, vocab_size]
|
] # [B, teacher_seq_len, vocab_size]
|
||||||
|
|
||||||
# Gather student logits for teacher's top-K tokens
|
# Gather student logits for teacher's top-K tokens
|
||||||
student_logits_topk = torch.gather(
|
student_logits_topk = torch.gather(
|
||||||
student_logits_for_kd, dim=-1, index=target_token_ids
|
student_logits_for_kd, dim=-1, index=target_token_ids
|
||||||
) # [B, teacher_seq_len, K]
|
) # [B, teacher_seq_len, K]
|
||||||
|
|
||||||
# Convert student top-K logits to logprobs
|
# Convert student top-k logits to logprobs
|
||||||
student_logprobs_topk = student_logits_topk - torch.logsumexp(
|
student_logprobs_topk = student_logits_topk - torch.logsumexp(
|
||||||
student_logits_topk, dim=-1, keepdim=True
|
student_logits_topk, dim=-1, keepdim=True
|
||||||
)
|
) # [B, seq_len, K]
|
||||||
|
|
||||||
# teacher_probs are simply exp of teacher_logprobs (already scaled)
|
# Convert teacher_mask to boolean for indexing
|
||||||
|
valid_mask = target_mask.bool()
|
||||||
|
|
||||||
|
# Prune tensors to only keep valid tokens
|
||||||
|
# This will result in 1D arrays of only valid positions
|
||||||
|
student_logprobs_topk = student_logprobs_topk[valid_mask] # [N_valid_tokens]
|
||||||
|
target_logprobs = target_logprobs[valid_mask] # [N_valid_tokens]
|
||||||
|
|
||||||
|
# Since teacher_logprobs are already normalized, just exponentiate to get probabilities
|
||||||
teacher_probs = target_logprobs.exp()
|
teacher_probs = target_logprobs.exp()
|
||||||
|
|
||||||
# Compute forward KL
|
# Compute forward KL:
|
||||||
# L_kl = sum_k p^T_k (log p^T_k - log p^S_k)
|
# KL = sum p^T_k (log p^T_k - log p^S_k), summed over all valid tokens.
|
||||||
kd_loss_per_position = (
|
kd_loss_per_token = teacher_probs * (target_logprobs - student_logprobs_topk)
|
||||||
teacher_probs * (target_logprobs - student_logprobs_topk)
|
kd_loss = kd_loss_per_token.sum()
|
||||||
).sum(
|
|
||||||
dim=-1
|
|
||||||
) # [B, teacher_seq_len]
|
|
||||||
|
|
||||||
# gradient accumulation fixes
|
# Normalize by number of items or mean over valid tokens
|
||||||
if num_items_in_batch:
|
if num_items_in_batch is not None:
|
||||||
kd_loss = kd_loss_per_position.sum() / num_items_in_batch # Scalar
|
# If you know how many items should be considered in the batch
|
||||||
|
kd_loss = kd_loss / num_items_in_batch
|
||||||
else:
|
else:
|
||||||
kd_loss = kd_loss_per_position.mean() # Scalar
|
# Otherwise, just average over all valid tokens
|
||||||
|
kd_loss = kd_loss / kd_loss_per_token.size(0)
|
||||||
|
|
||||||
return kd_loss
|
return kd_loss
|
||||||
|
|
||||||
@@ -70,6 +76,8 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
|||||||
columns_to_add.append("target_logprobs")
|
columns_to_add.append("target_logprobs")
|
||||||
if "target_token_ids" not in self._signature_columns:
|
if "target_token_ids" not in self._signature_columns:
|
||||||
columns_to_add.append("target_token_ids")
|
columns_to_add.append("target_token_ids")
|
||||||
|
if "target_mask" not in self._signature_columns:
|
||||||
|
columns_to_add.append("target_mask")
|
||||||
if columns_to_add:
|
if columns_to_add:
|
||||||
self._signature_columns += columns_to_add
|
self._signature_columns += columns_to_add
|
||||||
|
|
||||||
@@ -83,6 +91,7 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
|||||||
"""
|
"""
|
||||||
target_logprobs = inputs.pop("target_logprobs")
|
target_logprobs = inputs.pop("target_logprobs")
|
||||||
target_token_ids = inputs.pop("target_token_ids")
|
target_token_ids = inputs.pop("target_token_ids")
|
||||||
|
target_mask = inputs.pop("target_mask")
|
||||||
|
|
||||||
if self.model_accepts_loss_kwargs:
|
if self.model_accepts_loss_kwargs:
|
||||||
loss_kwargs = {}
|
loss_kwargs = {}
|
||||||
@@ -96,6 +105,7 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
|||||||
student_logits,
|
student_logits,
|
||||||
target_token_ids,
|
target_token_ids,
|
||||||
target_logprobs,
|
target_logprobs,
|
||||||
|
target_mask,
|
||||||
num_items_in_batch=num_items_in_batch,
|
num_items_in_batch=num_items_in_batch,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -490,8 +490,23 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
|
|
||||||
def transform_logprobs(self, sample):
|
def transform_logprobs(self, sample):
|
||||||
logprobs = sample.pop(self.logprobs_field)
|
logprobs = sample.pop(self.logprobs_field)
|
||||||
|
target_seq_len = len(logprobs)
|
||||||
|
input_seq_len = len(sample["input_ids"])
|
||||||
|
padding_len = input_seq_len - target_seq_len
|
||||||
|
top_k = len(logprobs[0])
|
||||||
target_logprobs = []
|
target_logprobs = []
|
||||||
target_token_ids = []
|
target_token_ids = []
|
||||||
|
target_mask = []
|
||||||
|
|
||||||
|
# fill with -inf for padding_len tokens for top_k tokens
|
||||||
|
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf
|
||||||
|
for _ in range(padding_len):
|
||||||
|
target_logprobs.append([-float("inf")] * top_k)
|
||||||
|
target_token_ids.append(list(range(top_k)))
|
||||||
|
target_mask.append([0] * top_k)
|
||||||
|
|
||||||
|
for _ in range(target_seq_len):
|
||||||
|
target_mask.append([1] * top_k)
|
||||||
|
|
||||||
for _, token_pos_logprobs in enumerate(logprobs):
|
for _, token_pos_logprobs in enumerate(logprobs):
|
||||||
# Initialize collections for logprobs and token_ids
|
# Initialize collections for logprobs and token_ids
|
||||||
@@ -519,6 +534,7 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
# Apply temperature scaling at data load time
|
# Apply temperature scaling at data load time
|
||||||
# log p_k^(T) = (log p_k / T) - logsumexp(log p_j / T)
|
# log p_k^(T) = (log p_k / T) - logsumexp(log p_j / T)
|
||||||
position_logprobs_tensor = position_logprobs_tensor / self.temperature
|
position_logprobs_tensor = position_logprobs_tensor / self.temperature
|
||||||
|
# normalize to probabilities so they sum up to 1
|
||||||
position_logprobs_tensor = position_logprobs_tensor - torch.logsumexp(
|
position_logprobs_tensor = position_logprobs_tensor - torch.logsumexp(
|
||||||
position_logprobs_tensor, dim=0, keepdim=True
|
position_logprobs_tensor, dim=0, keepdim=True
|
||||||
)
|
)
|
||||||
@@ -531,6 +547,7 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
# Update sample with transformed logprobs
|
# Update sample with transformed logprobs
|
||||||
sample["target_logprobs"] = target_logprobs
|
sample["target_logprobs"] = target_logprobs
|
||||||
sample["target_token_ids"] = target_token_ids
|
sample["target_token_ids"] = target_token_ids
|
||||||
|
sample["target_mask"] = target_mask
|
||||||
|
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
@@ -538,7 +555,9 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
logprobs = prompt.pop(self.logprobs_field)
|
logprobs = prompt.pop(self.logprobs_field)
|
||||||
tokenized_prompt = super().tokenize_prompt(prompt)
|
tokenized_prompt = super().tokenize_prompt(prompt)
|
||||||
tokenized_prompt[self.logprobs_field] = logprobs
|
tokenized_prompt[self.logprobs_field] = logprobs
|
||||||
return self.transform_logprobs(tokenized_prompt)
|
tokenized_prompt = self.transform_logprobs(tokenized_prompt)
|
||||||
|
|
||||||
|
return tokenized_prompt
|
||||||
|
|
||||||
|
|
||||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None):
|
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None):
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
"""
|
"""
|
||||||
DataCollator for axolotl to handle KD fields
|
DataCollator for axolotl to handle KD fields without using -inf for padding,
|
||||||
|
and with a teacher_mask to identify padded positions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@@ -17,6 +18,9 @@ from axolotl.utils.collators.batching import DataCollatorForSeq2Seq
|
|||||||
class DataCollatorForKD(DataCollatorForSeq2Seq):
|
class DataCollatorForKD(DataCollatorForSeq2Seq):
|
||||||
"""
|
"""
|
||||||
Data collator for KD, including handling KD-specific fields.
|
Data collator for KD, including handling KD-specific fields.
|
||||||
|
|
||||||
|
This version avoids using -inf and instead uses a large negative value for padding
|
||||||
|
target_logprobs. It also creates a teacher_mask to indicate which entries are valid.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tokenizer: PreTrainedTokenizerBase
|
tokenizer: PreTrainedTokenizerBase
|
||||||
@@ -32,7 +36,9 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
|
|||||||
if return_tensors is None:
|
if return_tensors is None:
|
||||||
return_tensors = self.return_tensors
|
return_tensors = self.return_tensors
|
||||||
|
|
||||||
# Extract labels and position_ids first (as in original code)
|
padding_side = self.tokenizer.padding_side
|
||||||
|
|
||||||
|
# Pad labels and position_ids first
|
||||||
for feature_name, pad_token_id in [
|
for feature_name, pad_token_id in [
|
||||||
("labels", self.label_pad_token_id),
|
("labels", self.label_pad_token_id),
|
||||||
("position_ids", self.position_pad_token_id),
|
("position_ids", self.position_pad_token_id),
|
||||||
@@ -46,7 +52,6 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
|
|||||||
// self.pad_to_multiple_of
|
// self.pad_to_multiple_of
|
||||||
) * self.pad_to_multiple_of
|
) * self.pad_to_multiple_of
|
||||||
|
|
||||||
padding_side = self.tokenizer.padding_side
|
|
||||||
for f in features: # pylint: disable=invalid-name
|
for f in features: # pylint: disable=invalid-name
|
||||||
remainder = [pad_token_id] * (max_len - len(f[feature_name]))
|
remainder = [pad_token_id] * (max_len - len(f[feature_name]))
|
||||||
if isinstance(f[feature_name], list):
|
if isinstance(f[feature_name], list):
|
||||||
@@ -69,63 +74,77 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
|
|||||||
# Handle target_logprobs and target_token_ids manually
|
# Handle target_logprobs and target_token_ids manually
|
||||||
target_logprobs_list = []
|
target_logprobs_list = []
|
||||||
target_token_ids_list = []
|
target_token_ids_list = []
|
||||||
|
target_mask_list = []
|
||||||
has_teacher_data = ("target_logprobs" in features[0]) and (
|
has_teacher_data = ("target_logprobs" in features[0]) and (
|
||||||
"target_token_ids" in features[0]
|
"target_token_ids" in features[0]
|
||||||
)
|
)
|
||||||
|
|
||||||
if has_teacher_data:
|
if has_teacher_data:
|
||||||
# Extract these fields
|
# Extract and remove from features
|
||||||
for f in features: # pylint: disable=invalid-name
|
for f in features: # pylint: disable=invalid-name
|
||||||
target_logprobs_list.append(f.pop("target_logprobs"))
|
target_logprobs_list.append(f.pop("target_logprobs"))
|
||||||
target_token_ids_list.append(f.pop("target_token_ids"))
|
target_token_ids_list.append(f.pop("target_token_ids"))
|
||||||
|
target_mask_list.append(f.pop("target_mask"))
|
||||||
|
|
||||||
# Determine max lengths to pad
|
# Determine max lengths
|
||||||
max_teacher_seq_len = max(len(seq) for seq in target_logprobs_list)
|
max_teacher_seq_len = max(len(seq) for seq in target_logprobs_list)
|
||||||
max_k = max(len(seq_k) for seq in target_logprobs_list for seq_k in seq)
|
max_k = max(len(seq_k) for seq in target_logprobs_list for seq_k in seq)
|
||||||
|
|
||||||
# Pad target_logprobs and target_token_ids
|
|
||||||
padded_target_logprobs = []
|
padded_target_logprobs = []
|
||||||
padded_target_token_ids = []
|
padded_target_token_ids = []
|
||||||
for t_logprobs, t_ids in zip(target_logprobs_list, target_token_ids_list):
|
padded_teacher_mask_list = []
|
||||||
# Pad seq dimension
|
|
||||||
|
for t_logprobs, t_ids, t_mask in zip(
|
||||||
|
target_logprobs_list, target_token_ids_list, target_mask_list
|
||||||
|
):
|
||||||
t_logprobs_padded = []
|
t_logprobs_padded = []
|
||||||
t_ids_padded = []
|
t_ids_padded = []
|
||||||
for i in range( # pylint: disable=consider-using-enumerate
|
t_mask_padded = []
|
||||||
len(t_logprobs)
|
|
||||||
|
for lp, ids, mask in zip( # pylint: disable=invalid-name
|
||||||
|
t_logprobs, t_ids, t_mask
|
||||||
):
|
):
|
||||||
lp = t_logprobs[i] # pylint: disable=invalid-name
|
|
||||||
ids = t_ids[i]
|
|
||||||
# Pad K dimension
|
|
||||||
lp_len = len(lp)
|
lp_len = len(lp)
|
||||||
if lp_len < max_k:
|
if lp_len < max_k:
|
||||||
lp = lp + [-float("inf")] * ( # pylint: disable=invalid-name
|
# Use -1e9 for padding logprobs and 0 for token_ids
|
||||||
max_k - lp_len
|
pad_len = max_k - lp_len
|
||||||
) # or some pad value that won't break exp()
|
lp = lp + [-1e9] * pad_len # pylint: disable=invalid-name
|
||||||
ids = ids + [0] * (max_k - lp_len)
|
ids = ids + [0] * pad_len
|
||||||
|
mask = mask + [0] * pad_len
|
||||||
|
else:
|
||||||
|
lp = lp[:max_k] # pylint: disable=invalid-name
|
||||||
|
ids = ids[:max_k]
|
||||||
|
mask = mask[:max_k]
|
||||||
|
|
||||||
t_logprobs_padded.append(lp)
|
t_logprobs_padded.append(lp)
|
||||||
t_ids_padded.append(ids)
|
t_ids_padded.append(ids)
|
||||||
|
t_mask_padded.append(mask)
|
||||||
|
|
||||||
# If sequence is shorter than max_teacher_seq_len
|
|
||||||
seq_len_diff = max_teacher_seq_len - len(t_logprobs_padded)
|
seq_len_diff = max_teacher_seq_len - len(t_logprobs_padded)
|
||||||
if seq_len_diff > 0:
|
if seq_len_diff > 0:
|
||||||
|
# Pad sequences fully if needed
|
||||||
t_logprobs_padded.extend(
|
t_logprobs_padded.extend(
|
||||||
[[-float("inf")] * max_k for _ in range(seq_len_diff)]
|
[[-1e9] * max_k for _ in range(seq_len_diff)]
|
||||||
)
|
)
|
||||||
t_ids_padded.extend([[0] * max_k for _ in range(seq_len_diff)])
|
t_ids_padded.extend([[0] * max_k for _ in range(seq_len_diff)])
|
||||||
|
t_mask_padded.extend([[0] * max_k for _ in range(seq_len_diff)])
|
||||||
|
|
||||||
padded_target_logprobs.append(t_logprobs_padded)
|
padded_target_logprobs.append(t_logprobs_padded)
|
||||||
padded_target_token_ids.append(t_ids_padded)
|
padded_target_token_ids.append(t_ids_padded)
|
||||||
|
padded_teacher_mask_list.append(t_mask_padded)
|
||||||
|
|
||||||
# Convert to tensors
|
# Convert to tensors
|
||||||
padded_target_logprobs = torch.tensor(
|
padded_target_logprobs = torch.tensor(
|
||||||
padded_target_logprobs, dtype=torch.float
|
padded_target_logprobs, dtype=torch.float
|
||||||
)
|
)
|
||||||
# We can store token_ids as long tensor
|
|
||||||
padded_target_token_ids = torch.tensor(
|
padded_target_token_ids = torch.tensor(
|
||||||
padded_target_token_ids, dtype=torch.long
|
padded_target_token_ids, dtype=torch.long
|
||||||
)
|
)
|
||||||
|
padded_teacher_mask_list = torch.tensor(
|
||||||
|
padded_teacher_mask_list, dtype=torch.int
|
||||||
|
)
|
||||||
|
|
||||||
# Now pad using tokenizer for the remaining fields (input_ids, attention_mask, etc.)
|
# Pad using tokenizer for regular fields
|
||||||
features = self.tokenizer.pad(
|
features = self.tokenizer.pad(
|
||||||
features,
|
features,
|
||||||
padding=self.padding,
|
padding=self.padding,
|
||||||
@@ -134,12 +153,13 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
|
|||||||
return_tensors=return_tensors,
|
return_tensors=return_tensors,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add back the teacher data if it exists
|
# Add back teacher data if present
|
||||||
if has_teacher_data:
|
if has_teacher_data:
|
||||||
features["target_logprobs"] = padded_target_logprobs
|
features["target_logprobs"] = padded_target_logprobs
|
||||||
features["target_token_ids"] = padded_target_token_ids
|
features["target_token_ids"] = padded_target_token_ids
|
||||||
|
features["target_mask"] = padded_teacher_mask_list
|
||||||
|
|
||||||
# Prepare decoder_input_ids if applicable
|
# Prepare decoder_input_ids if the model supports it
|
||||||
if (
|
if (
|
||||||
"labels" in features
|
"labels" in features
|
||||||
and self.model is not None
|
and self.model is not None
|
||||||
|
|||||||
Reference in New Issue
Block a user