diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py
index 70e443cb3..850569985 100644
--- a/src/axolotl/core/trainers/base.py
+++ b/src/axolotl/core/trainers/base.py
@@ -7,11 +7,13 @@ from __future__ import annotations
 import os
 from collections import defaultdict
 from functools import partial, wraps
-from typing import Callable, Literal, Optional
+from typing import Any, Callable, Literal, Optional
+from axolotl.utils.ctx_managers.context_parallel.distributed import get_context_parallel_manager
 
 import datasets
 import torch
 from datasets import Dataset
+from torch import nn
 from torch.utils.data import (
     BatchSampler,
     DataLoader,
@@ -65,6 +67,32 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
         if self.args.orpo_alpha:
             self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
 
+        # SDPA context parallelism: build a (dp, cp) device mesh over all ranks
+        import torch.distributed as dist
+
+        world_size = dist.get_world_size()
+        mesh_shape = (
+            world_size // 2,
+            2,
+        )
+        self.world_mesh = dist.DeviceMesh(
+            "cuda",
+            torch.tensor(list(range(world_size))).reshape(mesh_shape),
+            mesh_dim_names=("dp", "cp"),
+        )
+
+    def training_step(
+        self, model: nn.Module, inputs: dict[str, torch.Tensor | Any], num_items_in_batch=None
+    ) -> torch.Tensor:
+        ctx_manager = get_context_parallel_manager(
+            world_mesh=self.world_mesh,
+            model=model,
+        )
+        # Shard only multi-dimensional tensor inputs (batch, seq, ...) across the cp group
+        to_shard = {k: v for k, v in inputs.items() if isinstance(v, torch.Tensor) and v.ndim > 1}
+        with ctx_manager(list(to_shard.values())):
+            return super().training_step(model, inputs, num_items_in_batch)
+
     def _wrap_model(self, model, training=True, dataloader=None):
         if self.args.torch_compile:
             torch._dynamo.config.accumulated_cache_size_limit = (  # pylint: disable=protected-access
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index 5d2e77b4f..c32d9550b 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -25,7 +25,6 @@ from axolotl.common.datasets import TrainDatasetMeta
 from axolotl.contribs.lgpl import (  # pylint: disable = no-name-in-module
     fix_untrained_tokens,
 )
-from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
 from axolotl.integrations.base import PluginManager
 from axolotl.loaders import (
     ModelLoader,
@@ -148,7 +147,7 @@ def determine_resume_checkpoint(cfg: DictDefault) -> str | None:
 
 
 def setup_signal_handler(
-    cfg: DictDefault, model: PreTrainedModel, safe_serialization: bool
+    cfg: DictDefault, model: PeftModel | PreTrainedModel, safe_serialization: bool
 ):
     """
     Set up signal handler for graceful termination.
@@ -202,7 +201,7 @@ def execute_training(
         )
     )
 
-    if cfg.context_parallel_degree > 1:
+    if cfg.context_parallel_degree > 1 and not cfg.sdp_attention:
         # Models to enter context parallel manager for
         models = [trainer.model]
         if hasattr(trainer, "ref_model") and trainer.ref_model:
@@ -229,7 +228,7 @@ def save_trained_model(
     cfg: DictDefault,
     trainer: Any,
-    model: PreTrainedModel,
+    model: PeftModel | PreTrainedModel,
     safe_serialization: bool,
 ):
     """
@@ -380,7 +379,7 @@ def create_model_card(cfg: DictDefault, trainer: Trainer):
 def save_initial_configs(
     cfg: DictDefault,
     tokenizer: PreTrainedTokenizer,
-    model: PreTrainedModel,
+    model: PeftModel | PreTrainedModel,
     peft_config: PeftConfig | None,
     processor: ProcessorMixin | None,
 ):
@@ -434,7 +433,7 @@ def setup_model_card(cfg: DictDefault):
 
 def handle_untrained_tokens_fix(
     cfg: DictDefault,
-    model: PreTrainedModel,
+    model: PeftModel | PreTrainedModel,
     tokenizer: PreTrainedTokenizer,
     train_dataset: Dataset,
     safe_serialization: bool,
@@ -477,7 +476,7 @@ def handle_untrained_tokens_fix(
 
 
 def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[
-    HFRLTrainerBuilder | HFCausalTrainerBuilder,
+    Trainer,
     PeftModel | PreTrainedModel,
     PreTrainedTokenizer,
     PeftConfig | None,
diff --git a/src/axolotl/utils/ctx_managers/context_parallel/distributed.py b/src/axolotl/utils/ctx_managers/context_parallel/distributed.py
index 7d7d774d1..13adeb132 100644
--- a/src/axolotl/utils/ctx_managers/context_parallel/distributed.py
+++ b/src/axolotl/utils/ctx_managers/context_parallel/distributed.py
@@ -35,14 +35,14 @@ https://github.com/pytorch/torchtune/blob/2344509cf83bd886538fe3e8263e5145d1afb5
 
 import contextlib
 from typing import Callable, Generator, Optional, Union
 
-from axolotl.utils.dict import DictDefault
 import torch
+from torch import nn
 from torch.distributed.tensor.experimental import context_parallel
 from torch.distributed.tensor.experimental._attention import set_rotate_method
 from torch.nn.attention import SDPBackend, sdpa_kernel
 from torch.nn.attention.flex_attention import BlockMask
-from transformers import PreTrainedModel
+from axolotl.utils.dict import DictDefault
 
 def _get_sdpa_context() -> (
     Callable[[Optional[Generator[None, None, None]]], Generator[None, None, None]]
@@ -77,7 +77,7 @@ def get_context_parallel_manager(
     *,
     world_mesh: torch.distributed.DeviceMesh,
-    model: PreTrainedModel,
+    model: nn.Module,
 ) -> Callable[[list[torch.Tensor]], Generator[None, None, None]]:
     """
     Context manager for applying context parallelism to a model.
diff --git a/src/axolotl/utils/ctx_managers/context_parallel/manager.py b/src/axolotl/utils/ctx_managers/context_parallel/manager.py
index bc16c66cd..f2b116b78 100644
--- a/src/axolotl/utils/ctx_managers/context_parallel/manager.py
+++ b/src/axolotl/utils/ctx_managers/context_parallel/manager.py
@@ -70,7 +70,7 @@ class ContextParallelContextManager:
         self.process_group = get_ring_attn_group()
         self.local_rank = dist.get_rank(self.process_group)
         self.local_world_size = dist.get_world_size(self.process_group)
-        
+
         # Create a partially applied version of the apply_context_parallelism function
         self.apply_context_parallelism = functools.partial(
             apply_context_parallelism,
@@ -79,7 +79,7 @@
             gradient_accumulation_steps=self.gradient_accumulation_steps,
             ring_attn_func=self.ring_attn_func,
         )
-        
+
         # Store original sequence length and padding information
         self.original_seq_len = 0
         self.pad_len = 0
@@ -95,7 +95,7 @@
             torch.tensor(list(range(world_size))).reshape(mesh_shape),
             mesh_dim_names=("dp", "cp"),
         )
-        
+
         # SDPA context parallel managers
         self.context_parallel_managers = []
         for model in models:
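For reference, a minimal sketch (not part of the patch) of the (dp, cp) device-mesh construction that AxolotlTrainer.__init__ performs above, assuming torch.distributed is already initialized and the world size is divisible by the context-parallel degree (hard-coded to 2 in the patch); build_dp_cp_mesh is a hypothetical helper name used only for illustration:

import torch
import torch.distributed as dist


def build_dp_cp_mesh(cp_degree: int = 2) -> dist.DeviceMesh:
    # Mirrors the mesh built in AxolotlTrainer.__init__: ranks are laid out
    # row-major, so each row is one data-parallel replica and the columns of a
    # row are the ranks that jointly shard a single sequence (context parallel).
    world_size = dist.get_world_size()
    mesh_shape = (world_size // cp_degree, cp_degree)
    return dist.DeviceMesh(
        "cuda",
        torch.arange(world_size).reshape(mesh_shape),
        mesh_dim_names=("dp", "cp"),
    )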