diff --git a/src/axolotl/integrations/kd/collator.py b/src/axolotl/integrations/kd/collator.py
index de63869c7..167b8cdbd 100644
--- a/src/axolotl/integrations/kd/collator.py
+++ b/src/axolotl/integrations/kd/collator.py
@@ -209,7 +209,9 @@ class KDBatchSamplerDataCollatorForSeq2Seq(DataCollatorForKD):
         # We want to produce a single "merged" feature dict for each sub-batch.
         out_features = [{} for _ in features]
-        for i, sub_features in enumerate(features):
+        for i, sub_features in enumerate(  # pylint: disable=too-many-nested-blocks
+            features
+        ):
             # sub_features is a list of dicts, each dict = one sequence’s features
             # We'll merge them into out_features[i].
             #
@@ -243,10 +245,17 @@ class KDBatchSamplerDataCollatorForSeq2Seq(DataCollatorForKD):
                 # For example, input_ids or labels are often arrays.
                 arrays = []
                 for feat in sub_features:
-                    if field_name in feat:
+                    if field_name in feat and isinstance(
+                        feat[field_name], (list, torch.Tensor)
+                    ):
+                        if isinstance(
+                            feat[field_name][0], (dict, str)
+                        ):  # pylint: disable=too-many-nested-blocks
+                            continue
                         arr = np.array(feat[field_name])
                         arrays.append(arr)
-                out_features[i][field_name] = np.concatenate(arrays)
+                if arrays:
+                    out_features[i][field_name] = np.concatenate(arrays)
 
         # 3) Now call the parent collator, which will do:
         #    - padding of labels/position_ids
diff --git a/src/axolotl/integrations/kd/topk_logprob/forward_kl.py b/src/axolotl/integrations/kd/topk_logprob/forward_kl.py
index 06bce6971..823d28ac2 100644
--- a/src/axolotl/integrations/kd/topk_logprob/forward_kl.py
+++ b/src/axolotl/integrations/kd/topk_logprob/forward_kl.py
@@ -258,6 +258,7 @@ class ChunkedTopKKDLoss(nn.Module):
         target_logprobs: torch.Tensor,  # [B, seq_len, K]
         target_mask: torch.Tensor,  # [B, seq_len, K]
         num_items_in_batch: int = -1,  # optional batch size for normalization
+        top_k_before_softmax: int = 0,  # optional top-k before softmax for teacher logits
     ) -> torch.Tensor:
 
         # 1. Split along the "token" dimension (dim=1).
@@ -284,6 +285,7 @@ class ChunkedTopKKDLoss(nn.Module):
                 target_mask=msk_chunk,
                 num_items_in_batch=-1,  # ensure per-chunk averaging by valid tokens
                 kd_temperature=self.kd_temperature,
+                top_k_before_softmax=top_k_before_softmax,
             )
             # kd_loss returns an average over the chunk's valid tokens.
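
Reviewer note: a minimal, self-contained sketch of the merge guards the collator
hunk adds. The sample feature dicts are invented for illustration; only the
isinstance filter and the empty-`arrays` guard mirror the diff.

    import numpy as np
    import torch

    # Two sequences' feature dicts, as sub_features would hold them.
    # The "messages" field has dict elements, so it must be skipped:
    # np.concatenate cannot meaningfully merge it.
    sub_features = [
        {"input_ids": [1, 2, 3], "messages": [{"role": "user"}]},
        {"input_ids": [4, 5], "messages": [{"role": "assistant"}]},
    ]

    merged = {}
    for field_name in sub_features[0]:
        arrays = []
        for feat in sub_features:
            # Same guard as the diff: only list/tensor fields qualify,
            # and fields whose elements are dicts or strings are skipped.
            if field_name in feat and isinstance(
                feat[field_name], (list, torch.Tensor)
            ):
                if isinstance(feat[field_name][0], (dict, str)):
                    continue
                arrays.append(np.array(feat[field_name]))
        # Empty-list guard: needed precisely because the new type filter can
        # skip a field for every sequence; np.concatenate([]) raises ValueError.
        if arrays:
            merged[field_name] = np.concatenate(arrays)

    print(merged)  # {'input_ids': array([1, 2, 3, 4, 5])}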
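
And a hedged sketch of how the new `top_k_before_softmax` argument threads
through `ChunkedTopKKDLoss`: the diff only shows the parameter being accepted
and forwarded to each per-chunk `kd_loss` call, so the `kd_loss` body below is
a stand-in top-K forward-KL, not the axolotl implementation, and the
argument's actual effect is deliberately left unimplemented.

    import torch

    def kd_loss(
        student_logits,    # [B, chunk_len, V]
        target_token_ids,  # [B, chunk_len, K] teacher top-K token ids
        target_logprobs,   # [B, chunk_len, K] teacher top-K logprobs
        target_mask,       # [B, chunk_len, K] 1 = valid position
        num_items_in_batch=-1,
        kd_temperature=1.0,
        # Accepted for signature parity only; its semantics inside kd_loss
        # are not shown in this diff.
        top_k_before_softmax=0,
    ):
        # Stand-in: forward KL between the teacher's top-K distribution and
        # the student's logprobs gathered at those same token ids, averaged
        # over valid tokens (the num_items_in_batch=-1 path in the diff).
        student_lp = torch.log_softmax(student_logits / kd_temperature, dim=-1)
        student_lp = torch.gather(student_lp, -1, target_token_ids)
        kl = target_logprobs.exp() * (target_logprobs - student_lp) * target_mask
        return kl.sum() / target_mask.sum().clamp(min=1)

    def chunked_kd_loss(logits, ids, logprobs, mask, num_chunks=4, **kw):
        # Split along the token dimension (dim=1), as the class does, and
        # forward top_k_before_softmax unchanged to every per-chunk call.
        chunk_losses = [
            kd_loss(lg, tid, lp, m, num_items_in_batch=-1, **kw)
            for lg, tid, lp, m in zip(
                logits.chunk(num_chunks, dim=1),
                ids.chunk(num_chunks, dim=1),
                logprobs.chunk(num_chunks, dim=1),
                mask.chunk(num_chunks, dim=1),
            )
        ]
        return torch.stack(chunk_losses).mean()

    B, S, V, K = 2, 8, 32, 4
    loss = chunked_kd_loss(
        torch.randn(B, S, V),
        torch.randint(0, V, (B, S, K)),
        torch.log_softmax(torch.randn(B, S, K), dim=-1),  # fake teacher top-K
        torch.ones(B, S, K),
        top_k_before_softmax=8,  # forwarded through, as in the diff
    )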