This commit is contained in:
Dan Saunders
2025-04-25 02:28:38 +00:00
parent 6810f0ee19
commit 926dc4af90
6 changed files with 594 additions and 532 deletions

View File

@@ -1009,6 +1009,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
training_args_kwargs["dataloader_prefetch_factor"] = ( training_args_kwargs["dataloader_prefetch_factor"] = (
self.cfg.dataloader_prefetch_factor self.cfg.dataloader_prefetch_factor
) )
if self.cfg.seed:
training_args_kwargs["seed"] = self.cfg.seed
if self.cfg.gradient_checkpointing: if self.cfg.gradient_checkpointing:
training_args_kwargs["gradient_checkpointing"] = ( training_args_kwargs["gradient_checkpointing"] = (
self.cfg.gradient_checkpointing self.cfg.gradient_checkpointing

File diff suppressed because it is too large Load Diff

View File

@@ -4,7 +4,6 @@ Module for Axolotl trainer sequence parallelism mixin and training context manag
import functools import functools
import logging import logging
from contextlib import contextmanager
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@@ -14,14 +13,66 @@ from torch.utils.data import DistributedSampler, Sampler
from torch.utils.hooks import RemovableHandle from torch.utils.hooks import RemovableHandle
from axolotl.monkeypatch.attention.ring_attn import ( from axolotl.monkeypatch.attention.ring_attn import (
RingAttnFunc,
get_ring_attn_group, get_ring_attn_group,
update_ring_attn_params, update_ring_attn_params,
) )
from axolotl.utils.schemas.enums import RingAttnFunc
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
def _handle_logits_to_keep(
logits_to_keep,
local_rank: int,
local_world_size: int,
ring_attn_func: RingAttnFunc,
total_seq_len: int,
):
"""
Handle logits_to_keep parameter for sequence parallelism.
Args:
logits_to_keep: Integer or tensor indicating which positions to compute logits
for.
local_rank: Rank in the sequence parallel group.
local_world_size: World size of the sequence parallel group.
ring_attn_func: Ring attention function being used.
total_seq_len: Full sequence length.
Returns:
Adjusted logits_to_keep appropriate for this rank's sharded sequence
"""
print("start of _handle_logits_to_keep")
print(dist.get_rank(), logits_to_keep)
# No transformation needed if logits_to_keep is None
if logits_to_keep is None:
return None
assert isinstance(
logits_to_keep, int
), "sequence parallelism currently only supports integer logits_to_keep"
assert ring_attn_func in [
RingAttnFunc.VARLEN_LLAMA3,
RingAttnFunc.BATCH_RING,
], "if specifying logits_to_keep, sequence parallelism currently only supports 'batch_ring' and 'varlen_llama3' `ring_attn_func`s"
# For standard sharding, each rank gets a contiguous chunk
chunk_size = total_seq_len // local_world_size
start_idx = local_rank * chunk_size
end_idx = start_idx + chunk_size
# Check if logits_to_keep is in this rank's range
if start_idx <= logits_to_keep < end_idx:
print("end of _handle_logits_to_keep")
print(dist.get_rank(), logits_to_keep - start_idx)
return logits_to_keep - start_idx
else:
print("end of _handle_logits_to_keep")
print(dist.get_rank(), -1)
return -1
def apply_sequence_parallelism( def apply_sequence_parallelism(
batch: dict[str, torch.Tensor], batch: dict[str, torch.Tensor],
local_rank: int, local_rank: int,
@@ -32,10 +83,10 @@ def apply_sequence_parallelism(
Apply sequence parallelism slicing to a batch. Apply sequence parallelism slicing to a batch.
Args: Args:
batch: Batch dictionary (e.g., input_ids, attention_mask, etc.) batch: Batch dictionary (e.g., input_ids, attention_mask, etc.).
local_rank: Local rank in the sequence parallel group local_rank: Local rank in the sequence parallel group.
local_world_size: World size of the sequence parallel group local_world_size: World size of the sequence parallel group.
ring_attn_func: The ring attention function to use ring_attn_func: The ring attention function to use.
Returns: Returns:
Sliced batch dictionary. Sliced batch dictionary.
@@ -48,12 +99,10 @@ def apply_sequence_parallelism(
total_seq_len = batch["input_ids"].size(1) total_seq_len = batch["input_ids"].size(1)
for key in batch: for key in batch:
if ( if (
key in batch isinstance(batch[key], torch.Tensor)
and isinstance(batch[key], torch.Tensor)
and batch[key].dim() > 1 and batch[key].dim() > 1
and batch[key].size(1) == total_seq_len and batch[key].size(1) == total_seq_len
): ):
if ring_attn_func in [ if ring_attn_func in [
RingAttnFunc.VARLEN_LLAMA3, RingAttnFunc.VARLEN_LLAMA3,
RingAttnFunc.BATCH_RING, RingAttnFunc.BATCH_RING,
@@ -78,6 +127,14 @@ def apply_sequence_parallelism(
dim=1, dim=1,
).transpose(1, 2) ).transpose(1, 2)
batch[key] = tensor[:, local_rank].contiguous() batch[key] = tensor[:, local_rank].contiguous()
if key == "logits_to_keep":
batch[key] = _handle_logits_to_keep(
logits_to_keep=batch[key],
local_rank=local_rank,
local_world_size=local_world_size,
ring_attn_func=ring_attn_func,
total_seq_len=total_seq_len,
)
return batch return batch
@@ -205,8 +262,11 @@ class SequenceParallelContextManager:
# Forward post-hook to gather outputs # Forward post-hook to gather outputs
def sequence_parallel_post_hook(_, __, output): def sequence_parallel_post_hook(_, __, output):
print("start of sequence_parallel_post_hook")
# Gather the sharded outputs # Gather the sharded outputs
return self.gather_outputs(output) output = self.gather_outputs(output)
print("end of sequence_parallel_post_hook")
return output
# Register both hooks # Register both hooks
self.hook_handles.append( self.hook_handles.append(

View File

@@ -18,7 +18,6 @@ from pydantic import (
) )
from transformers.utils.import_utils import is_torch_npu_available from transformers.utils.import_utils import is_torch_npu_available
from axolotl.monkeypatch.attention.ring_attn import RingAttnFunc
from axolotl.utils.distributed import is_main_process from axolotl.utils.distributed import is_main_process
from axolotl.utils.schemas.datasets import ( from axolotl.utils.schemas.datasets import (
DatasetConfig, DatasetConfig,

View File

@@ -1,6 +1,4 @@
""" """E2E tests for mixtral"""
E2E tests for mixtral
"""
import logging import logging
import os import os
@@ -99,6 +97,7 @@ class TestMixtral(unittest.TestCase):
"bf16": "auto", "bf16": "auto",
} }
) )
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -12,12 +12,12 @@ from accelerate.state import PartialState
from axolotl.core.trainers.mixins.sequence_parallel import apply_sequence_parallelism from axolotl.core.trainers.mixins.sequence_parallel import apply_sequence_parallelism
from axolotl.monkeypatch.attention.ring_attn import ( from axolotl.monkeypatch.attention.ring_attn import (
RingAttnFunc,
get_ring_attn_group, get_ring_attn_group,
register_ring_attn, register_ring_attn,
set_ring_attn_group, set_ring_attn_group,
) )
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.schemas.enums import RingAttnFunc
@pytest.fixture @pytest.fixture