updates
This commit is contained in:
@@ -1009,6 +1009,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_args_kwargs["dataloader_prefetch_factor"] = (
|
training_args_kwargs["dataloader_prefetch_factor"] = (
|
||||||
self.cfg.dataloader_prefetch_factor
|
self.cfg.dataloader_prefetch_factor
|
||||||
)
|
)
|
||||||
|
if self.cfg.seed:
|
||||||
|
training_args_kwargs["seed"] = self.cfg.seed
|
||||||
if self.cfg.gradient_checkpointing:
|
if self.cfg.gradient_checkpointing:
|
||||||
training_args_kwargs["gradient_checkpointing"] = (
|
training_args_kwargs["gradient_checkpointing"] = (
|
||||||
self.cfg.gradient_checkpointing
|
self.cfg.gradient_checkpointing
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -4,7 +4,6 @@ Module for Axolotl trainer sequence parallelism mixin and training context manag
|
|||||||
|
|
||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
from contextlib import contextmanager
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
@@ -14,14 +13,66 @@ from torch.utils.data import DistributedSampler, Sampler
|
|||||||
from torch.utils.hooks import RemovableHandle
|
from torch.utils.hooks import RemovableHandle
|
||||||
|
|
||||||
from axolotl.monkeypatch.attention.ring_attn import (
|
from axolotl.monkeypatch.attention.ring_attn import (
|
||||||
RingAttnFunc,
|
|
||||||
get_ring_attn_group,
|
get_ring_attn_group,
|
||||||
update_ring_attn_params,
|
update_ring_attn_params,
|
||||||
)
|
)
|
||||||
|
from axolotl.utils.schemas.enums import RingAttnFunc
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_logits_to_keep(
|
||||||
|
logits_to_keep,
|
||||||
|
local_rank: int,
|
||||||
|
local_world_size: int,
|
||||||
|
ring_attn_func: RingAttnFunc,
|
||||||
|
total_seq_len: int,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Handle logits_to_keep parameter for sequence parallelism.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logits_to_keep: Integer or tensor indicating which positions to compute logits
|
||||||
|
for.
|
||||||
|
local_rank: Rank in the sequence parallel group.
|
||||||
|
local_world_size: World size of the sequence parallel group.
|
||||||
|
ring_attn_func: Ring attention function being used.
|
||||||
|
total_seq_len: Full sequence length.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Adjusted logits_to_keep appropriate for this rank's sharded sequence
|
||||||
|
"""
|
||||||
|
print("start of _handle_logits_to_keep")
|
||||||
|
print(dist.get_rank(), logits_to_keep)
|
||||||
|
|
||||||
|
# No transformation needed if logits_to_keep is None
|
||||||
|
if logits_to_keep is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
assert isinstance(
|
||||||
|
logits_to_keep, int
|
||||||
|
), "sequence parallelism currently only supports integer logits_to_keep"
|
||||||
|
assert ring_attn_func in [
|
||||||
|
RingAttnFunc.VARLEN_LLAMA3,
|
||||||
|
RingAttnFunc.BATCH_RING,
|
||||||
|
], "if specifying logits_to_keep, sequence parallelism currently only supports 'batch_ring' and 'varlen_llama3' `ring_attn_func`s"
|
||||||
|
|
||||||
|
# For standard sharding, each rank gets a contiguous chunk
|
||||||
|
chunk_size = total_seq_len // local_world_size
|
||||||
|
start_idx = local_rank * chunk_size
|
||||||
|
end_idx = start_idx + chunk_size
|
||||||
|
|
||||||
|
# Check if logits_to_keep is in this rank's range
|
||||||
|
if start_idx <= logits_to_keep < end_idx:
|
||||||
|
print("end of _handle_logits_to_keep")
|
||||||
|
print(dist.get_rank(), logits_to_keep - start_idx)
|
||||||
|
return logits_to_keep - start_idx
|
||||||
|
else:
|
||||||
|
print("end of _handle_logits_to_keep")
|
||||||
|
print(dist.get_rank(), -1)
|
||||||
|
return -1
|
||||||
|
|
||||||
|
|
||||||
def apply_sequence_parallelism(
|
def apply_sequence_parallelism(
|
||||||
batch: dict[str, torch.Tensor],
|
batch: dict[str, torch.Tensor],
|
||||||
local_rank: int,
|
local_rank: int,
|
||||||
@@ -32,10 +83,10 @@ def apply_sequence_parallelism(
|
|||||||
Apply sequence parallelism slicing to a batch.
|
Apply sequence parallelism slicing to a batch.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
batch: Batch dictionary (e.g., input_ids, attention_mask, etc.)
|
batch: Batch dictionary (e.g., input_ids, attention_mask, etc.).
|
||||||
local_rank: Local rank in the sequence parallel group
|
local_rank: Local rank in the sequence parallel group.
|
||||||
local_world_size: World size of the sequence parallel group
|
local_world_size: World size of the sequence parallel group.
|
||||||
ring_attn_func: The ring attention function to use
|
ring_attn_func: The ring attention function to use.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Sliced batch dictionary.
|
Sliced batch dictionary.
|
||||||
@@ -48,12 +99,10 @@ def apply_sequence_parallelism(
|
|||||||
total_seq_len = batch["input_ids"].size(1)
|
total_seq_len = batch["input_ids"].size(1)
|
||||||
for key in batch:
|
for key in batch:
|
||||||
if (
|
if (
|
||||||
key in batch
|
isinstance(batch[key], torch.Tensor)
|
||||||
and isinstance(batch[key], torch.Tensor)
|
|
||||||
and batch[key].dim() > 1
|
and batch[key].dim() > 1
|
||||||
and batch[key].size(1) == total_seq_len
|
and batch[key].size(1) == total_seq_len
|
||||||
):
|
):
|
||||||
|
|
||||||
if ring_attn_func in [
|
if ring_attn_func in [
|
||||||
RingAttnFunc.VARLEN_LLAMA3,
|
RingAttnFunc.VARLEN_LLAMA3,
|
||||||
RingAttnFunc.BATCH_RING,
|
RingAttnFunc.BATCH_RING,
|
||||||
@@ -78,6 +127,14 @@ def apply_sequence_parallelism(
|
|||||||
dim=1,
|
dim=1,
|
||||||
).transpose(1, 2)
|
).transpose(1, 2)
|
||||||
batch[key] = tensor[:, local_rank].contiguous()
|
batch[key] = tensor[:, local_rank].contiguous()
|
||||||
|
if key == "logits_to_keep":
|
||||||
|
batch[key] = _handle_logits_to_keep(
|
||||||
|
logits_to_keep=batch[key],
|
||||||
|
local_rank=local_rank,
|
||||||
|
local_world_size=local_world_size,
|
||||||
|
ring_attn_func=ring_attn_func,
|
||||||
|
total_seq_len=total_seq_len,
|
||||||
|
)
|
||||||
|
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
@@ -205,8 +262,11 @@ class SequenceParallelContextManager:
|
|||||||
|
|
||||||
# Forward post-hook to gather outputs
|
# Forward post-hook to gather outputs
|
||||||
def sequence_parallel_post_hook(_, __, output):
|
def sequence_parallel_post_hook(_, __, output):
|
||||||
|
print("start of sequence_parallel_post_hook")
|
||||||
# Gather the sharded outputs
|
# Gather the sharded outputs
|
||||||
return self.gather_outputs(output)
|
output = self.gather_outputs(output)
|
||||||
|
print("end of sequence_parallel_post_hook")
|
||||||
|
return output
|
||||||
|
|
||||||
# Register both hooks
|
# Register both hooks
|
||||||
self.hook_handles.append(
|
self.hook_handles.append(
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ from pydantic import (
|
|||||||
)
|
)
|
||||||
from transformers.utils.import_utils import is_torch_npu_available
|
from transformers.utils.import_utils import is_torch_npu_available
|
||||||
|
|
||||||
from axolotl.monkeypatch.attention.ring_attn import RingAttnFunc
|
|
||||||
from axolotl.utils.distributed import is_main_process
|
from axolotl.utils.distributed import is_main_process
|
||||||
from axolotl.utils.schemas.datasets import (
|
from axolotl.utils.schemas.datasets import (
|
||||||
DatasetConfig,
|
DatasetConfig,
|
||||||
|
|||||||
@@ -1,6 +1,4 @@
|
|||||||
"""
|
"""E2E tests for mixtral"""
|
||||||
E2E tests for mixtral
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
@@ -99,6 +97,7 @@ class TestMixtral(unittest.TestCase):
|
|||||||
"bf16": "auto",
|
"bf16": "auto",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -12,12 +12,12 @@ from accelerate.state import PartialState
|
|||||||
|
|
||||||
from axolotl.core.trainers.mixins.sequence_parallel import apply_sequence_parallelism
|
from axolotl.core.trainers.mixins.sequence_parallel import apply_sequence_parallelism
|
||||||
from axolotl.monkeypatch.attention.ring_attn import (
|
from axolotl.monkeypatch.attention.ring_attn import (
|
||||||
RingAttnFunc,
|
|
||||||
get_ring_attn_group,
|
get_ring_attn_group,
|
||||||
register_ring_attn,
|
register_ring_attn,
|
||||||
set_ring_attn_group,
|
set_ring_attn_group,
|
||||||
)
|
)
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.schemas.enums import RingAttnFunc
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|||||||
Reference in New Issue
Block a user