temp: trying another approach
This commit is contained in:
@@ -7,11 +7,13 @@ from __future__ import annotations
|
|||||||
import os
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from functools import partial, wraps
|
from functools import partial, wraps
|
||||||
from typing import Callable, Literal, Optional
|
from typing import Any, Callable, Literal, Optional
|
||||||
|
|
||||||
|
from axolotl.utils.ctx_managers.context_parallel.distributed import get_context_parallel_manager
|
||||||
import datasets
|
import datasets
|
||||||
import torch
|
import torch
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
|
from torch import nn
|
||||||
from torch.utils.data import (
|
from torch.utils.data import (
|
||||||
BatchSampler,
|
BatchSampler,
|
||||||
DataLoader,
|
DataLoader,
|
||||||
@@ -65,6 +67,32 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
|
|||||||
if self.args.orpo_alpha:
|
if self.args.orpo_alpha:
|
||||||
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
||||||
|
|
||||||
|
# SPDA device mesh init
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
world_size = dist.get_world_size()
|
||||||
|
mesh_shape = (
|
||||||
|
world_size // 2,
|
||||||
|
2,
|
||||||
|
)
|
||||||
|
self.world_mesh = dist.DeviceMesh(
|
||||||
|
"cuda",
|
||||||
|
torch.tensor(list(range(world_size))).reshape(mesh_shape),
|
||||||
|
mesh_dim_names=("dp", "cp"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def training_step(
|
||||||
|
self, model: nn.Module, inputs: dict[str, torch.Tensor | Any], num_items_in_batch=None
|
||||||
|
) -> torch.Tensor:
|
||||||
|
ctx_manager = get_context_parallel_manager(
|
||||||
|
world_mesh=self.world_mesh,
|
||||||
|
model=model,
|
||||||
|
)
|
||||||
|
to_shard = {k: v for k, v in inputs.items() if v.ndim > 1}
|
||||||
|
with ctx_manager(list(to_shard.values())):
|
||||||
|
super().training_step(model, inputs, num_items_in_batch)
|
||||||
|
|
||||||
|
|
||||||
def _wrap_model(self, model, training=True, dataloader=None):
|
def _wrap_model(self, model, training=True, dataloader=None):
|
||||||
if self.args.torch_compile:
|
if self.args.torch_compile:
|
||||||
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
|
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
|
||||||
|
|||||||
@@ -25,7 +25,6 @@ from axolotl.common.datasets import TrainDatasetMeta
|
|||||||
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
|
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
|
||||||
fix_untrained_tokens,
|
fix_untrained_tokens,
|
||||||
)
|
)
|
||||||
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.loaders import (
|
from axolotl.loaders import (
|
||||||
ModelLoader,
|
ModelLoader,
|
||||||
@@ -148,7 +147,7 @@ def determine_resume_checkpoint(cfg: DictDefault) -> str | None:
|
|||||||
|
|
||||||
|
|
||||||
def setup_signal_handler(
|
def setup_signal_handler(
|
||||||
cfg: DictDefault, model: PreTrainedModel, safe_serialization: bool
|
cfg: DictDefault, model: PeftModel | PreTrainedModel, safe_serialization: bool
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Set up signal handler for graceful termination.
|
Set up signal handler for graceful termination.
|
||||||
@@ -202,7 +201,7 @@ def execute_training(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.context_parallel_degree > 1:
|
if cfg.context_parallel_degree > 1 and not cfg.sdp_attention:
|
||||||
# Models to enter context parallel manager for
|
# Models to enter context parallel manager for
|
||||||
models = [trainer.model]
|
models = [trainer.model]
|
||||||
if hasattr(trainer, "ref_model") and trainer.ref_model:
|
if hasattr(trainer, "ref_model") and trainer.ref_model:
|
||||||
@@ -229,7 +228,7 @@ def execute_training(
|
|||||||
def save_trained_model(
|
def save_trained_model(
|
||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
trainer: Any,
|
trainer: Any,
|
||||||
model: PreTrainedModel,
|
model: PeftModel | PreTrainedModel,
|
||||||
safe_serialization: bool,
|
safe_serialization: bool,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -380,7 +379,7 @@ def create_model_card(cfg: DictDefault, trainer: Trainer):
|
|||||||
def save_initial_configs(
|
def save_initial_configs(
|
||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
tokenizer: PreTrainedTokenizer,
|
tokenizer: PreTrainedTokenizer,
|
||||||
model: PreTrainedModel,
|
model: PeftModel | PreTrainedModel,
|
||||||
peft_config: PeftConfig | None,
|
peft_config: PeftConfig | None,
|
||||||
processor: ProcessorMixin | None,
|
processor: ProcessorMixin | None,
|
||||||
):
|
):
|
||||||
@@ -434,7 +433,7 @@ def setup_model_card(cfg: DictDefault):
|
|||||||
|
|
||||||
def handle_untrained_tokens_fix(
|
def handle_untrained_tokens_fix(
|
||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
model: PreTrainedModel,
|
model: PeftModel | PreTrainedModel,
|
||||||
tokenizer: PreTrainedTokenizer,
|
tokenizer: PreTrainedTokenizer,
|
||||||
train_dataset: Dataset,
|
train_dataset: Dataset,
|
||||||
safe_serialization: bool,
|
safe_serialization: bool,
|
||||||
@@ -477,7 +476,7 @@ def handle_untrained_tokens_fix(
|
|||||||
|
|
||||||
|
|
||||||
def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[
|
def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[
|
||||||
HFRLTrainerBuilder | HFCausalTrainerBuilder,
|
Trainer,
|
||||||
PeftModel | PreTrainedModel,
|
PeftModel | PreTrainedModel,
|
||||||
PreTrainedTokenizer,
|
PreTrainedTokenizer,
|
||||||
PeftConfig | None,
|
PeftConfig | None,
|
||||||
|
|||||||
@@ -35,14 +35,14 @@ https://github.com/pytorch/torchtune/blob/2344509cf83bd886538fe3e8263e5145d1afb5
|
|||||||
import contextlib
|
import contextlib
|
||||||
from typing import Callable, Generator, Optional, Union
|
from typing import Callable, Generator, Optional, Union
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
import torch
|
import torch
|
||||||
|
from torch import nn
|
||||||
from torch.distributed.tensor.experimental import context_parallel
|
from torch.distributed.tensor.experimental import context_parallel
|
||||||
from torch.distributed.tensor.experimental._attention import set_rotate_method
|
from torch.distributed.tensor.experimental._attention import set_rotate_method
|
||||||
from torch.nn.attention import SDPBackend, sdpa_kernel
|
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||||
from torch.nn.attention.flex_attention import BlockMask
|
from torch.nn.attention.flex_attention import BlockMask
|
||||||
from transformers import PreTrainedModel
|
|
||||||
|
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
def _get_sdpa_context() -> (
|
def _get_sdpa_context() -> (
|
||||||
Callable[[Optional[Generator[None, None, None]]], Generator[None, None, None]]
|
Callable[[Optional[Generator[None, None, None]]], Generator[None, None, None]]
|
||||||
@@ -77,7 +77,7 @@ def _get_sdpa_context() -> (
|
|||||||
def get_context_parallel_manager(
|
def get_context_parallel_manager(
|
||||||
*,
|
*,
|
||||||
world_mesh: torch.distributed.DeviceMesh,
|
world_mesh: torch.distributed.DeviceMesh,
|
||||||
model: PreTrainedModel,
|
model: nn.Module,
|
||||||
) -> Callable[[list[torch.Tensor]], Generator[None, None, None]]:
|
) -> Callable[[list[torch.Tensor]], Generator[None, None, None]]:
|
||||||
"""
|
"""
|
||||||
Context manager for applying context parallelism to a model. In addition to applying the
|
Context manager for applying context parallelism to a model. In addition to applying the
|
||||||
|
|||||||
Reference in New Issue
Block a user