progress (messy :O)

This commit is contained in:
Dan Saunders
2025-06-12 18:54:41 +00:00
parent ae73123eae
commit aced809989
8 changed files with 333 additions and 199 deletions

View File

@@ -15,7 +15,7 @@ from axolotl.monkeypatch.ring_attn import (
register_ring_attn,
set_ring_attn_group,
)
-from axolotl.utils.ctx_managers.sequence_parallel import apply_sequence_parallelism
+from axolotl.utils.ctx_managers.sequence_parallel import apply_context_parallelism
from axolotl.utils.dict import DictDefault
from axolotl.utils.schemas.enums import RingAttnFunc
from axolotl.utils.schemas.trl import TRLConfig
@@ -328,7 +328,7 @@ class TestApplySequenceParallelism:
"""Test that function returns original batch when world size is 1."""
mock_get_ring_attn_group.return_value = 0
-result, _, _ = apply_sequence_parallelism(
+result, _, _ = apply_context_parallelism(
batch=sequence_parallel_batch,
local_rank=0,
local_world_size=1,
@@ -347,7 +347,7 @@ class TestApplySequenceParallelism:
batch = sequence_parallel_batch
seq_len = batch["input_ids"].size(1)
-result, _, _ = apply_sequence_parallelism(
+result, _, _ = apply_context_parallelism(
batch=batch,
local_rank=0,
local_world_size=2,
@@ -374,7 +374,7 @@ class TestApplySequenceParallelism:
seq_len = batch["input_ids"].size(1)
original_input_ids = batch["input_ids"].clone()
-result, _, _ = apply_sequence_parallelism(
+result, _, _ = apply_context_parallelism(
batch=batch,
local_rank=1,
local_world_size=2,
@@ -440,7 +440,7 @@ class TestApplySequenceParallelism:
# Create a partially applied function
rank0_ring_parallel = functools.partial(
-apply_sequence_parallelism,
+apply_context_parallelism,
local_rank=0,
local_world_size=2,
gradient_accumulation_steps=1,
@@ -466,7 +466,7 @@ class TestApplySequenceParallelism:
original_input_ids = batch["input_ids"].clone()
# This should run without error even though position_ids is missing
-result, _, _ = apply_sequence_parallelism(
+result, _, _ = apply_context_parallelism(
batch=batch,
local_rank=0,
local_world_size=2,