change up import to prevent AttributeError (#1863)
* change up import to prevent AttributeError
* tweak patching check for updated upstream
This commit is contained in:
@@ -9,18 +9,18 @@ from axolotl.monkeypatch.utils import (
|
|||||||
|
|
||||||
|
|
||||||
def hijack_llama_prepare_4d_mask():
|
def hijack_llama_prepare_4d_mask():
|
||||||
import transformers.modeling_attn_mask_utils
|
from transformers import modeling_attn_mask_utils
|
||||||
import transformers.models.llama.modeling_llama
|
from transformers.models.llama import modeling_llama
|
||||||
|
|
||||||
transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access
|
modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access
|
||||||
patched_prepare_4d_causal_attention_mask_for_sdpa
|
patched_prepare_4d_causal_attention_mask_for_sdpa
|
||||||
)
|
)
|
||||||
transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access
|
modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access
|
||||||
patched_prepare_4d_causal_attention_mask_for_sdpa
|
patched_prepare_4d_causal_attention_mask_for_sdpa
|
||||||
)
|
)
|
||||||
transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access
|
modeling_llama._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access
|
||||||
patched_prepare_4d_causal_attention_mask
|
patched_prepare_4d_causal_attention_mask
|
||||||
)
|
)
|
||||||
transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access
|
modeling_attn_mask_utils._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access
|
||||||
patched_prepare_4d_causal_attention_mask
|
patched_prepare_4d_causal_attention_mask
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -16,8 +16,7 @@ from transformers.models.llama.modeling_llama import (
|
|||||||
|
|
||||||
LOG = get_logger("axolotl.monkeypatch.unsloth")
|
LOG = get_logger("axolotl.monkeypatch.unsloth")
|
||||||
|
|
||||||
ORIGINAL_CEL_CODE = """ if labels is not None:
|
ORIGINAL_CEL_CODE = """# Shift so that tokens < n predict n
|
||||||
# Shift so that tokens < n predict n
|
|
||||||
shift_logits = logits[..., :-1, :].contiguous()
|
shift_logits = logits[..., :-1, :].contiguous()
|
||||||
shift_labels = labels[..., 1:].contiguous()
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
# Flatten the tokens
|
# Flatten the tokens
|
||||||
@@ -29,8 +28,7 @@ ORIGINAL_CEL_CODE = """ if labels is not None:
|
|||||||
loss = loss_fct(shift_logits, shift_labels)
|
loss = loss_fct(shift_logits, shift_labels)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
PATCHED_CEL_CODE = """ if labels is not None:
|
PATCHED_CEL_CODE = """shift_logits = logits[..., :-1, :].contiguous()
|
||||||
shift_logits = logits[..., :-1, :].contiguous()
|
|
||||||
shift_labels = labels[..., 1:].contiguous()
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
loss = fast_cross_entropy_loss(
|
loss = fast_cross_entropy_loss(
|
||||||
logits = shift_logits,
|
logits = shift_logits,
|
||||||
|
|||||||
Reference in New Issue
Block a user