diff --git a/src/axolotl/integrations/kd/trainer.py b/src/axolotl/integrations/kd/trainer.py index f99f2ca28..c493e025a 100644 --- a/src/axolotl/integrations/kd/trainer.py +++ b/src/axolotl/integrations/kd/trainer.py @@ -74,6 +74,9 @@ class AxolotlKDTrainer(AxolotlTrainer): target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous() target_mask_for_loss = target_mask[..., 1:, :].contiguous() + if num_items_in_batch is None: + num_items_in_batch = -1 + if self.args.kd_zscore_base_temp: loss_kd = topk_kd_loss_with_zscore( shift_logits,