This commit is contained in:
Dan Saunders
2025-04-25 02:28:38 +00:00
parent 6810f0ee19
commit 926dc4af90
6 changed files with 594 additions and 532 deletions

View File

@@ -1009,6 +1009,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
training_args_kwargs["dataloader_prefetch_factor"] = (
self.cfg.dataloader_prefetch_factor
)
if self.cfg.seed:
training_args_kwargs["seed"] = self.cfg.seed
if self.cfg.gradient_checkpointing:
training_args_kwargs["gradient_checkpointing"] = (
self.cfg.gradient_checkpointing

File diff suppressed because it is too large Load Diff

View File

@@ -4,7 +4,6 @@ Module for Axolotl trainer sequence parallelism mixin and training context manag
import functools
import logging
from contextlib import contextmanager
import torch
import torch.distributed as dist
@@ -14,14 +13,66 @@ from torch.utils.data import DistributedSampler, Sampler
from torch.utils.hooks import RemovableHandle
from axolotl.monkeypatch.attention.ring_attn import (
RingAttnFunc,
get_ring_attn_group,
update_ring_attn_params,
)
from axolotl.utils.schemas.enums import RingAttnFunc
LOG = logging.getLogger(__name__)
def _handle_logits_to_keep(
logits_to_keep,
local_rank: int,
local_world_size: int,
ring_attn_func: RingAttnFunc,
total_seq_len: int,
):
"""
Handle logits_to_keep parameter for sequence parallelism.
Args:
logits_to_keep: Integer or tensor indicating which positions to compute logits
for.
local_rank: Rank in the sequence parallel group.
local_world_size: World size of the sequence parallel group.
ring_attn_func: Ring attention function being used.
total_seq_len: Full sequence length.
Returns:
Adjusted logits_to_keep appropriate for this rank's sharded sequence
"""
print("start of _handle_logits_to_keep")
print(dist.get_rank(), logits_to_keep)
# No transformation needed if logits_to_keep is None
if logits_to_keep is None:
return None
assert isinstance(
logits_to_keep, int
), "sequence parallelism currently only supports integer logits_to_keep"
assert ring_attn_func in [
RingAttnFunc.VARLEN_LLAMA3,
RingAttnFunc.BATCH_RING,
], "if specifying logits_to_keep, sequence parallelism currently only supports 'batch_ring' and 'varlen_llama3' `ring_attn_func`s"
# For standard sharding, each rank gets a contiguous chunk
chunk_size = total_seq_len // local_world_size
start_idx = local_rank * chunk_size
end_idx = start_idx + chunk_size
# Check if logits_to_keep is in this rank's range
if start_idx <= logits_to_keep < end_idx:
print("end of _handle_logits_to_keep")
print(dist.get_rank(), logits_to_keep - start_idx)
return logits_to_keep - start_idx
else:
print("end of _handle_logits_to_keep")
print(dist.get_rank(), -1)
return -1
def apply_sequence_parallelism(
batch: dict[str, torch.Tensor],
local_rank: int,
@@ -32,10 +83,10 @@ def apply_sequence_parallelism(
Apply sequence parallelism slicing to a batch.
Args:
batch: Batch dictionary (e.g., input_ids, attention_mask, etc.)
local_rank: Local rank in the sequence parallel group
local_world_size: World size of the sequence parallel group
ring_attn_func: The ring attention function to use
batch: Batch dictionary (e.g., input_ids, attention_mask, etc.).
local_rank: Local rank in the sequence parallel group.
local_world_size: World size of the sequence parallel group.
ring_attn_func: The ring attention function to use.
Returns:
Sliced batch dictionary.
@@ -48,12 +99,10 @@ def apply_sequence_parallelism(
total_seq_len = batch["input_ids"].size(1)
for key in batch:
if (
key in batch
and isinstance(batch[key], torch.Tensor)
isinstance(batch[key], torch.Tensor)
and batch[key].dim() > 1
and batch[key].size(1) == total_seq_len
):
if ring_attn_func in [
RingAttnFunc.VARLEN_LLAMA3,
RingAttnFunc.BATCH_RING,
@@ -78,6 +127,14 @@ def apply_sequence_parallelism(
dim=1,
).transpose(1, 2)
batch[key] = tensor[:, local_rank].contiguous()
if key == "logits_to_keep":
batch[key] = _handle_logits_to_keep(
logits_to_keep=batch[key],
local_rank=local_rank,
local_world_size=local_world_size,
ring_attn_func=ring_attn_func,
total_seq_len=total_seq_len,
)
return batch
@@ -205,8 +262,11 @@ class SequenceParallelContextManager:
# Forward post-hook to gather outputs
def sequence_parallel_post_hook(_, __, output):
print("start of sequence_parallel_post_hook")
# Gather the sharded outputs
return self.gather_outputs(output)
output = self.gather_outputs(output)
print("end of sequence_parallel_post_hook")
return output
# Register both hooks
self.hook_handles.append(

View File

@@ -18,7 +18,6 @@ from pydantic import (
)
from transformers.utils.import_utils import is_torch_npu_available
from axolotl.monkeypatch.attention.ring_attn import RingAttnFunc
from axolotl.utils.distributed import is_main_process
from axolotl.utils.schemas.datasets import (
DatasetConfig,

View File

@@ -1,6 +1,4 @@
"""
E2E tests for mixtral
"""
"""E2E tests for mixtral"""
import logging
import os
@@ -99,6 +97,7 @@ class TestMixtral(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -12,12 +12,12 @@ from accelerate.state import PartialState
from axolotl.core.trainers.mixins.sequence_parallel import apply_sequence_parallelism
from axolotl.monkeypatch.attention.ring_attn import (
RingAttnFunc,
get_ring_attn_group,
register_ring_attn,
set_ring_attn_group,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.schemas.enums import RingAttnFunc
@pytest.fixture