fix: num_items_in_batch wrong type in kd trainer loss
This commit is contained in:
@@ -74,6 +74,9 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
||||
target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
|
||||
target_mask_for_loss = target_mask[..., 1:, :].contiguous()
|
||||
|
||||
if num_items_in_batch is None:
|
||||
num_items_in_batch = -1
|
||||
|
||||
if self.args.kd_zscore_base_temp:
|
||||
loss_kd = topk_kd_loss_with_zscore(
|
||||
shift_logits,
|
||||
|
||||
Reference in New Issue
Block a user