updates
@@ -1009,6 +1009,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             training_args_kwargs["dataloader_prefetch_factor"] = (
                 self.cfg.dataloader_prefetch_factor
             )
         if self.cfg.seed:
             training_args_kwargs["seed"] = self.cfg.seed
         if self.cfg.gradient_checkpointing:
             training_args_kwargs["gradient_checkpointing"] = (
                 self.cfg.gradient_checkpointing
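For context, these keys are read straight off the user config and forwarded into the trainer's TrainingArguments. A minimal sketch of that plumbing, with a plain dict standing in for Axolotl's cfg object (values are illustrative, not taken from this diff):

    # Sketch only: a plain dict stands in for Axolotl's DictDefault config; values are made up.
    cfg = {"dataloader_prefetch_factor": 2, "seed": 42, "gradient_checkpointing": True}

    training_args_kwargs = {}
    if cfg.get("dataloader_prefetch_factor"):
        training_args_kwargs["dataloader_prefetch_factor"] = cfg["dataloader_prefetch_factor"]
    if cfg.get("seed"):
        training_args_kwargs["seed"] = cfg["seed"]
    if cfg.get("gradient_checkpointing"):
        training_args_kwargs["gradient_checkpointing"] = cfg["gradient_checkpointing"]

    # The builder later expands these into TrainingArguments(**training_args_kwargs).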
File diff suppressed because it is too large
@@ -4,7 +4,6 @@ Module for Axolotl trainer sequence parallelism mixin and training context manag
 
 import functools
 import logging
 from contextlib import contextmanager
 
 import torch
 import torch.distributed as dist
@@ -14,14 +13,66 @@ from torch.utils.data import DistributedSampler, Sampler
 from torch.utils.hooks import RemovableHandle
 
 from axolotl.monkeypatch.attention.ring_attn import (
-    RingAttnFunc,
     get_ring_attn_group,
     update_ring_attn_params,
 )
+from axolotl.utils.schemas.enums import RingAttnFunc
 
 LOG = logging.getLogger(__name__)
 
 
+def _handle_logits_to_keep(
+    logits_to_keep,
+    local_rank: int,
+    local_world_size: int,
+    ring_attn_func: RingAttnFunc,
+    total_seq_len: int,
+):
+    """
+    Handle logits_to_keep parameter for sequence parallelism.
+
+    Args:
+        logits_to_keep: Integer or tensor indicating which positions to compute logits
+            for.
+        local_rank: Rank in the sequence parallel group.
+        local_world_size: World size of the sequence parallel group.
+        ring_attn_func: Ring attention function being used.
+        total_seq_len: Full sequence length.
+
+    Returns:
+        Adjusted logits_to_keep appropriate for this rank's sharded sequence
+    """
+    print("start of _handle_logits_to_keep")
+    print(dist.get_rank(), logits_to_keep)
+
+    # No transformation needed if logits_to_keep is None
+    if logits_to_keep is None:
+        return None
+
+    assert isinstance(
+        logits_to_keep, int
+    ), "sequence parallelism currently only supports integer logits_to_keep"
+    assert ring_attn_func in [
+        RingAttnFunc.VARLEN_LLAMA3,
+        RingAttnFunc.BATCH_RING,
+    ], "if specifying logits_to_keep, sequence parallelism currently only supports 'batch_ring' and 'varlen_llama3' `ring_attn_func`s"
+
+    # For standard sharding, each rank gets a contiguous chunk
+    chunk_size = total_seq_len // local_world_size
+    start_idx = local_rank * chunk_size
+    end_idx = start_idx + chunk_size
+
+    # Check if logits_to_keep is in this rank's range
+    if start_idx <= logits_to_keep < end_idx:
+        print("end of _handle_logits_to_keep")
+        print(dist.get_rank(), logits_to_keep - start_idx)
+        return logits_to_keep - start_idx
+    else:
+        print("end of _handle_logits_to_keep")
+        print(dist.get_rank(), -1)
+        return -1
+
+
 def apply_sequence_parallelism(
     batch: dict[str, torch.Tensor],
     local_rank: int,
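To make the chunk arithmetic above concrete, here is a small standalone rendering of the same index math (the numbers are illustrative, not taken from the diff): with a full sequence of 8192 tokens split across 4 sequence-parallel ranks, each rank owns a contiguous 2048-token chunk, so only the rank whose chunk contains the requested position returns a re-based index and every other rank returns -1.

    # Standalone sketch of _handle_logits_to_keep's index math; values are illustrative.
    total_seq_len, local_world_size = 8192, 4
    chunk_size = total_seq_len // local_world_size  # 2048 tokens per rank
    logits_to_keep = 5000                           # position in the full (unsharded) sequence

    for local_rank in range(local_world_size):
        start_idx = local_rank * chunk_size
        end_idx = start_idx + chunk_size
        if start_idx <= logits_to_keep < end_idx:
            print(local_rank, logits_to_keep - start_idx)  # rank 2 prints 904
        else:
            print(local_rank, -1)                          # ranks 0, 1, 3 print -1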
@@ -32,10 +83,10 @@ def apply_sequence_parallelism(
     Apply sequence parallelism slicing to a batch.
 
     Args:
-        batch: Batch dictionary (e.g., input_ids, attention_mask, etc.)
-        local_rank: Local rank in the sequence parallel group
-        local_world_size: World size of the sequence parallel group
-        ring_attn_func: The ring attention function to use
+        batch: Batch dictionary (e.g., input_ids, attention_mask, etc.).
+        local_rank: Local rank in the sequence parallel group.
+        local_world_size: World size of the sequence parallel group.
+        ring_attn_func: The ring attention function to use.
 
     Returns:
         Sliced batch dictionary.
@@ -48,12 +99,10 @@ def apply_sequence_parallelism(
     total_seq_len = batch["input_ids"].size(1)
     for key in batch:
         if (
-            key in batch
-            and isinstance(batch[key], torch.Tensor)
+            isinstance(batch[key], torch.Tensor)
             and batch[key].dim() > 1
             and batch[key].size(1) == total_seq_len
         ):
-
             if ring_attn_func in [
                 RingAttnFunc.VARLEN_LLAMA3,
                 RingAttnFunc.BATCH_RING,
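To illustrate the predicate and the slicing it gates, here is a small self-contained stand-in (simplified; it is not the library function, and it only mimics the contiguous 'batch_ring'/'varlen_llama3'-style split): only tensors that span the full sequence dimension get sliced, everything else passes through untouched.

    # Simplified stand-in for the slicing gated by the predicate above; not the Axolotl code.
    import torch

    def slice_for_rank(batch, local_rank, local_world_size):
        total_seq_len = batch["input_ids"].size(1)
        sliced = {}
        for key, value in batch.items():
            if (
                isinstance(value, torch.Tensor)
                and value.dim() > 1
                and value.size(1) == total_seq_len
            ):
                # Contiguous 1/world_size chunk along the sequence dimension.
                sliced[key] = value.chunk(local_world_size, dim=1)[local_rank].contiguous()
            else:
                # Scalars and non-sequence tensors are left as-is.
                sliced[key] = value
        return sliced

    batch = {
        "input_ids": torch.randint(0, 32000, (1, 4096)),
        "attention_mask": torch.ones(1, 4096, dtype=torch.long),
        "num_items_in_batch": torch.tensor(17),  # fails the predicate, passes through
    }
    print(slice_for_rank(batch, local_rank=1, local_world_size=2)["input_ids"].shape)
    # torch.Size([1, 2048])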
@@ -78,6 +127,14 @@ def apply_sequence_parallelism(
                     dim=1,
                 ).transpose(1, 2)
                 batch[key] = tensor[:, local_rank].contiguous()
+        if key == "logits_to_keep":
+            batch[key] = _handle_logits_to_keep(
+                logits_to_keep=batch[key],
+                local_rank=local_rank,
+                local_world_size=local_world_size,
+                ring_attn_func=ring_attn_func,
+                total_seq_len=total_seq_len,
+            )
 
     return batch
@@ -205,8 +262,11 @@ class SequenceParallelContextManager:
 
         # Forward post-hook to gather outputs
        def sequence_parallel_post_hook(_, __, output):
+            print("start of sequence_parallel_post_hook")
             # Gather the sharded outputs
-            return self.gather_outputs(output)
+            output = self.gather_outputs(output)
+            print("end of sequence_parallel_post_hook")
+            return output
 
         # Register both hooks
         self.hook_handles.append(
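The context manager's pattern here is standard PyTorch module hooks: a forward pre-hook that adjusts inputs before the forward pass and a forward post-hook that undoes it on the outputs, with both handles kept so they can be removed when the context exits. A generic, self-contained sketch of that pattern (not the Axolotl shard/gather logic itself):

    # Generic forward pre-/post-hook pattern; the bodies are placeholders, not Axolotl's logic.
    import torch
    import torch.nn as nn

    model = nn.Linear(8, 8)
    hook_handles = []

    def pre_hook(module, args):
        # Runs before forward; this is where inputs could be sliced per rank.
        return args

    def post_hook(module, args, output):
        # Runs after forward; this is where sharded outputs could be gathered.
        return output

    hook_handles.append(model.register_forward_pre_hook(pre_hook))
    hook_handles.append(model.register_forward_hook(post_hook))

    model(torch.randn(2, 8))

    # On context exit the handles are removed, restoring the unhooked module.
    for handle in hook_handles:
        handle.remove()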
@@ -18,7 +18,6 @@ from pydantic import (
 )
 from transformers.utils.import_utils import is_torch_npu_available
 
-from axolotl.monkeypatch.attention.ring_attn import RingAttnFunc
 from axolotl.utils.distributed import is_main_process
 from axolotl.utils.schemas.datasets import (
     DatasetConfig,
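For orientation, RingAttnFunc is now imported from axolotl.utils.schemas.enums rather than the monkeypatch module. A hedged sketch of the enum's shape, assuming it is a plain string enum and showing only the two members this diff actually references:

    # Hedged sketch; member list is limited to what this diff references.
    from enum import Enum

    class RingAttnFunc(str, Enum):
        VARLEN_LLAMA3 = "varlen_llama3"
        BATCH_RING = "batch_ring"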
@@ -1,6 +1,4 @@
-"""
-E2E tests for mixtral
-"""
+"""E2E tests for mixtral"""
 
 import logging
 import os
@@ -99,6 +97,7 @@ class TestMixtral(unittest.TestCase):
                 "bf16": "auto",
             }
         )
+        cfg = validate_config(cfg)
         normalize_config(cfg)
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
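The surrounding test follows the usual Axolotl e2e shape: build a DictDefault config, validate and normalize it, then load the datasets. A trimmed sketch of that flow (the config keys and model id are placeholders, and the helper names are assumed to come from the test module's existing imports):

    # Trimmed illustration of the flow above; config keys and model id are placeholders.
    cfg = DictDefault(
        {
            "base_model": "some-org/tiny-mixtral",  # placeholder
            "sequence_len": 1024,
            "bf16": "auto",
        }
    )
    cfg = validate_config(cfg)
    normalize_config(cfg)
    cli_args = TrainerCliArgs()
    dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)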
@@ -12,12 +12,12 @@ from accelerate.state import PartialState
 
 from axolotl.core.trainers.mixins.sequence_parallel import apply_sequence_parallelism
 from axolotl.monkeypatch.attention.ring_attn import (
-    RingAttnFunc,
     get_ring_attn_group,
     register_ring_attn,
     set_ring_attn_group,
 )
 from axolotl.utils.dict import DictDefault
+from axolotl.utils.schemas.enums import RingAttnFunc
 
 
 @pytest.fixture