From ab3b36339ab5901e7e53fcd9905292711e4a2177 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Thu, 20 Mar 2025 12:04:22 -0400 Subject: [PATCH] fix tests --- cicd/Dockerfile.jinja | 4 ++-- src/axolotl/monkeypatch/attention/ring_attn.py | 9 ++------- src/axolotl/utils/schemas/config.py | 3 +-- tests/e2e/patched/test_sp.py | 1 + 4 files changed, 6 insertions(+), 11 deletions(-) diff --git a/cicd/Dockerfile.jinja b/cicd/Dockerfile.jinja index b212a0065..6988e092b 100644 --- a/cicd/Dockerfile.jinja +++ b/cicd/Dockerfile.jinja @@ -33,9 +33,9 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \ RUN pip install packaging==23.2 setuptools==75.8.0 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ - pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \ + pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \ else \ - pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \ + pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \ fi RUN python scripts/unsloth_install.py | sh diff --git a/src/axolotl/monkeypatch/attention/ring_attn.py b/src/axolotl/monkeypatch/attention/ring_attn.py index 9ed332dfa..95c44a820 100644 --- a/src/axolotl/monkeypatch/attention/ring_attn.py +++ b/src/axolotl/monkeypatch/attention/ring_attn.py @@ -11,13 +11,6 @@ from accelerate.logging import get_logger from axolotl.logging_config import configure_logging -try: - from ring_flash_attn import substitute_hf_flash_attn -except ImportError: - # We pass silently here, but raise an ImportError in our Axolotl config validation - # if cfg.sequence_parallel_degree > 1 and `ring-flash-attn` is not installed. - pass - configure_logging() LOG = get_logger(__name__) @@ -91,4 +84,6 @@ def register_ring_attn(sequence_parallel_degree: int): if rank == 0: LOG.info(f"Sequence parallel group assignments: {group_assignments}") + from ring_flash_attn import substitute_hf_flash_attn + substitute_hf_flash_attn(get_ring_attn_group(), sequence_parallel_degree) diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 847d4e510..463e957ce 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -1,5 +1,4 @@ -"""Main Axolotl input configuration Pydantic models""" - +"""Module with Pydantic models for configuration.""" # pylint: disable=too-many-lines import logging diff --git a/tests/e2e/patched/test_sp.py b/tests/e2e/patched/test_sp.py index 48d264ece..7dd0e152d 100644 --- a/tests/e2e/patched/test_sp.py +++ b/tests/e2e/patched/test_sp.py @@ -170,6 +170,7 @@ def test_sequence_parallel_slicing( assert torch.all(result["input_ids"] == expected_input_ids) +@patch.dict("sys.modules", {"ring_flash_attn": MagicMock()}) def test_config_validation_with_valid_inputs(cfg): """Test that valid sequence parallelism configurations pass validation.""" # Import the actual model class with appropriate mocks