fix: num_items_in_batch wrong type in kd trainer loss

This commit is contained in:
NanoCode012
2025-05-20 16:56:24 +07:00
parent a27b909c5c
commit 348409c2ff

View File

@@ -74,6 +74,9 @@ class AxolotlKDTrainer(AxolotlTrainer):
target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
target_mask_for_loss = target_mask[..., 1:, :].contiguous()
if num_items_in_batch is None:
num_items_in_batch = -1
if self.args.kd_zscore_base_temp:
loss_kd = topk_kd_loss_with_zscore(
shift_logits,