change up import to prevent AttributeError (#1863)
* change up import to prevent AttributeError
* tweak patching check for updated upstream
@@ -9,18 +9,18 @@ from axolotl.monkeypatch.utils import (
 
 
 def hijack_llama_prepare_4d_mask():
-    import transformers.modeling_attn_mask_utils
-    import transformers.models.llama.modeling_llama
+    from transformers import modeling_attn_mask_utils
+    from transformers.models.llama import modeling_llama
 
-    transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = (  # pylint: disable=protected-access
+    modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = (  # pylint: disable=protected-access
         patched_prepare_4d_causal_attention_mask_for_sdpa
     )
-    transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = (  # pylint: disable=protected-access
+    modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = (  # pylint: disable=protected-access
         patched_prepare_4d_causal_attention_mask_for_sdpa
     )
-    transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask = (  # pylint: disable=protected-access
+    modeling_llama._prepare_4d_causal_attention_mask = (  # pylint: disable=protected-access
         patched_prepare_4d_causal_attention_mask
     )
-    transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask = (  # pylint: disable=protected-access
+    modeling_attn_mask_utils._prepare_4d_causal_attention_mask = (  # pylint: disable=protected-access
        patched_prepare_4d_causal_attention_mask
     )
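For readers skimming the hunk above: the fix swaps dotted-path assignment for assignment through directly imported submodule objects. Re-resolving the full transformers.models.llama.modeling_llama attribute chain at patch time is what surfaced the AttributeError against newer transformers releases; binding each submodule with a from-import avoids that lookup. A minimal sketch of the resulting pattern, with an illustrative stub in place of the real replacement functions (which, per the hunk header, come from axolotl.monkeypatch.utils), showing only one of the four assignments:

    # Sketch of the import-and-assign monkeypatch pattern used above.
    from transformers import modeling_attn_mask_utils
    from transformers.models.llama import modeling_llama


    def patched_prepare_4d_causal_attention_mask(*args, **kwargs):
        """Illustrative stand-in for axolotl's multipack-aware mask builder."""
        raise NotImplementedError


    # Assign on the bound submodule objects; nothing is looked up through the
    # top-level transformers package at patch time.
    modeling_llama._prepare_4d_causal_attention_mask = (  # pylint: disable=protected-access
        patched_prepare_4d_causal_attention_mask
    )
    modeling_attn_mask_utils._prepare_4d_causal_attention_mask = (  # pylint: disable=protected-access
        patched_prepare_4d_causal_attention_mask
    )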
@@ -16,8 +16,7 @@ from transformers.models.llama.modeling_llama import (
 
 LOG = get_logger("axolotl.monkeypatch.unsloth")
 
-ORIGINAL_CEL_CODE = """    if labels is not None:
-        # Shift so that tokens < n predict n
+ORIGINAL_CEL_CODE = """# Shift so that tokens < n predict n
         shift_logits = logits[..., :-1, :].contiguous()
         shift_labels = labels[..., 1:].contiguous()
         # Flatten the tokens
@@ -29,8 +28,7 @@ ORIGINAL_CEL_CODE = """    if labels is not None:
         loss = loss_fct(shift_logits, shift_labels)
 """
 
-PATCHED_CEL_CODE = """    if labels is not None:
-        shift_logits = logits[..., :-1, :].contiguous()
+PATCHED_CEL_CODE = """shift_logits = logits[..., :-1, :].contiguous()
         shift_labels = labels[..., 1:].contiguous()
         loss = fast_cross_entropy_loss(
             logits = shift_logits,
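Context for the second file: ORIGINAL_CEL_CODE and PATCHED_CEL_CODE are source fragments, not executable constants. The unsloth integration reads the source of the upstream LlamaForCausalLM.forward, checks that the original cross-entropy block is still present, and swaps it for a call to fast_cross_entropy_loss; trimming the `if labels is not None:` prefix from both fragments is the "tweak patching check for updated upstream" half of this commit. A rough sketch of that source-rewriting technique, using a hypothetical helper name and less validation than the real patcher performs:

    import inspect
    import textwrap

    from transformers.models.llama import modeling_llama


    def swap_cel_block(original_snippet: str, patched_snippet: str) -> None:
        """Hypothetical sketch of the source-level patch; axolotl's real helper
        also injects the fast_cross_entropy_loss import and renames the
        generated function before installing it."""
        source = textwrap.dedent(
            inspect.getsource(modeling_llama.LlamaForCausalLM.forward)
        )
        # The "patching check": refuse to patch if upstream changed the block.
        if original_snippet not in source:
            raise RuntimeError("upstream forward() no longer matches ORIGINAL_CEL_CODE")

        patched_source = source.replace(original_snippet, patched_snippet)
        namespace = dict(vars(modeling_llama))
        exec(patched_source, namespace)  # pylint: disable=exec-used  # nosec
        modeling_llama.LlamaForCausalLM.forward = namespace["forward"]

Because the check runs against whatever transformers version is installed, a mismatch surfaces as an explicit error at patch time instead of silently training with the stock loss.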