fix tests
This commit is contained in:
committed by
Dan Saunders
parent
22cfa42961
commit
ab3b36339a
@@ -33,9 +33,9 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
||||
|
||||
RUN pip install packaging==23.2 setuptools==75.8.0
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
||||
fi
|
||||
|
||||
RUN python scripts/unsloth_install.py | sh
|
||||
|
||||
@@ -11,13 +11,6 @@ from accelerate.logging import get_logger
|
||||
|
||||
from axolotl.logging_config import configure_logging
|
||||
|
||||
try:
|
||||
from ring_flash_attn import substitute_hf_flash_attn
|
||||
except ImportError:
|
||||
# We pass silently here, but raise an ImportError in our Axolotl config validation
|
||||
# if cfg.sequence_parallel_degree > 1 and `ring-flash-attn` is not installed.
|
||||
pass
|
||||
|
||||
configure_logging()
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
@@ -91,4 +84,6 @@ def register_ring_attn(sequence_parallel_degree: int):
|
||||
if rank == 0:
|
||||
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
|
||||
|
||||
from ring_flash_attn import substitute_hf_flash_attn
|
||||
|
||||
substitute_hf_flash_attn(get_ring_attn_group(), sequence_parallel_degree)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""Main Axolotl input configuration Pydantic models"""
|
||||
|
||||
"""Module with Pydantic models for configuration."""
|
||||
# pylint: disable=too-many-lines
|
||||
|
||||
import logging
|
||||
|
||||
@@ -170,6 +170,7 @@ def test_sequence_parallel_slicing(
|
||||
assert torch.all(result["input_ids"] == expected_input_ids)
|
||||
|
||||
|
||||
@patch.dict("sys.modules", {"ring_flash_attn": MagicMock()})
|
||||
def test_config_validation_with_valid_inputs(cfg):
|
||||
"""Test that valid sequence parallelism configurations pass validation."""
|
||||
# Import the actual model class with appropriate mocks
|
||||
|
||||
Reference in New Issue
Block a user