finish basic impl; change naming from SP -> CP to match torch
@@ -1,4 +1,4 @@
-"""E2E tests for sequence parallelism"""
+"""E2E tests for context parallelism"""
 
 from pathlib import Path
 
@@ -12,10 +12,10 @@ from axolotl.utils.dict import DictDefault
 from ...utils import check_tensorboard
 
 
-class TestSequenceParallelism:
-    """Test case for training with sequence parallelism enabled"""
+class TestContextParallelism:
+    """Test case for training with context parallelism enabled"""
 
-    def _run_sequence_parallel_test(
+    def _run_context_parallel_test(
         self,
         temp_dir,
         sample_packing=True,
@@ -24,7 +24,7 @@ class TestSequenceParallelism:
         ring_attn_func=None,
         threshold=2.0,
     ):
-        """Helper method to run sequence parallel tests with different configurations"""
+        """Helper method to run context parallel tests with different configurations"""
        cfg = DictDefault(
             {
                 "base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -66,7 +66,7 @@ class TestSequenceParallelism:
                 "logging_steps": 1,
                 "weight_decay": 0.0,
                 "use_tensorboard": True,
-                "sequence_parallel_degree": 2,
+                "context_parallel_degree": 2,
                 "ring_attn_func": ring_attn_func,
             }
         )
@@ -109,7 +109,7 @@ class TestSequenceParallelism:
             "no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func",
         ],
     )
-    def test_sequence_parallel_training(
+    def test_context_parallel_training(
         self,
         temp_dir,
         sample_packing,
@@ -118,8 +118,8 @@ class TestSequenceParallelism:
         ring_attn_func,
         threshold,
     ):
-        """Test sequence parallel training with different configurations"""
-        self._run_sequence_parallel_test(
+        """Test context parallel training with different configurations"""
+        self._run_context_parallel_test(
            temp_dir,
             sample_packing=sample_packing,
             micro_batch_size=micro_batch_size,
@@ -296,7 +296,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
            "lora_alpha": 16,
            "lora_dropout": 0.05,
            "lora_target_linear": True,
-            "sequence_parallel_degree": 2,
+            "context_parallel_degree": 2,
            "flash_attention": True,
            "sequence_len": 1024,
            "special_tokens": {
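Below is a minimal sketch of what the renamed option looks like in a test config after this change. The config keys and the SmolLM2-135M base model are taken from the hunks above; the remaining values are illustrative assumptions, not the full test cfg.

# Minimal sketch of the renamed option (assumption: heavily trimmed cfg;
# the real tests set many more keys, e.g. dataset, optimizer, and tensorboard options).
from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "base_model": "HuggingFaceTB/SmolLM2-135M",
        "sequence_len": 1024,
        "flash_attention": True,
        "sample_packing": True,        # matches the helper's default
        # Renamed from "sequence_parallel_degree" to match torch's
        # context-parallelism terminology.
        "context_parallel_degree": 2,
        "ring_attn_func": None,        # matches the helper's default
    }
)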