temp: trying another approach

This commit is contained in:
Dan Saunders
2025-06-15 21:32:10 +00:00
parent f8f87321bd
commit e34b6f4dfe
4 changed files with 41 additions and 14 deletions

View File

@@ -7,11 +7,13 @@ from __future__ import annotations
import os import os
from collections import defaultdict from collections import defaultdict
from functools import partial, wraps from functools import partial, wraps
from typing import Callable, Literal, Optional from typing import Any, Callable, Literal, Optional
from axolotl.utils.ctx_managers.context_parallel.distributed import get_context_parallel_manager
import datasets import datasets
import torch import torch
from datasets import Dataset from datasets import Dataset
from torch import nn
from torch.utils.data import ( from torch.utils.data import (
BatchSampler, BatchSampler,
DataLoader, DataLoader,
@@ -65,6 +67,32 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
if self.args.orpo_alpha: if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none") self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
# SDPA device mesh init
import torch.distributed as dist
world_size = dist.get_world_size()
mesh_shape = (
world_size // 2,
2,
)
self.world_mesh = dist.DeviceMesh(
"cuda",
torch.tensor(list(range(world_size))).reshape(mesh_shape),
mesh_dim_names=("dp", "cp"),
)
def training_step(
    self,
    model: nn.Module,
    inputs: dict[str, torch.Tensor | Any],
    num_items_in_batch=None,
) -> torch.Tensor:
    """
    Run the parent class's training step inside the context parallel manager.

    A context manager is obtained from `get_context_parallel_manager` using
    `self.world_mesh` and `model`, and it is entered with the list of input
    tensors that have more than one dimension.

    Args:
        model: Module forwarded to the context parallel manager and to the
            parent `training_step`.
        inputs: Batch mapping forwarded unchanged to the parent
            `training_step`. Only tensor values with `ndim > 1` are handed to
            the context manager.
        num_items_in_batch: Forwarded unchanged to the parent `training_step`.

    Returns:
        The value returned by the parent `training_step`.
    """
    ctx_manager = get_context_parallel_manager(
        world_mesh=self.world_mesh,
        model=model,
    )

    # `inputs` is annotated as possibly holding non-tensor values; those have
    # no `ndim`, so filter on the type before looking at dimensionality.
    # NOTE(review): presumably the >1-dim tensors are the per-token batch
    # tensors that should be sharded along the sequence axis -- confirm.
    to_shard = [
        value
        for value in inputs.values()
        if isinstance(value, torch.Tensor) and value.ndim > 1
    ]

    with ctx_manager(to_shard):
        # The parent's result must be propagated; dropping it makes this
        # method return None, contradicting the declared return type.
        return super().training_step(model, inputs, num_items_in_batch)
def _wrap_model(self, model, training=True, dataloader=None): def _wrap_model(self, model, training=True, dataloader=None):
if self.args.torch_compile: if self.args.torch_compile:
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access

View File

@@ -25,7 +25,6 @@ from axolotl.common.datasets import TrainDatasetMeta
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
fix_untrained_tokens, fix_untrained_tokens,
) )
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.integrations.base import PluginManager from axolotl.integrations.base import PluginManager
from axolotl.loaders import ( from axolotl.loaders import (
ModelLoader, ModelLoader,
@@ -148,7 +147,7 @@ def determine_resume_checkpoint(cfg: DictDefault) -> str | None:
def setup_signal_handler( def setup_signal_handler(
cfg: DictDefault, model: PreTrainedModel, safe_serialization: bool cfg: DictDefault, model: PeftModel | PreTrainedModel, safe_serialization: bool
): ):
""" """
Set up signal handler for graceful termination. Set up signal handler for graceful termination.
@@ -202,7 +201,7 @@ def execute_training(
) )
) )
if cfg.context_parallel_degree > 1: if cfg.context_parallel_degree > 1 and not cfg.sdp_attention:
# Models to enter context parallel manager for # Models to enter context parallel manager for
models = [trainer.model] models = [trainer.model]
if hasattr(trainer, "ref_model") and trainer.ref_model: if hasattr(trainer, "ref_model") and trainer.ref_model:
@@ -229,7 +228,7 @@ def execute_training(
def save_trained_model( def save_trained_model(
cfg: DictDefault, cfg: DictDefault,
trainer: Any, trainer: Any,
model: PreTrainedModel, model: PeftModel | PreTrainedModel,
safe_serialization: bool, safe_serialization: bool,
): ):
""" """
@@ -380,7 +379,7 @@ def create_model_card(cfg: DictDefault, trainer: Trainer):
def save_initial_configs( def save_initial_configs(
cfg: DictDefault, cfg: DictDefault,
tokenizer: PreTrainedTokenizer, tokenizer: PreTrainedTokenizer,
model: PreTrainedModel, model: PeftModel | PreTrainedModel,
peft_config: PeftConfig | None, peft_config: PeftConfig | None,
processor: ProcessorMixin | None, processor: ProcessorMixin | None,
): ):
@@ -434,7 +433,7 @@ def setup_model_card(cfg: DictDefault):
def handle_untrained_tokens_fix( def handle_untrained_tokens_fix(
cfg: DictDefault, cfg: DictDefault,
model: PreTrainedModel, model: PeftModel | PreTrainedModel,
tokenizer: PreTrainedTokenizer, tokenizer: PreTrainedTokenizer,
train_dataset: Dataset, train_dataset: Dataset,
safe_serialization: bool, safe_serialization: bool,
@@ -477,7 +476,7 @@ def handle_untrained_tokens_fix(
def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[
HFRLTrainerBuilder | HFCausalTrainerBuilder, Trainer,
PeftModel | PreTrainedModel, PeftModel | PreTrainedModel,
PreTrainedTokenizer, PreTrainedTokenizer,
PeftConfig | None, PeftConfig | None,

View File

@@ -35,14 +35,14 @@ https://github.com/pytorch/torchtune/blob/2344509cf83bd886538fe3e8263e5145d1afb5
import contextlib import contextlib
from typing import Callable, Generator, Optional, Union from typing import Callable, Generator, Optional, Union
from axolotl.utils.dict import DictDefault
import torch import torch
from torch import nn
from torch.distributed.tensor.experimental import context_parallel from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import set_rotate_method from torch.distributed.tensor.experimental._attention import set_rotate_method
from torch.nn.attention import SDPBackend, sdpa_kernel from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.attention.flex_attention import BlockMask from torch.nn.attention.flex_attention import BlockMask
from transformers import PreTrainedModel
from axolotl.utils.dict import DictDefault
def _get_sdpa_context() -> ( def _get_sdpa_context() -> (
Callable[[Optional[Generator[None, None, None]]], Generator[None, None, None]] Callable[[Optional[Generator[None, None, None]]], Generator[None, None, None]]
@@ -77,7 +77,7 @@ def _get_sdpa_context() -> (
def get_context_parallel_manager( def get_context_parallel_manager(
*, *,
world_mesh: torch.distributed.DeviceMesh, world_mesh: torch.distributed.DeviceMesh,
model: PreTrainedModel, model: nn.Module,
) -> Callable[[list[torch.Tensor]], Generator[None, None, None]]: ) -> Callable[[list[torch.Tensor]], Generator[None, None, None]]:
""" """
Context manager for applying context parallelism to a model. In addition to applying the Context manager for applying context parallelism to a model. In addition to applying the

View File

@@ -70,7 +70,7 @@ class ContextParallelContextManager:
self.process_group = get_ring_attn_group() self.process_group = get_ring_attn_group()
self.local_rank = dist.get_rank(self.process_group) self.local_rank = dist.get_rank(self.process_group)
self.local_world_size = dist.get_world_size(self.process_group) self.local_world_size = dist.get_world_size(self.process_group)
# Create a partially applied version of the apply_context_parallelism function # Create a partially applied version of the apply_context_parallelism function
self.apply_context_parallelism = functools.partial( self.apply_context_parallelism = functools.partial(
apply_context_parallelism, apply_context_parallelism,
@@ -79,7 +79,7 @@ class ContextParallelContextManager:
gradient_accumulation_steps=self.gradient_accumulation_steps, gradient_accumulation_steps=self.gradient_accumulation_steps,
ring_attn_func=self.ring_attn_func, ring_attn_func=self.ring_attn_func,
) )
# Store original sequence length and padding information # Store original sequence length and padding information
self.original_seq_len = 0 self.original_seq_len = 0
self.pad_len = 0 self.pad_len = 0
@@ -95,7 +95,7 @@ class ContextParallelContextManager:
torch.tensor(list(range(world_size))).reshape(mesh_shape), torch.tensor(list(range(world_size))).reshape(mesh_shape),
mesh_dim_names=("dp", "cp"), mesh_dim_names=("dp", "cp"),
) )
# SDPA context parallel managers # SDPA context parallel managers
self.context_parallel_managers = [] self.context_parallel_managers = []
for model in models: for model in models: