finish basic impl; change naming from SP (sequence parallelism) -> CP (context parallelism) to match torch

This commit is contained in:
Dan Saunders
2025-06-13 09:51:06 -04:00
parent aced809989
commit 7a88de4fa8
25 changed files with 525 additions and 488 deletions
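The user-facing part of this rename is the config key: `sequence_parallel_degree` becomes `context_parallel_degree`, matching torch's "context parallelism" terminology. As a rough sketch of what migrating an old-style config could look like (this helper is hypothetical and not part of the commit; `migrate_cp_config` is an assumed name):

# Hypothetical migration sketch (not part of this commit): copy the old
# sequence-parallel key over to the new context-parallel name if present.
def migrate_cp_config(cfg: dict) -> dict:
    out = dict(cfg)
    if "sequence_parallel_degree" in out and "context_parallel_degree" not in out:
        out["context_parallel_degree"] = out.pop("sequence_parallel_degree")
    return out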

View File

@@ -1,4 +1,4 @@
"""E2E tests for sequence parallelism"""
"""E2E tests for context parallelism"""
from pathlib import Path
@@ -12,10 +12,10 @@ from axolotl.utils.dict import DictDefault
from ...utils import check_tensorboard
class TestSequenceParallelism:
"""Test case for training with sequence parallelism enabled"""
class TestContextParallelism:
"""Test case for training with context parallelism enabled"""
def _run_sequence_parallel_test(
def _run_context_parallel_test(
self,
temp_dir,
sample_packing=True,
@@ -24,7 +24,7 @@ class TestSequenceParallelism:
ring_attn_func=None,
threshold=2.0,
):
"""Helper method to run sequence parallel tests with different configurations"""
"""Helper method to run context parallel tests with different configurations"""
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -66,7 +66,7 @@ class TestSequenceParallelism:
"logging_steps": 1,
"weight_decay": 0.0,
"use_tensorboard": True,
"sequence_parallel_degree": 2,
"context_parallel_degree": 2,
"ring_attn_func": ring_attn_func,
}
)
@@ -109,7 +109,7 @@ class TestSequenceParallelism:
"no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func",
],
)
def test_sequence_parallel_training(
def test_context_parallel_training(
self,
temp_dir,
sample_packing,
@@ -118,8 +118,8 @@ class TestSequenceParallelism:
ring_attn_func,
threshold,
):
"""Test sequence parallel training with different configurations"""
self._run_sequence_parallel_test(
"""Test context parallel training with different configurations"""
self._run_context_parallel_test(
temp_dir,
sample_packing=sample_packing,
micro_batch_size=micro_batch_size,
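For orientation, a rough sketch of how a context-parallel degree of 2 is usually interpreted in a test job like this one (the world size and sequence length below are assumed example values, not taken from this diff; only context_parallel_degree mirrors the test config above):

# Assumed example values for illustration only.
world_size = 4                       # total GPUs launched for the job
context_parallel_degree = 2          # sequence is sharded across this many ranks
data_parallel_degree = world_size // context_parallel_degree  # -> 2 replicas

sequence_len = 2048                  # example packed sequence length
shard_len = sequence_len // context_parallel_degree  # -> 1024 tokens per CP rank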

View File

@@ -296,7 +296,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"sequence_parallel_degree": 2,
"context_parallel_degree": 2,
"flash_attention": True,
"sequence_len": 1024,
"special_tokens": {