From 9640aacfc9a24035a1ca16a03ef2baf7b3e30ed0 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Wed, 16 Apr 2025 03:47:51 +0000 Subject: [PATCH] fixes for batch API funcs, simplify --- src/axolotl/monkeypatch/attention/ring_attn/patch.py | 3 +-- src/axolotl/utils/schemas/config.py | 5 ++--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/axolotl/monkeypatch/attention/ring_attn/patch.py b/src/axolotl/monkeypatch/attention/ring_attn/patch.py index 160e57052..b5587ddca 100644 --- a/src/axolotl/monkeypatch/attention/ring_attn/patch.py +++ b/src/axolotl/monkeypatch/attention/ring_attn/patch.py @@ -76,8 +76,7 @@ def register_ring_attn( LOG.info( "Enabling ring attention sequence parallelism: " - f"each sequence will be processed across {sequence_parallel_degree} GPUs " - f"using the {ring_attn_func.value} ring-flash-attn implementation" + f"each sequence will be processed across {sequence_parallel_degree} GPUs" ) rank = dist.get_rank() diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index f68d160df..9463df6d9 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -19,6 +19,7 @@ from pydantic import ( from transformers.utils.import_utils import is_torch_npu_available from axolotl.utils.distributed import is_main_process +from axolotl.monkeypatch.attention.ring_attn import RingAttnFunc from axolotl.utils.schemas.datasets import ( DatasetConfig, DPODataset, @@ -260,7 +261,7 @@ class AxolotlInputConfig( sequence_parallel_degree: int | None = None heads_k_stride: int | None = None - ring_attn_func: str | None = None + ring_attn_func: RingAttnFunc | None = None special_tokens: SpecialTokensConfig | None = None tokens: list[str] | None = None @@ -1196,8 +1197,6 @@ class AxolotlInputConfig( if getattr(self, "sequence_parallel_degree", 1) == 1: return self - from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc - if self.ring_attn_func is not None: valid_funcs = list(RingAttnFunc) if self.ring_attn_func in valid_funcs: