Compare commits
1 Commits
fix/kd-tra
...
axolotl-ci
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f9e5e22e6b |
@@ -74,9 +74,6 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
|||||||
target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
|
target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
|
||||||
target_mask_for_loss = target_mask[..., 1:, :].contiguous()
|
target_mask_for_loss = target_mask[..., 1:, :].contiguous()
|
||||||
|
|
||||||
if num_items_in_batch is None:
|
|
||||||
num_items_in_batch = -1
|
|
||||||
|
|
||||||
if self.args.kd_zscore_base_temp:
|
if self.args.kd_zscore_base_temp:
|
||||||
loss_kd = topk_kd_loss_with_zscore(
|
loss_kd = topk_kd_loss_with_zscore(
|
||||||
shift_logits,
|
shift_logits,
|
||||||
|
|||||||
@@ -58,11 +58,15 @@ def snapshot_download_w_retry(*args, **kwargs):
|
|||||||
"""
|
"""
|
||||||
with hf_offline_context(True):
|
with hf_offline_context(True):
|
||||||
try:
|
try:
|
||||||
return snapshot_download(*args, **kwargs)
|
return snapshot_download(
|
||||||
|
*args, user_agent={"is_ci": "true", "axolotl": "ci"}, **kwargs
|
||||||
|
)
|
||||||
except LocalEntryNotFoundError:
|
except LocalEntryNotFoundError:
|
||||||
pass
|
pass
|
||||||
with hf_offline_context(False):
|
with hf_offline_context(False):
|
||||||
return snapshot_download(*args, **kwargs)
|
return snapshot_download(
|
||||||
|
*args, user_agent={"is_ci": "true", "axolotl": "ci"}, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user