Fixes for batch API funcs; simplify ring_attn_func handling (typed as the RingAttnFunc enum in the config schema, shorter log message)

This commit is contained in:
Dan Saunders
2025-04-16 03:47:51 +00:00
parent 5306c6acbb
commit 9640aacfc9
2 changed files with 3 additions and 5 deletions

View File

@@ -76,8 +76,7 @@ def register_ring_attn(
LOG.info(
"Enabling ring attention sequence parallelism: "
f"each sequence will be processed across {sequence_parallel_degree} GPUs "
f"using the {ring_attn_func.value} ring-flash-attn implementation"
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
)
rank = dist.get_rank()

View File

@@ -19,6 +19,7 @@ from pydantic import (
from transformers.utils.import_utils import is_torch_npu_available
from axolotl.utils.distributed import is_main_process
from axolotl.monkeypatch.attention.ring_attn import RingAttnFunc
from axolotl.utils.schemas.datasets import (
DatasetConfig,
DPODataset,
@@ -260,7 +261,7 @@ class AxolotlInputConfig(
sequence_parallel_degree: int | None = None
heads_k_stride: int | None = None
ring_attn_func: str | None = None
ring_attn_func: RingAttnFunc | None = None
special_tokens: SpecialTokensConfig | None = None
tokens: list[str] | None = None
@@ -1196,8 +1197,6 @@ class AxolotlInputConfig(
if getattr(self, "sequence_parallel_degree", 1) == 1:
return self
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
if self.ring_attn_func is not None:
valid_funcs = list(RingAttnFunc)
if self.ring_attn_func in valid_funcs: