Change the import style to prevent an AttributeError (#1863)

* Change the import style (`from package import module`) to prevent an AttributeError

* Tweak the patching check to match the updated upstream code
This commit is contained in:
Wing Lian
2024-08-23 17:00:01 -04:00
committed by GitHub
parent 810ecd4e81
commit 77a4b9cda2
2 changed files with 8 additions and 10 deletions

View File

@@ -9,18 +9,18 @@ from axolotl.monkeypatch.utils import (
def hijack_llama_prepare_4d_mask(): def hijack_llama_prepare_4d_mask():
import transformers.modeling_attn_mask_utils from transformers import modeling_attn_mask_utils
import transformers.models.llama.modeling_llama from transformers.models.llama import modeling_llama
transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access
patched_prepare_4d_causal_attention_mask_for_sdpa patched_prepare_4d_causal_attention_mask_for_sdpa
) )
transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access
patched_prepare_4d_causal_attention_mask_for_sdpa patched_prepare_4d_causal_attention_mask_for_sdpa
) )
transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access modeling_llama._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access
patched_prepare_4d_causal_attention_mask patched_prepare_4d_causal_attention_mask
) )
transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access modeling_attn_mask_utils._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access
patched_prepare_4d_causal_attention_mask patched_prepare_4d_causal_attention_mask
) )

View File

@@ -16,8 +16,7 @@ from transformers.models.llama.modeling_llama import (
LOG = get_logger("axolotl.monkeypatch.unsloth") LOG = get_logger("axolotl.monkeypatch.unsloth")
ORIGINAL_CEL_CODE = """ if labels is not None: ORIGINAL_CEL_CODE = """# Shift so that tokens < n predict n
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous() shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous() shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens # Flatten the tokens
@@ -29,8 +28,7 @@ ORIGINAL_CEL_CODE = """ if labels is not None:
loss = loss_fct(shift_logits, shift_labels) loss = loss_fct(shift_logits, shift_labels)
""" """
PATCHED_CEL_CODE = """ if labels is not None: PATCHED_CEL_CODE = """shift_logits = logits[..., :-1, :].contiguous()
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous() shift_labels = labels[..., 1:].contiguous()
loss = fast_cross_entropy_loss( loss = fast_cross_entropy_loss(
logits = shift_logits, logits = shift_logits,