precommit
This commit is contained in:
@@ -175,7 +175,7 @@ def test_sequence_parallel_slicing(
|
|||||||
def test_config_validation_with_valid_inputs(cfg):
|
def test_config_validation_with_valid_inputs(cfg):
|
||||||
"""Test that valid sequence parallelism configurations pass validation."""
|
"""Test that valid sequence parallelism configurations pass validation."""
|
||||||
# Import the actual model class with appropriate mocks
|
# Import the actual model class with appropriate mocks
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||||
|
|
||||||
# Valid configuration: sequence_parallel_degree > 1 and flash_attention is True
|
# Valid configuration: sequence_parallel_degree > 1 and flash_attention is True
|
||||||
cfg = cfg | {
|
cfg = cfg | {
|
||||||
@@ -191,7 +191,7 @@ def test_config_validation_with_valid_inputs(cfg):
|
|||||||
|
|
||||||
def test_config_validation_with_invalid_inputs(cfg):
|
def test_config_validation_with_invalid_inputs(cfg):
|
||||||
"""Test that invalid sequence parallelism configurations fail validation."""
|
"""Test that invalid sequence parallelism configurations fail validation."""
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||||
|
|
||||||
# Invalid configuration: sequence_parallel_degree > 1 but flash_attention is False
|
# Invalid configuration: sequence_parallel_degree > 1 but flash_attention is False
|
||||||
cfg = cfg | {
|
cfg = cfg | {
|
||||||
|
|||||||
Reference in New Issue
Block a user