fixes for batch API funcs, simplify
This commit is contained in:
@@ -76,8 +76,7 @@ def register_ring_attn(
|
|||||||
|
|
||||||
LOG.info(
|
LOG.info(
|
||||||
"Enabling ring attention sequence parallelism: "
|
"Enabling ring attention sequence parallelism: "
|
||||||
f"each sequence will be processed across {sequence_parallel_degree} GPUs "
|
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
|
||||||
f"using the {ring_attn_func.value} ring-flash-attn implementation"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
rank = dist.get_rank()
|
rank = dist.get_rank()
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from pydantic import (
|
|||||||
from transformers.utils.import_utils import is_torch_npu_available
|
from transformers.utils.import_utils import is_torch_npu_available
|
||||||
|
|
||||||
from axolotl.utils.distributed import is_main_process
|
from axolotl.utils.distributed import is_main_process
|
||||||
|
from axolotl.monkeypatch.attention.ring_attn import RingAttnFunc
|
||||||
from axolotl.utils.schemas.datasets import (
|
from axolotl.utils.schemas.datasets import (
|
||||||
DatasetConfig,
|
DatasetConfig,
|
||||||
DPODataset,
|
DPODataset,
|
||||||
@@ -260,7 +261,7 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
sequence_parallel_degree: int | None = None
|
sequence_parallel_degree: int | None = None
|
||||||
heads_k_stride: int | None = None
|
heads_k_stride: int | None = None
|
||||||
ring_attn_func: str | None = None
|
ring_attn_func: RingAttnFunc | None = None
|
||||||
|
|
||||||
special_tokens: SpecialTokensConfig | None = None
|
special_tokens: SpecialTokensConfig | None = None
|
||||||
tokens: list[str] | None = None
|
tokens: list[str] | None = None
|
||||||
@@ -1196,8 +1197,6 @@ class AxolotlInputConfig(
|
|||||||
if getattr(self, "sequence_parallel_degree", 1) == 1:
|
if getattr(self, "sequence_parallel_degree", 1) == 1:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
|
|
||||||
|
|
||||||
if self.ring_attn_func is not None:
|
if self.ring_attn_func is not None:
|
||||||
valid_funcs = list(RingAttnFunc)
|
valid_funcs = list(RingAttnFunc)
|
||||||
if self.ring_attn_func in valid_funcs:
|
if self.ring_attn_func in valid_funcs:
|
||||||
|
|||||||
Reference in New Issue
Block a user