Distributed/ND-Parallel (#2977)
This commit is contained in:
@@ -2,7 +2,7 @@
|
|||||||
set -e
|
set -e
|
||||||
|
|
||||||
# Only run two tests at a time to avoid OOM on GPU (with coverage collection)
|
# Only run two tests at a time to avoid OOM on GPU (with coverage collection)
|
||||||
pytest -v -n2 \
|
pytest -v --durations=10 -n2 \
|
||||||
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \
|
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \
|
||||||
--ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \
|
--ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \
|
||||||
/workspace/axolotl/tests/e2e/multigpu/ \
|
/workspace/axolotl/tests/e2e/multigpu/ \
|
||||||
|
|||||||
@@ -65,6 +65,9 @@ GPU_CONFIG = f"L40S:{N_GPUS}"
|
|||||||
def run_cmd(cmd: str, run_folder: str):
|
def run_cmd(cmd: str, run_folder: str):
|
||||||
import subprocess # nosec
|
import subprocess # nosec
|
||||||
|
|
||||||
|
sp_env = os.environ.copy()
|
||||||
|
sp_env["AXOLOTL_DATASET_PROCESSES"] = "8"
|
||||||
|
|
||||||
# Propagate errors from subprocess.
|
# Propagate errors from subprocess.
|
||||||
if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec
|
if exit_code := subprocess.call(cmd.split(), cwd=run_folder, env=sp_env): # nosec
|
||||||
exit(exit_code) # pylint: disable=consider-using-sys-exit
|
exit(exit_code) # pylint: disable=consider-using-sys-exit
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ To enable sequence parallelism, add the following to your configuration file:
|
|||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
# Set to a divisor (> 1) of the number of GPUs available
|
# Set to a divisor (> 1) of the number of GPUs available
|
||||||
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
|
context_parallel_size: 4 # Split sequences across 4 GPUs
|
||||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||||
heads_k_stride: 1
|
heads_k_stride: 1
|
||||||
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
|
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
|
||||||
@@ -30,7 +30,7 @@ heads_k_stride: 1
|
|||||||
ring_attn_func:
|
ring_attn_func:
|
||||||
```
|
```
|
||||||
|
|
||||||
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
|
The `context_parallel_size` should be a divisor of the total number of GPUs. For example:
|
||||||
|
|
||||||
- With 8 GPUs, valid values would be 2, 4, or 8
|
- With 8 GPUs, valid values would be 2, 4, or 8
|
||||||
- With 4 GPUs, valid values would be 2 or 4
|
- With 4 GPUs, valid values would be 2 or 4
|
||||||
@@ -66,7 +66,7 @@ sequence_len: 8192
|
|||||||
|
|
||||||
...
|
...
|
||||||
|
|
||||||
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
|
context_parallel_size: 4 # Split each sequence into 4 parts, one per GPU
|
||||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||||
heads_k_stride: 1
|
heads_k_stride: 1
|
||||||
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
|
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
|
||||||
@@ -89,12 +89,12 @@ Sequence parallelism is compatible with Axolotl's sample packing functionality.
|
|||||||
|
|
||||||
## Effect on Batch Size
|
## Effect on Batch Size
|
||||||
|
|
||||||
When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:
|
When using sequence parallelism, your effective global batch size is **divided** by the `context_parallel_size`. This happens because:
|
||||||
|
|
||||||
- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
|
- Each group of `context_parallel_size` GPUs works on the same batch (just different parts of each sequence)
|
||||||
- The number of batches processed per step decreases
|
- The number of batches processed per step decreases
|
||||||
|
|
||||||
For example:
|
For example:
|
||||||
- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
|
- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
|
||||||
- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
|
- With 8 GPUs and `context_parallel_size=4`: Only 2 different batches processed per step (each split across 4 GPUs)
|
||||||
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4
|
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ min_sample_len: 200_000
|
|||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|
||||||
tiled_mlp: true
|
tiled_mlp: true
|
||||||
sequence_parallel_degree: 8
|
context_parallel_size: 8
|
||||||
plugins:
|
plugins:
|
||||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
|
||||||
|
|||||||
@@ -13,9 +13,9 @@ packaging==23.2
|
|||||||
|
|
||||||
huggingface_hub>=0.33.0
|
huggingface_hub>=0.33.0
|
||||||
peft==0.16.0
|
peft==0.16.0
|
||||||
transformers==4.54.0
|
transformers==4.54.1
|
||||||
tokenizers>=0.21.1
|
tokenizers>=0.21.1
|
||||||
accelerate==1.9.0
|
accelerate @ git+https://github.com/huggingface/accelerate.git@9359a0194f210624f1e6e85c3d838fdd55c11152
|
||||||
datasets==4.0.0
|
datasets==4.0.0
|
||||||
deepspeed>=0.17.0
|
deepspeed>=0.17.0
|
||||||
trl==0.20.0
|
trl==0.20.0
|
||||||
|
|||||||
3
setup.py
3
setup.py
@@ -72,12 +72,13 @@ def parse_requirements(extras_require_map):
|
|||||||
extras_require_map.pop("vllm")
|
extras_require_map.pop("vllm")
|
||||||
else:
|
else:
|
||||||
_install_requires.append("xformers==0.0.31")
|
_install_requires.append("xformers==0.0.31")
|
||||||
|
extras_require_map["vllm"] = ["vllm>=0.10.0"]
|
||||||
elif (major, minor) >= (2, 6):
|
elif (major, minor) >= (2, 6):
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
_install_requires.append("xformers==0.0.29.post3")
|
_install_requires.append("xformers==0.0.29.post3")
|
||||||
# since we only support 2.6.0+cu126
|
# since we only support 2.6.0+cu126
|
||||||
_dependency_links.append("https://download.pytorch.org/whl/cu126")
|
_dependency_links.append("https://download.pytorch.org/whl/cu126")
|
||||||
extras_require_map["vllm"] = ["vllm==0.8.5.post1"]
|
extras_require_map.pop("vllm")
|
||||||
elif (major, minor) >= (2, 5):
|
elif (major, minor) >= (2, 5):
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
if patch == 0:
|
if patch == 0:
|
||||||
|
|||||||
@@ -69,7 +69,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
|||||||
load_in_8bit=False,
|
load_in_8bit=False,
|
||||||
load_in_4bit=False,
|
load_in_4bit=False,
|
||||||
flash_attention=False,
|
flash_attention=False,
|
||||||
sequence_parallel_degree=None,
|
context_parallel_size=None,
|
||||||
deepspeed=None,
|
deepspeed=None,
|
||||||
fsdp=None,
|
fsdp=None,
|
||||||
fsdp_config=None,
|
fsdp_config=None,
|
||||||
|
|||||||
@@ -24,9 +24,11 @@ from pathlib import Path
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from accelerate import PartialState
|
||||||
from transformers import (
|
from transformers import (
|
||||||
TrainerCallback,
|
TrainerCallback,
|
||||||
)
|
)
|
||||||
|
from transformers.trainer_pt_utils import AcceleratorConfig
|
||||||
from transformers.training_args import OptimizerNames
|
from transformers.training_args import OptimizerNames
|
||||||
|
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
@@ -434,8 +436,30 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
|
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
|
||||||
|
|
||||||
def _configure_accelerator_config(self, training_args_kwargs: dict):
|
def _configure_accelerator_config(self, training_args_kwargs: dict):
|
||||||
|
partial_state = PartialState()
|
||||||
|
has_pc_attr = (
|
||||||
|
hasattr(partial_state, "parallelism_config")
|
||||||
|
and partial_state.parallelism_config
|
||||||
|
)
|
||||||
|
has_pc_key = (
|
||||||
|
"parallelism_config"
|
||||||
|
in partial_state._shared_state # pylint: disable=protected-access
|
||||||
|
and partial_state._shared_state[ # pylint: disable=protected-access
|
||||||
|
"parallelism_config"
|
||||||
|
]
|
||||||
|
)
|
||||||
|
use_configured_state = has_pc_attr or has_pc_key
|
||||||
if self.cfg.accelerator_config:
|
if self.cfg.accelerator_config:
|
||||||
training_args_kwargs["accelerator_config"] = self.cfg.accelerator_config
|
use_configured_state = self.cfg.accelerator_config.pop(
|
||||||
|
"use_configured_state", use_configured_state
|
||||||
|
)
|
||||||
|
training_args_kwargs["accelerator_config"] = AcceleratorConfig(
|
||||||
|
use_configured_state=use_configured_state, **self.cfg.accelerator_config
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
training_args_kwargs["accelerator_config"] = AcceleratorConfig(
|
||||||
|
use_configured_state=use_configured_state,
|
||||||
|
)
|
||||||
|
|
||||||
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
|
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
|
||||||
if self.cfg.activation_offloading is True:
|
if self.cfg.activation_offloading is True:
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
if self.cfg.rl is RLType.GRPO:
|
if self.cfg.rl is RLType.GRPO:
|
||||||
trainer_cls = GRPOStrategy.get_trainer_class(
|
trainer_cls = GRPOStrategy.get_trainer_class(
|
||||||
sequence_parallel=self.cfg.sequence_parallel_degree > 1
|
sequence_parallel=self.cfg.context_parallel_size > 1
|
||||||
)
|
)
|
||||||
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
|
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
|
||||||
|
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ from typing_extensions import override
|
|||||||
from axolotl.core.trainers.mixins import (
|
from axolotl.core.trainers.mixins import (
|
||||||
ActivationOffloadingMixin,
|
ActivationOffloadingMixin,
|
||||||
CheckpointSaveMixin,
|
CheckpointSaveMixin,
|
||||||
|
DistributedParallelMixin,
|
||||||
OptimizerMixin,
|
OptimizerMixin,
|
||||||
PackingMixin,
|
PackingMixin,
|
||||||
RngLoaderMixin,
|
RngLoaderMixin,
|
||||||
@@ -50,6 +51,7 @@ class AxolotlTrainer(
|
|||||||
RngLoaderMixin,
|
RngLoaderMixin,
|
||||||
CheckpointSaveMixin,
|
CheckpointSaveMixin,
|
||||||
ActivationOffloadingMixin,
|
ActivationOffloadingMixin,
|
||||||
|
DistributedParallelMixin,
|
||||||
Trainer,
|
Trainer,
|
||||||
):
|
):
|
||||||
"""Extend the base Trainer for axolotl helpers"""
|
"""Extend the base Trainer for axolotl helpers"""
|
||||||
|
|||||||
@@ -8,7 +8,11 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from trl import DPOTrainer
|
from trl import DPOTrainer
|
||||||
|
|
||||||
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
|
from axolotl.core.trainers.mixins import (
|
||||||
|
DistributedParallelMixin,
|
||||||
|
RngLoaderMixin,
|
||||||
|
SchedulerMixin,
|
||||||
|
)
|
||||||
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
|
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
|
||||||
from axolotl.core.trainers.utils import (
|
from axolotl.core.trainers.utils import (
|
||||||
sanitize_kwargs_for_ds_tagging,
|
sanitize_kwargs_for_ds_tagging,
|
||||||
@@ -17,7 +21,12 @@ from axolotl.core.trainers.utils import (
|
|||||||
|
|
||||||
|
|
||||||
class AxolotlDPOTrainer(
|
class AxolotlDPOTrainer(
|
||||||
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, DPOTrainer
|
RngLoaderMixin,
|
||||||
|
SchedulerMixin,
|
||||||
|
OptimizerMixin,
|
||||||
|
OptimizerInitMixin,
|
||||||
|
DPOTrainer,
|
||||||
|
DistributedParallelMixin,
|
||||||
):
|
):
|
||||||
"""Extend the base DPOTrainer for axolotl helpers."""
|
"""Extend the base DPOTrainer for axolotl helpers."""
|
||||||
|
|
||||||
|
|||||||
@@ -82,14 +82,14 @@ class GRPOStrategy:
|
|||||||
grpo_args_kwargs["log_completions"] = trl.log_completions
|
grpo_args_kwargs["log_completions"] = trl.log_completions
|
||||||
grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print
|
grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print
|
||||||
|
|
||||||
|
if cfg.context_parallel_size > 1:
|
||||||
|
grpo_args_kwargs["context_parallel_size"] = cfg.context_parallel_size
|
||||||
|
|
||||||
if trl.importance_sampling_level is not None:
|
if trl.importance_sampling_level is not None:
|
||||||
grpo_args_kwargs["importance_sampling_level"] = (
|
grpo_args_kwargs["importance_sampling_level"] = (
|
||||||
trl.importance_sampling_level
|
trl.importance_sampling_level
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.sequence_parallel_degree > 1:
|
|
||||||
grpo_args_kwargs["sequence_parallel_degree"] = cfg.sequence_parallel_degree
|
|
||||||
|
|
||||||
if trl.reward_weights:
|
if trl.reward_weights:
|
||||||
grpo_args_kwargs["reward_weights"] = trl.reward_weights
|
grpo_args_kwargs["reward_weights"] = trl.reward_weights
|
||||||
|
|
||||||
|
|||||||
@@ -13,4 +13,4 @@ from axolotl.core.training_args import AxolotlTrainingMixins
|
|||||||
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
|
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
|
||||||
"""Axolotl GRPO Config for GRPO training"""
|
"""Axolotl GRPO Config for GRPO training"""
|
||||||
|
|
||||||
sequence_parallel_degree: int | None = None
|
context_parallel_size: int | None = None
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
|
|||||||
- Data is properly distributed across SP groups.
|
- Data is properly distributed across SP groups.
|
||||||
|
|
||||||
In the table below, the values represent dataset indices. Each SP group has
|
In the table below, the values represent dataset indices. Each SP group has
|
||||||
`sequence_parallel_degree = 2` GPUs working together on the same data. There are 2
|
`context_parallel_size = 2` GPUs working together on the same data. There are 2
|
||||||
SP groups (SP0 and SP1), with `world_size = 4` total GPUs.
|
SP groups (SP0 and SP1), with `world_size = 4` total GPUs.
|
||||||
|
|
||||||
Sequence Parallel Groups
|
Sequence Parallel Groups
|
||||||
@@ -45,7 +45,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
|
|||||||
rank: Rank of current process.
|
rank: Rank of current process.
|
||||||
batch_size: Number of samples per batch.
|
batch_size: Number of samples per batch.
|
||||||
repeat_count: How many times to repeat the full sampling process.
|
repeat_count: How many times to repeat the full sampling process.
|
||||||
sequence_parallel_degree: Number of ranks in a sequence parallel group.
|
context_parallel_size: Number of ranks in a sequence parallel group.
|
||||||
shuffle: Whether to shuffle the dataset.
|
shuffle: Whether to shuffle the dataset.
|
||||||
seed: Random seed for shuffling.
|
seed: Random seed for shuffling.
|
||||||
drop_last: Whether to drop the last incomplete batch.
|
drop_last: Whether to drop the last incomplete batch.
|
||||||
@@ -59,7 +59,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
|
|||||||
rank: int,
|
rank: int,
|
||||||
batch_size: int = 1,
|
batch_size: int = 1,
|
||||||
repeat_count: int = 1,
|
repeat_count: int = 1,
|
||||||
sequence_parallel_degree: int = 1,
|
context_parallel_size: int = 1,
|
||||||
shuffle: bool = True,
|
shuffle: bool = True,
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
drop_last: bool = False,
|
drop_last: bool = False,
|
||||||
@@ -77,9 +77,9 @@ class SequenceParallelRepeatRandomSampler(Sampler):
|
|||||||
self.rank = rank
|
self.rank = rank
|
||||||
|
|
||||||
# Sequence parallelism parameters
|
# Sequence parallelism parameters
|
||||||
self.sequence_parallel_degree = sequence_parallel_degree
|
self.context_parallel_size = context_parallel_size
|
||||||
self.num_sp_groups = world_size // sequence_parallel_degree
|
self.num_sp_groups = world_size // context_parallel_size
|
||||||
self.sp_group_id = rank // sequence_parallel_degree
|
self.sp_group_id = rank // context_parallel_size
|
||||||
|
|
||||||
# Adjust dataset size for distributed sampling
|
# Adjust dataset size for distributed sampling
|
||||||
self.num_samples = len(self.dataset)
|
self.num_samples = len(self.dataset)
|
||||||
|
|||||||
@@ -43,7 +43,11 @@ from trl.trainer.grpo_trainer import RewardFunc, nanstd
|
|||||||
from trl.trainer.utils import pad
|
from trl.trainer.utils import pad
|
||||||
|
|
||||||
from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler
|
from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler
|
||||||
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
|
from axolotl.core.trainers.mixins import (
|
||||||
|
DistributedParallelMixin,
|
||||||
|
RngLoaderMixin,
|
||||||
|
SchedulerMixin,
|
||||||
|
)
|
||||||
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
|
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
|
||||||
from axolotl.monkeypatch.ring_attn import get_ring_attn_group
|
from axolotl.monkeypatch.ring_attn import get_ring_attn_group
|
||||||
|
|
||||||
@@ -53,7 +57,12 @@ if is_peft_available():
|
|||||||
|
|
||||||
|
|
||||||
class AxolotlGRPOTrainer(
|
class AxolotlGRPOTrainer(
|
||||||
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, GRPOTrainer
|
RngLoaderMixin,
|
||||||
|
SchedulerMixin,
|
||||||
|
OptimizerMixin,
|
||||||
|
OptimizerInitMixin,
|
||||||
|
DistributedParallelMixin,
|
||||||
|
GRPOTrainer,
|
||||||
):
|
):
|
||||||
"""Extend the base GRPOTrainer for axolotl helpers"""
|
"""Extend the base GRPOTrainer for axolotl helpers"""
|
||||||
|
|
||||||
@@ -100,7 +109,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
|||||||
|
|
||||||
# Get number of SP groups (number of processes divided by SP degree)
|
# Get number of SP groups (number of processes divided by SP degree)
|
||||||
num_processes = self.accelerator.num_processes
|
num_processes = self.accelerator.num_processes
|
||||||
num_sp_groups = num_processes // self.args.sequence_parallel_degree
|
num_sp_groups = num_processes // self.args.context_parallel_size
|
||||||
|
|
||||||
# Calculate batch size per SP group (not per process)
|
# Calculate batch size per SP group (not per process)
|
||||||
sp_group_batch_size = self.args.per_device_train_batch_size * num_sp_groups
|
sp_group_batch_size = self.args.per_device_train_batch_size * num_sp_groups
|
||||||
@@ -130,7 +139,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
|||||||
|
|
||||||
if self.num_generations not in possible_values:
|
if self.num_generations not in possible_values:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"With sequence parallelism (degree {self.args.sequence_parallel_degree}), "
|
f"With sequence parallelism (degree {self.args.context_parallel_size}), "
|
||||||
f"the eval batch size per SP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) "
|
f"the eval batch size per SP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) "
|
||||||
f"must be evenly divisible by the number of generations per prompt "
|
f"must be evenly divisible by the number of generations per prompt "
|
||||||
f"({self.num_generations}). Given the current eval batch size, "
|
f"({self.num_generations}). Given the current eval batch size, "
|
||||||
@@ -167,9 +176,9 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
|||||||
rank=self.rank,
|
rank=self.rank,
|
||||||
batch_size=effective_batch_size
|
batch_size=effective_batch_size
|
||||||
// self.num_generations
|
// self.num_generations
|
||||||
// self.args.sequence_parallel_degree,
|
// self.args.context_parallel_size,
|
||||||
repeat_count=self.num_iterations * self.args.gradient_accumulation_steps,
|
repeat_count=self.num_iterations * self.args.gradient_accumulation_steps,
|
||||||
sequence_parallel_degree=self.args.sequence_parallel_degree,
|
context_parallel_size=self.args.context_parallel_size,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
seed=self.args.seed,
|
seed=self.args.seed,
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
@@ -235,7 +244,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
|||||||
# TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
|
# TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
|
||||||
# if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
|
# if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
|
||||||
# slice each batch along the sequence dimension).
|
# slice each batch along the sequence dimension).
|
||||||
if self.args.sequence_parallel_degree > 1:
|
if self.args.context_parallel_size > 1:
|
||||||
return dataloader
|
return dataloader
|
||||||
|
|
||||||
# Otherwise prepare with accelerator
|
# Otherwise prepare with accelerator
|
||||||
@@ -308,18 +317,18 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
|||||||
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
|
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
|
||||||
all_prompts_text = gather_object(prompts_text)
|
all_prompts_text = gather_object(prompts_text)
|
||||||
if self.accelerator.is_main_process:
|
if self.accelerator.is_main_process:
|
||||||
if self.args.sequence_parallel_degree > 1:
|
if self.args.context_parallel_size > 1:
|
||||||
# Calculate sequence parallel group information
|
# Calculate sequence parallel group information
|
||||||
world_size = self.accelerator.num_processes
|
world_size = self.accelerator.num_processes
|
||||||
sequence_parallel_degree = self.args.sequence_parallel_degree
|
context_parallel_size = self.args.context_parallel_size
|
||||||
num_sp_groups = world_size // sequence_parallel_degree
|
num_sp_groups = world_size // context_parallel_size
|
||||||
|
|
||||||
# Since processes in the same SP group have the same prompts, we need to ensure
|
# Since processes in the same SP group have the same prompts, we need to ensure
|
||||||
# we only take one copy of each prompt from each SP group
|
# we only take one copy of each prompt from each SP group
|
||||||
ordered_set_of_prompts = []
|
ordered_set_of_prompts = []
|
||||||
for sp_group_id in range(num_sp_groups):
|
for sp_group_id in range(num_sp_groups):
|
||||||
# Get the first process from each SP group (typically the group leader)
|
# Get the first process from each SP group (typically the group leader)
|
||||||
group_leader_rank = sp_group_id * sequence_parallel_degree
|
group_leader_rank = sp_group_id * context_parallel_size
|
||||||
|
|
||||||
# Extract prompts from this SP group, accounting for num_generations duplicates
|
# Extract prompts from this SP group, accounting for num_generations duplicates
|
||||||
# We only need prompts from one rank in each SP group
|
# We only need prompts from one rank in each SP group
|
||||||
@@ -335,7 +344,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
|||||||
# num_generations outputs for each one. This is faster than generating outputs for each duplicate
|
# num_generations outputs for each one. This is faster than generating outputs for each duplicate
|
||||||
# prompt individually.
|
# prompt individually.
|
||||||
ordered_set_of_prompts = all_prompts_text[
|
ordered_set_of_prompts = all_prompts_text[
|
||||||
:: self.num_generations * self.args.sequence_parallel_degree
|
:: self.num_generations * self.args.context_parallel_size
|
||||||
]
|
]
|
||||||
|
|
||||||
with profiling_context(self, "vLLM.generate"):
|
with profiling_context(self, "vLLM.generate"):
|
||||||
@@ -352,14 +361,14 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
completion_ids = [None] * (
|
completion_ids = [None] * (
|
||||||
len(all_prompts_text) // self.args.sequence_parallel_degree
|
len(all_prompts_text) // self.args.context_parallel_size
|
||||||
)
|
)
|
||||||
|
|
||||||
# Broadcast the completions from the main process to all processes
|
# Broadcast the completions from the main process to all processes
|
||||||
completion_ids = broadcast_object_list(completion_ids, from_process=0)
|
completion_ids = broadcast_object_list(completion_ids, from_process=0)
|
||||||
|
|
||||||
# Determine the appropriate slice based on sequence parallelism
|
# Determine the appropriate slice based on sequence parallelism
|
||||||
if self.args.sequence_parallel_degree > 1:
|
if self.args.context_parallel_size > 1:
|
||||||
# Calculate SP group ID (which group of ranks this rank belongs to)
|
# Calculate SP group ID (which group of ranks this rank belongs to)
|
||||||
sp_group_id = self.accelerator.process_index // self.local_world_size
|
sp_group_id = self.accelerator.process_index // self.local_world_size
|
||||||
|
|
||||||
@@ -583,7 +592,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
|||||||
advantages = advantages / (std_grouped_rewards + 1e-4)
|
advantages = advantages / (std_grouped_rewards + 1e-4)
|
||||||
|
|
||||||
# Slice to keep only the local part of the data
|
# Slice to keep only the local part of the data
|
||||||
if self.args.sequence_parallel_degree > 1:
|
if self.args.context_parallel_size > 1:
|
||||||
# Calculate SP group ID (which group of ranks this rank belongs to)
|
# Calculate SP group ID (which group of ranks this rank belongs to)
|
||||||
sp_group_id = self.accelerator.process_index // self.local_world_size
|
sp_group_id = self.accelerator.process_index // self.local_world_size
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import torch
|
|||||||
from axolotl.core.trainers.base import AxolotlTrainer
|
from axolotl.core.trainers.base import AxolotlTrainer
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=too-many-ancestors
|
||||||
class AxolotlMambaTrainer(AxolotlTrainer):
|
class AxolotlMambaTrainer(AxolotlTrainer):
|
||||||
"""Mamba specific trainer to handle loss calculation"""
|
"""Mamba specific trainer to handle loss calculation"""
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@
|
|||||||
|
|
||||||
from .activation_checkpointing import ActivationOffloadingMixin
|
from .activation_checkpointing import ActivationOffloadingMixin
|
||||||
from .checkpoints import CheckpointSaveMixin
|
from .checkpoints import CheckpointSaveMixin
|
||||||
|
from .distributed_parallel import DistributedParallelMixin
|
||||||
from .optimizer import OptimizerMixin
|
from .optimizer import OptimizerMixin
|
||||||
from .packing import PackingMixin
|
from .packing import PackingMixin
|
||||||
from .rng_state_loader import RngLoaderMixin
|
from .rng_state_loader import RngLoaderMixin
|
||||||
|
|||||||
@@ -13,9 +13,11 @@ class CheckpointSaveMixin(Trainer):
|
|||||||
def _save_optimizer_and_scheduler(self, output_dir):
|
def _save_optimizer_and_scheduler(self, output_dir):
|
||||||
try:
|
try:
|
||||||
super()._save_optimizer_and_scheduler(output_dir)
|
super()._save_optimizer_and_scheduler(output_dir)
|
||||||
except NotImplementedError as exc:
|
except (NotImplementedError, KeyError) as exc:
|
||||||
LOG.warning(
|
# TODO: fix fsdp2 optimizer saving
|
||||||
|
LOG.warning_once(
|
||||||
f"Trainer does not support saving optimizer and scheduler: {exc}\n"
|
f"Trainer does not support saving optimizer and scheduler: {exc}\n"
|
||||||
"Optimizer and scheduler states were not saved - resuming from checkpoints "
|
"Optimizer and scheduler states were not saved - resuming from checkpoints "
|
||||||
"for this training run will not be possible."
|
"for this training run will not be possible.",
|
||||||
|
main_process_only=True,
|
||||||
)
|
)
|
||||||
|
|||||||
20
src/axolotl/core/trainers/mixins/distributed_parallel.py
Normal file
20
src/axolotl/core/trainers/mixins/distributed_parallel.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
"""
|
||||||
|
Mixin for correctly saving fsdp
|
||||||
|
"""
|
||||||
|
|
||||||
|
from transformers import Trainer
|
||||||
|
|
||||||
|
|
||||||
|
class DistributedParallelMixin(Trainer):
|
||||||
|
"""
|
||||||
|
Mixin for correctly saving fsdp
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _save(self, output_dir: str | None = None, state_dict=None):
|
||||||
|
if (
|
||||||
|
state_dict is None
|
||||||
|
and self.accelerator.parallelism_config
|
||||||
|
and self.accelerator.parallelism_config.dp_shard_enabled
|
||||||
|
):
|
||||||
|
state_dict = self.accelerator.get_state_dict(self.model)
|
||||||
|
super()._save(output_dir, state_dict=state_dict)
|
||||||
@@ -8,13 +8,18 @@ from trl import (
|
|||||||
RewardTrainer,
|
RewardTrainer,
|
||||||
)
|
)
|
||||||
|
|
||||||
from axolotl.core.trainers.mixins import RngLoaderMixin
|
from axolotl.core.trainers.mixins import DistributedParallelMixin, RngLoaderMixin
|
||||||
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
|
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
|
||||||
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
|
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
|
||||||
|
|
||||||
|
|
||||||
class AxolotlORPOTrainer(
|
class AxolotlORPOTrainer(
|
||||||
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, ORPOTrainer
|
RngLoaderMixin,
|
||||||
|
SchedulerMixin,
|
||||||
|
OptimizerMixin,
|
||||||
|
OptimizerInitMixin,
|
||||||
|
DistributedParallelMixin,
|
||||||
|
ORPOTrainer,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Extend the base ORPOTrainer for axolotl helpers
|
Extend the base ORPOTrainer for axolotl helpers
|
||||||
@@ -24,7 +29,12 @@ class AxolotlORPOTrainer(
|
|||||||
|
|
||||||
|
|
||||||
class AxolotlKTOTrainer(
|
class AxolotlKTOTrainer(
|
||||||
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, KTOTrainer
|
RngLoaderMixin,
|
||||||
|
SchedulerMixin,
|
||||||
|
OptimizerMixin,
|
||||||
|
OptimizerInitMixin,
|
||||||
|
DistributedParallelMixin,
|
||||||
|
KTOTrainer,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Extend the base KTOTrainer for axolotl helpers
|
Extend the base KTOTrainer for axolotl helpers
|
||||||
@@ -34,7 +44,12 @@ class AxolotlKTOTrainer(
|
|||||||
|
|
||||||
|
|
||||||
class AxolotlCPOTrainer(
|
class AxolotlCPOTrainer(
|
||||||
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, CPOTrainer
|
RngLoaderMixin,
|
||||||
|
SchedulerMixin,
|
||||||
|
OptimizerMixin,
|
||||||
|
OptimizerInitMixin,
|
||||||
|
DistributedParallelMixin,
|
||||||
|
CPOTrainer,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Extend the base CPOTrainer for axolotl helpers
|
Extend the base CPOTrainer for axolotl helpers
|
||||||
@@ -44,7 +59,12 @@ class AxolotlCPOTrainer(
|
|||||||
|
|
||||||
|
|
||||||
class AxolotlRewardTrainer(
|
class AxolotlRewardTrainer(
|
||||||
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, RewardTrainer
|
RngLoaderMixin,
|
||||||
|
SchedulerMixin,
|
||||||
|
OptimizerMixin,
|
||||||
|
OptimizerInitMixin,
|
||||||
|
DistributedParallelMixin,
|
||||||
|
RewardTrainer,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Extend the base RewardTrainer for axolotl helpers
|
Extend the base RewardTrainer for axolotl helpers
|
||||||
@@ -54,7 +74,12 @@ class AxolotlRewardTrainer(
|
|||||||
|
|
||||||
|
|
||||||
class AxolotlPRMTrainer(
|
class AxolotlPRMTrainer(
|
||||||
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, PRMTrainer
|
RngLoaderMixin,
|
||||||
|
SchedulerMixin,
|
||||||
|
OptimizerMixin,
|
||||||
|
OptimizerInitMixin,
|
||||||
|
DistributedParallelMixin,
|
||||||
|
PRMTrainer,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Extend the base trl.PRMTrainer for axolotl helpers
|
Extend the base trl.PRMTrainer for axolotl helpers
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from axolotl.core.trainers.base import AxolotlTrainer
|
|||||||
from .kernels.liger import LigerFusedLinearKLTopKLogprobLoss
|
from .kernels.liger import LigerFusedLinearKLTopKLogprobLoss
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=too-many-ancestors
|
||||||
class AxolotlKDTrainer(AxolotlTrainer):
|
class AxolotlKDTrainer(AxolotlTrainer):
|
||||||
"""
|
"""
|
||||||
Custom trainer subclass for Knowledge Distillation (KD)
|
Custom trainer subclass for Knowledge Distillation (KD)
|
||||||
|
|||||||
@@ -16,8 +16,6 @@
|
|||||||
Module for handling LIGER input arguments.
|
Module for handling LIGER input arguments.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel, model_validator
|
from pydantic import BaseModel, model_validator
|
||||||
|
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
@@ -30,13 +28,13 @@ class LigerArgs(BaseModel):
|
|||||||
Input args for LIGER.
|
Input args for LIGER.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
liger_rope: Optional[bool] = None
|
liger_rope: bool | None = None
|
||||||
liger_rms_norm: Optional[bool] = None
|
liger_rms_norm: bool | None = None
|
||||||
liger_layer_norm: Optional[bool] = None
|
liger_layer_norm: bool | None = None
|
||||||
liger_swiglu: Optional[bool] = None
|
liger_swiglu: bool | None = None
|
||||||
liger_glu_activation: Optional[bool] = None
|
liger_glu_activation: bool | None = None
|
||||||
liger_cross_entropy: Optional[bool] = None
|
liger_cross_entropy: bool | None = None
|
||||||
liger_fused_linear_cross_entropy: Optional[bool] = None
|
liger_fused_linear_cross_entropy: bool | None = None
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -66,3 +64,20 @@ class LigerArgs(BaseModel):
|
|||||||
"You cannot have both `liger_glu_activation` and `tiled_mlp` set without `tiled_mlp_use_original_mlp: true`."
|
"You cannot have both `liger_glu_activation` and `tiled_mlp` set without `tiled_mlp_use_original_mlp: true`."
|
||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_liger_rms_norm_tensor_parallel(cls, data):
|
||||||
|
if data.get("liger_rms_norm") and data.get("tensor_parallel_size", 1) > 1:
|
||||||
|
raise ValueError(
|
||||||
|
"`liger_rms_norm` is incompatible with tensor parallelism, "
|
||||||
|
"see https://github.com/linkedin/Liger-Kernel/issues/826"
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def check_tensor_parallel_size_liger_fused_linear_cross_entropy(self):
|
||||||
|
# TODO @SalmanMohammadi this is a larger fix - investigate
|
||||||
|
if self.tensor_parallel_size > 1 and self.liger_fused_linear_cross_entropy:
|
||||||
|
raise ValueError("Tensor parallelism is not compatible with liger losses.")
|
||||||
|
return self
|
||||||
|
|||||||
@@ -13,7 +13,8 @@ import peft
|
|||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
import transformers.modeling_utils
|
import transformers.modeling_utils
|
||||||
from accelerate import init_empty_weights
|
from accelerate import PartialState, init_empty_weights
|
||||||
|
from accelerate.parallelism_config import ParallelismConfig
|
||||||
from peft import (
|
from peft import (
|
||||||
PeftConfig,
|
PeftConfig,
|
||||||
PeftMixedModel,
|
PeftMixedModel,
|
||||||
@@ -48,10 +49,7 @@ from axolotl.loaders.utils import (
|
|||||||
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import (
|
from axolotl.utils.distributed import get_device_count, get_device_type, get_world_size
|
||||||
get_device_count,
|
|
||||||
get_device_type,
|
|
||||||
)
|
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.model_shard_quant import load_sharded_model_quant
|
from axolotl.utils.model_shard_quant import load_sharded_model_quant
|
||||||
from axolotl.utils.schemas.enums import RLType
|
from axolotl.utils.schemas.enums import RLType
|
||||||
@@ -87,6 +85,9 @@ class ModelLoader:
|
|||||||
`AutoModelForCausalLM`).
|
`AutoModelForCausalLM`).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
use_parallel_config: bool | None = False
|
||||||
|
parallelism_config: ParallelismConfig | None = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
@@ -183,6 +184,20 @@ class ModelLoader:
|
|||||||
|
|
||||||
def _apply_pre_model_load_setup(self):
|
def _apply_pre_model_load_setup(self):
|
||||||
"""Apply patches and setup configurations before model loading."""
|
"""Apply patches and setup configurations before model loading."""
|
||||||
|
if self.use_parallel_config is not None:
|
||||||
|
self.use_parallel_config = (
|
||||||
|
self.cfg.fsdp_config
|
||||||
|
or (self.cfg.tensor_parallel_size and self.cfg.tensor_parallel_size > 1)
|
||||||
|
or (
|
||||||
|
self.cfg.context_parallel_size
|
||||||
|
and self.cfg.context_parallel_size > 1
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if self.cfg.fsdp_config and self.cfg.fsdp_version != 2:
|
||||||
|
self.use_parallel_config = False
|
||||||
|
|
||||||
|
if self.use_parallel_config:
|
||||||
|
self._set_parallel_config()
|
||||||
self._set_auto_model_loader()
|
self._set_auto_model_loader()
|
||||||
self._set_device_map_config()
|
self._set_device_map_config()
|
||||||
if self.cfg.revision_of_model:
|
if self.cfg.revision_of_model:
|
||||||
@@ -390,6 +405,86 @@ class ModelLoader:
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_parallel_config_kwargs(
|
||||||
|
world_size: int,
|
||||||
|
tensor_parallel_size: int = 1,
|
||||||
|
context_parallel_size: int = 1,
|
||||||
|
dp_shard_size: int | None = None,
|
||||||
|
dp_replicate_size: int | None = None,
|
||||||
|
is_fsdp: bool = False,
|
||||||
|
):
|
||||||
|
pc_kwargs = {}
|
||||||
|
remaining_world_size = world_size
|
||||||
|
|
||||||
|
if tensor_parallel_size and tensor_parallel_size > 1:
|
||||||
|
pc_kwargs["tp_size"] = tensor_parallel_size
|
||||||
|
remaining_world_size = remaining_world_size // tensor_parallel_size
|
||||||
|
|
||||||
|
if context_parallel_size and context_parallel_size > 1:
|
||||||
|
pc_kwargs["cp_size"] = context_parallel_size
|
||||||
|
remaining_world_size = remaining_world_size // context_parallel_size
|
||||||
|
|
||||||
|
if dp_shard_size is None and dp_replicate_size in (None, 1):
|
||||||
|
if remaining_world_size > 1:
|
||||||
|
pc_kwargs["dp_shard_size"] = remaining_world_size
|
||||||
|
remaining_world_size = 1
|
||||||
|
|
||||||
|
if dp_replicate_size and dp_replicate_size > 1:
|
||||||
|
pc_kwargs["dp_replicate_size"] = dp_replicate_size
|
||||||
|
remaining_world_size = remaining_world_size // dp_replicate_size
|
||||||
|
|
||||||
|
if remaining_world_size > 1 and dp_shard_size and dp_shard_size > 1:
|
||||||
|
if not is_fsdp:
|
||||||
|
raise ValueError(
|
||||||
|
"dp_shard_size was configured without a corresponding fsdp_config! "
|
||||||
|
"Please ensure you have configured FSDP using fsdp_config."
|
||||||
|
)
|
||||||
|
pc_kwargs["dp_shard_size"] = dp_shard_size
|
||||||
|
remaining_world_size = remaining_world_size // dp_shard_size
|
||||||
|
if remaining_world_size > 1 and "dp_replicate_size" not in pc_kwargs:
|
||||||
|
pc_kwargs["dp_replicate_size"] = remaining_world_size
|
||||||
|
remaining_world_size = 1
|
||||||
|
|
||||||
|
if remaining_world_size > 1:
|
||||||
|
if "dp_shard_size" not in pc_kwargs and is_fsdp:
|
||||||
|
pc_kwargs["dp_shard_size"] = remaining_world_size
|
||||||
|
remaining_world_size = 1
|
||||||
|
|
||||||
|
if remaining_world_size > 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"The configured parallelisms are incompatible with the current world size ({get_world_size()})!\n"
|
||||||
|
f"{pc_kwargs}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return pc_kwargs
|
||||||
|
|
||||||
|
def _set_parallel_config(self):
|
||||||
|
"""Set parallelism configuration (DP, FSDP, TP, CP) in PartialState/Accelerator"""
|
||||||
|
pc_kwargs = ModelLoader._get_parallel_config_kwargs(
|
||||||
|
get_world_size(),
|
||||||
|
self.cfg.tensor_parallel_size,
|
||||||
|
self.cfg.context_parallel_size,
|
||||||
|
self.cfg.dp_shard_size,
|
||||||
|
self.cfg.dp_replicate_size,
|
||||||
|
bool(self.cfg.fsdp or self.cfg.fsdp_config),
|
||||||
|
)
|
||||||
|
|
||||||
|
if pc_kwargs:
|
||||||
|
self.parallelism_config = ParallelismConfig(
|
||||||
|
**pc_kwargs,
|
||||||
|
)
|
||||||
|
device_mesh = self.parallelism_config.build_device_mesh("cuda")
|
||||||
|
partial_state = PartialState()
|
||||||
|
# fmt: off
|
||||||
|
partial_state._shared_state["parallelism_config"] = ( # pylint: disable=protected-access
|
||||||
|
self.parallelism_config
|
||||||
|
)
|
||||||
|
partial_state._shared_state["device_mesh"] = ( # pylint: disable=protected-access
|
||||||
|
device_mesh
|
||||||
|
)
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
def _set_auto_model_loader(self):
|
def _set_auto_model_loader(self):
|
||||||
"""Set `self.auto_model_loader`. Defaults to `transformers.AutoModelForCausalLM`
|
"""Set `self.auto_model_loader`. Defaults to `transformers.AutoModelForCausalLM`
|
||||||
(set at `__init__`). When using a multimodal model, `self.auto_model_loader`
|
(set at `__init__`). When using a multimodal model, `self.auto_model_loader`
|
||||||
@@ -622,6 +717,14 @@ class ModelLoader:
|
|||||||
def _build_model(self) -> bool:
|
def _build_model(self) -> bool:
|
||||||
"""Load model, with load strategy depending on config."""
|
"""Load model, with load strategy depending on config."""
|
||||||
skip_move_to_device = False
|
skip_move_to_device = False
|
||||||
|
|
||||||
|
if self.cfg.tensor_parallel_size > 1:
|
||||||
|
self.model_kwargs["tp_size"] = self.cfg.tensor_parallel_size
|
||||||
|
self.model_kwargs["tp_plan"] = "auto"
|
||||||
|
self.model_kwargs["device_mesh"] = PartialState().device_mesh
|
||||||
|
if "device_map" in self.model_kwargs:
|
||||||
|
del self.model_kwargs["device_map"] # not compatible with `tp_plan`
|
||||||
|
|
||||||
if self.is_fsdp_enabled:
|
if self.is_fsdp_enabled:
|
||||||
if self.cfg.fsdp_config.cpu_ram_efficient_loading:
|
if self.cfg.fsdp_config.cpu_ram_efficient_loading:
|
||||||
skip_move_to_device = True
|
skip_move_to_device = True
|
||||||
@@ -734,6 +837,14 @@ class ModelLoader:
|
|||||||
if is_deepspeed_zero3_enabled():
|
if is_deepspeed_zero3_enabled():
|
||||||
skip_move_to_device = True
|
skip_move_to_device = True
|
||||||
|
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
if self.cfg.tensor_parallel_size > 1:
|
||||||
|
# workaround for upstream 4.54.0 not setting _tp_size or _device_mesh
|
||||||
|
# TODO(wing): remove once 4.54.1 is released
|
||||||
|
if self.model._tp_size != self.cfg.tensor_parallel_size:
|
||||||
|
self.model._tp_size = self.cfg.tensor_parallel_size
|
||||||
|
self.model._device_mesh = self.model_kwargs["device_mesh"]
|
||||||
|
|
||||||
return skip_move_to_device
|
return skip_move_to_device
|
||||||
|
|
||||||
def _set_z3_leaf_modules(self):
|
def _set_z3_leaf_modules(self):
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ class PatchManager:
|
|||||||
|
|
||||||
def apply_pre_model_load_patches(self):
|
def apply_pre_model_load_patches(self):
|
||||||
"""Apply pre-model load patches based on config."""
|
"""Apply pre-model load patches based on config."""
|
||||||
|
self._apply_transformers_patches()
|
||||||
# self._apply_flex_attention_patches()
|
# self._apply_flex_attention_patches()
|
||||||
self._apply_flash_attention_patches()
|
self._apply_flash_attention_patches()
|
||||||
self._apply_chunked_cross_entropy_patch()
|
self._apply_chunked_cross_entropy_patch()
|
||||||
@@ -64,13 +65,19 @@ class PatchManager:
|
|||||||
self._patch_llama_derived_model()
|
self._patch_llama_derived_model()
|
||||||
self._apply_mistral_cross_entropy_patch()
|
self._apply_mistral_cross_entropy_patch()
|
||||||
self._apply_self_attention_lora_patch()
|
self._apply_self_attention_lora_patch()
|
||||||
self._apply_sequence_parallel_patches()
|
|
||||||
|
|
||||||
def apply_post_plugin_pre_model_load_patches(self):
|
def apply_post_plugin_pre_model_load_patches(self):
|
||||||
"""Apply post plugin-pre_model_load load patches based on config."""
|
"""Apply post plugin-pre_model_load load patches based on config."""
|
||||||
self._apply_tiled_mlp(self.cfg.model_config_type)
|
self._apply_tiled_mlp(self.cfg.model_config_type)
|
||||||
self._apply_voxtral_patches()
|
self._apply_voxtral_patches()
|
||||||
|
|
||||||
|
def _apply_transformers_patches(self):
|
||||||
|
from axolotl.monkeypatch.transformers.modeling_flash_attention_utils import (
|
||||||
|
patch_prepare_from_posids,
|
||||||
|
)
|
||||||
|
|
||||||
|
patch_prepare_from_posids()
|
||||||
|
|
||||||
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
||||||
"""Apply patches that require the model instance."""
|
"""Apply patches that require the model instance."""
|
||||||
self._apply_llama_flash_attn_patches(model)
|
self._apply_llama_flash_attn_patches(model)
|
||||||
@@ -253,17 +260,6 @@ class PatchManager:
|
|||||||
has_remote_code=has_remote_code,
|
has_remote_code=has_remote_code,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _apply_sequence_parallel_patches(self):
|
|
||||||
"""Apply sequence parallelism patches."""
|
|
||||||
if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:
|
|
||||||
from axolotl.monkeypatch.ring_attn.patch import (
|
|
||||||
patch_prepare_data_loader,
|
|
||||||
patch_prepare_device_mesh,
|
|
||||||
)
|
|
||||||
|
|
||||||
patch_prepare_data_loader()
|
|
||||||
patch_prepare_device_mesh(self.cfg.sequence_parallel_degree, self.cfg.fsdp)
|
|
||||||
|
|
||||||
def _apply_tiled_mlp(self, model_type: str):
|
def _apply_tiled_mlp(self, model_type: str):
|
||||||
if self.cfg.tiled_mlp:
|
if self.cfg.tiled_mlp:
|
||||||
from axolotl.monkeypatch.tiled_mlp import (
|
from axolotl.monkeypatch.tiled_mlp import (
|
||||||
|
|||||||
@@ -249,13 +249,19 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
|
|||||||
auto_wrap_policy=fsdp2_plugin.auto_wrap_policy,
|
auto_wrap_policy=fsdp2_plugin.auto_wrap_policy,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
mesh = getattr(accelerator.state, "device_mesh", None)
|
||||||
|
|
||||||
fsdp2_kwargs = {
|
fsdp2_kwargs = {
|
||||||
"reshard_after_forward": fsdp2_plugin.reshard_after_forward,
|
"reshard_after_forward": fsdp2_plugin.reshard_after_forward,
|
||||||
"offload_policy": fsdp2_plugin.cpu_offload,
|
"offload_policy": fsdp2_plugin.cpu_offload,
|
||||||
# `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy`
|
# `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy`
|
||||||
"mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
|
"mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
|
||||||
|
"mesh": (
|
||||||
|
mesh[tuple(accelerator.state.parallelism_config.fsdp_dim_names)]
|
||||||
|
if mesh is not None
|
||||||
|
else None
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
model_has_params4bit = False
|
model_has_params4bit = False
|
||||||
for _, param in model.named_parameters():
|
for _, param in model.named_parameters():
|
||||||
# this is a temporary fix whereby loading models with bnb params cannot be moved from
|
# this is a temporary fix whereby loading models with bnb params cannot be moved from
|
||||||
|
|||||||
@@ -5,18 +5,14 @@
|
|||||||
|
|
||||||
from .patch import (
|
from .patch import (
|
||||||
get_ring_attn_group,
|
get_ring_attn_group,
|
||||||
patch_prepare_data_loader,
|
register_ring_attn_from_device_mesh,
|
||||||
patch_prepare_device_mesh,
|
|
||||||
register_ring_attn,
|
|
||||||
set_ring_attn_group,
|
set_ring_attn_group,
|
||||||
update_ring_attn_params,
|
update_ring_attn_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
"get_ring_attn_group",
|
"get_ring_attn_group",
|
||||||
"patch_prepare_data_loader",
|
"register_ring_attn_from_device_mesh",
|
||||||
"patch_prepare_device_mesh",
|
|
||||||
"register_ring_attn",
|
|
||||||
"set_ring_attn_group",
|
"set_ring_attn_group",
|
||||||
"update_ring_attn_params",
|
"update_ring_attn_params",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -8,13 +8,12 @@ We also provide some patches for accelerate functions to prepare the dataloader
|
|||||||
sequence parallelism training.
|
sequence parallelism training.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import inspect
|
|
||||||
import os
|
import os
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import accelerate
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
from torch.distributed import DeviceMesh
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from transformers.modeling_flash_attention_utils import _flash_supports_window
|
from transformers.modeling_flash_attention_utils import _flash_supports_window
|
||||||
@@ -29,39 +28,13 @@ from axolotl.utils.schemas.enums import RingAttnFunc
|
|||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
RING_ATTN_GROUP = None
|
RING_ATTN_GROUP = None
|
||||||
|
|
||||||
ORIGINAL_PREPARE_DATALOADER_CODE = """ submesh_fsdp_size = 1
|
|
||||||
submesh_dp_size = 1
|
|
||||||
submesh_tp_size = 1
|
|
||||||
if "tp" in torch_device_mesh.mesh_dim_names:
|
|
||||||
submesh_tp_size = torch_device_mesh["tp"].size()
|
|
||||||
if "dp" in torch_device_mesh.mesh_dim_names:
|
|
||||||
submesh_dp_size = torch_device_mesh["dp"].size()
|
|
||||||
if "fsdp" in torch_device_mesh.mesh_dim_names:
|
|
||||||
submesh_fsdp_size = torch_device_mesh["fsdp"].size()
|
|
||||||
process_index = process_index // submesh_tp_size"""
|
|
||||||
|
|
||||||
NEW_PREPARE_DATALOADER_CODE = """ submesh_fsdp_size = 1
|
|
||||||
submesh_dp_size = 1
|
|
||||||
submesh_tp_size = 1
|
|
||||||
submesh_cp_size = 1
|
|
||||||
if "cp" in torch_device_mesh.mesh_dim_names:
|
|
||||||
submesh_cp_size = torch_device_mesh["cp"].size()
|
|
||||||
if "tp" in torch_device_mesh.mesh_dim_names:
|
|
||||||
submesh_tp_size = torch_device_mesh["tp"].size()
|
|
||||||
if "dp" in torch_device_mesh.mesh_dim_names:
|
|
||||||
submesh_dp_size = torch_device_mesh["dp"].size()
|
|
||||||
if "fsdp" in torch_device_mesh.mesh_dim_names:
|
|
||||||
submesh_fsdp_size = torch_device_mesh["fsdp"].size()
|
|
||||||
process_index = process_index // (submesh_tp_size * submesh_cp_size)"""
|
|
||||||
|
|
||||||
|
|
||||||
def get_ring_attn_group() -> dist.ProcessGroup:
|
def get_ring_attn_group() -> dist.ProcessGroup:
|
||||||
"""Getter for ring attention group on this rank."""
|
"""Getter for ring attention group on this rank."""
|
||||||
if RING_ATTN_GROUP is None:
|
if RING_ATTN_GROUP is None:
|
||||||
raise RuntimeError("register_ring_attn() not yet called")
|
raise RuntimeError("register_ring_attn_from_device_mesh() not yet called")
|
||||||
return RING_ATTN_GROUP
|
return RING_ATTN_GROUP
|
||||||
|
|
||||||
|
|
||||||
@@ -161,15 +134,17 @@ def create_ring_flash_attention_forward(
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def register_ring_attn(
|
def register_ring_attn_from_device_mesh(
|
||||||
sequence_parallel_degree: int,
|
device_mesh: "DeviceMesh",
|
||||||
|
context_parallel_dim: tuple[str, ...],
|
||||||
heads_k_stride: int | None,
|
heads_k_stride: int | None,
|
||||||
ring_attn_func: RingAttnFunc | None,
|
ring_attn_func: RingAttnFunc | None,
|
||||||
):
|
):
|
||||||
"""Create ring attention group and substitute flash attn with ring flash attn.
|
"""Create ring attention group using DeviceMesh and substitute flash attn with ring flash attn.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sequence_parallel_degree: Sequence parallelism factor.
|
device_mesh: DeviceMesh object containing the parallelism topology.
|
||||||
|
context_parallel_dim: Name of the sequence parallel dimension in the device mesh.
|
||||||
heads_k_stride: Sequence parallelism K head stride size. Passed through to
|
heads_k_stride: Sequence parallelism K head stride size. Passed through to
|
||||||
`varlen_llama3` `ring_flash_attn` implementation.
|
`varlen_llama3` `ring_flash_attn` implementation.
|
||||||
ring_attn_func: `ring_flash_attn` ring attention implemention. If sample
|
ring_attn_func: `ring_flash_attn` ring attention implemention. If sample
|
||||||
@@ -177,44 +152,39 @@ def register_ring_attn(
|
|||||||
`batch` function.
|
`batch` function.
|
||||||
"""
|
"""
|
||||||
rank = dist.get_rank()
|
rank = dist.get_rank()
|
||||||
world_size = dist.get_world_size()
|
|
||||||
|
LOG.info(
|
||||||
|
f"Enabling ring attention sequence parallelism using DeviceMesh "
|
||||||
|
f"dimension '{context_parallel_dim}'",
|
||||||
|
main_process_only=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract the sequence parallel submesh
|
||||||
|
try:
|
||||||
|
sequence_mesh = device_mesh[context_parallel_dim]
|
||||||
|
except (KeyError, IndexError) as e:
|
||||||
|
raise ValueError(
|
||||||
|
f"Dimension '{context_parallel_dim}' not found in device_mesh. "
|
||||||
|
f"Available dimensions: {device_mesh.mesh_dim_names}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
# Get the process group for context parallelism
|
||||||
|
sequence_pg = sequence_mesh.get_group()
|
||||||
|
context_parallel_size = sequence_mesh.size()
|
||||||
|
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
LOG.info(
|
LOG.info(
|
||||||
"Enabling ring attention sequence parallelism: "
|
f"Sequence parallel degree: {context_parallel_size}, "
|
||||||
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
|
f"mesh shape: {sequence_mesh.mesh.shape}"
|
||||||
)
|
)
|
||||||
|
|
||||||
assert sequence_parallel_degree <= world_size, (
|
# Log which ranks are in the current process group
|
||||||
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
if sequence_pg != dist.GroupMember.WORLD:
|
||||||
f"must be less than or equal to world_size ({world_size})"
|
ranks_in_group = dist.get_process_group_ranks(sequence_pg)
|
||||||
)
|
LOG.info(f"Current sequence parallel group ranks: {ranks_in_group}")
|
||||||
assert world_size % sequence_parallel_degree == 0, (
|
|
||||||
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
|
||||||
f"must evenly divide world_size ({world_size})"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assign ranks to sequence parallel groups
|
# Set the ring attention group
|
||||||
group_assignments = {}
|
set_ring_attn_group(sequence_pg)
|
||||||
for i in range(world_size // sequence_parallel_degree):
|
|
||||||
ring_attn_ranks = list(
|
|
||||||
range(
|
|
||||||
i * sequence_parallel_degree,
|
|
||||||
(i + 1) * sequence_parallel_degree,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
|
|
||||||
|
|
||||||
# Track which GPUs are in which groups
|
|
||||||
for r in ring_attn_ranks:
|
|
||||||
group_assignments[r] = i
|
|
||||||
|
|
||||||
if rank in ring_attn_ranks:
|
|
||||||
set_ring_attn_group(group)
|
|
||||||
|
|
||||||
# Log the GPU group assignments
|
|
||||||
if rank == 0:
|
|
||||||
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
|
|
||||||
|
|
||||||
if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:
|
if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:
|
||||||
# fmt: off
|
# fmt: off
|
||||||
@@ -257,92 +227,3 @@ def update_ring_attn_params(position_ids: torch.Tensor | None):
|
|||||||
cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids)
|
cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids)
|
||||||
cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device())
|
cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device())
|
||||||
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
|
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
|
||||||
|
|
||||||
|
|
||||||
def patch_prepare_data_loader():
|
|
||||||
"""Patch `accelerate.data_loader.prepare_data_loader` to respect the SP degree.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: If source code to patch does not exist.
|
|
||||||
"""
|
|
||||||
original_fn = accelerate.data_loader.prepare_data_loader
|
|
||||||
original_source = inspect.getsource(original_fn)
|
|
||||||
|
|
||||||
if ORIGINAL_PREPARE_DATALOADER_CODE not in original_source:
|
|
||||||
raise RuntimeError(
|
|
||||||
"SP patch failed - target snippet not found. "
|
|
||||||
"Check accelerate's version or update the patch."
|
|
||||||
)
|
|
||||||
|
|
||||||
patched_source = original_source.replace(
|
|
||||||
ORIGINAL_PREPARE_DATALOADER_CODE, NEW_PREPARE_DATALOADER_CODE
|
|
||||||
)
|
|
||||||
|
|
||||||
items_to_import = []
|
|
||||||
for item in dir(accelerate.data_loader):
|
|
||||||
if item in patched_source:
|
|
||||||
items_to_import.append(item)
|
|
||||||
|
|
||||||
# Create a new function from the patched source
|
|
||||||
namespace = {}
|
|
||||||
exec( # pylint: disable=exec-used # nosec B102
|
|
||||||
f"from accelerate.data_loader import ({', '.join(items_to_import)})",
|
|
||||||
globals(),
|
|
||||||
)
|
|
||||||
exec( # pylint: disable=exec-used # nosec B102
|
|
||||||
patched_source, globals(), namespace
|
|
||||||
)
|
|
||||||
|
|
||||||
patched_function = namespace["prepare_data_loader"]
|
|
||||||
original_fn.__code__ = patched_function.__code__
|
|
||||||
|
|
||||||
LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support")
|
|
||||||
|
|
||||||
|
|
||||||
def patch_prepare_device_mesh(sequence_parallel_degree: int, fsdp: bool = False):
|
|
||||||
"""Patches the `Accelerator._prepare_device_mesh` method to create a device mesh
|
|
||||||
that includes sequence parallelism with the specified degree.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
sequence_parallel_degree: The degree of sequence parallelism to use.
|
|
||||||
fsdp: Whether to use FSDP.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _prepare_device_mesh(self):
|
|
||||||
"""Prepare the device mesh for distributed training. The dataloader will
|
|
||||||
determine how to load data based on the device mesh.
|
|
||||||
"""
|
|
||||||
if self.state.torch_tp_plugin:
|
|
||||||
return self.state.torch_tp_plugin.torch_device_mesh
|
|
||||||
if (
|
|
||||||
self.distributed_type == accelerate.accelerator.DistributedType.DEEPSPEED
|
|
||||||
and hasattr(self.state, "ds_device_mesh")
|
|
||||||
):
|
|
||||||
return self.state.ds_device_mesh
|
|
||||||
|
|
||||||
# Create device mesh with sequence parallelism
|
|
||||||
world_size = dist.get_world_size()
|
|
||||||
mesh_shape = (
|
|
||||||
world_size // sequence_parallel_degree,
|
|
||||||
sequence_parallel_degree,
|
|
||||||
)
|
|
||||||
device_ids = list(range(world_size))
|
|
||||||
|
|
||||||
# NOTE: We use "cp" instead of "sp" to match the PyTorch native "context
|
|
||||||
# parallelism" implementation naming.
|
|
||||||
# NOTE: We have a simplified FSDP handling here; i.e., if FSDP is enabled, we
|
|
||||||
# only use "fsdp" and "cp" for the device mesh.
|
|
||||||
return dist.DeviceMesh(
|
|
||||||
"cuda",
|
|
||||||
torch.tensor(device_ids).reshape(mesh_shape),
|
|
||||||
mesh_dim_names=("dp", "cp") if not fsdp else ("fsdp", "cp"),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Replace the original method with our new method
|
|
||||||
# pylint: disable=protected-access
|
|
||||||
accelerate.accelerator.Accelerator._prepare_device_mesh = _prepare_device_mesh
|
|
||||||
|
|
||||||
LOG.info(
|
|
||||||
"Successfully patched Accelerator._prepare_device_mesh "
|
|
||||||
f"with sequence_parallel_degree={sequence_parallel_degree}"
|
|
||||||
)
|
|
||||||
|
|||||||
0
src/axolotl/monkeypatch/transformers/__init__.py
Normal file
0
src/axolotl/monkeypatch/transformers/__init__.py
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
"""
|
||||||
|
Monkey patch to fix transformers.modeling_flash_attention_utils.
|
||||||
|
|
||||||
|
see https://github.com/huggingface/transformers/pull/39653/files
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def _prepare_from_posids(query, key, value, position_ids):
|
||||||
|
"""
|
||||||
|
This function returns necessary arguments to call `flash_attn_varlen_func`.
|
||||||
|
All three query, key, value states will be flattened.
|
||||||
|
Cumulative lengths of each examples in the batch will be extracted from position_ids.
|
||||||
|
NOTE: ideally cumulative lengths should be prepared at the data collator stage
|
||||||
|
Arguments:
|
||||||
|
query (`torch.Tensor`):
|
||||||
|
Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
|
||||||
|
key (`torch.Tensor`):
|
||||||
|
Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
|
||||||
|
value (`torch.Tensor`):
|
||||||
|
Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
|
||||||
|
position_ids (`torch.Tensor`):
|
||||||
|
Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
|
||||||
|
Return:
|
||||||
|
query (`torch.Tensor`):
|
||||||
|
Query state without padding. Shape: (total_target_length, num_heads, head_dim).
|
||||||
|
key (`torch.Tensor`):
|
||||||
|
Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
|
||||||
|
value (`torch.Tensor`):
|
||||||
|
Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
|
||||||
|
indices_q (`torch.Tensor`):
|
||||||
|
The indices of non-masked tokens from the flattened input target sequence.
|
||||||
|
(cu_seqlens_q, cu_seqlens_k) (`tuple[int]`):
|
||||||
|
The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
|
||||||
|
(max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
|
||||||
|
Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
|
||||||
|
"""
|
||||||
|
query = query.contiguous().view(-1, query.size(-2), query.size(-1))
|
||||||
|
key = key.contiguous().view(-1, key.size(-2), key.size(-1))
|
||||||
|
value = value.contiguous().view(-1, value.size(-2), value.size(-1))
|
||||||
|
|
||||||
|
position_ids = position_ids.flatten()
|
||||||
|
indices_q = torch.arange(
|
||||||
|
position_ids.size(0), device=position_ids.device, dtype=torch.int32
|
||||||
|
)
|
||||||
|
|
||||||
|
cu_seq_lens = torch.cat(
|
||||||
|
(
|
||||||
|
indices_q[position_ids == 0],
|
||||||
|
torch.tensor(
|
||||||
|
position_ids.size(), device=position_ids.device, dtype=torch.int32
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# NOTE: With torch compile, this will cause a graph break if you don't set
|
||||||
|
# `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call
|
||||||
|
# `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass.
|
||||||
|
# This is a limitation of flash attention API, as the function `flash_attn_varlen_func`
|
||||||
|
# requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`.
|
||||||
|
# https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424
|
||||||
|
# We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
|
||||||
|
# for some models (e.g. qwen2-vl).
|
||||||
|
max_length = cu_seq_lens.diff().max().item()
|
||||||
|
return (
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
indices_q,
|
||||||
|
(cu_seq_lens, cu_seq_lens),
|
||||||
|
(max_length, max_length),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_prepare_from_posids():
|
||||||
|
import transformers.modeling_flash_attention_utils
|
||||||
|
|
||||||
|
transformers.modeling_flash_attention_utils._prepare_from_posids = ( # pylint: disable=protected-access
|
||||||
|
_prepare_from_posids
|
||||||
|
)
|
||||||
|
setattr(
|
||||||
|
sys.modules["transformers.modeling_flash_attention_utils"],
|
||||||
|
"_prepare_from_posids",
|
||||||
|
_prepare_from_posids,
|
||||||
|
)
|
||||||
@@ -205,7 +205,7 @@ def execute_training(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.sequence_parallel_degree > 1:
|
if cfg.context_parallel_size > 1:
|
||||||
models = [trainer.model]
|
models = [trainer.model]
|
||||||
if hasattr(trainer, "ref_model") and trainer.ref_model:
|
if hasattr(trainer, "ref_model") and trainer.ref_model:
|
||||||
models.append(trainer.ref_model)
|
models.append(trainer.ref_model)
|
||||||
@@ -213,7 +213,7 @@ def execute_training(
|
|||||||
stack.enter_context(
|
stack.enter_context(
|
||||||
SequenceParallelContextManager(
|
SequenceParallelContextManager(
|
||||||
models=models,
|
models=models,
|
||||||
sequence_parallel_degree=cfg.sequence_parallel_degree,
|
context_parallel_size=cfg.context_parallel_size,
|
||||||
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
|
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
|
||||||
ring_attn_func=cfg.ring_attn_func,
|
ring_attn_func=cfg.ring_attn_func,
|
||||||
heads_k_stride=cfg.heads_k_stride,
|
heads_k_stride=cfg.heads_k_stride,
|
||||||
|
|||||||
@@ -57,10 +57,10 @@ def gpu_memory_usage(device=0):
|
|||||||
|
|
||||||
@check_cuda_device((0.0, 0.0, 0.0))
|
@check_cuda_device((0.0, 0.0, 0.0))
|
||||||
def gpu_memory_usage_all(device=0):
|
def gpu_memory_usage_all(device=0):
|
||||||
usage = torch.cuda.memory_allocated(device) / 1024.0**3
|
active = torch.cuda.memory_stats().get("active_bytes.all.peak", 0) / 1024.0**3
|
||||||
reserved = torch.cuda.memory_reserved(device) / 1024.0**3
|
allocated = torch.cuda.max_memory_allocated(device) / 1024.0**3
|
||||||
smi = gpu_memory_usage_smi(device)
|
reserved = torch.cuda.max_memory_reserved(device) / 1024.0**3
|
||||||
return usage, reserved - usage, max(0, smi - reserved)
|
return active, allocated, reserved
|
||||||
|
|
||||||
|
|
||||||
def mps_memory_usage_all():
|
def mps_memory_usage_all():
|
||||||
@@ -92,27 +92,38 @@ def gpu_memory_usage_smi(device=0):
|
|||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
|
|
||||||
def log_gpu_memory_usage(
|
def get_gpu_memory_usage(device: int | torch.device = 0):
|
||||||
log: logging.Logger | logging.LoggerAdapter,
|
|
||||||
msg: str = "",
|
|
||||||
device: int | torch.device = 0,
|
|
||||||
):
|
|
||||||
cur_device_type = str(get_device_type())
|
cur_device_type = str(get_device_type())
|
||||||
if torch.backends.mps.is_available():
|
if torch.backends.mps.is_available():
|
||||||
usage, cache, misc = mps_memory_usage_all()
|
usage, cache, misc = mps_memory_usage_all()
|
||||||
elif "npu" in cur_device_type and is_torch_npu_available():
|
elif "npu" in cur_device_type and is_torch_npu_available():
|
||||||
usage, cache, misc = npu_memory_usage_all(device)
|
usage, cache, misc = npu_memory_usage_all(device)
|
||||||
elif "gpu" in cur_device_type and torch.cuda.is_available():
|
elif "cuda" in cur_device_type and torch.cuda.is_available():
|
||||||
usage, cache, misc = gpu_memory_usage_all(device)
|
usage, cache, misc = gpu_memory_usage_all(device)
|
||||||
else:
|
else:
|
||||||
|
return 0.0, 0.0, 0.0
|
||||||
|
|
||||||
|
return usage, cache, misc
|
||||||
|
|
||||||
|
|
||||||
|
def log_gpu_memory_usage(
|
||||||
|
log: logging.Logger | logging.LoggerAdapter,
|
||||||
|
msg: str = "",
|
||||||
|
device: int | torch.device = 0,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
active, allocated, reserved = get_gpu_memory_usage(device)
|
||||||
|
except ValueError:
|
||||||
|
# likely CPU, ignore
|
||||||
return
|
return
|
||||||
|
cur_device_type = str(get_device_type())
|
||||||
extras = []
|
extras = []
|
||||||
if cache > 0:
|
if allocated > 0:
|
||||||
extras.append(f"+{cache:.03f}GB cache")
|
extras.append(f"+{allocated:.03f}GB allocated")
|
||||||
if misc > 0:
|
if reserved > 0:
|
||||||
extras.append(f"+{misc:.03f}GB misc")
|
extras.append(f"+{reserved:.03f}GB reserved")
|
||||||
msg = f"{cur_device_type} memory usage:" if not msg else msg
|
msg = f"{cur_device_type} memory active:" if not msg else msg
|
||||||
log.info(
|
log.debug(
|
||||||
f"{msg} {usage:.03f}GB ({', '.join(extras)})",
|
f"{msg} {active:.03f}GB ({', '.join(extras)})",
|
||||||
stacklevel=2,
|
stacklevel=2,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ from transformers.trainer_utils import (
|
|||||||
from trl.models import unwrap_model_for_generation
|
from trl.models import unwrap_model_for_generation
|
||||||
|
|
||||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import get_gpu_memory_usage, log_gpu_memory_usage
|
||||||
from axolotl.utils.callbacks.perplexity import Perplexity
|
from axolotl.utils.callbacks.perplexity import Perplexity
|
||||||
from axolotl.utils.distributed import (
|
from axolotl.utils.distributed import (
|
||||||
barrier,
|
barrier,
|
||||||
@@ -100,7 +100,6 @@ class GPUStatsCallback(
|
|||||||
|
|
||||||
def __init__(self, cfg):
|
def __init__(self, cfg):
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self.logged = False
|
|
||||||
|
|
||||||
def on_step_end(
|
def on_step_end(
|
||||||
self,
|
self,
|
||||||
@@ -109,9 +108,21 @@ class GPUStatsCallback(
|
|||||||
control: TrainerControl,
|
control: TrainerControl,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> TrainerControl:
|
) -> TrainerControl:
|
||||||
if not self.logged and state.global_step > 1:
|
if state.global_step > 0:
|
||||||
log_gpu_memory_usage(LOG, "while training", self.cfg.device)
|
if self.cfg.use_wandb and state.is_world_process_zero:
|
||||||
self.logged = True
|
try:
|
||||||
|
active, allocated, reserved = get_gpu_memory_usage()
|
||||||
|
wandb.log(
|
||||||
|
{
|
||||||
|
"memory/max_memory_active": active,
|
||||||
|
"memory/max_memory_allocated": allocated,
|
||||||
|
"memory/device_memory_reserved": reserved,
|
||||||
|
},
|
||||||
|
step=state.global_step,
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
log_gpu_memory_usage(LOG, "", self.cfg.device)
|
||||||
return control
|
return control
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import inspect
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
from accelerate import PartialState
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils.hooks import RemovableHandle
|
from torch.utils.hooks import RemovableHandle
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
@@ -12,7 +13,7 @@ from transformers.utils import ModelOutput
|
|||||||
|
|
||||||
from axolotl.monkeypatch.ring_attn import (
|
from axolotl.monkeypatch.ring_attn import (
|
||||||
get_ring_attn_group,
|
get_ring_attn_group,
|
||||||
register_ring_attn,
|
register_ring_attn_from_device_mesh,
|
||||||
update_ring_attn_params,
|
update_ring_attn_params,
|
||||||
)
|
)
|
||||||
from axolotl.utils.schemas.enums import RingAttnFunc
|
from axolotl.utils.schemas.enums import RingAttnFunc
|
||||||
@@ -150,9 +151,18 @@ def apply_sequence_parallelism(
|
|||||||
if "num_items_in_batch" in batch:
|
if "num_items_in_batch" in batch:
|
||||||
# Approximation; this needed since num_items_in_batch may be counted across
|
# Approximation; this needed since num_items_in_batch may be counted across
|
||||||
# all samples in a gradient accumulated batch, not on a per-step basis.
|
# all samples in a gradient accumulated batch, not on a per-step basis.
|
||||||
|
local_valid_tokens = (batch["labels"] != -100).sum()
|
||||||
|
|
||||||
|
# All-reduce across sequence parallel ranks to get global token count
|
||||||
|
cp_group = get_ring_attn_group()
|
||||||
|
global_valid_tokens = local_valid_tokens.clone()
|
||||||
|
# we use AVG instead of SUM as using sum seems to scale down the loss by over-accounting the number of tokens
|
||||||
|
dist.all_reduce(global_valid_tokens, op=dist.ReduceOp.AVG, group=cp_group)
|
||||||
|
global_valid_tokens = int(global_valid_tokens.item())
|
||||||
|
|
||||||
batch["num_items_in_batch"] = (
|
batch["num_items_in_batch"] = (
|
||||||
batch["labels"] != -100
|
global_valid_tokens * gradient_accumulation_steps
|
||||||
).sum() * gradient_accumulation_steps
|
)
|
||||||
|
|
||||||
return batch, original_seq_len, pad_len
|
return batch, original_seq_len, pad_len
|
||||||
|
|
||||||
@@ -167,7 +177,7 @@ class SequenceParallelContextManager:
|
|||||||
Args:
|
Args:
|
||||||
models: List of models to apply sequence parallelism to pre- and post- forward
|
models: List of models to apply sequence parallelism to pre- and post- forward
|
||||||
hooks.
|
hooks.
|
||||||
sequence_parallel_degree: Number of processes to split sequences over.
|
context_parallel_size: Number of processes to split sequences over.
|
||||||
gradient_accumulation_steps: Number of steps to accumulate gradients over.
|
gradient_accumulation_steps: Number of steps to accumulate gradients over.
|
||||||
ring_attn_func: Which ring attention function to use. Currently unused.
|
ring_attn_func: Which ring attention function to use. Currently unused.
|
||||||
heads_k_stride: Sequence parallelism K head stride size. Passed through to
|
heads_k_stride: Sequence parallelism K head stride size. Passed through to
|
||||||
@@ -179,14 +189,14 @@ class SequenceParallelContextManager:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
models: list[nn.Module],
|
models: list[nn.Module],
|
||||||
sequence_parallel_degree: int,
|
context_parallel_size: int,
|
||||||
gradient_accumulation_steps: int,
|
gradient_accumulation_steps: int,
|
||||||
ring_attn_func: RingAttnFunc,
|
ring_attn_func: RingAttnFunc,
|
||||||
heads_k_stride: int | None,
|
heads_k_stride: int | None,
|
||||||
gather_outputs: bool,
|
gather_outputs: bool,
|
||||||
):
|
):
|
||||||
self.models = models
|
self.models = models
|
||||||
self.sequence_parallel_degree = sequence_parallel_degree
|
self.context_parallel_size = context_parallel_size
|
||||||
self.gradient_accumulation_steps = gradient_accumulation_steps
|
self.gradient_accumulation_steps = gradient_accumulation_steps
|
||||||
self.ring_attn_func = ring_attn_func
|
self.ring_attn_func = ring_attn_func
|
||||||
self.heads_k_stride = heads_k_stride
|
self.heads_k_stride = heads_k_stride
|
||||||
@@ -230,8 +240,10 @@ class SequenceParallelContextManager:
|
|||||||
|
|
||||||
def _register_ring_attn(self):
|
def _register_ring_attn(self):
|
||||||
# Initialize ring attn for sequence parallelism
|
# Initialize ring attn for sequence parallelism
|
||||||
register_ring_attn(
|
partial_state = PartialState()
|
||||||
sequence_parallel_degree=self.sequence_parallel_degree,
|
register_ring_attn_from_device_mesh(
|
||||||
|
device_mesh=partial_state.device_mesh,
|
||||||
|
context_parallel_dim=("cp",),
|
||||||
heads_k_stride=self.heads_k_stride,
|
heads_k_stride=self.heads_k_stride,
|
||||||
ring_attn_func=self.ring_attn_func,
|
ring_attn_func=self.ring_attn_func,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -430,10 +430,11 @@ def save_preprocessed_dataset(
|
|||||||
num_shards=cfg.num_dataset_shards_to_save,
|
num_shards=cfg.num_dataset_shards_to_save,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
min_rows_per_proc = 256
|
||||||
os.makedirs(prepared_ds_path, exist_ok=True)
|
os.makedirs(prepared_ds_path, exist_ok=True)
|
||||||
dataset.save_to_disk(
|
dataset.save_to_disk(
|
||||||
str(prepared_ds_path),
|
str(prepared_ds_path),
|
||||||
num_proc=min(max(1, len(dataset) // 8), num_workers),
|
num_proc=min(max(1, len(dataset) // min_rows_per_proc), num_workers),
|
||||||
max_shard_size=None,
|
max_shard_size=None,
|
||||||
num_shards=cfg.num_dataset_shards_to_save,
|
num_shards=cfg.num_dataset_shards_to_save,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -2,12 +2,15 @@
|
|||||||
utils to get GPU info for the current environment
|
utils to get GPU info for the current environment
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from importlib.metadata import version
|
||||||
|
|
||||||
from accelerate.utils.environment import (
|
from accelerate.utils.environment import (
|
||||||
check_cuda_p2p_ib_support as accelerate_check_cuda_p2p_ib_support,
|
check_cuda_p2p_ib_support as accelerate_check_cuda_p2p_ib_support,
|
||||||
)
|
)
|
||||||
from accelerate.utils.environment import (
|
from accelerate.utils.environment import (
|
||||||
get_gpu_info,
|
get_gpu_info,
|
||||||
)
|
)
|
||||||
|
from packaging.version import Version, parse
|
||||||
|
|
||||||
|
|
||||||
def check_cuda_p2p_ib_support():
|
def check_cuda_p2p_ib_support():
|
||||||
@@ -26,3 +29,13 @@ def check_cuda_p2p_ib_support():
|
|||||||
except Exception: # pylint: disable=broad-except # nosec
|
except Exception: # pylint: disable=broad-except # nosec
|
||||||
pass
|
pass
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def get_package_version(package: str) -> Version:
|
||||||
|
version_str = version(package)
|
||||||
|
return parse(version_str)
|
||||||
|
|
||||||
|
|
||||||
|
def is_package_version_ge(package: str, version_: str) -> bool:
|
||||||
|
package_version = get_package_version(package)
|
||||||
|
return package_version >= parse(version_)
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ into fixed-capacity batches to optimize memory usage and training throughput.
|
|||||||
|
|
||||||
import gc
|
import gc
|
||||||
import math
|
import math
|
||||||
|
import time
|
||||||
from concurrent.futures import ProcessPoolExecutor
|
from concurrent.futures import ProcessPoolExecutor
|
||||||
from multiprocessing import cpu_count, get_context
|
from multiprocessing import cpu_count, get_context
|
||||||
from typing import Iterable, Iterator, Union
|
from typing import Iterable, Iterator, Union
|
||||||
@@ -453,7 +454,10 @@ class MultipackBatchSampler(BatchSampler):
|
|||||||
_sampled_lens = []
|
_sampled_lens = []
|
||||||
for _ in range(self.num_count_samples):
|
for _ in range(self.num_count_samples):
|
||||||
self._batches = None # Reset cached batches
|
self._batches = None # Reset cached batches
|
||||||
|
# log timer for generating batches
|
||||||
|
start_time = time.time()
|
||||||
_sampled_lens.append(len(self.generate_batches(set_stats=False)))
|
_sampled_lens.append(len(self.generate_batches(set_stats=False)))
|
||||||
|
LOG.debug(f"generate_batches time: {time.time() - start_time}")
|
||||||
len_batches = min(_sampled_lens)
|
len_batches = min(_sampled_lens)
|
||||||
|
|
||||||
# Gather minimum across all ranks
|
# Gather minimum across all ranks
|
||||||
|
|||||||
@@ -651,7 +651,23 @@ class AxolotlInputConfig(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
dp_shard_size: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Number of devices to shard across. If not set, will use all available devices."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
dp_replicate_size: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={"description": "Number of devices to replicate across."},
|
||||||
|
)
|
||||||
sequence_parallel_degree: int | None = Field(
|
sequence_parallel_degree: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Deprecated: use `context_parallel_size` instead"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
context_parallel_size: int | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Set to a divisor of the number of GPUs available to split sequences into chunks of equal size. Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM. E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized subsequences, or set to 4 to split into four equal-sized subsequences. See https://docs.axolotl.ai/docs/sequence_parallelism.html for more details."
|
"description": "Set to a divisor of the number of GPUs available to split sequences into chunks of equal size. Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM. E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized subsequences, or set to 4 to split into four equal-sized subsequences. See https://docs.axolotl.ai/docs/sequence_parallelism.html for more details."
|
||||||
|
|||||||
@@ -673,7 +673,7 @@ class RLValidationMixin:
|
|||||||
data.get("rl") == "grpo"
|
data.get("rl") == "grpo"
|
||||||
and data.get("trl", {})
|
and data.get("trl", {})
|
||||||
and data.get("trl").get("use_liger_loss")
|
and data.get("trl").get("use_liger_loss")
|
||||||
and data.get("sequence_parallel_degree", 1) > 1
|
and data.get("context_parallel_size", 1) > 1
|
||||||
):
|
):
|
||||||
raise ValueError("GRPO + SP + Liger not currently supported")
|
raise ValueError("GRPO + SP + Liger not currently supported")
|
||||||
return data
|
return data
|
||||||
@@ -880,51 +880,35 @@ class OptimizationValidationMixin:
|
|||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@model_validator(mode="after")
|
|
||||||
def check_fsdp_sharded_state_dict_w_safetensors(self):
|
|
||||||
if (
|
|
||||||
hasattr(self, "fsdp_config")
|
|
||||||
and self.fsdp_config
|
|
||||||
and hasattr(self, "save_safetensors")
|
|
||||||
and self.save_safetensors
|
|
||||||
and self.fsdp_config.get("state_dict_type", "") == "SHARDED_STATE_DICT"
|
|
||||||
and str(getattr(self, "fsdp_version", "1")) != "2"
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"FSDP SHARDED_STATE_DICT not compatible with save_safetensors"
|
|
||||||
)
|
|
||||||
return self
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_tensor_parallel_size_update_ds_json(cls, data):
|
def check_tensor_parallel_size_update_ds_json(cls, data):
|
||||||
tensor_parallel_size = data.get("tensor_parallel_size")
|
tensor_parallel_size = data.get("tensor_parallel_size")
|
||||||
if tensor_parallel_size is not None and tensor_parallel_size > 1:
|
if tensor_parallel_size is not None and tensor_parallel_size > 1:
|
||||||
if not data.get("deepspeed"):
|
if data.get("deepspeed"):
|
||||||
raise ValueError(
|
with open(data.get("deepspeed"), "r", encoding="utf-8") as ds_fin:
|
||||||
"Tensor parallelism (TP) is only supported with DeepSpeed"
|
ds_config = json.load(ds_fin)
|
||||||
)
|
should_save = False
|
||||||
with open(data.get("deepspeed"), "r", encoding="utf-8") as ds_fin:
|
if "tensor_parallel" not in ds_config:
|
||||||
ds_config = json.load(ds_fin)
|
ds_config["tensor_parallel"] = {
|
||||||
should_save = False
|
"autotp_size": tensor_parallel_size
|
||||||
if "tensor_parallel" not in ds_config:
|
}
|
||||||
ds_config["tensor_parallel"] = {"autotp_size": tensor_parallel_size}
|
should_save = True
|
||||||
should_save = True
|
if (
|
||||||
if (
|
|
||||||
"gather_16bit_weights_on_model_save"
|
|
||||||
not in ds_config["zero_optimization"]
|
|
||||||
):
|
|
||||||
ds_config["zero_optimization"][
|
|
||||||
"gather_16bit_weights_on_model_save"
|
"gather_16bit_weights_on_model_save"
|
||||||
] = True
|
not in ds_config["zero_optimization"]
|
||||||
should_save = True
|
):
|
||||||
if should_save:
|
ds_config["zero_optimization"][
|
||||||
temp_dir = tempfile.mkdtemp()
|
"gather_16bit_weights_on_model_save"
|
||||||
with open(
|
] = True
|
||||||
Path(temp_dir) / "autotp_ds.json", "w", encoding="utf-8"
|
should_save = True
|
||||||
) as ds_fout:
|
if should_save:
|
||||||
json.dump(ds_config, ds_fout, indent=4)
|
temp_dir = tempfile.mkdtemp()
|
||||||
data["deepspeed"] = str(Path(temp_dir) / "autotp_ds.json")
|
with open(
|
||||||
|
Path(temp_dir) / "autotp_ds.json", "w", encoding="utf-8"
|
||||||
|
) as ds_fout:
|
||||||
|
json.dump(ds_config, ds_fout, indent=4)
|
||||||
|
data["deepspeed"] = str(Path(temp_dir) / "autotp_ds.json")
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@@ -1205,13 +1189,18 @@ class ComplexValidationMixin:
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def check_sequence_parallel_degree(self):
|
def check_context_parallel_size(self):
|
||||||
if not self.sequence_parallel_degree:
|
if self.sequence_parallel_degree and not self.context_parallel_size:
|
||||||
self.sequence_parallel_degree = 1
|
LOG.warning(
|
||||||
elif self.sequence_parallel_degree > 1:
|
"`sequence_parallel_degree` is deprecated, use `context_parallel_size`"
|
||||||
|
)
|
||||||
|
self.context_parallel_size = self.sequence_parallel_degree
|
||||||
|
if not self.context_parallel_size:
|
||||||
|
self.context_parallel_size = 1
|
||||||
|
elif self.context_parallel_size > 1:
|
||||||
if not self.flash_attention:
|
if not self.flash_attention:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"flash_attention: true must be set with sequence_parallel_degree > 1"
|
"flash_attention: true must be set with context_parallel_size > 1"
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.sample_packing and self.micro_batch_size > 1:
|
if self.sample_packing and self.micro_batch_size > 1:
|
||||||
@@ -1221,17 +1210,23 @@ class ComplexValidationMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
import transformers.modeling_flash_attention_utils
|
||||||
|
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
transformers.modeling_flash_attention_utils._flash_supports_window_size = (
|
||||||
|
transformers.modeling_flash_attention_utils._flash_supports_window
|
||||||
|
)
|
||||||
import ring_flash_attn # noqa: F401 # pylint:disable=unused-import
|
import ring_flash_attn # noqa: F401 # pylint:disable=unused-import
|
||||||
except ImportError as exception:
|
except ImportError as exception:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"sequence_parallel_degree > 1 but ring_flash_attn is not installed. "
|
"context_parallel_size > 1 but ring_flash_attn is not installed. "
|
||||||
"Please install it with `pip install axolotl[ring-flash-attn] "
|
"Please install it with `pip install axolotl[ring-flash-attn] "
|
||||||
"or `pip install ring-flash-attn>=0.1.4`."
|
"or `pip install ring-flash-attn>=0.1.4`."
|
||||||
) from exception
|
) from exception
|
||||||
|
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"Sequence parallelism (SP) is enabled with "
|
"Sequence parallelism (SP) is enabled with "
|
||||||
f"sequence_parallel_degree={self.sequence_parallel_degree}. "
|
f"context_parallel_size={self.context_parallel_size}. "
|
||||||
"Please note that logged losses may differ slightly to the non-SP "
|
"Please note that logged losses may differ slightly to the non-SP "
|
||||||
"losses due to transformers Trainer implementation details. "
|
"losses due to transformers Trainer implementation details. "
|
||||||
"Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
|
"Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
|
||||||
@@ -1242,7 +1237,7 @@ class ComplexValidationMixin:
|
|||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def validate_ring_attn_func(self):
|
def validate_ring_attn_func(self):
|
||||||
if getattr(self, "sequence_parallel_degree", 1) == 1:
|
if getattr(self, "context_parallel_size", 1) == 1:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
if self.ring_attn_func is not None:
|
if self.ring_attn_func is not None:
|
||||||
@@ -1259,6 +1254,20 @@ class ComplexValidationMixin:
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class DistributedValidationMixin:
|
||||||
|
"""validation for distributed training."""
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def check_tensor_parallel_optimizer(self):
|
||||||
|
if self.tensor_parallel_size > 1:
|
||||||
|
if self.optimizer in ["paged_adamw_8bit", "adamw_8bit", "adamw_bnb_8bit"]:
|
||||||
|
raise ValueError(
|
||||||
|
"tensor_parallel_size is not supported with paged_adamw_8bit, adamw_8bit, and adamw_bnb_8bit optimizers"
|
||||||
|
)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=too-many-ancestors
|
# pylint: disable=too-many-ancestors
|
||||||
class ValidationMixin(
|
class ValidationMixin(
|
||||||
DatasetValidationMixin,
|
DatasetValidationMixin,
|
||||||
|
|||||||
@@ -442,7 +442,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
- 1
|
- 1
|
||||||
)
|
)
|
||||||
* cfg.num_epochs
|
* cfg.num_epochs
|
||||||
* cfg.sequence_parallel_degree
|
* cfg.context_parallel_size
|
||||||
* cfg.tensor_parallel_size
|
* cfg.tensor_parallel_size
|
||||||
)
|
)
|
||||||
LOG.debug(
|
LOG.debug(
|
||||||
@@ -484,7 +484,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
math.floor(
|
math.floor(
|
||||||
data_loader_len
|
data_loader_len
|
||||||
* cfg.num_epochs
|
* cfg.num_epochs
|
||||||
* cfg.sequence_parallel_degree
|
* cfg.context_parallel_size
|
||||||
* cfg.tensor_parallel_size
|
* cfg.tensor_parallel_size
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -511,7 +511,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
math.ceil(
|
math.ceil(
|
||||||
len(train_dataset)
|
len(train_dataset)
|
||||||
* cfg.num_epochs
|
* cfg.num_epochs
|
||||||
* cfg.sequence_parallel_degree
|
* cfg.context_parallel_size
|
||||||
* cfg.tensor_parallel_size
|
* cfg.tensor_parallel_size
|
||||||
/ cfg.batch_size
|
/ cfg.batch_size
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -64,7 +64,7 @@ def fixture_base_cfg():
|
|||||||
"dataloader_num_workers": 1,
|
"dataloader_num_workers": 1,
|
||||||
"dataloader_pin_memory": True,
|
"dataloader_pin_memory": True,
|
||||||
"dataloader_prefetch_factor": 2,
|
"dataloader_prefetch_factor": 2,
|
||||||
"sequence_parallel_degree": 1,
|
"context_parallel_size": 1,
|
||||||
"tensor_parallel_size": 1,
|
"tensor_parallel_size": 1,
|
||||||
# Dtype
|
# Dtype
|
||||||
"fp16": False,
|
"fp16": False,
|
||||||
|
|||||||
@@ -67,7 +67,7 @@ class TestSequenceParallelism:
|
|||||||
"logging_steps": 1,
|
"logging_steps": 1,
|
||||||
"weight_decay": 0.0,
|
"weight_decay": 0.0,
|
||||||
"use_tensorboard": True,
|
"use_tensorboard": True,
|
||||||
"sequence_parallel_degree": 2,
|
"context_parallel_size": 2,
|
||||||
"ring_attn_func": ring_attn_func,
|
"ring_attn_func": ring_attn_func,
|
||||||
"save_first_step": False,
|
"save_first_step": False,
|
||||||
}
|
}
|
||||||
@@ -105,13 +105,13 @@ class TestSequenceParallelism:
|
|||||||
(True, 1, True, None, 2.5), # defaults to varlen_llama3 ring_attn_func
|
(True, 1, True, None, 2.5), # defaults to varlen_llama3 ring_attn_func
|
||||||
(False, 2, True, None, 2.5), # defaults to batch_ring ring_attn_func
|
(False, 2, True, None, 2.5), # defaults to batch_ring ring_attn_func
|
||||||
# (False, 2, True, "batch_zigzag", 2.5),
|
# (False, 2, True, "batch_zigzag", 2.5),
|
||||||
(False, 2, False, None, 2.65), # defaults to batch_ring ring_attn_func
|
# (False, 2, False, None, 2.65), # defaults to batch_ring ring_attn_func
|
||||||
],
|
],
|
||||||
ids=[
|
ids=[
|
||||||
"sample_packing, varlen_llama3 ring_attn_func",
|
"sample_packing, varlen_llama3 ring_attn_func",
|
||||||
"no sample_packing, pad_to_sequence_len, batch_ring ring_attn_func",
|
"no sample_packing, pad_to_sequence_len, batch_ring ring_attn_func",
|
||||||
# "no sample_packing, no pad_to_sequence_len, batch_zigzag ring_attn_func",
|
# "no sample_packing, no pad_to_sequence_len, batch_zigzag ring_attn_func",
|
||||||
"no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func",
|
# "no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_sequence_parallel_training(
|
def test_sequence_parallel_training(
|
||||||
|
|||||||
@@ -298,7 +298,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
|||||||
"lora_alpha": 16,
|
"lora_alpha": 16,
|
||||||
"lora_dropout": 0.05,
|
"lora_dropout": 0.05,
|
||||||
"lora_target_linear": True,
|
"lora_target_linear": True,
|
||||||
"sequence_parallel_degree": 2,
|
"context_parallel_size": 2,
|
||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
"sequence_len": 1024,
|
"sequence_len": 1024,
|
||||||
"special_tokens": {
|
"special_tokens": {
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from transformers.testing_utils import get_torch_dist_unique_port
|
|||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from tests.e2e.utils import most_recent_subdir, require_torch_2_7_0
|
from tests.e2e.utils import most_recent_subdir, require_hopper, require_torch_2_7_0
|
||||||
|
|
||||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||||
|
|
||||||
@@ -51,6 +51,7 @@ class TestFP8FSDP2:
|
|||||||
"""Test class for FP8 mixed precision with FSDP2 functionality."""
|
"""Test class for FP8 mixed precision with FSDP2 functionality."""
|
||||||
|
|
||||||
@require_torch_2_7_0
|
@require_torch_2_7_0
|
||||||
|
@require_hopper
|
||||||
def test_fp8_fsdp2_smoke(self, temp_dir):
|
def test_fp8_fsdp2_smoke(self, temp_dir):
|
||||||
"""Smoke test for 2-GPU FP8 + torch.compile + FSDP2 training"""
|
"""Smoke test for 2-GPU FP8 + torch.compile + FSDP2 training"""
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
|
|||||||
69
tests/e2e/multigpu/test_tp.py
Normal file
69
tests/e2e/multigpu/test_tp.py
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
"""multigpu e2e test for tensor parallelism."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import yaml
|
||||||
|
from accelerate.test_utils import execute_subprocess_async, get_torch_dist_unique_port
|
||||||
|
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from tests.e2e.utils import check_tensorboard, require_torch_2_7_0
|
||||||
|
|
||||||
|
|
||||||
|
class TestTensorParallel:
|
||||||
|
"""Test class for Tensor Parallel functionality."""
|
||||||
|
|
||||||
|
@pytest.mark.skip(
|
||||||
|
reason="TP doesn't work with models with tied weights (embeddings)"
|
||||||
|
)
|
||||||
|
@require_torch_2_7_0
|
||||||
|
def test_fft_sft(self, temp_dir):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "Qwen/Qwen2.5-0.5B",
|
||||||
|
"sequence_len": 2048,
|
||||||
|
"val_set_size": 0.01,
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "tatsu-lab/alpaca",
|
||||||
|
"type": "alpaca",
|
||||||
|
"split": "train[:10%]",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"max_steps": 2,
|
||||||
|
"micro_batch_size": 2,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_torch",
|
||||||
|
"tensor_parallel_size": 2,
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"flash_attention": True,
|
||||||
|
"use_tensorboard": True,
|
||||||
|
"bf16": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# write cfg to yaml file
|
||||||
|
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||||
|
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||||
|
|
||||||
|
execute_subprocess_async(
|
||||||
|
[
|
||||||
|
"axolotl",
|
||||||
|
"train",
|
||||||
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
|
"--num-processes",
|
||||||
|
"2",
|
||||||
|
"--main-process-port",
|
||||||
|
f"{get_torch_dist_unique_port()}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
check_tensorboard(
|
||||||
|
temp_dir + "/runs", "train/train_loss", 1.0, "Train Loss (%s) is too high"
|
||||||
|
)
|
||||||
@@ -1,481 +0,0 @@
|
|||||||
"""Tests for sequence parallelism functionality."""
|
|
||||||
|
|
||||||
# pylint: disable=redefined-outer-name,unused-argument
|
|
||||||
|
|
||||||
import functools
|
|
||||||
import sys
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
from accelerate.state import PartialState
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.ring_attn import (
|
|
||||||
get_ring_attn_group,
|
|
||||||
register_ring_attn,
|
|
||||||
set_ring_attn_group,
|
|
||||||
)
|
|
||||||
from axolotl.utils.ctx_managers.sequence_parallel import apply_sequence_parallelism
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
from axolotl.utils.schemas.enums import RingAttnFunc
|
|
||||||
from axolotl.utils.schemas.trl import TRLConfig
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def partial_state():
|
|
||||||
"""Create a real PartialState instance for testing."""
|
|
||||||
state = PartialState()
|
|
||||||
return state
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="cfg")
|
|
||||||
def fixture_cfg():
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "mhenrichsen/alpaca_2k_test",
|
|
||||||
"type": "alpaca",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"micro_batch_size": 1,
|
|
||||||
"gradient_accumulation_steps": 1,
|
|
||||||
"learning_rate": 1e-3,
|
|
||||||
"output_dir": "./model-out",
|
|
||||||
"sequence_len": 512,
|
|
||||||
"special_tokens": {
|
|
||||||
"pad_token": "<|endoftext|>",
|
|
||||||
},
|
|
||||||
"save_first_step": False,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return cfg
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sequence_parallel_batch():
|
|
||||||
"""Create a test batch for sequence parallelism tests."""
|
|
||||||
batch_size = 1
|
|
||||||
seq_len = 8
|
|
||||||
|
|
||||||
# Create test tensors
|
|
||||||
input_ids = torch.arange(batch_size * seq_len).reshape(batch_size, seq_len)
|
|
||||||
attention_mask = torch.ones(batch_size, seq_len)
|
|
||||||
position_ids = torch.arange(seq_len).expand(batch_size, seq_len)
|
|
||||||
labels = input_ids.clone()
|
|
||||||
|
|
||||||
# Create test batch
|
|
||||||
batch = {
|
|
||||||
"input_ids": input_ids,
|
|
||||||
"attention_mask": attention_mask,
|
|
||||||
"position_ids": position_ids,
|
|
||||||
"labels": labels,
|
|
||||||
}
|
|
||||||
|
|
||||||
return batch
|
|
||||||
|
|
||||||
|
|
||||||
class TestRingAttention:
|
|
||||||
"""Tests for the ring attention functionality."""
|
|
||||||
|
|
||||||
@patch("torch.distributed.get_rank")
|
|
||||||
@patch("torch.distributed.get_world_size")
|
|
||||||
def test_get_ring_attn_group_no_registration(
|
|
||||||
self, mock_world_size, mock_rank, partial_state
|
|
||||||
):
|
|
||||||
"""Test that get_ring_attn_group raises RuntimeError when no group has been registered."""
|
|
||||||
# Setup mocks
|
|
||||||
mock_world_size.return_value = 4
|
|
||||||
mock_rank.return_value = 0
|
|
||||||
|
|
||||||
# Verify that RuntimeError is raised when no group is registered
|
|
||||||
with pytest.raises(
|
|
||||||
RuntimeError, match="register_ring_attn\\(\\) not yet called"
|
|
||||||
):
|
|
||||||
get_ring_attn_group()
|
|
||||||
|
|
||||||
@patch("torch.distributed.new_group")
|
|
||||||
@patch("torch.distributed.get_rank")
|
|
||||||
@patch("torch.distributed.get_world_size")
|
|
||||||
def test_register_ring_attn(
|
|
||||||
self, mock_world_size, mock_rank, mock_new_group, partial_state
|
|
||||||
):
|
|
||||||
"""Test that ring attention groups are created correctly."""
|
|
||||||
# Setup mocks
|
|
||||||
mock_world_size.return_value = 8 # 8 GPUs total
|
|
||||||
mock_rank.return_value = 3 # GPU #3
|
|
||||||
mock_group = MagicMock()
|
|
||||||
mock_new_group.return_value = mock_group
|
|
||||||
|
|
||||||
# Call register_ring_attn with size 4
|
|
||||||
register_ring_attn(
|
|
||||||
sequence_parallel_degree=4,
|
|
||||||
heads_k_stride=1,
|
|
||||||
ring_attn_func=RingAttnFunc.VARLEN_LLAMA3,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify the number of calls without examining the arguments
|
|
||||||
assert mock_new_group.call_count == 2
|
|
||||||
|
|
||||||
# Verify that new_group was called
|
|
||||||
mock_new_group.assert_called()
|
|
||||||
|
|
||||||
# Clean up
|
|
||||||
set_ring_attn_group(None)
|
|
||||||
|
|
||||||
|
|
||||||
class TestConfigValidation:
|
|
||||||
"""Tests for validating sequence parallelism configurations."""
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def setup_mocks(self, monkeypatch):
|
|
||||||
"""Set up mocks for all tests in this class."""
|
|
||||||
# Mock the ring_flash_attn module
|
|
||||||
monkeypatch.setitem(sys.modules, "ring_flash_attn", MagicMock())
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def base_cfg(self):
|
|
||||||
"""Create a base configuration for testing."""
|
|
||||||
return DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
|
||||||
"datasets": [{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}],
|
|
||||||
"micro_batch_size": 1,
|
|
||||||
"gradient_accumulation_steps": 1,
|
|
||||||
"learning_rate": 1e-3,
|
|
||||||
"output_dir": "./model-out",
|
|
||||||
"sequence_len": 512,
|
|
||||||
"special_tokens": {"pad_token": "<|endoftext|>"},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"config_updates, expected_values, should_pass, error_msg",
|
|
||||||
[
|
|
||||||
# Valid configuration
|
|
||||||
(
|
|
||||||
{"sequence_parallel_degree": 2, "flash_attention": True},
|
|
||||||
{"sequence_parallel_degree": 2, "flash_attention": True},
|
|
||||||
True,
|
|
||||||
None,
|
|
||||||
),
|
|
||||||
# Default sequence_parallel_degree
|
|
||||||
({}, {"sequence_parallel_degree": 1}, True, None),
|
|
||||||
# Invalid: sequence_parallel_degree > 1 without flash_attention
|
|
||||||
(
|
|
||||||
{"sequence_parallel_degree": 2, "flash_attention": False},
|
|
||||||
None,
|
|
||||||
False,
|
|
||||||
"flash_attention: true must be set",
|
|
||||||
),
|
|
||||||
# Invalid: sequence_parallel_degree > 1 with sample_packing and micro_batch_size > 1
|
|
||||||
(
|
|
||||||
{
|
|
||||||
"sequence_parallel_degree": 2,
|
|
||||||
"flash_attention": True,
|
|
||||||
"sample_packing": True,
|
|
||||||
"micro_batch_size": 2,
|
|
||||||
"pad_to_sequence_len": True,
|
|
||||||
},
|
|
||||||
None,
|
|
||||||
False,
|
|
||||||
"micro_batch_size must be set to 1",
|
|
||||||
),
|
|
||||||
# Valid: Basic GRPO config
|
|
||||||
(
|
|
||||||
{
|
|
||||||
"sequence_parallel_degree": 2,
|
|
||||||
"flash_attention": True,
|
|
||||||
"micro_batch_size": 2,
|
|
||||||
"trl": {"use_liger_loss": True},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"sequence_parallel_degree": 2,
|
|
||||||
"flash_attention": True,
|
|
||||||
"micro_batch_size": 2,
|
|
||||||
"trl": TRLConfig(use_liger_loss=True),
|
|
||||||
},
|
|
||||||
True,
|
|
||||||
"GRPO + SP + Liger not currently supported",
|
|
||||||
),
|
|
||||||
# Invalid: GRPO config with Liger loss
|
|
||||||
(
|
|
||||||
{
|
|
||||||
"rl": "grpo",
|
|
||||||
"sequence_parallel_degree": 2,
|
|
||||||
"flash_attention": True,
|
|
||||||
"micro_batch_size": 2,
|
|
||||||
"trl": {"use_liger_loss": True},
|
|
||||||
},
|
|
||||||
None,
|
|
||||||
False,
|
|
||||||
"GRPO + SP + Liger not currently supported",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
ids=[
|
|
||||||
"valid_config",
|
|
||||||
"default_sp_degree",
|
|
||||||
"without_flash_attention",
|
|
||||||
"sample_packing_with_large_batch",
|
|
||||||
"valid_grpo",
|
|
||||||
"grpo_with_liger_loss",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_sequence_parallel_config_validation(
|
|
||||||
self, base_cfg, config_updates, expected_values, should_pass, error_msg
|
|
||||||
):
|
|
||||||
"""Test various sequence parallelism configuration scenarios."""
|
|
||||||
from axolotl.utils.schemas.config import AxolotlInputConfig
|
|
||||||
|
|
||||||
# Apply updates to base config
|
|
||||||
cfg = base_cfg
|
|
||||||
cfg.update(config_updates)
|
|
||||||
|
|
||||||
if should_pass:
|
|
||||||
# Should validate without errors
|
|
||||||
config = AxolotlInputConfig(**cfg)
|
|
||||||
|
|
||||||
# Check expected values
|
|
||||||
for key, value in expected_values.items():
|
|
||||||
assert getattr(config, key) == value
|
|
||||||
else:
|
|
||||||
# Should raise exception
|
|
||||||
with pytest.raises(ValueError) as excinfo:
|
|
||||||
AxolotlInputConfig(**cfg)
|
|
||||||
assert error_msg in str(excinfo.value)
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"ring_attn_func, sample_packing, expected_func",
|
|
||||||
[
|
|
||||||
(None, True, RingAttnFunc.VARLEN_LLAMA3),
|
|
||||||
(None, False, RingAttnFunc.BATCH_RING),
|
|
||||||
],
|
|
||||||
ids=["default_with_sample_packing", "default_without_sample_packing"],
|
|
||||||
)
|
|
||||||
def test_ring_attn_func_validation(
|
|
||||||
self, base_cfg, ring_attn_func, sample_packing, expected_func
|
|
||||||
):
|
|
||||||
"""Test ring_attn_func validation and defaults."""
|
|
||||||
from axolotl.utils.schemas.config import AxolotlInputConfig
|
|
||||||
|
|
||||||
# Apply updates to base config
|
|
||||||
cfg = base_cfg | {
|
|
||||||
"sequence_parallel_degree": 2,
|
|
||||||
"flash_attention": True,
|
|
||||||
"sample_packing": sample_packing,
|
|
||||||
}
|
|
||||||
|
|
||||||
if ring_attn_func is not None:
|
|
||||||
cfg["ring_attn_func"] = ring_attn_func
|
|
||||||
|
|
||||||
# Should validate without errors
|
|
||||||
config = AxolotlInputConfig(**cfg)
|
|
||||||
|
|
||||||
# Check ring_attn_func value
|
|
||||||
assert config.ring_attn_func.value == expected_func
|
|
||||||
|
|
||||||
def test_invalid_ring_attn_func(self, base_cfg):
|
|
||||||
"""Test that an invalid ring_attn_func is rejected."""
|
|
||||||
from axolotl.utils.schemas.config import AxolotlInputConfig
|
|
||||||
|
|
||||||
# Invalid configuration with invalid ring_attn_func
|
|
||||||
cfg = base_cfg | {
|
|
||||||
"sequence_parallel_degree": 2,
|
|
||||||
"flash_attention": True,
|
|
||||||
"ring_attn_func": "INVALID_FUNC",
|
|
||||||
}
|
|
||||||
|
|
||||||
# Should raise ValidationError
|
|
||||||
with pytest.raises(ValueError) as excinfo:
|
|
||||||
AxolotlInputConfig(**cfg)
|
|
||||||
|
|
||||||
# Verify error message
|
|
||||||
assert "Input should be 'varlen_llama3' or 'batch_ring'" in str(excinfo.value)
|
|
||||||
|
|
||||||
|
|
||||||
class TestApplySequenceParallelism:
|
|
||||||
"""Tests for the apply_sequence_parallelism function."""
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def mock_distributed(self, monkeypatch):
|
|
||||||
"""Mock torch.distributed functions for testing."""
|
|
||||||
# Mock is_initialized to return True
|
|
||||||
monkeypatch.setattr(torch.distributed, "is_initialized", lambda: True)
|
|
||||||
|
|
||||||
# Mock get_rank to return 0 by default
|
|
||||||
monkeypatch.setattr(torch.distributed, "get_rank", lambda *args, **kwargs: 0)
|
|
||||||
|
|
||||||
# Mock get_world_size to return 2 by default
|
|
||||||
monkeypatch.setattr(
|
|
||||||
torch.distributed, "get_world_size", lambda *args, **kwargs: 2
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock the process group
|
|
||||||
monkeypatch.setattr(
|
|
||||||
"axolotl.monkeypatch.ring_attn.get_ring_attn_group",
|
|
||||||
MagicMock,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock update_ring_attn_params
|
|
||||||
monkeypatch.setattr(
|
|
||||||
"axolotl.monkeypatch.ring_attn.update_ring_attn_params",
|
|
||||||
lambda **kwargs: None,
|
|
||||||
)
|
|
||||||
|
|
||||||
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
|
|
||||||
def test_world_size_one(self, mock_get_ring_attn_group, sequence_parallel_batch):
|
|
||||||
"""Test that function returns original batch when world size is 1."""
|
|
||||||
mock_get_ring_attn_group.return_value = 0
|
|
||||||
|
|
||||||
result, _, _ = apply_sequence_parallelism(
|
|
||||||
batch=sequence_parallel_batch,
|
|
||||||
local_rank=0,
|
|
||||||
local_world_size=1,
|
|
||||||
gradient_accumulation_steps=1,
|
|
||||||
ring_attn_func=RingAttnFunc.BATCH_RING,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should return the original batch unchanged
|
|
||||||
assert result == sequence_parallel_batch
|
|
||||||
|
|
||||||
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
|
|
||||||
def test_batch_ring_rank0(self, mock_get_ring_attn_group, sequence_parallel_batch):
|
|
||||||
"""Test BATCH_RING sharding for rank 0 in a 2-process group."""
|
|
||||||
mock_get_ring_attn_group.return_value = 0
|
|
||||||
|
|
||||||
batch = sequence_parallel_batch
|
|
||||||
seq_len = batch["input_ids"].size(1)
|
|
||||||
|
|
||||||
result, _, _ = apply_sequence_parallelism(
|
|
||||||
batch=batch,
|
|
||||||
local_rank=0,
|
|
||||||
local_world_size=2,
|
|
||||||
gradient_accumulation_steps=1,
|
|
||||||
ring_attn_func=RingAttnFunc.BATCH_RING,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check that sequence dimension was sharded correctly
|
|
||||||
assert result["input_ids"].shape[1] == seq_len // 2
|
|
||||||
assert result["attention_mask"].shape[1] == seq_len // 2
|
|
||||||
|
|
||||||
# Verify content: rank 0 should get the first half of the sequence
|
|
||||||
assert torch.equal(result["input_ids"], batch["input_ids"][:, : seq_len // 2])
|
|
||||||
assert torch.equal(
|
|
||||||
result["position_ids"], batch["position_ids"][:, : seq_len // 2]
|
|
||||||
)
|
|
||||||
|
|
||||||
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
|
|
||||||
def test_batch_ring_rank1(self, mock_get_ring_attn_group, sequence_parallel_batch):
|
|
||||||
"""Test BATCH_RING sharding for rank 1 in a 2-process group."""
|
|
||||||
mock_get_ring_attn_group.return_value = 0
|
|
||||||
|
|
||||||
batch = sequence_parallel_batch
|
|
||||||
seq_len = batch["input_ids"].size(1)
|
|
||||||
original_input_ids = batch["input_ids"].clone()
|
|
||||||
|
|
||||||
result, _, _ = apply_sequence_parallelism(
|
|
||||||
batch=batch,
|
|
||||||
local_rank=1,
|
|
||||||
local_world_size=2,
|
|
||||||
gradient_accumulation_steps=1,
|
|
||||||
ring_attn_func=RingAttnFunc.BATCH_RING,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify content: rank 1 should get the second half of the sequence
|
|
||||||
assert torch.equal(result["input_ids"], original_input_ids[:, seq_len // 2 :])
|
|
||||||
|
|
||||||
# TODO(djsaunde): add back once implemented.
|
|
||||||
# def test_batch_zigzag(self, sequence_parallel_batch):
|
|
||||||
# """Test BATCH_ZIGZAG sharding pattern."""
|
|
||||||
# batch = sequence_parallel_batch
|
|
||||||
# original_input_ids = batch["input_ids"].clone()
|
|
||||||
# seq_len = batch["input_ids"].size(1)
|
|
||||||
|
|
||||||
# # Test rank 0
|
|
||||||
# result_rank0 = apply_sequence_parallelism(
|
|
||||||
# batch={k: v.clone() for k, v in batch.items()},
|
|
||||||
# local_rank=0,
|
|
||||||
# local_world_size=2,
|
|
||||||
# ring_attn_func=RingAttnFunc.BATCH_ZIGZAG,
|
|
||||||
# )
|
|
||||||
|
|
||||||
# # Test rank 1
|
|
||||||
# result_rank1 = apply_sequence_parallelism(
|
|
||||||
# batch={k: v.clone() for k, v in batch.items()},
|
|
||||||
# local_rank=1,
|
|
||||||
# local_world_size=2,
|
|
||||||
# ring_attn_func=RingAttnFunc.BATCH_ZIGZAG,
|
|
||||||
# )
|
|
||||||
|
|
||||||
# # Checks for both ranks
|
|
||||||
# assert result_rank0["input_ids"].shape[1] == seq_len // 2
|
|
||||||
# assert result_rank1["input_ids"].shape[1] == seq_len // 2
|
|
||||||
|
|
||||||
# # For a 2-rank system with 8 tokens, check specific zigzag pattern
|
|
||||||
# # Rank 0 should get chunks [0, 1] and [6, 7]
|
|
||||||
# # Rank 1 should get chunks [2, 3] and [4, 5]
|
|
||||||
# if seq_len == 8:
|
|
||||||
# # Create expected tensors for comparison
|
|
||||||
# rank0_expected = torch.cat(
|
|
||||||
# [original_input_ids[:, :2], original_input_ids[:, 6:8]], dim=1
|
|
||||||
# )
|
|
||||||
|
|
||||||
# rank1_expected = torch.cat(
|
|
||||||
# [original_input_ids[:, 2:4], original_input_ids[:, 4:6]], dim=1
|
|
||||||
# )
|
|
||||||
|
|
||||||
# assert torch.equal(result_rank0["input_ids"], rank0_expected)
|
|
||||||
# assert torch.equal(result_rank1["input_ids"], rank1_expected)
|
|
||||||
|
|
||||||
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
|
|
||||||
def test_partial_application(
|
|
||||||
self, mock_get_ring_attn_group, sequence_parallel_batch
|
|
||||||
):
|
|
||||||
"""Test that we can create a partially applied version of the function."""
|
|
||||||
mock_get_ring_attn_group.return_value = 0
|
|
||||||
|
|
||||||
batch = sequence_parallel_batch
|
|
||||||
original_input_ids = batch["input_ids"].clone()
|
|
||||||
|
|
||||||
# Create a partially applied function
|
|
||||||
rank0_ring_parallel = functools.partial(
|
|
||||||
apply_sequence_parallelism,
|
|
||||||
local_rank=0,
|
|
||||||
local_world_size=2,
|
|
||||||
gradient_accumulation_steps=1,
|
|
||||||
ring_attn_func=RingAttnFunc.BATCH_RING,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use the partially applied function
|
|
||||||
result, _, _ = rank0_ring_parallel(batch=batch)
|
|
||||||
|
|
||||||
# Verify it works as expected
|
|
||||||
assert result["input_ids"].shape[1] == original_input_ids.shape[1] // 2
|
|
||||||
assert torch.equal(
|
|
||||||
result["input_ids"],
|
|
||||||
original_input_ids[:, : original_input_ids.shape[1] // 2],
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_missing_position_ids(self, sequence_parallel_batch):
|
|
||||||
"""Test handling of batch without position_ids."""
|
|
||||||
# Create a batch without position_ids
|
|
||||||
batch = {
|
|
||||||
k: v for k, v in sequence_parallel_batch.items() if k != "position_ids"
|
|
||||||
}
|
|
||||||
original_input_ids = batch["input_ids"].clone()
|
|
||||||
|
|
||||||
# This should run without error even though position_ids is missing
|
|
||||||
result, _, _ = apply_sequence_parallelism(
|
|
||||||
batch=batch,
|
|
||||||
local_rank=0,
|
|
||||||
local_world_size=2,
|
|
||||||
gradient_accumulation_steps=1,
|
|
||||||
ring_attn_func=RingAttnFunc.BATCH_RING,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verification should pass
|
|
||||||
assert "position_ids" in result
|
|
||||||
assert result["input_ids"].shape[1] == result["position_ids"].shape[1]
|
|
||||||
assert result["input_ids"].shape[1] == original_input_ids.shape[1] // 2
|
|
||||||
@@ -52,6 +52,8 @@ class TestLoadModelUtils:
|
|||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_torch_fused",
|
"optimizer": "adamw_torch_fused",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
|
"tensor_parallel_size": 1,
|
||||||
|
"context_parallel_size": 1,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
self.model_loader = ( # pylint: disable=attribute-defined-outside-init
|
self.model_loader = ( # pylint: disable=attribute-defined-outside-init
|
||||||
|
|||||||
@@ -142,6 +142,10 @@ def is_hopper():
|
|||||||
return compute_capability == (9, 0)
|
return compute_capability == (9, 0)
|
||||||
|
|
||||||
|
|
||||||
|
def require_hopper(test_case):
|
||||||
|
return unittest.skipUnless(is_hopper(), "test requires h100/hopper GPU")(test_case)
|
||||||
|
|
||||||
|
|
||||||
def check_tensorboard(
|
def check_tensorboard(
|
||||||
temp_run_dir: str, tag: str, lt_val: float, assertion_err: str
|
temp_run_dir: str, tag: str, lt_val: float, assertion_err: str
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|||||||
@@ -171,3 +171,44 @@ class TestModelsUtils:
|
|||||||
message_property_mappings={"content": "different_content"},
|
message_property_mappings={"content": "different_content"},
|
||||||
)
|
)
|
||||||
assert "Conflicting message content fields" in str(exc_info.value)
|
assert "Conflicting message content fields" in str(exc_info.value)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"world_size, tensor_parallel_size, context_parallel_size, dp_shard_size, dp_replicate_size, is_fsdp, expected",
|
||||||
|
[
|
||||||
|
(16, 2, 2, 2, 2, True, (2, 2, 2, 2)),
|
||||||
|
(16, 1, 1, None, None, True, (0, 0, 16, 1)),
|
||||||
|
(16, 2, 2, 2, None, True, (2, 2, 2, 2)),
|
||||||
|
(16, 2, 2, None, 2, True, (2, 2, 2, 2)),
|
||||||
|
(16, 1, 1, None, 2, True, (0, 0, 8, 2)),
|
||||||
|
(2, 1, 1, None, None, True, (0, 0, 2, 1)),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_get_parallel_config_kwargs(
|
||||||
|
self,
|
||||||
|
world_size,
|
||||||
|
tensor_parallel_size,
|
||||||
|
context_parallel_size,
|
||||||
|
dp_shard_size,
|
||||||
|
dp_replicate_size,
|
||||||
|
is_fsdp,
|
||||||
|
expected,
|
||||||
|
):
|
||||||
|
res = (
|
||||||
|
ModelLoader._get_parallel_config_kwargs( # pylint: disable=protected-access
|
||||||
|
world_size,
|
||||||
|
tensor_parallel_size,
|
||||||
|
context_parallel_size,
|
||||||
|
dp_shard_size,
|
||||||
|
dp_replicate_size,
|
||||||
|
is_fsdp,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if expected[0] > 1:
|
||||||
|
assert res["tp_size"] == expected[0]
|
||||||
|
if expected[1] > 1:
|
||||||
|
assert res["cp_size"] == expected[1]
|
||||||
|
if expected[2] > 1:
|
||||||
|
assert res["dp_shard_size"] == expected[2]
|
||||||
|
if expected[3] > 1:
|
||||||
|
assert res["dp_replicate_size"] == expected[3]
|
||||||
|
|||||||
@@ -26,32 +26,6 @@ class TestFSDPValidation:
|
|||||||
assert cfg.fsdp_version == 2
|
assert cfg.fsdp_version == 2
|
||||||
assert cfg.fsdp_config.fsdp_version is None
|
assert cfg.fsdp_config.fsdp_version is None
|
||||||
|
|
||||||
def test_fsdp_sharded_state_dict_safetensors(self, min_base_cfg):
|
|
||||||
cfg = min_base_cfg | DictDefault(
|
|
||||||
fsdp_config={
|
|
||||||
"fsdp_state_dict_type": "SHARDED_STATE_DICT",
|
|
||||||
},
|
|
||||||
save_safetensors=True,
|
|
||||||
)
|
|
||||||
with pytest.raises(
|
|
||||||
ValueError,
|
|
||||||
match="FSDP SHARDED_STATE_DICT not compatible with save_safetensors",
|
|
||||||
):
|
|
||||||
validate_config(cfg)
|
|
||||||
|
|
||||||
# test w/o prefix too
|
|
||||||
cfg = min_base_cfg | DictDefault(
|
|
||||||
fsdp_config={
|
|
||||||
"state_dict_type": "SHARDED_STATE_DICT",
|
|
||||||
},
|
|
||||||
save_safetensors=True,
|
|
||||||
)
|
|
||||||
with pytest.raises(
|
|
||||||
ValueError,
|
|
||||||
match="FSDP SHARDED_STATE_DICT not compatible with save_safetensors",
|
|
||||||
):
|
|
||||||
validate_config(cfg)
|
|
||||||
|
|
||||||
def test_fsdp_offload_w_8bit_optim(self, min_base_cfg):
|
def test_fsdp_offload_w_8bit_optim(self, min_base_cfg):
|
||||||
cfg = min_base_cfg | DictDefault(
|
cfg = min_base_cfg | DictDefault(
|
||||||
fsdp_config={
|
fsdp_config={
|
||||||
|
|||||||
Reference in New Issue
Block a user