Update config.qmd and rename option `sequence_parallel_size` to `sequence_parallel_degree`
This commit is contained in:
@@ -25,7 +25,7 @@ def test_integration_with_config():
|
||||
],
|
||||
"load_in_8bit": False,
|
||||
"sequence_len": 1024,
|
||||
"sequence_parallel_size": 2,
|
||||
"sequence_parallel_degree": 2,
|
||||
"flash_attention": True,
|
||||
"sample_packing": True,
|
||||
"pad_to_sequence_len": True,
|
||||
@@ -58,17 +58,17 @@ def test_integration_with_config():
|
||||
normalize_config(cfg)
|
||||
|
||||
# Verify sequence parallelism settings were properly processed
|
||||
assert cfg.sequence_parallel_size == 2
|
||||
assert cfg.sequence_parallel_degree == 2
|
||||
assert cfg.flash_attention is True
|
||||
|
||||
# Check if the sequence_parallel_size was propagated to the training args
|
||||
# Check if the sequence_parallel_degree was propagated to the training args
|
||||
from axolotl.core.training_args import AxolotlTrainingArguments
|
||||
|
||||
# pylint: disable=unexpected-keyword-arg
|
||||
training_args = AxolotlTrainingArguments(
|
||||
output_dir=temp_dir, sequence_parallel_size=cfg.sequence_parallel_size
|
||||
output_dir=temp_dir, sequence_parallel_degree=cfg.sequence_parallel_degree
|
||||
)
|
||||
assert training_args.sequence_parallel_size == 2
|
||||
assert training_args.sequence_parallel_degree == 2
|
||||
|
||||
|
||||
def test_ring_attn_group_creation():
|
||||
@@ -90,7 +90,7 @@ def test_ring_attn_group_creation():
|
||||
pytest.skip(f"Need an even number of GPUs, but got {world_size}")
|
||||
|
||||
# Register with sequence parallel size of 2
|
||||
register_ring_attn(sequence_parallel_size=2)
|
||||
register_ring_attn(sequence_parallel_degree=2)
|
||||
|
||||
# Get the ring attention group
|
||||
group = get_ring_attn_group()
|
||||
|
||||
@@ -94,7 +94,7 @@ class TestRingAttention:
|
||||
mock_new_group.return_value = mock_group
|
||||
|
||||
# Call register_ring_attn with size 4
|
||||
register_ring_attn(sequence_parallel_size=4)
|
||||
register_ring_attn(sequence_parallel_degree=4)
|
||||
|
||||
# Verify the number of calls without examining the arguments
|
||||
assert mock_new_group.call_count == 2
|
||||
@@ -175,15 +175,15 @@ def test_config_validation_with_valid_inputs(cfg):
|
||||
# Import the actual model class with appropriate mocks
|
||||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
||||
|
||||
# Valid configuration: sequence_parallel_size > 1 and flash_attention is True
|
||||
# Valid configuration: sequence_parallel_degree > 1 and flash_attention is True
|
||||
cfg = cfg | {
|
||||
"sequence_parallel_size": 2,
|
||||
"sequence_parallel_degree": 2,
|
||||
"flash_attention": True,
|
||||
}
|
||||
|
||||
# Should validate without errors
|
||||
config = AxolotlInputConfig(**cfg)
|
||||
assert config.sequence_parallel_size == 2
|
||||
assert config.sequence_parallel_degree == 2
|
||||
assert config.flash_attention is True
|
||||
|
||||
|
||||
@@ -191,9 +191,9 @@ def test_config_validation_with_invalid_inputs(cfg):
|
||||
"""Test that invalid sequence parallelism configurations fail validation."""
|
||||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
||||
|
||||
# Invalid configuration: sequence_parallel_size > 1 but flash_attention is False
|
||||
# Invalid configuration: sequence_parallel_degree > 1 but flash_attention is False
|
||||
cfg = cfg | {
|
||||
"sequence_parallel_size": 2,
|
||||
"sequence_parallel_degree": 2,
|
||||
"flash_attention": False,
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user