fix tests

This commit is contained in:
Dan Saunders
2025-03-20 12:04:22 -04:00
committed by Dan Saunders
parent 22cfa42961
commit ab3b36339a
4 changed files with 6 additions and 11 deletions

View File

@@ -33,9 +33,9 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
RUN pip install packaging==23.2 setuptools==75.8.0
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
fi
RUN python scripts/unsloth_install.py | sh

View File

@@ -11,13 +11,6 @@ from accelerate.logging import get_logger
from axolotl.logging_config import configure_logging
try:
from ring_flash_attn import substitute_hf_flash_attn
except ImportError:
# We pass silently here, but raise an ImportError in our Axolotl config validation
# if cfg.sequence_parallel_degree > 1 and `ring-flash-attn` is not installed.
pass
configure_logging()
LOG = get_logger(__name__)
@@ -91,4 +84,6 @@ def register_ring_attn(sequence_parallel_degree: int):
if rank == 0:
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
from ring_flash_attn import substitute_hf_flash_attn
substitute_hf_flash_attn(get_ring_attn_group(), sequence_parallel_degree)

View File

@@ -1,5 +1,4 @@
"""Main Axolotl input configuration Pydantic models"""
"""Module with Pydantic models for configuration."""
# pylint: disable=too-many-lines
import logging

View File

@@ -170,6 +170,7 @@ def test_sequence_parallel_slicing(
assert torch.all(result["input_ids"] == expected_input_ids)
@patch.dict("sys.modules", {"ring_flash_attn": MagicMock()})
def test_config_validation_with_valid_inputs(cfg):
"""Test that valid sequence parallelism configurations pass validation."""
# Import the actual model class with appropriate mocks