Distributed/ND-Parallel (#2977)
This commit is contained in:
@@ -69,7 +69,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||
load_in_8bit=False,
|
||||
load_in_4bit=False,
|
||||
flash_attention=False,
|
||||
sequence_parallel_degree=None,
|
||||
context_parallel_size=None,
|
||||
deepspeed=None,
|
||||
fsdp=None,
|
||||
fsdp_config=None,
|
||||
|
||||
@@ -24,9 +24,11 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from accelerate import PartialState
|
||||
from transformers import (
|
||||
TrainerCallback,
|
||||
)
|
||||
from transformers.trainer_pt_utils import AcceleratorConfig
|
||||
from transformers.training_args import OptimizerNames
|
||||
|
||||
from axolotl.integrations.base import PluginManager
|
||||
@@ -434,8 +436,30 @@ class TrainerBuilderBase(abc.ABC):
|
||||
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
|
||||
|
||||
def _configure_accelerator_config(self, training_args_kwargs: dict):
|
||||
partial_state = PartialState()
|
||||
has_pc_attr = (
|
||||
hasattr(partial_state, "parallelism_config")
|
||||
and partial_state.parallelism_config
|
||||
)
|
||||
has_pc_key = (
|
||||
"parallelism_config"
|
||||
in partial_state._shared_state # pylint: disable=protected-access
|
||||
and partial_state._shared_state[ # pylint: disable=protected-access
|
||||
"parallelism_config"
|
||||
]
|
||||
)
|
||||
use_configured_state = has_pc_attr or has_pc_key
|
||||
if self.cfg.accelerator_config:
|
||||
training_args_kwargs["accelerator_config"] = self.cfg.accelerator_config
|
||||
use_configured_state = self.cfg.accelerator_config.pop(
|
||||
"use_configured_state", use_configured_state
|
||||
)
|
||||
training_args_kwargs["accelerator_config"] = AcceleratorConfig(
|
||||
use_configured_state=use_configured_state, **self.cfg.accelerator_config
|
||||
)
|
||||
else:
|
||||
training_args_kwargs["accelerator_config"] = AcceleratorConfig(
|
||||
use_configured_state=use_configured_state,
|
||||
)
|
||||
|
||||
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
|
||||
if self.cfg.activation_offloading is True:
|
||||
|
||||
@@ -53,7 +53,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
if self.cfg.rl is RLType.GRPO:
|
||||
trainer_cls = GRPOStrategy.get_trainer_class(
|
||||
sequence_parallel=self.cfg.sequence_parallel_degree > 1
|
||||
sequence_parallel=self.cfg.context_parallel_size > 1
|
||||
)
|
||||
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ from typing_extensions import override
|
||||
from axolotl.core.trainers.mixins import (
|
||||
ActivationOffloadingMixin,
|
||||
CheckpointSaveMixin,
|
||||
DistributedParallelMixin,
|
||||
OptimizerMixin,
|
||||
PackingMixin,
|
||||
RngLoaderMixin,
|
||||
@@ -50,6 +51,7 @@ class AxolotlTrainer(
|
||||
RngLoaderMixin,
|
||||
CheckpointSaveMixin,
|
||||
ActivationOffloadingMixin,
|
||||
DistributedParallelMixin,
|
||||
Trainer,
|
||||
):
|
||||
"""Extend the base Trainer for axolotl helpers"""
|
||||
|
||||
@@ -8,7 +8,11 @@ import torch
|
||||
from torch import nn
|
||||
from trl import DPOTrainer
|
||||
|
||||
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
|
||||
from axolotl.core.trainers.mixins import (
|
||||
DistributedParallelMixin,
|
||||
RngLoaderMixin,
|
||||
SchedulerMixin,
|
||||
)
|
||||
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
|
||||
from axolotl.core.trainers.utils import (
|
||||
sanitize_kwargs_for_ds_tagging,
|
||||
@@ -17,7 +21,12 @@ from axolotl.core.trainers.utils import (
|
||||
|
||||
|
||||
class AxolotlDPOTrainer(
|
||||
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, DPOTrainer
|
||||
RngLoaderMixin,
|
||||
SchedulerMixin,
|
||||
OptimizerMixin,
|
||||
OptimizerInitMixin,
|
||||
DPOTrainer,
|
||||
DistributedParallelMixin,
|
||||
):
|
||||
"""Extend the base DPOTrainer for axolotl helpers."""
|
||||
|
||||
|
||||
@@ -82,14 +82,14 @@ class GRPOStrategy:
|
||||
grpo_args_kwargs["log_completions"] = trl.log_completions
|
||||
grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print
|
||||
|
||||
if cfg.context_parallel_size > 1:
|
||||
grpo_args_kwargs["context_parallel_size"] = cfg.context_parallel_size
|
||||
|
||||
if trl.importance_sampling_level is not None:
|
||||
grpo_args_kwargs["importance_sampling_level"] = (
|
||||
trl.importance_sampling_level
|
||||
)
|
||||
|
||||
if cfg.sequence_parallel_degree > 1:
|
||||
grpo_args_kwargs["sequence_parallel_degree"] = cfg.sequence_parallel_degree
|
||||
|
||||
if trl.reward_weights:
|
||||
grpo_args_kwargs["reward_weights"] = trl.reward_weights
|
||||
|
||||
|
||||
@@ -13,4 +13,4 @@ from axolotl.core.training_args import AxolotlTrainingMixins
|
||||
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
|
||||
"""Axolotl GRPO Config for GRPO training"""
|
||||
|
||||
sequence_parallel_degree: int | None = None
|
||||
context_parallel_size: int | None = None
|
||||
|
||||
@@ -20,7 +20,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
|
||||
- Data is properly distributed across SP groups.
|
||||
|
||||
In the table below, the values represent dataset indices. Each SP group has
|
||||
`sequence_parallel_degree = 2` GPUs working together on the same data. There are 2
|
||||
`context_parallel_size = 2` GPUs working together on the same data. There are 2
|
||||
SP groups (SP0 and SP1), with `world_size = 4` total GPUs.
|
||||
|
||||
Sequence Parallel Groups
|
||||
@@ -45,7 +45,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
|
||||
rank: Rank of current process.
|
||||
batch_size: Number of samples per batch.
|
||||
repeat_count: How many times to repeat the full sampling process.
|
||||
sequence_parallel_degree: Number of ranks in a sequence parallel group.
|
||||
context_parallel_size: Number of ranks in a sequence parallel group.
|
||||
shuffle: Whether to shuffle the dataset.
|
||||
seed: Random seed for shuffling.
|
||||
drop_last: Whether to drop the last incomplete batch.
|
||||
@@ -59,7 +59,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
|
||||
rank: int,
|
||||
batch_size: int = 1,
|
||||
repeat_count: int = 1,
|
||||
sequence_parallel_degree: int = 1,
|
||||
context_parallel_size: int = 1,
|
||||
shuffle: bool = True,
|
||||
seed: int = 0,
|
||||
drop_last: bool = False,
|
||||
@@ -77,9 +77,9 @@ class SequenceParallelRepeatRandomSampler(Sampler):
|
||||
self.rank = rank
|
||||
|
||||
# Sequence parallelism parameters
|
||||
self.sequence_parallel_degree = sequence_parallel_degree
|
||||
self.num_sp_groups = world_size // sequence_parallel_degree
|
||||
self.sp_group_id = rank // sequence_parallel_degree
|
||||
self.context_parallel_size = context_parallel_size
|
||||
self.num_sp_groups = world_size // context_parallel_size
|
||||
self.sp_group_id = rank // context_parallel_size
|
||||
|
||||
# Adjust dataset size for distributed sampling
|
||||
self.num_samples = len(self.dataset)
|
||||
|
||||
@@ -43,7 +43,11 @@ from trl.trainer.grpo_trainer import RewardFunc, nanstd
|
||||
from trl.trainer.utils import pad
|
||||
|
||||
from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler
|
||||
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
|
||||
from axolotl.core.trainers.mixins import (
|
||||
DistributedParallelMixin,
|
||||
RngLoaderMixin,
|
||||
SchedulerMixin,
|
||||
)
|
||||
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
|
||||
from axolotl.monkeypatch.ring_attn import get_ring_attn_group
|
||||
|
||||
@@ -53,7 +57,12 @@ if is_peft_available():
|
||||
|
||||
|
||||
class AxolotlGRPOTrainer(
|
||||
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, GRPOTrainer
|
||||
RngLoaderMixin,
|
||||
SchedulerMixin,
|
||||
OptimizerMixin,
|
||||
OptimizerInitMixin,
|
||||
DistributedParallelMixin,
|
||||
GRPOTrainer,
|
||||
):
|
||||
"""Extend the base GRPOTrainer for axolotl helpers"""
|
||||
|
||||
@@ -100,7 +109,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
|
||||
# Get number of SP groups (number of processes divided by SP degree)
|
||||
num_processes = self.accelerator.num_processes
|
||||
num_sp_groups = num_processes // self.args.sequence_parallel_degree
|
||||
num_sp_groups = num_processes // self.args.context_parallel_size
|
||||
|
||||
# Calculate batch size per SP group (not per process)
|
||||
sp_group_batch_size = self.args.per_device_train_batch_size * num_sp_groups
|
||||
@@ -130,7 +139,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
|
||||
if self.num_generations not in possible_values:
|
||||
raise ValueError(
|
||||
f"With sequence parallelism (degree {self.args.sequence_parallel_degree}), "
|
||||
f"With sequence parallelism (degree {self.args.context_parallel_size}), "
|
||||
f"the eval batch size per SP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) "
|
||||
f"must be evenly divisible by the number of generations per prompt "
|
||||
f"({self.num_generations}). Given the current eval batch size, "
|
||||
@@ -167,9 +176,9 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
rank=self.rank,
|
||||
batch_size=effective_batch_size
|
||||
// self.num_generations
|
||||
// self.args.sequence_parallel_degree,
|
||||
// self.args.context_parallel_size,
|
||||
repeat_count=self.num_iterations * self.args.gradient_accumulation_steps,
|
||||
sequence_parallel_degree=self.args.sequence_parallel_degree,
|
||||
context_parallel_size=self.args.context_parallel_size,
|
||||
shuffle=True,
|
||||
seed=self.args.seed,
|
||||
drop_last=True,
|
||||
@@ -235,7 +244,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
# TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
|
||||
# if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
|
||||
# slice each batch along the sequence dimension).
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
if self.args.context_parallel_size > 1:
|
||||
return dataloader
|
||||
|
||||
# Otherwise prepare with accelerator
|
||||
@@ -308,18 +317,18 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
|
||||
all_prompts_text = gather_object(prompts_text)
|
||||
if self.accelerator.is_main_process:
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
if self.args.context_parallel_size > 1:
|
||||
# Calculate sequence parallel group information
|
||||
world_size = self.accelerator.num_processes
|
||||
sequence_parallel_degree = self.args.sequence_parallel_degree
|
||||
num_sp_groups = world_size // sequence_parallel_degree
|
||||
context_parallel_size = self.args.context_parallel_size
|
||||
num_sp_groups = world_size // context_parallel_size
|
||||
|
||||
# Since processes in the same SP group have the same prompts, we need to ensure
|
||||
# we only take one copy of each prompt from each SP group
|
||||
ordered_set_of_prompts = []
|
||||
for sp_group_id in range(num_sp_groups):
|
||||
# Get the first process from each SP group (typically the group leader)
|
||||
group_leader_rank = sp_group_id * sequence_parallel_degree
|
||||
group_leader_rank = sp_group_id * context_parallel_size
|
||||
|
||||
# Extract prompts from this SP group, accounting for num_generations duplicates
|
||||
# We only need prompts from one rank in each SP group
|
||||
@@ -335,7 +344,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
# num_generations outputs for each one. This is faster than generating outputs for each duplicate
|
||||
# prompt individually.
|
||||
ordered_set_of_prompts = all_prompts_text[
|
||||
:: self.num_generations * self.args.sequence_parallel_degree
|
||||
:: self.num_generations * self.args.context_parallel_size
|
||||
]
|
||||
|
||||
with profiling_context(self, "vLLM.generate"):
|
||||
@@ -352,14 +361,14 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
)
|
||||
else:
|
||||
completion_ids = [None] * (
|
||||
len(all_prompts_text) // self.args.sequence_parallel_degree
|
||||
len(all_prompts_text) // self.args.context_parallel_size
|
||||
)
|
||||
|
||||
# Broadcast the completions from the main process to all processes
|
||||
completion_ids = broadcast_object_list(completion_ids, from_process=0)
|
||||
|
||||
# Determine the appropriate slice based on sequence parallelism
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
if self.args.context_parallel_size > 1:
|
||||
# Calculate SP group ID (which group of ranks this rank belongs to)
|
||||
sp_group_id = self.accelerator.process_index // self.local_world_size
|
||||
|
||||
@@ -583,7 +592,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
advantages = advantages / (std_grouped_rewards + 1e-4)
|
||||
|
||||
# Slice to keep only the local part of the data
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
if self.args.context_parallel_size > 1:
|
||||
# Calculate SP group ID (which group of ranks this rank belongs to)
|
||||
sp_group_id = self.accelerator.process_index // self.local_world_size
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import torch
|
||||
from axolotl.core.trainers.base import AxolotlTrainer
|
||||
|
||||
|
||||
# pylint: disable=too-many-ancestors
|
||||
class AxolotlMambaTrainer(AxolotlTrainer):
|
||||
"""Mamba specific trainer to handle loss calculation"""
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
from .activation_checkpointing import ActivationOffloadingMixin
|
||||
from .checkpoints import CheckpointSaveMixin
|
||||
from .distributed_parallel import DistributedParallelMixin
|
||||
from .optimizer import OptimizerMixin
|
||||
from .packing import PackingMixin
|
||||
from .rng_state_loader import RngLoaderMixin
|
||||
|
||||
@@ -13,9 +13,11 @@ class CheckpointSaveMixin(Trainer):
|
||||
def _save_optimizer_and_scheduler(self, output_dir):
|
||||
try:
|
||||
super()._save_optimizer_and_scheduler(output_dir)
|
||||
except NotImplementedError as exc:
|
||||
LOG.warning(
|
||||
except (NotImplementedError, KeyError) as exc:
|
||||
# TODO: fix fsdp2 optimizer saving
|
||||
LOG.warning_once(
|
||||
f"Trainer does not support saving optimizer and scheduler: {exc}\n"
|
||||
"Optimizer and scheduler states were not saved - resuming from checkpoints "
|
||||
"for this training run will not be possible."
|
||||
"for this training run will not be possible.",
|
||||
main_process_only=True,
|
||||
)
|
||||
|
||||
20
src/axolotl/core/trainers/mixins/distributed_parallel.py
Normal file
20
src/axolotl/core/trainers/mixins/distributed_parallel.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""
|
||||
Mixin for correctly saving fsdp
|
||||
"""
|
||||
|
||||
from transformers import Trainer
|
||||
|
||||
|
||||
class DistributedParallelMixin(Trainer):
|
||||
"""
|
||||
Mixin for correctly saving fsdp
|
||||
"""
|
||||
|
||||
def _save(self, output_dir: str | None = None, state_dict=None):
|
||||
if (
|
||||
state_dict is None
|
||||
and self.accelerator.parallelism_config
|
||||
and self.accelerator.parallelism_config.dp_shard_enabled
|
||||
):
|
||||
state_dict = self.accelerator.get_state_dict(self.model)
|
||||
super()._save(output_dir, state_dict=state_dict)
|
||||
@@ -8,13 +8,18 @@ from trl import (
|
||||
RewardTrainer,
|
||||
)
|
||||
|
||||
from axolotl.core.trainers.mixins import RngLoaderMixin
|
||||
from axolotl.core.trainers.mixins import DistributedParallelMixin, RngLoaderMixin
|
||||
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
|
||||
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
|
||||
|
||||
|
||||
class AxolotlORPOTrainer(
|
||||
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, ORPOTrainer
|
||||
RngLoaderMixin,
|
||||
SchedulerMixin,
|
||||
OptimizerMixin,
|
||||
OptimizerInitMixin,
|
||||
DistributedParallelMixin,
|
||||
ORPOTrainer,
|
||||
):
|
||||
"""
|
||||
Extend the base ORPOTrainer for axolotl helpers
|
||||
@@ -24,7 +29,12 @@ class AxolotlORPOTrainer(
|
||||
|
||||
|
||||
class AxolotlKTOTrainer(
|
||||
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, KTOTrainer
|
||||
RngLoaderMixin,
|
||||
SchedulerMixin,
|
||||
OptimizerMixin,
|
||||
OptimizerInitMixin,
|
||||
DistributedParallelMixin,
|
||||
KTOTrainer,
|
||||
):
|
||||
"""
|
||||
Extend the base KTOTrainer for axolotl helpers
|
||||
@@ -34,7 +44,12 @@ class AxolotlKTOTrainer(
|
||||
|
||||
|
||||
class AxolotlCPOTrainer(
|
||||
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, CPOTrainer
|
||||
RngLoaderMixin,
|
||||
SchedulerMixin,
|
||||
OptimizerMixin,
|
||||
OptimizerInitMixin,
|
||||
DistributedParallelMixin,
|
||||
CPOTrainer,
|
||||
):
|
||||
"""
|
||||
Extend the base CPOTrainer for axolotl helpers
|
||||
@@ -44,7 +59,12 @@ class AxolotlCPOTrainer(
|
||||
|
||||
|
||||
class AxolotlRewardTrainer(
|
||||
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, RewardTrainer
|
||||
RngLoaderMixin,
|
||||
SchedulerMixin,
|
||||
OptimizerMixin,
|
||||
OptimizerInitMixin,
|
||||
DistributedParallelMixin,
|
||||
RewardTrainer,
|
||||
):
|
||||
"""
|
||||
Extend the base RewardTrainer for axolotl helpers
|
||||
@@ -54,7 +74,12 @@ class AxolotlRewardTrainer(
|
||||
|
||||
|
||||
class AxolotlPRMTrainer(
|
||||
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, PRMTrainer
|
||||
RngLoaderMixin,
|
||||
SchedulerMixin,
|
||||
OptimizerMixin,
|
||||
OptimizerInitMixin,
|
||||
DistributedParallelMixin,
|
||||
PRMTrainer,
|
||||
):
|
||||
"""
|
||||
Extend the base trl.PRMTrainer for axolotl helpers
|
||||
|
||||
@@ -21,6 +21,7 @@ from axolotl.core.trainers.base import AxolotlTrainer
|
||||
from .kernels.liger import LigerFusedLinearKLTopKLogprobLoss
|
||||
|
||||
|
||||
# pylint: disable=too-many-ancestors
|
||||
class AxolotlKDTrainer(AxolotlTrainer):
|
||||
"""
|
||||
Custom trainer subclass for Knowledge Distillation (KD)
|
||||
|
||||
@@ -16,8 +16,6 @@
|
||||
Module for handling LIGER input arguments.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
@@ -30,13 +28,13 @@ class LigerArgs(BaseModel):
|
||||
Input args for LIGER.
|
||||
"""
|
||||
|
||||
liger_rope: Optional[bool] = None
|
||||
liger_rms_norm: Optional[bool] = None
|
||||
liger_layer_norm: Optional[bool] = None
|
||||
liger_swiglu: Optional[bool] = None
|
||||
liger_glu_activation: Optional[bool] = None
|
||||
liger_cross_entropy: Optional[bool] = None
|
||||
liger_fused_linear_cross_entropy: Optional[bool] = None
|
||||
liger_rope: bool | None = None
|
||||
liger_rms_norm: bool | None = None
|
||||
liger_layer_norm: bool | None = None
|
||||
liger_swiglu: bool | None = None
|
||||
liger_glu_activation: bool | None = None
|
||||
liger_cross_entropy: bool | None = None
|
||||
liger_fused_linear_cross_entropy: bool | None = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
@@ -66,3 +64,20 @@ class LigerArgs(BaseModel):
|
||||
"You cannot have both `liger_glu_activation` and `tiled_mlp` set without `tiled_mlp_use_original_mlp: true`."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_liger_rms_norm_tensor_parallel(cls, data):
|
||||
if data.get("liger_rms_norm") and data.get("tensor_parallel_size", 1) > 1:
|
||||
raise ValueError(
|
||||
"`liger_rms_norm` is incompatible with tensor parallelism, "
|
||||
"see https://github.com/linkedin/Liger-Kernel/issues/826"
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_tensor_parallel_size_liger_fused_linear_cross_entropy(self):
|
||||
# TODO @SalmanMohammadi this is a larger fix - investigate
|
||||
if self.tensor_parallel_size > 1 and self.liger_fused_linear_cross_entropy:
|
||||
raise ValueError("Tensor parallelism is not compatible with liger losses.")
|
||||
return self
|
||||
|
||||
@@ -13,7 +13,8 @@ import peft
|
||||
import torch
|
||||
import transformers
|
||||
import transformers.modeling_utils
|
||||
from accelerate import init_empty_weights
|
||||
from accelerate import PartialState, init_empty_weights
|
||||
from accelerate.parallelism_config import ParallelismConfig
|
||||
from peft import (
|
||||
PeftConfig,
|
||||
PeftMixedModel,
|
||||
@@ -48,10 +49,7 @@ from axolotl.loaders.utils import (
|
||||
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import (
|
||||
get_device_count,
|
||||
get_device_type,
|
||||
)
|
||||
from axolotl.utils.distributed import get_device_count, get_device_type, get_world_size
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.model_shard_quant import load_sharded_model_quant
|
||||
from axolotl.utils.schemas.enums import RLType
|
||||
@@ -87,6 +85,9 @@ class ModelLoader:
|
||||
`AutoModelForCausalLM`).
|
||||
"""
|
||||
|
||||
use_parallel_config: bool | None = False
|
||||
parallelism_config: ParallelismConfig | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cfg: DictDefault,
|
||||
@@ -183,6 +184,20 @@ class ModelLoader:
|
||||
|
||||
def _apply_pre_model_load_setup(self):
|
||||
"""Apply patches and setup configurations before model loading."""
|
||||
if self.use_parallel_config is not None:
|
||||
self.use_parallel_config = (
|
||||
self.cfg.fsdp_config
|
||||
or (self.cfg.tensor_parallel_size and self.cfg.tensor_parallel_size > 1)
|
||||
or (
|
||||
self.cfg.context_parallel_size
|
||||
and self.cfg.context_parallel_size > 1
|
||||
)
|
||||
)
|
||||
if self.cfg.fsdp_config and self.cfg.fsdp_version != 2:
|
||||
self.use_parallel_config = False
|
||||
|
||||
if self.use_parallel_config:
|
||||
self._set_parallel_config()
|
||||
self._set_auto_model_loader()
|
||||
self._set_device_map_config()
|
||||
if self.cfg.revision_of_model:
|
||||
@@ -390,6 +405,86 @@ class ModelLoader:
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@staticmethod
|
||||
def _get_parallel_config_kwargs(
|
||||
world_size: int,
|
||||
tensor_parallel_size: int = 1,
|
||||
context_parallel_size: int = 1,
|
||||
dp_shard_size: int | None = None,
|
||||
dp_replicate_size: int | None = None,
|
||||
is_fsdp: bool = False,
|
||||
):
|
||||
pc_kwargs = {}
|
||||
remaining_world_size = world_size
|
||||
|
||||
if tensor_parallel_size and tensor_parallel_size > 1:
|
||||
pc_kwargs["tp_size"] = tensor_parallel_size
|
||||
remaining_world_size = remaining_world_size // tensor_parallel_size
|
||||
|
||||
if context_parallel_size and context_parallel_size > 1:
|
||||
pc_kwargs["cp_size"] = context_parallel_size
|
||||
remaining_world_size = remaining_world_size // context_parallel_size
|
||||
|
||||
if dp_shard_size is None and dp_replicate_size in (None, 1):
|
||||
if remaining_world_size > 1:
|
||||
pc_kwargs["dp_shard_size"] = remaining_world_size
|
||||
remaining_world_size = 1
|
||||
|
||||
if dp_replicate_size and dp_replicate_size > 1:
|
||||
pc_kwargs["dp_replicate_size"] = dp_replicate_size
|
||||
remaining_world_size = remaining_world_size // dp_replicate_size
|
||||
|
||||
if remaining_world_size > 1 and dp_shard_size and dp_shard_size > 1:
|
||||
if not is_fsdp:
|
||||
raise ValueError(
|
||||
"dp_shard_size was configured without a corresponding fsdp_config! "
|
||||
"Please ensure you have configured FSDP using fsdp_config."
|
||||
)
|
||||
pc_kwargs["dp_shard_size"] = dp_shard_size
|
||||
remaining_world_size = remaining_world_size // dp_shard_size
|
||||
if remaining_world_size > 1 and "dp_replicate_size" not in pc_kwargs:
|
||||
pc_kwargs["dp_replicate_size"] = remaining_world_size
|
||||
remaining_world_size = 1
|
||||
|
||||
if remaining_world_size > 1:
|
||||
if "dp_shard_size" not in pc_kwargs and is_fsdp:
|
||||
pc_kwargs["dp_shard_size"] = remaining_world_size
|
||||
remaining_world_size = 1
|
||||
|
||||
if remaining_world_size > 1:
|
||||
raise ValueError(
|
||||
f"The configured parallelisms are incompatible with the current world size ({get_world_size()})!\n"
|
||||
f"{pc_kwargs}"
|
||||
)
|
||||
|
||||
return pc_kwargs
|
||||
|
||||
def _set_parallel_config(self):
|
||||
"""Set parallelism configuration (DP, FSDP, TP, CP) in PartialState/Accelerator"""
|
||||
pc_kwargs = ModelLoader._get_parallel_config_kwargs(
|
||||
get_world_size(),
|
||||
self.cfg.tensor_parallel_size,
|
||||
self.cfg.context_parallel_size,
|
||||
self.cfg.dp_shard_size,
|
||||
self.cfg.dp_replicate_size,
|
||||
bool(self.cfg.fsdp or self.cfg.fsdp_config),
|
||||
)
|
||||
|
||||
if pc_kwargs:
|
||||
self.parallelism_config = ParallelismConfig(
|
||||
**pc_kwargs,
|
||||
)
|
||||
device_mesh = self.parallelism_config.build_device_mesh("cuda")
|
||||
partial_state = PartialState()
|
||||
# fmt: off
|
||||
partial_state._shared_state["parallelism_config"] = ( # pylint: disable=protected-access
|
||||
self.parallelism_config
|
||||
)
|
||||
partial_state._shared_state["device_mesh"] = ( # pylint: disable=protected-access
|
||||
device_mesh
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
def _set_auto_model_loader(self):
|
||||
"""Set `self.auto_model_loader`. Defaults to `transformers.AutoModelForCausalLM`
|
||||
(set at `__init__`). When using a multimodal model, `self.auto_model_loader`
|
||||
@@ -622,6 +717,14 @@ class ModelLoader:
|
||||
def _build_model(self) -> bool:
|
||||
"""Load model, with load strategy depending on config."""
|
||||
skip_move_to_device = False
|
||||
|
||||
if self.cfg.tensor_parallel_size > 1:
|
||||
self.model_kwargs["tp_size"] = self.cfg.tensor_parallel_size
|
||||
self.model_kwargs["tp_plan"] = "auto"
|
||||
self.model_kwargs["device_mesh"] = PartialState().device_mesh
|
||||
if "device_map" in self.model_kwargs:
|
||||
del self.model_kwargs["device_map"] # not compatible with `tp_plan`
|
||||
|
||||
if self.is_fsdp_enabled:
|
||||
if self.cfg.fsdp_config.cpu_ram_efficient_loading:
|
||||
skip_move_to_device = True
|
||||
@@ -734,6 +837,14 @@ class ModelLoader:
|
||||
if is_deepspeed_zero3_enabled():
|
||||
skip_move_to_device = True
|
||||
|
||||
# pylint: disable=protected-access
|
||||
if self.cfg.tensor_parallel_size > 1:
|
||||
# workaround for upstream 4.54.0 not setting _tp_size or _device_mesh
|
||||
# TODO(wing): remove once 4.54.1 is released
|
||||
if self.model._tp_size != self.cfg.tensor_parallel_size:
|
||||
self.model._tp_size = self.cfg.tensor_parallel_size
|
||||
self.model._device_mesh = self.model_kwargs["device_mesh"]
|
||||
|
||||
return skip_move_to_device
|
||||
|
||||
def _set_z3_leaf_modules(self):
|
||||
|
||||
@@ -49,6 +49,7 @@ class PatchManager:
|
||||
|
||||
def apply_pre_model_load_patches(self):
|
||||
"""Apply pre-model load patches based on config."""
|
||||
self._apply_transformers_patches()
|
||||
# self._apply_flex_attention_patches()
|
||||
self._apply_flash_attention_patches()
|
||||
self._apply_chunked_cross_entropy_patch()
|
||||
@@ -64,13 +65,19 @@ class PatchManager:
|
||||
self._patch_llama_derived_model()
|
||||
self._apply_mistral_cross_entropy_patch()
|
||||
self._apply_self_attention_lora_patch()
|
||||
self._apply_sequence_parallel_patches()
|
||||
|
||||
def apply_post_plugin_pre_model_load_patches(self):
|
||||
"""Apply post plugin-pre_model_load load patches based on config."""
|
||||
self._apply_tiled_mlp(self.cfg.model_config_type)
|
||||
self._apply_voxtral_patches()
|
||||
|
||||
def _apply_transformers_patches(self):
|
||||
from axolotl.monkeypatch.transformers.modeling_flash_attention_utils import (
|
||||
patch_prepare_from_posids,
|
||||
)
|
||||
|
||||
patch_prepare_from_posids()
|
||||
|
||||
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
||||
"""Apply patches that require the model instance."""
|
||||
self._apply_llama_flash_attn_patches(model)
|
||||
@@ -253,17 +260,6 @@ class PatchManager:
|
||||
has_remote_code=has_remote_code,
|
||||
)
|
||||
|
||||
def _apply_sequence_parallel_patches(self):
|
||||
"""Apply sequence parallelism patches."""
|
||||
if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:
|
||||
from axolotl.monkeypatch.ring_attn.patch import (
|
||||
patch_prepare_data_loader,
|
||||
patch_prepare_device_mesh,
|
||||
)
|
||||
|
||||
patch_prepare_data_loader()
|
||||
patch_prepare_device_mesh(self.cfg.sequence_parallel_degree, self.cfg.fsdp)
|
||||
|
||||
def _apply_tiled_mlp(self, model_type: str):
|
||||
if self.cfg.tiled_mlp:
|
||||
from axolotl.monkeypatch.tiled_mlp import (
|
||||
|
||||
@@ -249,13 +249,19 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
|
||||
auto_wrap_policy=fsdp2_plugin.auto_wrap_policy,
|
||||
)
|
||||
|
||||
mesh = getattr(accelerator.state, "device_mesh", None)
|
||||
|
||||
fsdp2_kwargs = {
|
||||
"reshard_after_forward": fsdp2_plugin.reshard_after_forward,
|
||||
"offload_policy": fsdp2_plugin.cpu_offload,
|
||||
# `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy`
|
||||
"mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
|
||||
"mesh": (
|
||||
mesh[tuple(accelerator.state.parallelism_config.fsdp_dim_names)]
|
||||
if mesh is not None
|
||||
else None
|
||||
),
|
||||
}
|
||||
|
||||
model_has_params4bit = False
|
||||
for _, param in model.named_parameters():
|
||||
# this is a temporary fix whereby loading models with bnb params cannot be moved from
|
||||
|
||||
@@ -5,18 +5,14 @@
|
||||
|
||||
from .patch import (
|
||||
get_ring_attn_group,
|
||||
patch_prepare_data_loader,
|
||||
patch_prepare_device_mesh,
|
||||
register_ring_attn,
|
||||
register_ring_attn_from_device_mesh,
|
||||
set_ring_attn_group,
|
||||
update_ring_attn_params,
|
||||
)
|
||||
|
||||
__all__ = (
|
||||
"get_ring_attn_group",
|
||||
"patch_prepare_data_loader",
|
||||
"patch_prepare_device_mesh",
|
||||
"register_ring_attn",
|
||||
"register_ring_attn_from_device_mesh",
|
||||
"set_ring_attn_group",
|
||||
"update_ring_attn_params",
|
||||
)
|
||||
|
||||
@@ -8,13 +8,12 @@ We also provide some patches for accelerate functions to prepare the dataloader
|
||||
sequence parallelism training.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import accelerate
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import DeviceMesh
|
||||
|
||||
try:
|
||||
from transformers.modeling_flash_attention_utils import _flash_supports_window
|
||||
@@ -29,39 +28,13 @@ from axolotl.utils.schemas.enums import RingAttnFunc
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
RING_ATTN_GROUP = None
|
||||
|
||||
ORIGINAL_PREPARE_DATALOADER_CODE = """ submesh_fsdp_size = 1
|
||||
submesh_dp_size = 1
|
||||
submesh_tp_size = 1
|
||||
if "tp" in torch_device_mesh.mesh_dim_names:
|
||||
submesh_tp_size = torch_device_mesh["tp"].size()
|
||||
if "dp" in torch_device_mesh.mesh_dim_names:
|
||||
submesh_dp_size = torch_device_mesh["dp"].size()
|
||||
if "fsdp" in torch_device_mesh.mesh_dim_names:
|
||||
submesh_fsdp_size = torch_device_mesh["fsdp"].size()
|
||||
process_index = process_index // submesh_tp_size"""
|
||||
|
||||
NEW_PREPARE_DATALOADER_CODE = """ submesh_fsdp_size = 1
|
||||
submesh_dp_size = 1
|
||||
submesh_tp_size = 1
|
||||
submesh_cp_size = 1
|
||||
if "cp" in torch_device_mesh.mesh_dim_names:
|
||||
submesh_cp_size = torch_device_mesh["cp"].size()
|
||||
if "tp" in torch_device_mesh.mesh_dim_names:
|
||||
submesh_tp_size = torch_device_mesh["tp"].size()
|
||||
if "dp" in torch_device_mesh.mesh_dim_names:
|
||||
submesh_dp_size = torch_device_mesh["dp"].size()
|
||||
if "fsdp" in torch_device_mesh.mesh_dim_names:
|
||||
submesh_fsdp_size = torch_device_mesh["fsdp"].size()
|
||||
process_index = process_index // (submesh_tp_size * submesh_cp_size)"""
|
||||
|
||||
|
||||
def get_ring_attn_group() -> dist.ProcessGroup:
|
||||
"""Getter for ring attention group on this rank."""
|
||||
if RING_ATTN_GROUP is None:
|
||||
raise RuntimeError("register_ring_attn() not yet called")
|
||||
raise RuntimeError("register_ring_attn_from_device_mesh() not yet called")
|
||||
return RING_ATTN_GROUP
|
||||
|
||||
|
||||
@@ -161,15 +134,17 @@ def create_ring_flash_attention_forward(
|
||||
]
|
||||
|
||||
|
||||
def register_ring_attn(
|
||||
sequence_parallel_degree: int,
|
||||
def register_ring_attn_from_device_mesh(
|
||||
device_mesh: "DeviceMesh",
|
||||
context_parallel_dim: tuple[str, ...],
|
||||
heads_k_stride: int | None,
|
||||
ring_attn_func: RingAttnFunc | None,
|
||||
):
|
||||
"""Create ring attention group and substitute flash attn with ring flash attn.
|
||||
"""Create ring attention group using DeviceMesh and substitute flash attn with ring flash attn.
|
||||
|
||||
Args:
|
||||
sequence_parallel_degree: Sequence parallelism factor.
|
||||
device_mesh: DeviceMesh object containing the parallelism topology.
|
||||
context_parallel_dim: Name of the sequence parallel dimension in the device mesh.
|
||||
heads_k_stride: Sequence parallelism K head stride size. Passed through to
|
||||
`varlen_llama3` `ring_flash_attn` implementation.
|
||||
ring_attn_func: `ring_flash_attn` ring attention implemention. If sample
|
||||
@@ -177,44 +152,39 @@ def register_ring_attn(
|
||||
`batch` function.
|
||||
"""
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
LOG.info(
|
||||
f"Enabling ring attention sequence parallelism using DeviceMesh "
|
||||
f"dimension '{context_parallel_dim}'",
|
||||
main_process_only=True,
|
||||
)
|
||||
|
||||
# Extract the sequence parallel submesh
|
||||
try:
|
||||
sequence_mesh = device_mesh[context_parallel_dim]
|
||||
except (KeyError, IndexError) as e:
|
||||
raise ValueError(
|
||||
f"Dimension '{context_parallel_dim}' not found in device_mesh. "
|
||||
f"Available dimensions: {device_mesh.mesh_dim_names}"
|
||||
) from e
|
||||
|
||||
# Get the process group for context parallelism
|
||||
sequence_pg = sequence_mesh.get_group()
|
||||
context_parallel_size = sequence_mesh.size()
|
||||
|
||||
if rank == 0:
|
||||
LOG.info(
|
||||
"Enabling ring attention sequence parallelism: "
|
||||
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
|
||||
f"Sequence parallel degree: {context_parallel_size}, "
|
||||
f"mesh shape: {sequence_mesh.mesh.shape}"
|
||||
)
|
||||
|
||||
assert sequence_parallel_degree <= world_size, (
|
||||
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
||||
f"must be less than or equal to world_size ({world_size})"
|
||||
)
|
||||
assert world_size % sequence_parallel_degree == 0, (
|
||||
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
||||
f"must evenly divide world_size ({world_size})"
|
||||
)
|
||||
# Log which ranks are in the current process group
|
||||
if sequence_pg != dist.GroupMember.WORLD:
|
||||
ranks_in_group = dist.get_process_group_ranks(sequence_pg)
|
||||
LOG.info(f"Current sequence parallel group ranks: {ranks_in_group}")
|
||||
|
||||
# Assign ranks to sequence parallel groups
|
||||
group_assignments = {}
|
||||
for i in range(world_size // sequence_parallel_degree):
|
||||
ring_attn_ranks = list(
|
||||
range(
|
||||
i * sequence_parallel_degree,
|
||||
(i + 1) * sequence_parallel_degree,
|
||||
)
|
||||
)
|
||||
group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
|
||||
|
||||
# Track which GPUs are in which groups
|
||||
for r in ring_attn_ranks:
|
||||
group_assignments[r] = i
|
||||
|
||||
if rank in ring_attn_ranks:
|
||||
set_ring_attn_group(group)
|
||||
|
||||
# Log the GPU group assignments
|
||||
if rank == 0:
|
||||
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
|
||||
# Set the ring attention group
|
||||
set_ring_attn_group(sequence_pg)
|
||||
|
||||
if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:
|
||||
# fmt: off
|
||||
@@ -257,92 +227,3 @@ def update_ring_attn_params(position_ids: torch.Tensor | None):
|
||||
cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids)
|
||||
cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device())
|
||||
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
|
||||
|
||||
|
||||
def patch_prepare_data_loader():
|
||||
"""Patch `accelerate.data_loader.prepare_data_loader` to respect the SP degree.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If source code to patch does not exist.
|
||||
"""
|
||||
original_fn = accelerate.data_loader.prepare_data_loader
|
||||
original_source = inspect.getsource(original_fn)
|
||||
|
||||
if ORIGINAL_PREPARE_DATALOADER_CODE not in original_source:
|
||||
raise RuntimeError(
|
||||
"SP patch failed - target snippet not found. "
|
||||
"Check accelerate's version or update the patch."
|
||||
)
|
||||
|
||||
patched_source = original_source.replace(
|
||||
ORIGINAL_PREPARE_DATALOADER_CODE, NEW_PREPARE_DATALOADER_CODE
|
||||
)
|
||||
|
||||
items_to_import = []
|
||||
for item in dir(accelerate.data_loader):
|
||||
if item in patched_source:
|
||||
items_to_import.append(item)
|
||||
|
||||
# Create a new function from the patched source
|
||||
namespace = {}
|
||||
exec( # pylint: disable=exec-used # nosec B102
|
||||
f"from accelerate.data_loader import ({', '.join(items_to_import)})",
|
||||
globals(),
|
||||
)
|
||||
exec( # pylint: disable=exec-used # nosec B102
|
||||
patched_source, globals(), namespace
|
||||
)
|
||||
|
||||
patched_function = namespace["prepare_data_loader"]
|
||||
original_fn.__code__ = patched_function.__code__
|
||||
|
||||
LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support")
|
||||
|
||||
|
||||
def patch_prepare_device_mesh(sequence_parallel_degree: int, fsdp: bool = False):
|
||||
"""Patches the `Accelerator._prepare_device_mesh` method to create a device mesh
|
||||
that includes sequence parallelism with the specified degree.
|
||||
|
||||
Args:
|
||||
sequence_parallel_degree: The degree of sequence parallelism to use.
|
||||
fsdp: Whether to use FSDP.
|
||||
"""
|
||||
|
||||
def _prepare_device_mesh(self):
|
||||
"""Prepare the device mesh for distributed training. The dataloader will
|
||||
determine how to load data based on the device mesh.
|
||||
"""
|
||||
if self.state.torch_tp_plugin:
|
||||
return self.state.torch_tp_plugin.torch_device_mesh
|
||||
if (
|
||||
self.distributed_type == accelerate.accelerator.DistributedType.DEEPSPEED
|
||||
and hasattr(self.state, "ds_device_mesh")
|
||||
):
|
||||
return self.state.ds_device_mesh
|
||||
|
||||
# Create device mesh with sequence parallelism
|
||||
world_size = dist.get_world_size()
|
||||
mesh_shape = (
|
||||
world_size // sequence_parallel_degree,
|
||||
sequence_parallel_degree,
|
||||
)
|
||||
device_ids = list(range(world_size))
|
||||
|
||||
# NOTE: We use "cp" instead of "sp" to match the PyTorch native "context
|
||||
# parallelism" implementation naming.
|
||||
# NOTE: We have a simplified FSDP handling here; i.e., if FSDP is enabled, we
|
||||
# only use "fsdp" and "cp" for the device mesh.
|
||||
return dist.DeviceMesh(
|
||||
"cuda",
|
||||
torch.tensor(device_ids).reshape(mesh_shape),
|
||||
mesh_dim_names=("dp", "cp") if not fsdp else ("fsdp", "cp"),
|
||||
)
|
||||
|
||||
# Replace the original method with our new method
|
||||
# pylint: disable=protected-access
|
||||
accelerate.accelerator.Accelerator._prepare_device_mesh = _prepare_device_mesh
|
||||
|
||||
LOG.info(
|
||||
"Successfully patched Accelerator._prepare_device_mesh "
|
||||
f"with sequence_parallel_degree={sequence_parallel_degree}"
|
||||
)
|
||||
|
||||
0
src/axolotl/monkeypatch/transformers/__init__.py
Normal file
0
src/axolotl/monkeypatch/transformers/__init__.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""
|
||||
Monkey patch to fix transformers.modeling_flash_attention_utils.
|
||||
|
||||
see https://github.com/huggingface/transformers/pull/39653/files
|
||||
"""
|
||||
|
||||
import sys
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def _prepare_from_posids(query, key, value, position_ids):
|
||||
"""
|
||||
This function returns necessary arguments to call `flash_attn_varlen_func`.
|
||||
All three query, key, value states will be flattened.
|
||||
Cumulative lengths of each examples in the batch will be extracted from position_ids.
|
||||
NOTE: ideally cumulative lengths should be prepared at the data collator stage
|
||||
Arguments:
|
||||
query (`torch.Tensor`):
|
||||
Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
|
||||
key (`torch.Tensor`):
|
||||
Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
|
||||
value (`torch.Tensor`):
|
||||
Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
|
||||
position_ids (`torch.Tensor`):
|
||||
Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
|
||||
Return:
|
||||
query (`torch.Tensor`):
|
||||
Query state without padding. Shape: (total_target_length, num_heads, head_dim).
|
||||
key (`torch.Tensor`):
|
||||
Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
|
||||
value (`torch.Tensor`):
|
||||
Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
|
||||
indices_q (`torch.Tensor`):
|
||||
The indices of non-masked tokens from the flattened input target sequence.
|
||||
(cu_seqlens_q, cu_seqlens_k) (`tuple[int]`):
|
||||
The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
|
||||
(max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
|
||||
Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
|
||||
"""
|
||||
query = query.contiguous().view(-1, query.size(-2), query.size(-1))
|
||||
key = key.contiguous().view(-1, key.size(-2), key.size(-1))
|
||||
value = value.contiguous().view(-1, value.size(-2), value.size(-1))
|
||||
|
||||
position_ids = position_ids.flatten()
|
||||
indices_q = torch.arange(
|
||||
position_ids.size(0), device=position_ids.device, dtype=torch.int32
|
||||
)
|
||||
|
||||
cu_seq_lens = torch.cat(
|
||||
(
|
||||
indices_q[position_ids == 0],
|
||||
torch.tensor(
|
||||
position_ids.size(), device=position_ids.device, dtype=torch.int32
|
||||
),
|
||||
)
|
||||
)
|
||||
# NOTE: With torch compile, this will cause a graph break if you don't set
|
||||
# `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call
|
||||
# `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass.
|
||||
# This is a limitation of flash attention API, as the function `flash_attn_varlen_func`
|
||||
# requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`.
|
||||
# https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424
|
||||
# We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
|
||||
# for some models (e.g. qwen2-vl).
|
||||
max_length = cu_seq_lens.diff().max().item()
|
||||
return (
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
indices_q,
|
||||
(cu_seq_lens, cu_seq_lens),
|
||||
(max_length, max_length),
|
||||
)
|
||||
|
||||
|
||||
def patch_prepare_from_posids():
|
||||
import transformers.modeling_flash_attention_utils
|
||||
|
||||
transformers.modeling_flash_attention_utils._prepare_from_posids = ( # pylint: disable=protected-access
|
||||
_prepare_from_posids
|
||||
)
|
||||
setattr(
|
||||
sys.modules["transformers.modeling_flash_attention_utils"],
|
||||
"_prepare_from_posids",
|
||||
_prepare_from_posids,
|
||||
)
|
||||
@@ -205,7 +205,7 @@ def execute_training(
|
||||
)
|
||||
)
|
||||
|
||||
if cfg.sequence_parallel_degree > 1:
|
||||
if cfg.context_parallel_size > 1:
|
||||
models = [trainer.model]
|
||||
if hasattr(trainer, "ref_model") and trainer.ref_model:
|
||||
models.append(trainer.ref_model)
|
||||
@@ -213,7 +213,7 @@ def execute_training(
|
||||
stack.enter_context(
|
||||
SequenceParallelContextManager(
|
||||
models=models,
|
||||
sequence_parallel_degree=cfg.sequence_parallel_degree,
|
||||
context_parallel_size=cfg.context_parallel_size,
|
||||
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
|
||||
ring_attn_func=cfg.ring_attn_func,
|
||||
heads_k_stride=cfg.heads_k_stride,
|
||||
|
||||
@@ -57,10 +57,10 @@ def gpu_memory_usage(device=0):
|
||||
|
||||
@check_cuda_device((0.0, 0.0, 0.0))
|
||||
def gpu_memory_usage_all(device=0):
|
||||
usage = torch.cuda.memory_allocated(device) / 1024.0**3
|
||||
reserved = torch.cuda.memory_reserved(device) / 1024.0**3
|
||||
smi = gpu_memory_usage_smi(device)
|
||||
return usage, reserved - usage, max(0, smi - reserved)
|
||||
active = torch.cuda.memory_stats().get("active_bytes.all.peak", 0) / 1024.0**3
|
||||
allocated = torch.cuda.max_memory_allocated(device) / 1024.0**3
|
||||
reserved = torch.cuda.max_memory_reserved(device) / 1024.0**3
|
||||
return active, allocated, reserved
|
||||
|
||||
|
||||
def mps_memory_usage_all():
|
||||
@@ -92,27 +92,38 @@ def gpu_memory_usage_smi(device=0):
|
||||
return 0.0
|
||||
|
||||
|
||||
def log_gpu_memory_usage(
|
||||
log: logging.Logger | logging.LoggerAdapter,
|
||||
msg: str = "",
|
||||
device: int | torch.device = 0,
|
||||
):
|
||||
def get_gpu_memory_usage(device: int | torch.device = 0):
|
||||
cur_device_type = str(get_device_type())
|
||||
if torch.backends.mps.is_available():
|
||||
usage, cache, misc = mps_memory_usage_all()
|
||||
elif "npu" in cur_device_type and is_torch_npu_available():
|
||||
usage, cache, misc = npu_memory_usage_all(device)
|
||||
elif "gpu" in cur_device_type and torch.cuda.is_available():
|
||||
elif "cuda" in cur_device_type and torch.cuda.is_available():
|
||||
usage, cache, misc = gpu_memory_usage_all(device)
|
||||
else:
|
||||
return 0.0, 0.0, 0.0
|
||||
|
||||
return usage, cache, misc
|
||||
|
||||
|
||||
def log_gpu_memory_usage(
|
||||
log: logging.Logger | logging.LoggerAdapter,
|
||||
msg: str = "",
|
||||
device: int | torch.device = 0,
|
||||
):
|
||||
try:
|
||||
active, allocated, reserved = get_gpu_memory_usage(device)
|
||||
except ValueError:
|
||||
# likely CPU, ignore
|
||||
return
|
||||
cur_device_type = str(get_device_type())
|
||||
extras = []
|
||||
if cache > 0:
|
||||
extras.append(f"+{cache:.03f}GB cache")
|
||||
if misc > 0:
|
||||
extras.append(f"+{misc:.03f}GB misc")
|
||||
msg = f"{cur_device_type} memory usage:" if not msg else msg
|
||||
log.info(
|
||||
f"{msg} {usage:.03f}GB ({', '.join(extras)})",
|
||||
if allocated > 0:
|
||||
extras.append(f"+{allocated:.03f}GB allocated")
|
||||
if reserved > 0:
|
||||
extras.append(f"+{reserved:.03f}GB reserved")
|
||||
msg = f"{cur_device_type} memory active:" if not msg else msg
|
||||
log.debug(
|
||||
f"{msg} {active:.03f}GB ({', '.join(extras)})",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
@@ -35,7 +35,7 @@ from transformers.trainer_utils import (
|
||||
from trl.models import unwrap_model_for_generation
|
||||
|
||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
from axolotl.utils.bench import get_gpu_memory_usage, log_gpu_memory_usage
|
||||
from axolotl.utils.callbacks.perplexity import Perplexity
|
||||
from axolotl.utils.distributed import (
|
||||
barrier,
|
||||
@@ -100,7 +100,6 @@ class GPUStatsCallback(
|
||||
|
||||
def __init__(self, cfg):
|
||||
self.cfg = cfg
|
||||
self.logged = False
|
||||
|
||||
def on_step_end(
|
||||
self,
|
||||
@@ -109,9 +108,21 @@ class GPUStatsCallback(
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
) -> TrainerControl:
|
||||
if not self.logged and state.global_step > 1:
|
||||
log_gpu_memory_usage(LOG, "while training", self.cfg.device)
|
||||
self.logged = True
|
||||
if state.global_step > 0:
|
||||
if self.cfg.use_wandb and state.is_world_process_zero:
|
||||
try:
|
||||
active, allocated, reserved = get_gpu_memory_usage()
|
||||
wandb.log(
|
||||
{
|
||||
"memory/max_memory_active": active,
|
||||
"memory/max_memory_allocated": allocated,
|
||||
"memory/device_memory_reserved": reserved,
|
||||
},
|
||||
step=state.global_step,
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
log_gpu_memory_usage(LOG, "", self.cfg.device)
|
||||
return control
|
||||
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import inspect
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from accelerate import PartialState
|
||||
from torch import nn
|
||||
from torch.utils.hooks import RemovableHandle
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
@@ -12,7 +13,7 @@ from transformers.utils import ModelOutput
|
||||
|
||||
from axolotl.monkeypatch.ring_attn import (
|
||||
get_ring_attn_group,
|
||||
register_ring_attn,
|
||||
register_ring_attn_from_device_mesh,
|
||||
update_ring_attn_params,
|
||||
)
|
||||
from axolotl.utils.schemas.enums import RingAttnFunc
|
||||
@@ -150,9 +151,18 @@ def apply_sequence_parallelism(
|
||||
if "num_items_in_batch" in batch:
|
||||
# Approximation; this needed since num_items_in_batch may be counted across
|
||||
# all samples in a gradient accumulated batch, not on a per-step basis.
|
||||
local_valid_tokens = (batch["labels"] != -100).sum()
|
||||
|
||||
# All-reduce across sequence parallel ranks to get global token count
|
||||
cp_group = get_ring_attn_group()
|
||||
global_valid_tokens = local_valid_tokens.clone()
|
||||
# we use AVG instead of SUM as using sum seems to scale down the loss by over-accounting the number of tokens
|
||||
dist.all_reduce(global_valid_tokens, op=dist.ReduceOp.AVG, group=cp_group)
|
||||
global_valid_tokens = int(global_valid_tokens.item())
|
||||
|
||||
batch["num_items_in_batch"] = (
|
||||
batch["labels"] != -100
|
||||
).sum() * gradient_accumulation_steps
|
||||
global_valid_tokens * gradient_accumulation_steps
|
||||
)
|
||||
|
||||
return batch, original_seq_len, pad_len
|
||||
|
||||
@@ -167,7 +177,7 @@ class SequenceParallelContextManager:
|
||||
Args:
|
||||
models: List of models to apply sequence parallelism to pre- and post- forward
|
||||
hooks.
|
||||
sequence_parallel_degree: Number of processes to split sequences over.
|
||||
context_parallel_size: Number of processes to split sequences over.
|
||||
gradient_accumulation_steps: Number of steps to accumulate gradients over.
|
||||
ring_attn_func: Which ring attention function to use. Currently unused.
|
||||
heads_k_stride: Sequence parallelism K head stride size. Passed through to
|
||||
@@ -179,14 +189,14 @@ class SequenceParallelContextManager:
|
||||
def __init__(
|
||||
self,
|
||||
models: list[nn.Module],
|
||||
sequence_parallel_degree: int,
|
||||
context_parallel_size: int,
|
||||
gradient_accumulation_steps: int,
|
||||
ring_attn_func: RingAttnFunc,
|
||||
heads_k_stride: int | None,
|
||||
gather_outputs: bool,
|
||||
):
|
||||
self.models = models
|
||||
self.sequence_parallel_degree = sequence_parallel_degree
|
||||
self.context_parallel_size = context_parallel_size
|
||||
self.gradient_accumulation_steps = gradient_accumulation_steps
|
||||
self.ring_attn_func = ring_attn_func
|
||||
self.heads_k_stride = heads_k_stride
|
||||
@@ -230,8 +240,10 @@ class SequenceParallelContextManager:
|
||||
|
||||
def _register_ring_attn(self):
|
||||
# Initialize ring attn for sequence parallelism
|
||||
register_ring_attn(
|
||||
sequence_parallel_degree=self.sequence_parallel_degree,
|
||||
partial_state = PartialState()
|
||||
register_ring_attn_from_device_mesh(
|
||||
device_mesh=partial_state.device_mesh,
|
||||
context_parallel_dim=("cp",),
|
||||
heads_k_stride=self.heads_k_stride,
|
||||
ring_attn_func=self.ring_attn_func,
|
||||
)
|
||||
|
||||
@@ -430,10 +430,11 @@ def save_preprocessed_dataset(
|
||||
num_shards=cfg.num_dataset_shards_to_save,
|
||||
)
|
||||
else:
|
||||
min_rows_per_proc = 256
|
||||
os.makedirs(prepared_ds_path, exist_ok=True)
|
||||
dataset.save_to_disk(
|
||||
str(prepared_ds_path),
|
||||
num_proc=min(max(1, len(dataset) // 8), num_workers),
|
||||
num_proc=min(max(1, len(dataset) // min_rows_per_proc), num_workers),
|
||||
max_shard_size=None,
|
||||
num_shards=cfg.num_dataset_shards_to_save,
|
||||
)
|
||||
|
||||
@@ -2,12 +2,15 @@
|
||||
utils to get GPU info for the current environment
|
||||
"""
|
||||
|
||||
from importlib.metadata import version
|
||||
|
||||
from accelerate.utils.environment import (
|
||||
check_cuda_p2p_ib_support as accelerate_check_cuda_p2p_ib_support,
|
||||
)
|
||||
from accelerate.utils.environment import (
|
||||
get_gpu_info,
|
||||
)
|
||||
from packaging.version import Version, parse
|
||||
|
||||
|
||||
def check_cuda_p2p_ib_support():
|
||||
@@ -26,3 +29,13 @@ def check_cuda_p2p_ib_support():
|
||||
except Exception: # pylint: disable=broad-except # nosec
|
||||
pass
|
||||
return True
|
||||
|
||||
|
||||
def get_package_version(package: str) -> Version:
|
||||
version_str = version(package)
|
||||
return parse(version_str)
|
||||
|
||||
|
||||
def is_package_version_ge(package: str, version_: str) -> bool:
|
||||
package_version = get_package_version(package)
|
||||
return package_version >= parse(version_)
|
||||
|
||||
@@ -5,6 +5,7 @@ into fixed-capacity batches to optimize memory usage and training throughput.
|
||||
|
||||
import gc
|
||||
import math
|
||||
import time
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from multiprocessing import cpu_count, get_context
|
||||
from typing import Iterable, Iterator, Union
|
||||
@@ -453,7 +454,10 @@ class MultipackBatchSampler(BatchSampler):
|
||||
_sampled_lens = []
|
||||
for _ in range(self.num_count_samples):
|
||||
self._batches = None # Reset cached batches
|
||||
# log timer for generating batches
|
||||
start_time = time.time()
|
||||
_sampled_lens.append(len(self.generate_batches(set_stats=False)))
|
||||
LOG.debug(f"generate_batches time: {time.time() - start_time}")
|
||||
len_batches = min(_sampled_lens)
|
||||
|
||||
# Gather minimum across all ranks
|
||||
|
||||
@@ -651,7 +651,23 @@ class AxolotlInputConfig(
|
||||
},
|
||||
)
|
||||
|
||||
dp_shard_size: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Number of devices to shard across. If not set, will use all available devices."
|
||||
},
|
||||
)
|
||||
dp_replicate_size: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Number of devices to replicate across."},
|
||||
)
|
||||
sequence_parallel_degree: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Deprecated: use `context_parallel_size` instead"
|
||||
},
|
||||
)
|
||||
context_parallel_size: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Set to a divisor of the number of GPUs available to split sequences into chunks of equal size. Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM. E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized subsequences, or set to 4 to split into four equal-sized subsequences. See https://docs.axolotl.ai/docs/sequence_parallelism.html for more details."
|
||||
|
||||
@@ -673,7 +673,7 @@ class RLValidationMixin:
|
||||
data.get("rl") == "grpo"
|
||||
and data.get("trl", {})
|
||||
and data.get("trl").get("use_liger_loss")
|
||||
and data.get("sequence_parallel_degree", 1) > 1
|
||||
and data.get("context_parallel_size", 1) > 1
|
||||
):
|
||||
raise ValueError("GRPO + SP + Liger not currently supported")
|
||||
return data
|
||||
@@ -880,51 +880,35 @@ class OptimizationValidationMixin:
|
||||
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_fsdp_sharded_state_dict_w_safetensors(self):
|
||||
if (
|
||||
hasattr(self, "fsdp_config")
|
||||
and self.fsdp_config
|
||||
and hasattr(self, "save_safetensors")
|
||||
and self.save_safetensors
|
||||
and self.fsdp_config.get("state_dict_type", "") == "SHARDED_STATE_DICT"
|
||||
and str(getattr(self, "fsdp_version", "1")) != "2"
|
||||
):
|
||||
raise ValueError(
|
||||
"FSDP SHARDED_STATE_DICT not compatible with save_safetensors"
|
||||
)
|
||||
return self
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_tensor_parallel_size_update_ds_json(cls, data):
|
||||
tensor_parallel_size = data.get("tensor_parallel_size")
|
||||
if tensor_parallel_size is not None and tensor_parallel_size > 1:
|
||||
if not data.get("deepspeed"):
|
||||
raise ValueError(
|
||||
"Tensor parallelism (TP) is only supported with DeepSpeed"
|
||||
)
|
||||
with open(data.get("deepspeed"), "r", encoding="utf-8") as ds_fin:
|
||||
ds_config = json.load(ds_fin)
|
||||
should_save = False
|
||||
if "tensor_parallel" not in ds_config:
|
||||
ds_config["tensor_parallel"] = {"autotp_size": tensor_parallel_size}
|
||||
should_save = True
|
||||
if (
|
||||
"gather_16bit_weights_on_model_save"
|
||||
not in ds_config["zero_optimization"]
|
||||
):
|
||||
ds_config["zero_optimization"][
|
||||
if data.get("deepspeed"):
|
||||
with open(data.get("deepspeed"), "r", encoding="utf-8") as ds_fin:
|
||||
ds_config = json.load(ds_fin)
|
||||
should_save = False
|
||||
if "tensor_parallel" not in ds_config:
|
||||
ds_config["tensor_parallel"] = {
|
||||
"autotp_size": tensor_parallel_size
|
||||
}
|
||||
should_save = True
|
||||
if (
|
||||
"gather_16bit_weights_on_model_save"
|
||||
] = True
|
||||
should_save = True
|
||||
if should_save:
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
with open(
|
||||
Path(temp_dir) / "autotp_ds.json", "w", encoding="utf-8"
|
||||
) as ds_fout:
|
||||
json.dump(ds_config, ds_fout, indent=4)
|
||||
data["deepspeed"] = str(Path(temp_dir) / "autotp_ds.json")
|
||||
not in ds_config["zero_optimization"]
|
||||
):
|
||||
ds_config["zero_optimization"][
|
||||
"gather_16bit_weights_on_model_save"
|
||||
] = True
|
||||
should_save = True
|
||||
if should_save:
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
with open(
|
||||
Path(temp_dir) / "autotp_ds.json", "w", encoding="utf-8"
|
||||
) as ds_fout:
|
||||
json.dump(ds_config, ds_fout, indent=4)
|
||||
data["deepspeed"] = str(Path(temp_dir) / "autotp_ds.json")
|
||||
|
||||
return data
|
||||
|
||||
@@ -1205,13 +1189,18 @@ class ComplexValidationMixin:
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_sequence_parallel_degree(self):
|
||||
if not self.sequence_parallel_degree:
|
||||
self.sequence_parallel_degree = 1
|
||||
elif self.sequence_parallel_degree > 1:
|
||||
def check_context_parallel_size(self):
|
||||
if self.sequence_parallel_degree and not self.context_parallel_size:
|
||||
LOG.warning(
|
||||
"`sequence_parallel_degree` is deprecated, use `context_parallel_size`"
|
||||
)
|
||||
self.context_parallel_size = self.sequence_parallel_degree
|
||||
if not self.context_parallel_size:
|
||||
self.context_parallel_size = 1
|
||||
elif self.context_parallel_size > 1:
|
||||
if not self.flash_attention:
|
||||
raise ValueError(
|
||||
"flash_attention: true must be set with sequence_parallel_degree > 1"
|
||||
"flash_attention: true must be set with context_parallel_size > 1"
|
||||
)
|
||||
|
||||
if self.sample_packing and self.micro_batch_size > 1:
|
||||
@@ -1221,17 +1210,23 @@ class ComplexValidationMixin:
|
||||
)
|
||||
|
||||
try:
|
||||
import transformers.modeling_flash_attention_utils
|
||||
|
||||
# pylint: disable=protected-access
|
||||
transformers.modeling_flash_attention_utils._flash_supports_window_size = (
|
||||
transformers.modeling_flash_attention_utils._flash_supports_window
|
||||
)
|
||||
import ring_flash_attn # noqa: F401 # pylint:disable=unused-import
|
||||
except ImportError as exception:
|
||||
raise ImportError(
|
||||
"sequence_parallel_degree > 1 but ring_flash_attn is not installed. "
|
||||
"context_parallel_size > 1 but ring_flash_attn is not installed. "
|
||||
"Please install it with `pip install axolotl[ring-flash-attn] "
|
||||
"or `pip install ring-flash-attn>=0.1.4`."
|
||||
) from exception
|
||||
|
||||
LOG.warning(
|
||||
"Sequence parallelism (SP) is enabled with "
|
||||
f"sequence_parallel_degree={self.sequence_parallel_degree}. "
|
||||
f"context_parallel_size={self.context_parallel_size}. "
|
||||
"Please note that logged losses may differ slightly to the non-SP "
|
||||
"losses due to transformers Trainer implementation details. "
|
||||
"Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
|
||||
@@ -1242,7 +1237,7 @@ class ComplexValidationMixin:
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_ring_attn_func(self):
|
||||
if getattr(self, "sequence_parallel_degree", 1) == 1:
|
||||
if getattr(self, "context_parallel_size", 1) == 1:
|
||||
return self
|
||||
|
||||
if self.ring_attn_func is not None:
|
||||
@@ -1259,6 +1254,20 @@ class ComplexValidationMixin:
|
||||
return self
|
||||
|
||||
|
||||
class DistributedValidationMixin:
|
||||
"""validation for distributed training."""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_tensor_parallel_optimizer(self):
|
||||
if self.tensor_parallel_size > 1:
|
||||
if self.optimizer in ["paged_adamw_8bit", "adamw_8bit", "adamw_bnb_8bit"]:
|
||||
raise ValueError(
|
||||
"tensor_parallel_size is not supported with paged_adamw_8bit, adamw_8bit, and adamw_bnb_8bit optimizers"
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
# pylint: disable=too-many-ancestors
|
||||
class ValidationMixin(
|
||||
DatasetValidationMixin,
|
||||
|
||||
@@ -442,7 +442,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
- 1
|
||||
)
|
||||
* cfg.num_epochs
|
||||
* cfg.sequence_parallel_degree
|
||||
* cfg.context_parallel_size
|
||||
* cfg.tensor_parallel_size
|
||||
)
|
||||
LOG.debug(
|
||||
@@ -484,7 +484,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
math.floor(
|
||||
data_loader_len
|
||||
* cfg.num_epochs
|
||||
* cfg.sequence_parallel_degree
|
||||
* cfg.context_parallel_size
|
||||
* cfg.tensor_parallel_size
|
||||
)
|
||||
)
|
||||
@@ -511,7 +511,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
math.ceil(
|
||||
len(train_dataset)
|
||||
* cfg.num_epochs
|
||||
* cfg.sequence_parallel_degree
|
||||
* cfg.context_parallel_size
|
||||
* cfg.tensor_parallel_size
|
||||
/ cfg.batch_size
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user