fix: deepspeed with context parallel (#3220)
This commit is contained in:
@@ -13,9 +13,7 @@ from axolotl.utils.logging import get_logger
|
|||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
GUARD_PATTERN = 'if model.config._attn_implementation != "sdpa":'
|
GUARD_PATTERN = 'if model.config._attn_implementation != "sdpa":'
|
||||||
PATCHED_GUARD = (
|
PATCHED_GUARD = 'if (attn_impl := (getattr(model.config, "_attn_implementation", None) or getattr(model.model.config, "_attn_implementation", None))) and attn_impl not in ("sdpa", "flash_attention_2"):'
|
||||||
'if model.config._attn_implementation not in ("sdpa", "flash_attention_2"):'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def patch_prepare_context_parallel_inputs() -> None:
|
def patch_prepare_context_parallel_inputs() -> None:
|
||||||
|
|||||||
Reference in New Issue
Block a user