diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 9463df6d9..d3f7ae887 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -18,8 +18,8 @@ from pydantic import ( ) from transformers.utils.import_utils import is_torch_npu_available -from axolotl.utils.distributed import is_main_process from axolotl.monkeypatch.attention.ring_attn import RingAttnFunc +from axolotl.utils.distributed import is_main_process from axolotl.utils.schemas.datasets import ( DatasetConfig, DPODataset, @@ -261,7 +261,7 @@ class AxolotlInputConfig( sequence_parallel_degree: int | None = None heads_k_stride: int | None = None - ring_attn_func: RingAttnFunc | None = None + ring_attn_func: str | None = None special_tokens: SpecialTokensConfig | None = None tokens: list[str] | None = None