temp: trying another approach

This commit is contained in:
Dan Saunders
2025-06-15 21:32:10 +00:00
parent f8f87321bd
commit e34b6f4dfe
4 changed files with 41 additions and 14 deletions

View File

@@ -7,11 +7,13 @@ from __future__ import annotations
import os
from collections import defaultdict
from functools import partial, wraps
from typing import Callable, Literal, Optional
from typing import Any, Callable, Literal, Optional
from axolotl.utils.ctx_managers.context_parallel.distributed import get_context_parallel_manager
import datasets
import torch
from datasets import Dataset
from torch import nn
from torch.utils.data import (
BatchSampler,
DataLoader,
@@ -65,6 +67,32 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
# SDPA device mesh init
import torch.distributed as dist
world_size = dist.get_world_size()
mesh_shape = (
world_size // 2,
2,
)
self.world_mesh = dist.DeviceMesh(
"cuda",
torch.tensor(list(range(world_size))).reshape(mesh_shape),
mesh_dim_names=("dp", "cp"),
)
def training_step(
    self, model: nn.Module, inputs: dict[str, torch.Tensor | Any], num_items_in_batch=None
) -> torch.Tensor:
    """Run one training step with context parallelism applied via SDPA.

    Input tensors with more than one dimension (i.e. batched sequence
    tensors) are sharded across the context-parallel mesh by the context
    manager before delegating to the parent ``Trainer.training_step``.

    Args:
        model: Model being trained.
        inputs: Batch dict; only values with ``ndim > 1`` are sharded.
        num_items_in_batch: Forwarded to the parent implementation.

    Returns:
        The loss tensor produced by the parent ``training_step``.
    """
    ctx_manager = get_context_parallel_manager(
        world_mesh=self.world_mesh,
        model=model,
    )
    # Only shard tensors that have a sequence dimension; 0-/1-D values
    # (e.g. scalar counters) are left untouched.
    to_shard = {k: v for k, v in inputs.items() if v.ndim > 1}
    with ctx_manager(list(to_shard.values())):
        # Fix: the original discarded the result, so the method returned
        # None despite its ``-> torch.Tensor`` annotation — the Trainer
        # loop needs the loss back.
        return super().training_step(model, inputs, num_items_in_batch)
def _wrap_model(self, model, training=True, dataloader=None):
if self.args.torch_compile:
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access

View File

@@ -25,7 +25,6 @@ from axolotl.common.datasets import TrainDatasetMeta
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
fix_untrained_tokens,
)
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.integrations.base import PluginManager
from axolotl.loaders import (
ModelLoader,
@@ -148,7 +147,7 @@ def determine_resume_checkpoint(cfg: DictDefault) -> str | None:
def setup_signal_handler(
cfg: DictDefault, model: PreTrainedModel, safe_serialization: bool
cfg: DictDefault, model: PeftModel | PreTrainedModel, safe_serialization: bool
):
"""
Set up signal handler for graceful termination.
@@ -202,7 +201,7 @@ def execute_training(
)
)
if cfg.context_parallel_degree > 1:
if cfg.context_parallel_degree > 1 and not cfg.sdp_attention:
# Models to enter context parallel manager for
models = [trainer.model]
if hasattr(trainer, "ref_model") and trainer.ref_model:
@@ -229,7 +228,7 @@ def execute_training(
def save_trained_model(
cfg: DictDefault,
trainer: Any,
model: PreTrainedModel,
model: PeftModel | PreTrainedModel,
safe_serialization: bool,
):
"""
@@ -380,7 +379,7 @@ def create_model_card(cfg: DictDefault, trainer: Trainer):
def save_initial_configs(
cfg: DictDefault,
tokenizer: PreTrainedTokenizer,
model: PreTrainedModel,
model: PeftModel | PreTrainedModel,
peft_config: PeftConfig | None,
processor: ProcessorMixin | None,
):
@@ -434,7 +433,7 @@ def setup_model_card(cfg: DictDefault):
def handle_untrained_tokens_fix(
cfg: DictDefault,
model: PreTrainedModel,
model: PeftModel | PreTrainedModel,
tokenizer: PreTrainedTokenizer,
train_dataset: Dataset,
safe_serialization: bool,
@@ -477,7 +476,7 @@ def handle_untrained_tokens_fix(
def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[
HFRLTrainerBuilder | HFCausalTrainerBuilder,
Trainer,
PeftModel | PreTrainedModel,
PreTrainedTokenizer,
PeftConfig | None,

View File

@@ -35,14 +35,14 @@ https://github.com/pytorch/torchtune/blob/2344509cf83bd886538fe3e8263e5145d1afb5
import contextlib
from typing import Callable, Generator, Optional, Union
from axolotl.utils.dict import DictDefault
import torch
from torch import nn
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import set_rotate_method
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.attention.flex_attention import BlockMask
from transformers import PreTrainedModel
from axolotl.utils.dict import DictDefault
def _get_sdpa_context() -> (
Callable[[Optional[Generator[None, None, None]]], Generator[None, None, None]]
@@ -77,7 +77,7 @@ def _get_sdpa_context() -> (
def get_context_parallel_manager(
*,
world_mesh: torch.distributed.DeviceMesh,
model: PreTrainedModel,
model: nn.Module,
) -> Callable[[list[torch.Tensor]], Generator[None, None, None]]:
"""
Context manager for applying context parallelism to a model. In addition to applying the

View File

@@ -70,7 +70,7 @@ class ContextParallelContextManager:
self.process_group = get_ring_attn_group()
self.local_rank = dist.get_rank(self.process_group)
self.local_world_size = dist.get_world_size(self.process_group)
# Create a partially applied version of the apply_context_parallelism function
self.apply_context_parallelism = functools.partial(
apply_context_parallelism,
@@ -79,7 +79,7 @@ class ContextParallelContextManager:
gradient_accumulation_steps=self.gradient_accumulation_steps,
ring_attn_func=self.ring_attn_func,
)
# Store original sequence length and padding information
self.original_seq_len = 0
self.pad_len = 0
@@ -95,7 +95,7 @@ class ContextParallelContextManager:
torch.tensor(list(range(world_size))).reshape(mesh_shape),
mesh_dim_names=("dp", "cp"),
)
# SDPA context parallel managers
self.context_parallel_managers = []
for model in models: