fix tests
This commit is contained in:
committed by
Dan Saunders
parent
22cfa42961
commit
ab3b36339a
@@ -33,9 +33,9 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
|||||||
|
|
||||||
RUN pip install packaging==23.2 setuptools==75.8.0
|
RUN pip install packaging==23.2 setuptools==75.8.0
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
else \
|
else \
|
||||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
RUN python scripts/unsloth_install.py | sh
|
RUN python scripts/unsloth_install.py | sh
|
||||||
|
|||||||
@@ -11,13 +11,6 @@ from accelerate.logging import get_logger
|
|||||||
|
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
|
|
||||||
try:
|
|
||||||
from ring_flash_attn import substitute_hf_flash_attn
|
|
||||||
except ImportError:
|
|
||||||
# We pass silently here, but raise an ImportError in our Axolotl config validation
|
|
||||||
# if cfg.sequence_parallel_degree > 1 and `ring-flash-attn` is not installed.
|
|
||||||
pass
|
|
||||||
|
|
||||||
configure_logging()
|
configure_logging()
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
@@ -91,4 +84,6 @@ def register_ring_attn(sequence_parallel_degree: int):
|
|||||||
if rank == 0:
|
if rank == 0:
|
||||||
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
|
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
|
||||||
|
|
||||||
|
from ring_flash_attn import substitute_hf_flash_attn
|
||||||
|
|
||||||
substitute_hf_flash_attn(get_ring_attn_group(), sequence_parallel_degree)
|
substitute_hf_flash_attn(get_ring_attn_group(), sequence_parallel_degree)
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
"""Main Axolotl input configuration Pydantic models"""
|
"""Module with Pydantic models for configuration."""
|
||||||
|
|
||||||
# pylint: disable=too-many-lines
|
# pylint: disable=too-many-lines
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|||||||
@@ -170,6 +170,7 @@ def test_sequence_parallel_slicing(
|
|||||||
assert torch.all(result["input_ids"] == expected_input_ids)
|
assert torch.all(result["input_ids"] == expected_input_ids)
|
||||||
|
|
||||||
|
|
||||||
|
@patch.dict("sys.modules", {"ring_flash_attn": MagicMock()})
|
||||||
def test_config_validation_with_valid_inputs(cfg):
|
def test_config_validation_with_valid_inputs(cfg):
|
||||||
"""Test that valid sequence parallelism configurations pass validation."""
|
"""Test that valid sequence parallelism configurations pass validation."""
|
||||||
# Import the actual model class with appropriate mocks
|
# Import the actual model class with appropriate mocks
|
||||||
|
|||||||
Reference in New Issue
Block a user