temp: trying another approach
@@ -7,11 +7,13 @@ from __future__ import annotations
 import os
 from collections import defaultdict
 from functools import partial, wraps
-from typing import Callable, Literal, Optional
+from typing import Any, Callable, Literal, Optional
+
+from axolotl.utils.ctx_managers.context_parallel.distributed import get_context_parallel_manager
 import datasets
 import torch
 from datasets import Dataset
 from torch import nn
 from torch.utils.data import (
     BatchSampler,
     DataLoader,
@@ -65,6 +67,32 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
         if self.args.orpo_alpha:
             self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
 
+        # SDPA device mesh init
+        import torch.distributed as dist
+
+        world_size = dist.get_world_size()
+        mesh_shape = (
+            world_size // 2,
+            2,
+        )
+        self.world_mesh = dist.DeviceMesh(
+            "cuda",
+            torch.tensor(list(range(world_size))).reshape(mesh_shape),
+            mesh_dim_names=("dp", "cp"),
+        )
+
+    def training_step(
+        self, model: nn.Module, inputs: dict[str, torch.Tensor | Any], num_items_in_batch=None
+    ) -> torch.Tensor:
+        ctx_manager = get_context_parallel_manager(
+            world_mesh=self.world_mesh,
+            model=model,
+        )
+        to_shard = {k: v for k, v in inputs.items() if v.ndim > 1}
+        with ctx_manager(list(to_shard.values())):
+            return super().training_step(model, inputs, num_items_in_batch)
+
+
     def _wrap_model(self, model, training=True, dataloader=None):
         if self.args.torch_compile:
             torch._dynamo.config.accumulated_cache_size_limit = (  # pylint: disable=protected-access
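As a reading aid for the hunk above: a minimal sketch of the rank layout the hard-coded (world_size // 2, 2) mesh produces, assuming a hypothetical world_size of 8. Only the shapes and the "dp"/"cp" dim names come from the diff; the concrete numbers are illustrative.

# Sketch only: inspect the rank grid dist.DeviceMesh would be built from,
# assuming world_size == 8 (hypothetical). No process group is needed for this.
import torch

world_size = 8                     # stand-in for dist.get_world_size()
mesh_shape = (world_size // 2, 2)  # -> (4, 2)
ranks = torch.tensor(list(range(world_size))).reshape(mesh_shape)
print(ranks)
# tensor([[0, 1],
#         [2, 3],
#         [4, 5],
#         [6, 7]])
# With mesh_dim_names=("dp", "cp"), each row is a 2-rank "cp" group that splits
# one sequence between its members, and each column is a 4-rank "dp" group that
# sees different batches. training_step then enters get_context_parallel_manager()
# so every tensor in `inputs` with ndim > 1 is sharded across the "cp" group for
# the duration of the step.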
@@ -25,7 +25,6 @@ from axolotl.common.datasets import TrainDatasetMeta
 from axolotl.contribs.lgpl import (  # pylint: disable = no-name-in-module
     fix_untrained_tokens,
 )
 from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
 from axolotl.integrations.base import PluginManager
 from axolotl.loaders import (
     ModelLoader,
@@ -148,7 +147,7 @@ def determine_resume_checkpoint(cfg: DictDefault) -> str | None:


 def setup_signal_handler(
-    cfg: DictDefault, model: PreTrainedModel, safe_serialization: bool
+    cfg: DictDefault, model: PeftModel | PreTrainedModel, safe_serialization: bool
 ):
     """
     Set up signal handler for graceful termination.
@@ -202,7 +201,7 @@ def execute_training(
             )
         )

-    if cfg.context_parallel_degree > 1:
+    if cfg.context_parallel_degree > 1 and not cfg.sdp_attention:
         # Models to enter context parallel manager for
         models = [trainer.model]
         if hasattr(trainer, "ref_model") and trainer.ref_model:
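For clarity, a hedged sketch of the gating this one-line change introduces; cfg.context_parallel_degree and cfg.sdp_attention are the config fields used in the hunk, while the helper itself is hypothetical and not part of the commit.

# Hypothetical helper mirroring the updated condition in execute_training.
def _uses_ring_attn_cp_wrapping(cfg) -> bool:
    # Before: any context_parallel_degree > 1 entered the model-level context
    # parallel wrapping set up below in execute_training.
    # After: runs with sdp_attention enabled skip it, since this commit has
    # AxolotlTrainer build its own ("dp", "cp") device mesh and enter
    # get_context_parallel_manager() per training step instead.
    return bool(cfg.context_parallel_degree) and cfg.context_parallel_degree > 1 and not cfg.sdp_attention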
@@ -229,7 +228,7 @@ def execute_training(
 def save_trained_model(
     cfg: DictDefault,
     trainer: Any,
-    model: PreTrainedModel,
+    model: PeftModel | PreTrainedModel,
     safe_serialization: bool,
 ):
     """
@@ -380,7 +379,7 @@ def create_model_card(cfg: DictDefault, trainer: Trainer):
 def save_initial_configs(
     cfg: DictDefault,
     tokenizer: PreTrainedTokenizer,
-    model: PreTrainedModel,
+    model: PeftModel | PreTrainedModel,
     peft_config: PeftConfig | None,
     processor: ProcessorMixin | None,
 ):
@@ -434,7 +433,7 @@ def setup_model_card(cfg: DictDefault):

 def handle_untrained_tokens_fix(
     cfg: DictDefault,
-    model: PreTrainedModel,
+    model: PeftModel | PreTrainedModel,
     tokenizer: PreTrainedTokenizer,
     train_dataset: Dataset,
     safe_serialization: bool,
@@ -477,7 +476,7 @@ def handle_untrained_tokens_fix(
 def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[
     HFRLTrainerBuilder | HFCausalTrainerBuilder,
     Trainer,
-    PreTrainedModel,
+    PeftModel | PreTrainedModel,
     PreTrainedTokenizer,
     PeftConfig | None,
@@ -35,14 +35,14 @@ https://github.com/pytorch/torchtune/blob/2344509cf83bd886538fe3e8263e5145d1afb5
 import contextlib
 from typing import Callable, Generator, Optional, Union

-from axolotl.utils.dict import DictDefault
 import torch
 from torch import nn
 from torch.distributed.tensor.experimental import context_parallel
 from torch.distributed.tensor.experimental._attention import set_rotate_method
 from torch.nn.attention import SDPBackend, sdpa_kernel
 from torch.nn.attention.flex_attention import BlockMask
 from transformers import PreTrainedModel

+from axolotl.utils.dict import DictDefault

 def _get_sdpa_context() -> (
     Callable[[Optional[Generator[None, None, None]]], Generator[None, None, None]]
@@ -77,7 +77,7 @@ def _get_sdpa_context() -> (
 def get_context_parallel_manager(
     *,
     world_mesh: torch.distributed.DeviceMesh,
-    model: PreTrainedModel,
+    model: nn.Module,
 ) -> Callable[[list[torch.Tensor]], Generator[None, None, None]]:
     """
     Context manager for applying context parallelism to a model. In addition to applying the
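A rough usage sketch of the relaxed signature, mirroring how the trainer hunk earlier in this commit calls it; world_mesh, batch, and the helper function below are placeholders, while the keyword arguments and the list-of-tensors argument to the returned context manager come from the diff.

# Sketch only: calling get_context_parallel_manager with a plain nn.Module.
import torch
import torch.distributed as dist
from torch import nn

from axolotl.utils.ctx_managers.context_parallel.distributed import (
    get_context_parallel_manager,
)


def run_cp_step(world_mesh: dist.DeviceMesh, model: nn.Module,
                batch: dict[str, torch.Tensor]) -> None:
    ctx_manager = get_context_parallel_manager(world_mesh=world_mesh, model=model)
    # Only multi-dimensional tensors (input_ids, attention_mask, labels, ...)
    # are sharded along the sequence dimension while the block is active.
    to_shard = [t for t in batch.values() if t.ndim > 1]
    with ctx_manager(to_shard):
        ...  # forward/backward run on each rank's shard of the sequence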
@@ -70,7 +70,7 @@ class ContextParallelContextManager:
         self.process_group = get_ring_attn_group()
         self.local_rank = dist.get_rank(self.process_group)
         self.local_world_size = dist.get_world_size(self.process_group)

         # Create a partially applied version of the apply_context_parallelism function
         self.apply_context_parallelism = functools.partial(
             apply_context_parallelism,
@@ -79,7 +79,7 @@ class ContextParallelContextManager:
             gradient_accumulation_steps=self.gradient_accumulation_steps,
             ring_attn_func=self.ring_attn_func,
         )

         # Store original sequence length and padding information
         self.original_seq_len = 0
         self.pad_len = 0
@@ -95,7 +95,7 @@ class ContextParallelContextManager:
             torch.tensor(list(range(world_size))).reshape(mesh_shape),
             mesh_dim_names=("dp", "cp"),
         )

         # SDPA context parallel managers
         self.context_parallel_managers = []
         for model in models: