Compare commits

...

1 Commits

Author SHA1 Message Date
NanoCode012
348409c2ff fix: num_items_in_batch wrong type in kd trainer loss 2025-05-20 16:56:24 +07:00

View File

@@ -74,6 +74,9 @@ class AxolotlKDTrainer(AxolotlTrainer):
target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
target_mask_for_loss = target_mask[..., 1:, :].contiguous()
if num_items_in_batch is None:
num_items_in_batch = -1
if self.args.kd_zscore_base_temp:
loss_kd = topk_kd_loss_with_zscore(
shift_logits,