update config.qmd and rename option

This commit is contained in:
Dan Saunders
2025-03-13 23:13:37 +00:00
parent 345a9dd831
commit 919b88f11b
11 changed files with 58 additions and 54 deletions

View File

@@ -25,7 +25,7 @@ def test_integration_with_config():
],
"load_in_8bit": False,
"sequence_len": 1024,
"sequence_parallel_size": 2,
"sequence_parallel_degree": 2,
"flash_attention": True,
"sample_packing": True,
"pad_to_sequence_len": True,
@@ -58,17 +58,17 @@ def test_integration_with_config():
normalize_config(cfg)
# Verify sequence parallelism settings were properly processed
assert cfg.sequence_parallel_size == 2
assert cfg.sequence_parallel_degree == 2
assert cfg.flash_attention is True
# Check if the sequence_parallel_size was propagated to the training args
# Check if the sequence_parallel_degree was propagated to the training args
from axolotl.core.training_args import AxolotlTrainingArguments
# pylint: disable=unexpected-keyword-arg
training_args = AxolotlTrainingArguments(
output_dir=temp_dir, sequence_parallel_size=cfg.sequence_parallel_size
output_dir=temp_dir, sequence_parallel_degree=cfg.sequence_parallel_degree
)
assert training_args.sequence_parallel_size == 2
assert training_args.sequence_parallel_degree == 2
def test_ring_attn_group_creation():
@@ -90,7 +90,7 @@ def test_ring_attn_group_creation():
pytest.skip(f"Need an even number of GPUs, but got {world_size}")
# Register with sequence parallel size of 2
register_ring_attn(sequence_parallel_size=2)
register_ring_attn(sequence_parallel_degree=2)
# Get the ring attention group
group = get_ring_attn_group()