progress (messy :O)
This commit is contained in:
@@ -15,7 +15,7 @@ from axolotl.monkeypatch.ring_attn import (
|
||||
register_ring_attn,
|
||||
set_ring_attn_group,
|
||||
)
|
||||
from axolotl.utils.ctx_managers.sequence_parallel import apply_sequence_parallelism
|
||||
from axolotl.utils.ctx_managers.sequence_parallel import apply_context_parallelism
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.schemas.enums import RingAttnFunc
|
||||
from axolotl.utils.schemas.trl import TRLConfig
|
||||
@@ -328,7 +328,7 @@ class TestApplySequenceParallelism:
|
||||
"""Test that function returns original batch when world size is 1."""
|
||||
mock_get_ring_attn_group.return_value = 0
|
||||
|
||||
result, _, _ = apply_sequence_parallelism(
|
||||
result, _, _ = apply_context_parallelism(
|
||||
batch=sequence_parallel_batch,
|
||||
local_rank=0,
|
||||
local_world_size=1,
|
||||
@@ -347,7 +347,7 @@ class TestApplySequenceParallelism:
|
||||
batch = sequence_parallel_batch
|
||||
seq_len = batch["input_ids"].size(1)
|
||||
|
||||
result, _, _ = apply_sequence_parallelism(
|
||||
result, _, _ = apply_context_parallelism(
|
||||
batch=batch,
|
||||
local_rank=0,
|
||||
local_world_size=2,
|
||||
@@ -374,7 +374,7 @@ class TestApplySequenceParallelism:
|
||||
seq_len = batch["input_ids"].size(1)
|
||||
original_input_ids = batch["input_ids"].clone()
|
||||
|
||||
result, _, _ = apply_sequence_parallelism(
|
||||
result, _, _ = apply_context_parallelism(
|
||||
batch=batch,
|
||||
local_rank=1,
|
||||
local_world_size=2,
|
||||
@@ -440,7 +440,7 @@ class TestApplySequenceParallelism:
|
||||
|
||||
# Create a partially applied function
|
||||
rank0_ring_parallel = functools.partial(
|
||||
apply_sequence_parallelism,
|
||||
apply_context_parallelism,
|
||||
local_rank=0,
|
||||
local_world_size=2,
|
||||
gradient_accumulation_steps=1,
|
||||
@@ -466,7 +466,7 @@ class TestApplySequenceParallelism:
|
||||
original_input_ids = batch["input_ids"].clone()
|
||||
|
||||
# This should run without error even though position_ids is missing
|
||||
result, _, _ = apply_sequence_parallelism(
|
||||
result, _, _ = apply_context_parallelism(
|
||||
batch=batch,
|
||||
local_rank=0,
|
||||
local_world_size=2,
|
||||
|
||||
Reference in New Issue
Block a user