From 8bb871b5cf0810fd4034069821250d718db366ca Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Mon, 20 Oct 2025 14:06:58 +0700
Subject: [PATCH] fix: deepspeed with context parallel (#3220)

---
 .../monkeypatch/transformers/trainer_context_parallel.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/axolotl/monkeypatch/transformers/trainer_context_parallel.py b/src/axolotl/monkeypatch/transformers/trainer_context_parallel.py
index 74a35e83f..ba8b16dda 100644
--- a/src/axolotl/monkeypatch/transformers/trainer_context_parallel.py
+++ b/src/axolotl/monkeypatch/transformers/trainer_context_parallel.py
@@ -13,9 +13,7 @@ from axolotl.utils.logging import get_logger
 LOG = get_logger(__name__)
 
 GUARD_PATTERN = 'if model.config._attn_implementation != "sdpa":'
-PATCHED_GUARD = (
-    'if model.config._attn_implementation not in ("sdpa", "flash_attention_2"):'
-)
+PATCHED_GUARD = 'if (attn_impl := (getattr(model.config, "_attn_implementation", None) or getattr(model.model.config, "_attn_implementation", None))) and attn_impl not in ("sdpa", "flash_attention_2"):'
 
 
 def patch_prepare_context_parallel_inputs() -> None: