From 77a4b9cda21deabf97515fb04788b4eec7ed7783 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 23 Aug 2024 17:00:01 -0400
Subject: [PATCH] change up import to prevent AttributeError (#1863)

* change up import to prevent AttributeError

* tweak patching check for updated upstream
---
 src/axolotl/monkeypatch/llama_patch_multipack.py | 12 ++++++------
 src/axolotl/monkeypatch/unsloth_.py              |  6 ++----
 2 files changed, 8 insertions(+), 10 deletions(-)

diff --git a/src/axolotl/monkeypatch/llama_patch_multipack.py b/src/axolotl/monkeypatch/llama_patch_multipack.py
index 540c5577a..cfd525367 100644
--- a/src/axolotl/monkeypatch/llama_patch_multipack.py
+++ b/src/axolotl/monkeypatch/llama_patch_multipack.py
@@ -9,18 +9,18 @@ from axolotl.monkeypatch.utils import (
 
 
 def hijack_llama_prepare_4d_mask():
-    import transformers.modeling_attn_mask_utils
-    import transformers.models.llama.modeling_llama
+    from transformers import modeling_attn_mask_utils
+    from transformers.models.llama import modeling_llama
 
-    transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = (  # pylint: disable=protected-access
+    modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = (  # pylint: disable=protected-access
         patched_prepare_4d_causal_attention_mask_for_sdpa
     )
-    transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = (  # pylint: disable=protected-access
+    modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = (  # pylint: disable=protected-access
         patched_prepare_4d_causal_attention_mask_for_sdpa
     )
-    transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask = (  # pylint: disable=protected-access
+    modeling_llama._prepare_4d_causal_attention_mask = (  # pylint: disable=protected-access
         patched_prepare_4d_causal_attention_mask
     )
-    transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask = (  # pylint: disable=protected-access
+    modeling_attn_mask_utils._prepare_4d_causal_attention_mask = (  # pylint: disable=protected-access
         patched_prepare_4d_causal_attention_mask
     )
diff --git a/src/axolotl/monkeypatch/unsloth_.py b/src/axolotl/monkeypatch/unsloth_.py
index 5b1f0061d..3d42ad17f 100644
--- a/src/axolotl/monkeypatch/unsloth_.py
+++ b/src/axolotl/monkeypatch/unsloth_.py
@@ -16,8 +16,7 @@ from transformers.models.llama.modeling_llama import (
 
 LOG = get_logger("axolotl.monkeypatch.unsloth")
 
-ORIGINAL_CEL_CODE = """    if labels is not None:
-        # Shift so that tokens < n predict n
+ORIGINAL_CEL_CODE = """# Shift so that tokens < n predict n
         shift_logits = logits[..., :-1, :].contiguous()
         shift_labels = labels[..., 1:].contiguous()
         # Flatten the tokens
@@ -29,8 +28,7 @@ ORIGINAL_CEL_CODE = """    if labels is not None:
         loss = loss_fct(shift_logits, shift_labels)
 """
 
-PATCHED_CEL_CODE = """    if labels is not None:
-        shift_logits = logits[..., :-1, :].contiguous()
+PATCHED_CEL_CODE = """shift_logits = logits[..., :-1, :].contiguous()
         shift_labels = labels[..., 1:].contiguous()
         loss = fast_cross_entropy_loss(
             logits = shift_logits,
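
A minimal sketch (not part of the diff above) of the import pattern the first file switches to: bind the submodule object with `from ... import ...`, then assign the private attribute on that object, instead of chaining attribute access through the top-level `transformers` package, which is presumably where the AttributeError named in the subject surfaced. The `patched_fn` helper below is a hypothetical stand-in, not axolotl's actual patched mask function.

    # Sketch only; `patched_fn` is a placeholder for a real replacement that
    # would build the 4D causal mask for packed samples.
    from transformers.models.llama import modeling_llama


    def patched_fn(*args, **kwargs):
        raise NotImplementedError


    # Assigning on the bound module object avoids resolving
    # `transformers.models.llama.modeling_llama` via chained attribute lookups
    # on the lazily populated `transformers` package.
    modeling_llama._prepare_4d_causal_attention_mask = patched_fn  # pylint: disable=protected-access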