better handling to drop string fields for kd with raw dataset

Wing Lian
2025-05-20 08:49:23 -07:00
parent 83ad248e5b
commit 0399aefcb3
2 changed files with 14 additions and 3 deletions

@@ -209,7 +209,9 @@ class KDBatchSamplerDataCollatorForSeq2Seq(DataCollatorForKD):
         # We want to produce a single "merged" feature dict for each sub-batch.
         out_features = [{} for _ in features]
-        for i, sub_features in enumerate(features):
+        for i, sub_features in enumerate(  # pylint: disable=too-many-nested-blocks
+            features
+        ):
             # sub_features is a list of dicts, each dict = one sequence's features
             # We'll merge them into out_features[i].
             #
@@ -243,10 +245,17 @@ class KDBatchSamplerDataCollatorForSeq2Seq(DataCollatorForKD):
                 # For example, input_ids or labels are often arrays.
                 arrays = []
                 for feat in sub_features:
-                    if field_name in feat:
+                    if field_name in feat and isinstance(
+                        feat[field_name], (list, torch.Tensor)
+                    ):
+                        if isinstance(
+                            feat[field_name][0], (dict, str)
+                        ):  # pylint: disable=too-many-nested-blocks
+                            continue
                         arr = np.array(feat[field_name])
                         arrays.append(arr)
-                out_features[i][field_name] = np.concatenate(arrays)
+                if arrays:
+                    out_features[i][field_name] = np.concatenate(arrays)
         # 3) Now call the parent collator, which will do:
         #    - padding of labels/position_ids

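For context, a minimal runnable sketch of the merging behavior this hunk hardens. The helper name merge_sub_features is hypothetical (the real logic lives inline in the collator), but the filtering mirrors the diff: fields whose values are not list/Tensor, or whose elements are dicts or strings (e.g. raw text columns from a raw dataset), are dropped before np.array, and the `if arrays:` guard keeps a fully skipped field from reaching np.concatenate with an empty list.

    import numpy as np
    import torch

    def merge_sub_features(sub_features):
        """Hypothetical stand-in for the collator's inner merge loop."""
        merged = {}
        field_names = {name for feat in sub_features for name in feat}
        for field_name in field_names:
            arrays = []
            for feat in sub_features:
                if field_name in feat and isinstance(
                    feat[field_name], (list, torch.Tensor)
                ):
                    if isinstance(feat[field_name][0], (dict, str)):
                        continue  # string/dict fields can't become numeric arrays
                    arrays.append(np.array(feat[field_name]))
            if arrays:  # every sequence may have skipped this field entirely
                merged[field_name] = np.concatenate(arrays)
        return merged

    features = [
        {"input_ids": [1, 2, 3], "text": "raw prompt"},
        {"input_ids": [4, 5], "text": "another prompt"},
    ]
    print(merge_sub_features(features))  # {'input_ids': array([1, 2, 3, 4, 5])}

Without the new guards, a string field would become a zero-dimensional NumPy array that np.concatenate rejects (or a string array the parent collator cannot pad), which is what the commit title means by dropping string fields for raw datasets.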

@@ -258,6 +258,7 @@ class ChunkedTopKKDLoss(nn.Module):
         target_logprobs: torch.Tensor,  # [B, seq_len, K]
         target_mask: torch.Tensor,  # [B, seq_len, K]
         num_items_in_batch: int = -1,  # optional batch size for normalization
+        top_k_before_softmax: int = 0,  # optional top-k before softmax for teacher logits
     ) -> torch.Tensor:
         # 1. Split along the "token" dimension (dim=1).
@@ -284,6 +285,7 @@ class ChunkedTopKKDLoss(nn.Module):
                 target_mask=msk_chunk,
                 num_items_in_batch=-1,  # ensure per-chunk averaging by valid tokens
                 kd_temperature=self.kd_temperature,
+                top_k_before_softmax=top_k_before_softmax,
             )
             # kd_loss returns an average over the chunk's valid tokens.
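The kd_loss body is not part of this diff, so the following is only a sketch of what the new top_k_before_softmax flag plausibly toggles, going by its inline comment ("optional top-k before softmax for teacher logits"); the helper teacher_topk_logprobs and its exact semantics are assumptions, not the repo's code.

    import torch
    import torch.nn.functional as F

    def teacher_topk_logprobs(teacher_logits, k, top_k_before_softmax=0):
        # teacher_logits: [B, seq_len, vocab]
        if top_k_before_softmax:
            # Take the top-K logits first, then softmax over just those K,
            # so the teacher target is renormalized over its top-K tokens.
            topk_vals, topk_ids = teacher_logits.topk(k, dim=-1)
            return F.log_softmax(topk_vals, dim=-1), topk_ids  # [B, seq_len, K]
        # Otherwise softmax over the full vocabulary first, then keep the
        # top-K logprobs (their mass no longer sums to 1 over the K kept).
        logprobs = F.log_softmax(teacher_logits, dim=-1)
        return logprobs.topk(k, dim=-1)

    logits = torch.randn(2, 5, 100)  # [B, seq_len, vocab]
    lp, ids = teacher_topk_logprobs(logits, k=8, top_k_before_softmax=1)
    print(lp.exp().sum(-1))  # ~1.0 everywhere: renormalized over the kept K

The trade-off: renormalizing over the top-K entries gives the student a proper K-way target distribution, while softmax-then-top-k preserves the teacher's true probabilities at the cost of the kept mass summing to less than 1. Either way, forward itself only threads the flag through to each per-chunk kd_loss call, as the second hunk shows, so per-chunk averaging via num_items_in_batch=-1 is unchanged.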