Compare commits

..

1 Commits

Author SHA1 Message Date
NanoCode012
348409c2ff fix: num_items_in_batch wrong type in kd trainer loss 2025-05-20 16:56:24 +07:00
2 changed files with 5 additions and 6 deletions

View File

@@ -74,6 +74,9 @@ class AxolotlKDTrainer(AxolotlTrainer):
target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
target_mask_for_loss = target_mask[..., 1:, :].contiguous()
if num_items_in_batch is None:
num_items_in_batch = -1
if self.args.kd_zscore_base_temp:
loss_kd = topk_kd_loss_with_zscore(
shift_logits,

View File

@@ -58,15 +58,11 @@ def snapshot_download_w_retry(*args, **kwargs):
"""
with hf_offline_context(True):
try:
return snapshot_download(
*args, user_agent={"is_ci": "true", "axolotl": "ci"}, **kwargs
)
return snapshot_download(*args, **kwargs)
except LocalEntryNotFoundError:
pass
with hf_offline_context(False):
return snapshot_download(
*args, user_agent={"is_ci": "true", "axolotl": "ci"}, **kwargs
)
return snapshot_download(*args, **kwargs)
@pytest.fixture(scope="session", autouse=True)