Compare commits
4 Commits
nd_paralle
...
testingci
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e36d3c9f30 | ||
|
|
53614391ed | ||
|
|
1407aac779 | ||
|
|
b34c3371ed |
4
.github/workflows/base.yml
vendored
4
.github/workflows/base.yml
vendored
@@ -17,7 +17,7 @@ on:
|
||||
|
||||
jobs:
|
||||
build-base:
|
||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
|
||||
timeout-minutes: 480
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
runs-on: ubuntu-latest-m
|
||||
@@ -108,7 +108,7 @@ jobs:
|
||||
PYTORCH_VERSION=${{ matrix.pytorch }}
|
||||
TORCH_CUDA_ARCH_LIST=${{ matrix.torch_cuda_arch_list }}
|
||||
build-base-uv:
|
||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
|
||||
timeout-minutes: 480
|
||||
runs-on: ubuntu-latest-m
|
||||
strategy:
|
||||
|
||||
2
.github/workflows/lint.yml
vendored
2
.github/workflows/lint.yml
vendored
@@ -3,6 +3,7 @@ on:
|
||||
# check on PRs, and manual triggers
|
||||
merge_group:
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened, ready_for_review]
|
||||
paths:
|
||||
- '**.py'
|
||||
- 'requirements.txt'
|
||||
@@ -16,6 +17,7 @@ jobs:
|
||||
pre-commit:
|
||||
name: pre-commit
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ !github.event.pull_request.draft }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v5
|
||||
|
||||
2
.github/workflows/multi-gpu-e2e.yml
vendored
2
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -21,7 +21,7 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
test-axolotl-multigpu:
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
|
||||
3
.github/workflows/preview-docs.yml
vendored
3
.github/workflows/preview-docs.yml
vendored
@@ -2,7 +2,7 @@ name: Preview
|
||||
on:
|
||||
workflow_dispatch:
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened]
|
||||
types: [opened, synchronize, reopened, ready_for_review]
|
||||
|
||||
# Run the workflow only when one of these files changes
|
||||
paths:
|
||||
@@ -25,6 +25,7 @@ permissions:
|
||||
jobs:
|
||||
preview:
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ !github.event.pull_request.draft }}
|
||||
steps:
|
||||
- name: Check out repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
9
.github/workflows/tests.yml
vendored
9
.github/workflows/tests.yml
vendored
@@ -13,6 +13,7 @@ on:
|
||||
- 'cicd/cicd.sh'
|
||||
- 'cicd/Dockerfile.jinja'
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened, ready_for_review]
|
||||
paths:
|
||||
- '**.py'
|
||||
- 'requirements.txt'
|
||||
@@ -34,6 +35,7 @@ jobs:
|
||||
pre-commit:
|
||||
name: pre-commit
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ !github.event.pull_request.draft }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v5
|
||||
@@ -47,6 +49,7 @@ jobs:
|
||||
pytest:
|
||||
name: PyTest
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ !github.event.pull_request.draft }}
|
||||
# needs: [preload-cache]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
@@ -121,6 +124,7 @@ jobs:
|
||||
pytest-sdist:
|
||||
name: PyTest from Source Dist
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ !github.event.pull_request.draft }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
@@ -185,7 +189,7 @@ jobs:
|
||||
|
||||
docker-e2e-tests-1st:
|
||||
# Run this job first as a gate for running the remainder of the test matrix
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' && !github.event.pull_request.draft }}
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
runs-on: [self-hosted, modal]
|
||||
timeout-minutes: 120
|
||||
@@ -235,7 +239,7 @@ jobs:
|
||||
modal run cicd.e2e_tests
|
||||
|
||||
docker-e2e-tests:
|
||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && !github.event.pull_request.draft }}
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
runs-on: [self-hosted, modal]
|
||||
timeout-minutes: 120
|
||||
@@ -289,6 +293,7 @@ jobs:
|
||||
runs-on: [self-hosted, modal]
|
||||
timeout-minutes: 90
|
||||
needs: [docker-e2e-tests]
|
||||
if: ${{ !github.event.pull_request.draft }}
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
|
||||
@@ -22,7 +22,7 @@ To enable sequence parallelism, add the following to your configuration file:
|
||||
|
||||
```yaml
|
||||
# Set to a divisor (> 1) of the number of GPUs available
|
||||
context_parallel_size: 4 # Split sequences across 4 GPUs
|
||||
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
|
||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||
heads_k_stride: 1
|
||||
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
|
||||
@@ -30,7 +30,7 @@ heads_k_stride: 1
|
||||
ring_attn_func:
|
||||
```
|
||||
|
||||
The `context_parallel_size` should be a divisor of the total number of GPUs. For example:
|
||||
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
|
||||
|
||||
- With 8 GPUs, valid values would be 2, 4, or 8
|
||||
- With 4 GPUs, valid values would be 2 or 4
|
||||
@@ -66,7 +66,7 @@ sequence_len: 8192
|
||||
|
||||
...
|
||||
|
||||
context_parallel_size: 4 # Split each sequence into 4 parts, one per GPU
|
||||
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
|
||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||
heads_k_stride: 1
|
||||
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
|
||||
@@ -89,12 +89,12 @@ Sequence parallelism is compatible with Axolotl's sample packing functionality.
|
||||
|
||||
## Effect on Batch Size
|
||||
|
||||
When using sequence parallelism, your effective global batch size is **divided** by the `context_parallel_size`. This happens because:
|
||||
When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:
|
||||
|
||||
- Each group of `context_parallel_size` GPUs works on the same batch (just different parts of each sequence)
|
||||
- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
|
||||
- The number of batches processed per step decreases
|
||||
|
||||
For example:
|
||||
- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
|
||||
- With 8 GPUs and `context_parallel_size=4`: Only 2 different batches processed per step (each split across 4 GPUs)
|
||||
- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
|
||||
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4
|
||||
|
||||
@@ -13,9 +13,9 @@ packaging==23.2
|
||||
|
||||
huggingface_hub>=0.33.0
|
||||
peft==0.16.0
|
||||
transformers @ git+https://github.com/huggingface/transformers.git@82603b6cc284dbdf2b7a7cf070feb6a2c3bb53cf
|
||||
transformers==4.53.2
|
||||
tokenizers>=0.21.1
|
||||
accelerate @ git+https://github.com/SalmanMohammadi/accelerate.git@device_mesh_parallelism_config
|
||||
accelerate==1.9.0
|
||||
datasets==4.0.0
|
||||
deepspeed>=0.17.0
|
||||
trl==0.19.1
|
||||
@@ -62,7 +62,7 @@ langdetect==1.0.9
|
||||
immutabledict==4.2.0
|
||||
antlr4-python3-runtime==4.13.2
|
||||
|
||||
torchao==0.10.0
|
||||
torchao==0.12.0
|
||||
schedulefree==1.4.1
|
||||
|
||||
axolotl-contribs-lgpl==0.0.6
|
||||
|
||||
@@ -70,7 +70,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||
load_in_8bit=False,
|
||||
load_in_4bit=False,
|
||||
flash_attention=False,
|
||||
context_parallel_size=None,
|
||||
sequence_parallel_degree=None,
|
||||
deepspeed=None,
|
||||
fsdp=None,
|
||||
fsdp_config=None,
|
||||
|
||||
@@ -27,7 +27,6 @@ import torch
|
||||
from transformers import (
|
||||
TrainerCallback,
|
||||
)
|
||||
from transformers.trainer_pt_utils import AcceleratorConfig
|
||||
from transformers.training_args import OptimizerNames
|
||||
|
||||
from axolotl.integrations.base import PluginManager
|
||||
@@ -435,18 +434,8 @@ class TrainerBuilderBase(abc.ABC):
|
||||
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
|
||||
|
||||
def _configure_accelerator_config(self, training_args_kwargs: dict):
|
||||
use_configured_state = True
|
||||
if self.cfg.accelerator_config:
|
||||
use_configured_state = self.cfg.accelerator_config.pop(
|
||||
"use_configured_state", use_configured_state
|
||||
)
|
||||
training_args_kwargs["accelerator_config"] = AcceleratorConfig(
|
||||
use_configured_state=use_configured_state, **self.cfg.accelerator_config
|
||||
)
|
||||
else:
|
||||
training_args_kwargs["accelerator_config"] = AcceleratorConfig(
|
||||
use_configured_state=True,
|
||||
)
|
||||
training_args_kwargs["accelerator_config"] = self.cfg.accelerator_config
|
||||
|
||||
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
|
||||
if self.cfg.activation_offloading is True:
|
||||
|
||||
@@ -53,7 +53,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
if self.cfg.rl is RLType.GRPO:
|
||||
trainer_cls = GRPOStrategy.get_trainer_class(
|
||||
sequence_parallel=self.cfg.context_parallel_size > 1
|
||||
sequence_parallel=self.cfg.sequence_parallel_degree > 1
|
||||
)
|
||||
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
|
||||
|
||||
|
||||
@@ -82,8 +82,8 @@ class GRPOStrategy:
|
||||
grpo_args_kwargs["log_completions"] = trl.log_completions
|
||||
grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print
|
||||
|
||||
if cfg.context_parallel_size > 1:
|
||||
grpo_args_kwargs["context_parallel_size"] = cfg.context_parallel_size
|
||||
if cfg.sequence_parallel_degree > 1:
|
||||
grpo_args_kwargs["sequence_parallel_degree"] = cfg.sequence_parallel_degree
|
||||
|
||||
if trl.reward_weights:
|
||||
grpo_args_kwargs["reward_weights"] = trl.reward_weights
|
||||
|
||||
@@ -13,4 +13,4 @@ from axolotl.core.training_args import AxolotlTrainingMixins
|
||||
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
|
||||
"""Axolotl GRPO Config for GRPO training"""
|
||||
|
||||
context_parallel_size: int | None = None
|
||||
sequence_parallel_degree: int | None = None
|
||||
|
||||
@@ -20,7 +20,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
|
||||
- Data is properly distributed across SP groups.
|
||||
|
||||
In the table below, the values represent dataset indices. Each SP group has
|
||||
`context_parallel_size = 2` GPUs working together on the same data. There are 2
|
||||
`sequence_parallel_degree = 2` GPUs working together on the same data. There are 2
|
||||
SP groups (SP0 and SP1), with `world_size = 4` total GPUs.
|
||||
|
||||
Sequence Parallel Groups
|
||||
@@ -45,7 +45,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
|
||||
rank: Rank of current process.
|
||||
batch_size: Number of samples per batch.
|
||||
repeat_count: How many times to repeat the full sampling process.
|
||||
context_parallel_size: Number of ranks in a sequence parallel group.
|
||||
sequence_parallel_degree: Number of ranks in a sequence parallel group.
|
||||
shuffle: Whether to shuffle the dataset.
|
||||
seed: Random seed for shuffling.
|
||||
drop_last: Whether to drop the last incomplete batch.
|
||||
@@ -59,7 +59,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
|
||||
rank: int,
|
||||
batch_size: int = 1,
|
||||
repeat_count: int = 1,
|
||||
context_parallel_size: int = 1,
|
||||
sequence_parallel_degree: int = 1,
|
||||
shuffle: bool = True,
|
||||
seed: int = 0,
|
||||
drop_last: bool = False,
|
||||
@@ -77,9 +77,9 @@ class SequenceParallelRepeatRandomSampler(Sampler):
|
||||
self.rank = rank
|
||||
|
||||
# Sequence parallelism parameters
|
||||
self.context_parallel_size = context_parallel_size
|
||||
self.num_sp_groups = world_size // context_parallel_size
|
||||
self.sp_group_id = rank // context_parallel_size
|
||||
self.sequence_parallel_degree = sequence_parallel_degree
|
||||
self.num_sp_groups = world_size // sequence_parallel_degree
|
||||
self.sp_group_id = rank // sequence_parallel_degree
|
||||
|
||||
# Adjust dataset size for distributed sampling
|
||||
self.num_samples = len(self.dataset)
|
||||
|
||||
@@ -100,7 +100,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
|
||||
# Get number of SP groups (number of processes divided by SP degree)
|
||||
num_processes = self.accelerator.num_processes
|
||||
num_sp_groups = num_processes // self.args.context_parallel_size
|
||||
num_sp_groups = num_processes // self.args.sequence_parallel_degree
|
||||
|
||||
# Calculate batch size per SP group (not per process)
|
||||
sp_group_batch_size = self.args.per_device_train_batch_size * num_sp_groups
|
||||
@@ -130,7 +130,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
|
||||
if self.num_generations not in possible_values:
|
||||
raise ValueError(
|
||||
f"With sequence parallelism (degree {self.args.context_parallel_size}), "
|
||||
f"With sequence parallelism (degree {self.args.sequence_parallel_degree}), "
|
||||
f"the eval batch size per SP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) "
|
||||
f"must be evenly divisible by the number of generations per prompt "
|
||||
f"({self.num_generations}). Given the current eval batch size, "
|
||||
@@ -167,9 +167,9 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
rank=self.rank,
|
||||
batch_size=effective_batch_size
|
||||
// self.num_generations
|
||||
// self.args.context_parallel_size,
|
||||
// self.args.sequence_parallel_degree,
|
||||
repeat_count=self.num_iterations * self.args.gradient_accumulation_steps,
|
||||
context_parallel_size=self.args.context_parallel_size,
|
||||
sequence_parallel_degree=self.args.sequence_parallel_degree,
|
||||
shuffle=True,
|
||||
seed=self.args.seed,
|
||||
drop_last=True,
|
||||
@@ -235,7 +235,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
# TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
|
||||
# if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
|
||||
# slice each batch along the sequence dimension).
|
||||
if self.args.context_parallel_size > 1:
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
return dataloader
|
||||
|
||||
# Otherwise prepare with accelerator
|
||||
@@ -308,18 +308,18 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
|
||||
all_prompts_text = gather_object(prompts_text)
|
||||
if self.accelerator.is_main_process:
|
||||
if self.args.context_parallel_size > 1:
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
# Calculate sequence parallel group information
|
||||
world_size = self.accelerator.num_processes
|
||||
context_parallel_size = self.args.context_parallel_size
|
||||
num_sp_groups = world_size // context_parallel_size
|
||||
sequence_parallel_degree = self.args.sequence_parallel_degree
|
||||
num_sp_groups = world_size // sequence_parallel_degree
|
||||
|
||||
# Since processes in the same SP group have the same prompts, we need to ensure
|
||||
# we only take one copy of each prompt from each SP group
|
||||
ordered_set_of_prompts = []
|
||||
for sp_group_id in range(num_sp_groups):
|
||||
# Get the first process from each SP group (typically the group leader)
|
||||
group_leader_rank = sp_group_id * context_parallel_size
|
||||
group_leader_rank = sp_group_id * sequence_parallel_degree
|
||||
|
||||
# Extract prompts from this SP group, accounting for num_generations duplicates
|
||||
# We only need prompts from one rank in each SP group
|
||||
@@ -335,7 +335,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
# num_generations outputs for each one. This is faster than generating outputs for each duplicate
|
||||
# prompt individually.
|
||||
ordered_set_of_prompts = all_prompts_text[
|
||||
:: self.num_generations * self.args.context_parallel_size
|
||||
:: self.num_generations * self.args.sequence_parallel_degree
|
||||
]
|
||||
|
||||
with profiling_context(self, "vLLM.generate"):
|
||||
@@ -352,14 +352,14 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
)
|
||||
else:
|
||||
completion_ids = [None] * (
|
||||
len(all_prompts_text) // self.args.context_parallel_size
|
||||
len(all_prompts_text) // self.args.sequence_parallel_degree
|
||||
)
|
||||
|
||||
# Broadcast the completions from the main process to all processes
|
||||
completion_ids = broadcast_object_list(completion_ids, from_process=0)
|
||||
|
||||
# Determine the appropriate slice based on sequence parallelism
|
||||
if self.args.context_parallel_size > 1:
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
# Calculate SP group ID (which group of ranks this rank belongs to)
|
||||
sp_group_id = self.accelerator.process_index // self.local_world_size
|
||||
|
||||
@@ -583,7 +583,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
advantages = advantages / (std_grouped_rewards + 1e-4)
|
||||
|
||||
# Slice to keep only the local part of the data
|
||||
if self.args.context_parallel_size > 1:
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
# Calculate SP group ID (which group of ranks this rank belongs to)
|
||||
sp_group_id = self.accelerator.process_index // self.local_world_size
|
||||
|
||||
|
||||
@@ -13,11 +13,9 @@ class CheckpointSaveMixin(Trainer):
|
||||
def _save_optimizer_and_scheduler(self, output_dir):
|
||||
try:
|
||||
super()._save_optimizer_and_scheduler(output_dir)
|
||||
except (NotImplementedError, KeyError) as exc:
|
||||
# TODO: fix fsdp2 optimizer saving
|
||||
LOG.warning_once(
|
||||
except NotImplementedError as exc:
|
||||
LOG.warning(
|
||||
f"Trainer does not support saving optimizer and scheduler: {exc}\n"
|
||||
"Optimizer and scheduler states were not saved - resuming from checkpoints "
|
||||
"for this training run will not be possible.",
|
||||
main_process_only=True,
|
||||
"for this training run will not be possible."
|
||||
)
|
||||
|
||||
@@ -16,6 +16,8 @@
|
||||
Module for handling LIGER input arguments.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
@@ -28,13 +30,13 @@ class LigerArgs(BaseModel):
|
||||
Input args for LIGER.
|
||||
"""
|
||||
|
||||
liger_rope: bool | None = None
|
||||
liger_rms_norm: bool | None = None
|
||||
liger_layer_norm: bool | None = None
|
||||
liger_swiglu: bool | None = None
|
||||
liger_glu_activation: bool | None = None
|
||||
liger_cross_entropy: bool | None = None
|
||||
liger_fused_linear_cross_entropy: bool | None = None
|
||||
liger_rope: Optional[bool] = None
|
||||
liger_rms_norm: Optional[bool] = None
|
||||
liger_layer_norm: Optional[bool] = None
|
||||
liger_swiglu: Optional[bool] = None
|
||||
liger_glu_activation: Optional[bool] = None
|
||||
liger_cross_entropy: Optional[bool] = None
|
||||
liger_fused_linear_cross_entropy: Optional[bool] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
@@ -60,13 +62,3 @@ class LigerArgs(BaseModel):
|
||||
"You cannot have both `liger_glu_activation` and `tiled_mlp` set."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_liger_rms_norm_tensor_parallel(cls, data):
|
||||
if data.get("liger_rms_norm") and data.get("tensor_parallel_size", 1) > 1:
|
||||
raise ValueError(
|
||||
"`liger_rms_norm` is incompatible with tensor parallelism, "
|
||||
"see https://github.com/linkedin/Liger-Kernel/issues/826"
|
||||
)
|
||||
return data
|
||||
|
||||
@@ -13,8 +13,7 @@ import peft
|
||||
import torch
|
||||
import transformers
|
||||
import transformers.modeling_utils
|
||||
from accelerate import PartialState, init_empty_weights
|
||||
from accelerate.utils.dataclasses import ParallelismConfig
|
||||
from accelerate import init_empty_weights
|
||||
from peft import (
|
||||
PeftConfig,
|
||||
PeftMixedModel,
|
||||
@@ -52,7 +51,6 @@ from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import (
|
||||
get_device_count,
|
||||
get_device_type,
|
||||
get_world_size,
|
||||
)
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.model_shard_quant import load_sharded_model_quant
|
||||
@@ -184,7 +182,6 @@ class ModelLoader:
|
||||
|
||||
def _apply_pre_model_load_setup(self):
|
||||
"""Apply patches and setup configurations before model loading."""
|
||||
self._set_parallel_config()
|
||||
self._set_auto_model_loader()
|
||||
self._set_device_map_config()
|
||||
if self.cfg.revision_of_model:
|
||||
@@ -392,52 +389,6 @@ class ModelLoader:
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def _set_parallel_config(self):
|
||||
"""Set parallelism configuration (DP, FSDP, TP, CP) in PartialState/Accelerator"""
|
||||
dp_replicate_size = get_world_size()
|
||||
pc_kwargs = {}
|
||||
if self.cfg.dp_shard_size and self.cfg.dp_shard_size > 1:
|
||||
pc_kwargs["dp_shard_size"] = self.cfg.dp_shard_size
|
||||
dp_replicate_size = dp_replicate_size // self.cfg.dp_shard_size
|
||||
if self.cfg.tensor_parallel_size and self.cfg.tensor_parallel_size > 1:
|
||||
pc_kwargs["tp_size"] = self.cfg.tensor_parallel_size
|
||||
dp_replicate_size = dp_replicate_size // self.cfg.tensor_parallel_size
|
||||
if self.cfg.context_parallel_size and self.cfg.context_parallel_size > 1:
|
||||
pc_kwargs["cp_size"] = self.cfg.context_parallel_size
|
||||
dp_replicate_size = dp_replicate_size // self.cfg.context_parallel_size
|
||||
if dp_replicate_size > 1:
|
||||
pc_kwargs["dp_replicate_size"] = dp_replicate_size
|
||||
|
||||
parallelism_config = ParallelismConfig(
|
||||
**pc_kwargs,
|
||||
)
|
||||
mesh_dim_names, mesh_shape = parallelism_config.get_mesh()
|
||||
device_mesh = torch.distributed.init_device_mesh(
|
||||
"cuda", mesh_shape, mesh_dim_names=mesh_dim_names
|
||||
)
|
||||
submeshes = [
|
||||
tuple(parallelism_config.dp_dim_names),
|
||||
tuple(parallelism_config.dp_shard_cp_dim_names),
|
||||
tuple(parallelism_config.dp_cp_dim_names),
|
||||
]
|
||||
submesh_names = [
|
||||
# create a submesh which is only used for distributing data across data parallel dims (no comms)
|
||||
"dp",
|
||||
# create a submesh which is used *just* for FSDP parameter gathering/scattering
|
||||
# and gradients reduce-scattering
|
||||
"dp_shard_cp",
|
||||
# create a submesh which is used for correctly reducing loss across data replica/context parallel
|
||||
"dp_cp",
|
||||
]
|
||||
for submesh, submesh_name in zip(submeshes, submesh_names):
|
||||
if submesh:
|
||||
device_mesh[submesh]._flatten( # pylint: disable=protected-access
|
||||
submesh_name
|
||||
)
|
||||
|
||||
PartialState().parallelism_config = parallelism_config
|
||||
PartialState().device_mesh = device_mesh
|
||||
|
||||
def _set_auto_model_loader(self):
|
||||
"""Set `self.auto_model_loader`. Defaults to `transformers.AutoModelForCausalLM`
|
||||
(set at `__init__`). When using a multimodal model, `self.auto_model_loader`
|
||||
@@ -670,14 +621,6 @@ class ModelLoader:
|
||||
def _build_model(self) -> bool:
|
||||
"""Load model, with load strategy depending on config."""
|
||||
skip_move_to_device = False
|
||||
|
||||
if self.cfg.tensor_parallel_size > 1:
|
||||
self.model_kwargs["tp_size"] = self.cfg.tensor_parallel_size
|
||||
self.model_kwargs["tp_plan"] = "auto"
|
||||
self.model_kwargs["device_mesh"] = PartialState().device_mesh
|
||||
if "device_map" in self.model_kwargs:
|
||||
del self.model_kwargs["device_map"] # not compatible with `tp_plan`
|
||||
|
||||
if self.is_fsdp_enabled:
|
||||
if self.cfg.fsdp_config.cpu_ram_efficient_loading:
|
||||
skip_move_to_device = True
|
||||
|
||||
@@ -261,14 +261,14 @@ class PatchManager:
|
||||
|
||||
def _apply_sequence_parallel_patches(self):
|
||||
"""Apply sequence parallelism patches."""
|
||||
if self.cfg.context_parallel_size and self.cfg.context_parallel_size > 1:
|
||||
if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:
|
||||
from axolotl.monkeypatch.ring_attn.patch import (
|
||||
patch_prepare_data_loader,
|
||||
patch_prepare_device_mesh,
|
||||
)
|
||||
|
||||
patch_prepare_data_loader()
|
||||
patch_prepare_device_mesh(self.cfg.context_parallel_size, self.cfg.fsdp)
|
||||
patch_prepare_device_mesh(self.cfg.sequence_parallel_degree, self.cfg.fsdp)
|
||||
|
||||
def _apply_tiled_mlp(self, model_type: str):
|
||||
if self.cfg.tiled_mlp:
|
||||
|
||||
@@ -1,352 +0,0 @@
|
||||
"""
|
||||
monkeypatch for accelerate fsdp2 fix when modifying ordereddict during interation, and saving full state dicts
|
||||
"""
|
||||
|
||||
import copy
|
||||
import functools
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def fsdp2_load_full_state_dict(
|
||||
_accelerator, model: torch.nn.Module, full_sd: dict, offload_to_cpu: bool = False
|
||||
):
|
||||
"""
|
||||
Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the
|
||||
parameters from rank 0 to all other ranks. This function modifies the model in-place.
|
||||
Args:
|
||||
accelerator (`Accelerator`): The accelerator instance
|
||||
model (`torch.nn.Module`):
|
||||
The model to load the state dict into, expected to be on meta device or a VRAM spike can occur
|
||||
full_sd (`dict`): The full state dict to load, can only be on rank 0
|
||||
"""
|
||||
from torch.distributed.tensor import distribute_tensor
|
||||
|
||||
LOG.info("Broadcasting full state dict to all ranks...")
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
meta_sharded_sd = model.state_dict()
|
||||
sharded_sd = {}
|
||||
for param_name, full_tensor in full_sd.items():
|
||||
sharded_meta_param = meta_sharded_sd.get(param_name)
|
||||
full_tensor = full_tensor.to(sharded_meta_param.dtype).to(torch.device("cuda"))
|
||||
if hasattr(sharded_meta_param, "device_mesh"):
|
||||
sharded_param = distribute_tensor(
|
||||
full_tensor,
|
||||
sharded_meta_param.device_mesh,
|
||||
sharded_meta_param.placements,
|
||||
src_data_rank=0,
|
||||
)
|
||||
else:
|
||||
sharded_param = full_tensor
|
||||
|
||||
if offload_to_cpu:
|
||||
sharded_param = sharded_param.cpu()
|
||||
|
||||
sharded_sd[param_name] = nn.Parameter(sharded_param)
|
||||
del full_tensor
|
||||
full_sd[param_name] = None
|
||||
model.load_state_dict(sharded_sd, assign=True, strict=True)
|
||||
end_time = time.time()
|
||||
LOG.debug(
|
||||
f"Time taken to load full state dict: {(end_time - start_time):.2f} seconds"
|
||||
)
|
||||
log_gpu_memory_usage(LOG, "Memory usage after broadcasting full state dict", 0)
|
||||
return model
|
||||
|
||||
|
||||
def get_state_dict(self, model, unwrap=True):
|
||||
"""
|
||||
Returns the state dictionary of a model sent through [`Accelerator.prepare`] potentially without full
|
||||
precision.
|
||||
|
||||
Args:
|
||||
model (`torch.nn.Module`):
|
||||
A PyTorch model sent through [`Accelerator.prepare`]
|
||||
unwrap (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return the original underlying state_dict of `model` or to return the wrapped state_dict
|
||||
|
||||
Returns:
|
||||
`dict`: The state dictionary of the model potentially without full precision.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from accelerate import Accelerator
|
||||
|
||||
>>> accelerator = Accelerator()
|
||||
>>> net = torch.nn.Linear(2, 2)
|
||||
>>> net = accelerator.prepare(net)
|
||||
>>> state_dict = accelerator.get_state_dict(net)
|
||||
```
|
||||
"""
|
||||
from accelerate import DistributedType
|
||||
from accelerate.utils import compare_versions
|
||||
|
||||
if self.distributed_type == DistributedType.DEEPSPEED:
|
||||
zero3_sharding = self.deepspeed_config["zero_optimization"]["stage"] == 3
|
||||
tp_sharding = (
|
||||
self.deepspeed_config.get("tensor_parallel", {}).get("autotp_size", 0) > 1
|
||||
)
|
||||
if zero3_sharding or tp_sharding:
|
||||
if model.zero_gather_16bit_weights_on_model_save():
|
||||
if tp_sharding and not compare_versions("deepspeed", ">=", "0.16.4"):
|
||||
raise ImportError(
|
||||
"Deepspeed TP requires deepspeed >= 0.16.4, Please update DeepSpeed via `pip install deepspeed -U`."
|
||||
)
|
||||
state_dict = (
|
||||
model._consolidated_16bit_state_dict() # pylint: disable=protected-access
|
||||
if tp_sharding
|
||||
else model._zero3_consolidated_16bit_state_dict() # pylint: disable=protected-access
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Cannot get 16bit model weights because `stage3_gather_16bit_weights_on_model_save` in DeepSpeed config is False. "
|
||||
"To save the model weights in 16bit, set `stage3_gather_16bit_weights_on_model_save` to True in DeepSpeed config file or "
|
||||
"set `zero3_save_16bit_model` to True when using `accelerate config`. "
|
||||
"To save the full checkpoint, run `model.save_checkpoint(save_dir)` and use `zero_to_fp32.py` to recover weights."
|
||||
)
|
||||
else:
|
||||
from deepspeed.checkpoint.utils import clone_tensors_for_torch_save
|
||||
|
||||
state_dict = clone_tensors_for_torch_save(
|
||||
self.unwrap_model(model).state_dict()
|
||||
)
|
||||
elif self.is_fsdp2:
|
||||
# https://github.com/pytorch/torchtune/blob/main/torchtune/training/_distributed.py#L465
|
||||
state_dict = {}
|
||||
sharded_state_dict = model.state_dict()
|
||||
for param_name, param in sharded_state_dict.items():
|
||||
if param.is_cpu:
|
||||
param = param.to(torch.device("cuda"))
|
||||
|
||||
param = param.full_tensor()
|
||||
if torch.distributed.get_rank() == 0:
|
||||
state_dict[param_name] = param.cpu()
|
||||
torch.distributed.barrier()
|
||||
elif self.distributed_type == DistributedType.FSDP:
|
||||
from torch.distributed.fsdp import FullStateDictConfig
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from torch.distributed.fsdp import StateDictType
|
||||
|
||||
full_state_dict_config = FullStateDictConfig(
|
||||
offload_to_cpu=True, rank0_only=True
|
||||
)
|
||||
with FSDP.state_dict_type(
|
||||
model, StateDictType.FULL_STATE_DICT, full_state_dict_config
|
||||
):
|
||||
state_dict = model.state_dict()
|
||||
else:
|
||||
if unwrap:
|
||||
model = self.unwrap_model(model)
|
||||
state_dict = model.state_dict()
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
|
||||
"""Helper function to process LoRA modules for FSDP2."""
|
||||
from torch.distributed.fsdp import fully_shard
|
||||
|
||||
log_bias_dtype_mismatch = False
|
||||
|
||||
# Linear4Bit will keep it's bias term in fp32. If the weight dtype is in bf16 we are not able to
|
||||
# wrap this. Therefore we must ensure the bias has the same dtype as the weight
|
||||
if module.base_layer.bias is not None:
|
||||
if module.base_layer.weight.dtype != module.base_layer.bias.dtype:
|
||||
log_bias_dtype_mismatch = True
|
||||
module.base_layer.bias.data = module.base_layer.bias.data.to(
|
||||
module.base_layer.weight.dtype
|
||||
)
|
||||
|
||||
for active_adapter in module.active_adapters:
|
||||
if module.lora_A:
|
||||
fully_shard(module.lora_A[active_adapter], **fsdp2_kwargs)
|
||||
if module.lora_B:
|
||||
fully_shard(module.lora_B[active_adapter], **fsdp2_kwargs)
|
||||
if module.lora_embedding_A:
|
||||
fully_shard(module.lora_embedding_A[active_adapter], **fsdp2_kwargs)
|
||||
if module.lora_embedding_B:
|
||||
fully_shard(module.lora_embedding_B[active_adapter], **fsdp2_kwargs)
|
||||
if module.lora_magnitude_vector:
|
||||
fully_shard(module.lora_magnitude_vector[active_adapter], **fsdp2_kwargs)
|
||||
return log_bias_dtype_mismatch
|
||||
|
||||
|
||||
def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
|
||||
"""Prepares the model for FSDP2 in-place. Also returns the model to avoid misuse of the original model.
|
||||
|
||||
Args:
|
||||
accelerator (`Accelerator`): The accelerator instance
|
||||
model (`torch.nn.Module`): The model to prepare
|
||||
|
||||
Returns:
|
||||
`torch.nn.Module`: Prepared model
|
||||
"""
|
||||
from accelerate.utils import get_module_children_bottom_up, is_compiled_module
|
||||
from accelerate.utils.fsdp_utils import fsdp2_prepare_auto_wrap_policy
|
||||
from accelerate.utils.modeling import get_non_persistent_buffers
|
||||
from peft import PeftModel
|
||||
from peft.tuners.lora import LoraLayer
|
||||
from torch.distributed.fsdp import (
|
||||
CPUOffloadPolicy,
|
||||
FSDPModule,
|
||||
MixedPrecisionPolicy,
|
||||
fully_shard,
|
||||
)
|
||||
|
||||
is_type_fsdp = isinstance(model, FSDPModule) or (
|
||||
is_compiled_module(model)
|
||||
and isinstance(model._orig_mod, FSDPModule) # pylint: disable=protected-access
|
||||
)
|
||||
if is_type_fsdp:
|
||||
return model
|
||||
|
||||
fsdp2_plugin = accelerator.state.fsdp_plugin
|
||||
|
||||
original_sd = model.state_dict()
|
||||
|
||||
from torch.distributed.fsdp.wrap import (
|
||||
size_based_auto_wrap_policy,
|
||||
transformer_auto_wrap_policy,
|
||||
)
|
||||
|
||||
# We need the `auto_wrap_policy` original type to create a custom poilicy function for sharding
|
||||
# This is because `fully_shard` doesn't support old auto wrap policies, rather we have to imitate the behaviour
|
||||
if fsdp2_plugin.auto_wrap_policy is transformer_auto_wrap_policy:
|
||||
pass # auto_wrap_policy_type = "transformer"
|
||||
elif fsdp2_plugin.auto_wrap_policy is size_based_auto_wrap_policy:
|
||||
pass # auto_wrap_policy_type = "size"
|
||||
|
||||
# We set `auto_wrap_policy` to `functools.partial` to avoid creating it again
|
||||
# This is because of `apply_activation_checkpointing` which will can reuse this function
|
||||
fsdp2_plugin.set_auto_wrap_policy(model)
|
||||
|
||||
if fsdp2_plugin.activation_checkpointing:
|
||||
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
|
||||
CheckpointImpl,
|
||||
apply_activation_checkpointing,
|
||||
checkpoint_wrapper,
|
||||
)
|
||||
|
||||
# Apply activation checkpointing before applying `fully_shard`
|
||||
apply_activation_checkpointing(
|
||||
model,
|
||||
checkpoint_wrapper_fn=functools.partial(
|
||||
checkpoint_wrapper,
|
||||
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
|
||||
),
|
||||
auto_wrap_policy=fsdp2_plugin.auto_wrap_policy,
|
||||
)
|
||||
|
||||
fsdp2_kwargs = {
|
||||
"reshard_after_forward": fsdp2_plugin.reshard_after_forward,
|
||||
"offload_policy": fsdp2_plugin.cpu_offload,
|
||||
# `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy`
|
||||
"mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
|
||||
"mesh": accelerator.torch_device_mesh[tuple(accelerator.parallelism_config.model_shard_dim_names)],
|
||||
}
|
||||
|
||||
model_has_params4bit = False
|
||||
for _, param in model.named_parameters():
|
||||
# this is a temporary fix whereby loading models with bnb params cannot be moved from
|
||||
# GPU to a meta device due with FSDP2 because torch operations don't return the original class type
|
||||
# bypassing the move to meta will still cause the VRAM spike, but at least it still will load
|
||||
if param.__class__.__name__ == "Params4bit":
|
||||
model_has_params4bit = True
|
||||
break
|
||||
|
||||
if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:
|
||||
# Context: `fully_shard` moves the model to GPU if it was on CPU, however it can also be on `meta` and then it stays there even after `fully_shard`
|
||||
# For this reason, we need to move the model to `meta` device, as then sharding happens on `meta` device
|
||||
# If we kept the model on CPU (`cpu_ram_efficient_loading` has model be on CPU on all ranks, though non-main ranks only have `torch.emtpy`), `fully_shard` would move it to GPU
|
||||
# Afterwards, when we call `fsdp2_load_full_state_dict`, us creating the state_dict would result into briefly having two copies of model state_dict on the GPU -> VRAM spike
|
||||
|
||||
# We need to keep the original non-persistent buffers, as those MAY not be in the state_dict, resulting in them staying on meta device
|
||||
# Also, these buffers aren't getting sharded by default
|
||||
# We get the FQNs of all non-persistent buffers, to re-register them after
|
||||
non_persistent_buffer_fqns = get_non_persistent_buffers(
|
||||
model, recurse=True, fqns=True
|
||||
)
|
||||
original_non_persistent_buffers = copy.deepcopy(
|
||||
{k: v for k, v in model.named_buffers() if k in non_persistent_buffer_fqns}
|
||||
)
|
||||
# We move the model to meta device, as then sharding happens on meta device
|
||||
model = model.to(torch.device("meta"))
|
||||
# We need to re-tie the weights, not exactly sure why, but if we don't do this, reference to `lm_head/embed_tokens` stay hanging -> more VRAM usage
|
||||
# We assume `transformers` models have a `tie_weights` method if they support it
|
||||
if hasattr(model, "tie_weights"):
|
||||
model.tie_weights()
|
||||
|
||||
is_peft_model = isinstance(model, PeftModel)
|
||||
|
||||
auto_wrap_policy = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)
|
||||
log_bias_dtype_mismatch = False
|
||||
if auto_wrap_policy is not None:
|
||||
for module in get_module_children_bottom_up(model)[:-1]:
|
||||
if is_peft_model and isinstance(module, LoraLayer):
|
||||
module_log_bias_mismatch = _process_lora_module_for_fsdp(
|
||||
module, fsdp2_kwargs
|
||||
)
|
||||
log_bias_dtype_mismatch |= module_log_bias_mismatch
|
||||
if auto_wrap_policy(module) and not isinstance(module, FSDPModule):
|
||||
fully_shard(module, **fsdp2_kwargs)
|
||||
|
||||
fully_shard(model, **fsdp2_kwargs)
|
||||
|
||||
if log_bias_dtype_mismatch:
|
||||
LOG.warning(
|
||||
"Bias dtype mismatch detected in LoRA base linear layer. Bias parameters have been cast to weight dtype."
|
||||
)
|
||||
|
||||
if fsdp2_plugin.cpu_ram_efficient_loading:
|
||||
offload_to_cpu = isinstance(fsdp2_plugin.cpu_offload, CPUOffloadPolicy)
|
||||
fsdp2_load_full_state_dict(
|
||||
accelerator, model, original_sd, offload_to_cpu=offload_to_cpu
|
||||
)
|
||||
|
||||
if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:
|
||||
# We re-register the buffers, as they may not be in the state_dict
|
||||
for fqn, buffer_tensor in original_non_persistent_buffers.items():
|
||||
buffer_tensor = buffer_tensor.to(accelerator.device)
|
||||
|
||||
if "." in fqn:
|
||||
parent_fqn, local_buffer_name = fqn.rsplit(".", 1)
|
||||
parent_module = model.get_submodule(parent_fqn)
|
||||
else:
|
||||
local_buffer_name = fqn
|
||||
parent_module = model
|
||||
|
||||
parent_module.register_buffer(
|
||||
local_buffer_name, buffer_tensor, persistent=False
|
||||
)
|
||||
|
||||
# We need to tie the weights again, as call to `load_full_state_dict` breaks the tie
|
||||
# Needs to be called both here and above
|
||||
# removing this call makes the have slightly different loss
|
||||
# removing the call above leads to extra memory usage as explained in the comment above
|
||||
if hasattr(model, "tie_weights"):
|
||||
model.tie_weights()
|
||||
return model
|
||||
|
||||
|
||||
def patch_accelerate_fsdp2():
|
||||
import accelerate
|
||||
|
||||
accelerate.accelerator.fsdp2_prepare_model = fsdp2_prepare_model
|
||||
accelerate.Accelerator.get_state_dict = get_state_dict
|
||||
setattr(
|
||||
sys.modules["accelerate"],
|
||||
"Accelerator.get_state_dict",
|
||||
get_state_dict,
|
||||
)
|
||||
|
||||
@@ -162,14 +162,14 @@ def create_ring_flash_attention_forward(
|
||||
|
||||
|
||||
def register_ring_attn(
|
||||
context_parallel_size: int,
|
||||
sequence_parallel_degree: int,
|
||||
heads_k_stride: int | None,
|
||||
ring_attn_func: RingAttnFunc | None,
|
||||
):
|
||||
"""Create ring attention group and substitute flash attn with ring flash attn.
|
||||
|
||||
Args:
|
||||
context_parallel_size: Sequence parallelism factor.
|
||||
sequence_parallel_degree: Sequence parallelism factor.
|
||||
heads_k_stride: Sequence parallelism K head stride size. Passed through to
|
||||
`varlen_llama3` `ring_flash_attn` implementation.
|
||||
ring_attn_func: `ring_flash_attn` ring attention implemention. If sample
|
||||
@@ -182,25 +182,25 @@ def register_ring_attn(
|
||||
if rank == 0:
|
||||
LOG.info(
|
||||
"Enabling ring attention sequence parallelism: "
|
||||
f"each sequence will be processed across {context_parallel_size} GPUs"
|
||||
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
|
||||
)
|
||||
|
||||
assert context_parallel_size <= world_size, (
|
||||
f"context_parallel_size ({context_parallel_size}) "
|
||||
assert sequence_parallel_degree <= world_size, (
|
||||
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
||||
f"must be less than or equal to world_size ({world_size})"
|
||||
)
|
||||
assert world_size % context_parallel_size == 0, (
|
||||
f"context_parallel_size ({context_parallel_size}) "
|
||||
assert world_size % sequence_parallel_degree == 0, (
|
||||
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
||||
f"must evenly divide world_size ({world_size})"
|
||||
)
|
||||
|
||||
# Assign ranks to sequence parallel groups
|
||||
group_assignments = {}
|
||||
for i in range(world_size // context_parallel_size):
|
||||
for i in range(world_size // sequence_parallel_degree):
|
||||
ring_attn_ranks = list(
|
||||
range(
|
||||
i * context_parallel_size,
|
||||
(i + 1) * context_parallel_size,
|
||||
i * sequence_parallel_degree,
|
||||
(i + 1) * sequence_parallel_degree,
|
||||
)
|
||||
)
|
||||
group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
|
||||
@@ -299,12 +299,12 @@ def patch_prepare_data_loader():
|
||||
LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support")
|
||||
|
||||
|
||||
def patch_prepare_device_mesh(context_parallel_size: int, fsdp: bool = False):
|
||||
def patch_prepare_device_mesh(sequence_parallel_degree: int, fsdp: bool = False):
|
||||
"""Patches the `Accelerator._prepare_device_mesh` method to create a device mesh
|
||||
that includes sequence parallelism with the specified degree.
|
||||
|
||||
Args:
|
||||
context_parallel_size: The degree of sequence parallelism to use.
|
||||
sequence_parallel_degree: The degree of sequence parallelism to use.
|
||||
fsdp: Whether to use FSDP.
|
||||
"""
|
||||
|
||||
@@ -323,8 +323,8 @@ def patch_prepare_device_mesh(context_parallel_size: int, fsdp: bool = False):
|
||||
# Create device mesh with sequence parallelism
|
||||
world_size = dist.get_world_size()
|
||||
mesh_shape = (
|
||||
world_size // context_parallel_size,
|
||||
context_parallel_size,
|
||||
world_size // sequence_parallel_degree,
|
||||
sequence_parallel_degree,
|
||||
)
|
||||
device_ids = list(range(world_size))
|
||||
|
||||
@@ -344,5 +344,5 @@ def patch_prepare_device_mesh(context_parallel_size: int, fsdp: bool = False):
|
||||
|
||||
LOG.info(
|
||||
"Successfully patched Accelerator._prepare_device_mesh "
|
||||
f"with context_parallel_size={context_parallel_size}"
|
||||
f"with sequence_parallel_degree={sequence_parallel_degree}"
|
||||
)
|
||||
|
||||
@@ -202,7 +202,7 @@ def execute_training(
|
||||
)
|
||||
)
|
||||
|
||||
if cfg.context_parallel_size > 1:
|
||||
if cfg.sequence_parallel_degree > 1:
|
||||
models = [trainer.model]
|
||||
if hasattr(trainer, "ref_model") and trainer.ref_model:
|
||||
models.append(trainer.ref_model)
|
||||
@@ -210,7 +210,7 @@ def execute_training(
|
||||
stack.enter_context(
|
||||
SequenceParallelContextManager(
|
||||
models=models,
|
||||
context_parallel_size=cfg.context_parallel_size,
|
||||
sequence_parallel_degree=cfg.sequence_parallel_degree,
|
||||
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
|
||||
ring_attn_func=cfg.ring_attn_func,
|
||||
heads_k_stride=cfg.heads_k_stride,
|
||||
|
||||
@@ -167,7 +167,7 @@ class SequenceParallelContextManager:
|
||||
Args:
|
||||
models: List of models to apply sequence parallelism to pre- and post- forward
|
||||
hooks.
|
||||
context_parallel_size: Number of processes to split sequences over.
|
||||
sequence_parallel_degree: Number of processes to split sequences over.
|
||||
gradient_accumulation_steps: Number of steps to accumulate gradients over.
|
||||
ring_attn_func: Which ring attention function to use. Currently unused.
|
||||
heads_k_stride: Sequence parallelism K head stride size. Passed through to
|
||||
@@ -179,14 +179,14 @@ class SequenceParallelContextManager:
|
||||
def __init__(
|
||||
self,
|
||||
models: list[nn.Module],
|
||||
context_parallel_size: int,
|
||||
sequence_parallel_degree: int,
|
||||
gradient_accumulation_steps: int,
|
||||
ring_attn_func: RingAttnFunc,
|
||||
heads_k_stride: int | None,
|
||||
gather_outputs: bool,
|
||||
):
|
||||
self.models = models
|
||||
self.context_parallel_size = context_parallel_size
|
||||
self.sequence_parallel_degree = sequence_parallel_degree
|
||||
self.gradient_accumulation_steps = gradient_accumulation_steps
|
||||
self.ring_attn_func = ring_attn_func
|
||||
self.heads_k_stride = heads_k_stride
|
||||
@@ -231,7 +231,7 @@ class SequenceParallelContextManager:
|
||||
def _register_ring_attn(self):
|
||||
# Initialize ring attn for sequence parallelism
|
||||
register_ring_attn(
|
||||
context_parallel_size=self.context_parallel_size,
|
||||
sequence_parallel_degree=self.sequence_parallel_degree,
|
||||
heads_k_stride=self.heads_k_stride,
|
||||
ring_attn_func=self.ring_attn_func,
|
||||
)
|
||||
|
||||
@@ -644,19 +644,7 @@ class AxolotlInputConfig(
|
||||
},
|
||||
)
|
||||
|
||||
dp_shard_size: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Number of devices to shard across. If not set, will use all available devices."
|
||||
},
|
||||
)
|
||||
sequence_parallel_degree: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Deprecated: use `context_parallel_size` instead"
|
||||
},
|
||||
)
|
||||
context_parallel_size: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Set to a divisor of the number of GPUs available to split sequences into chunks of equal size. Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM. E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized subsequences, or set to 4 to split into four equal-sized subsequences. See https://docs.axolotl.ai/docs/sequence_parallelism.html for more details."
|
||||
|
||||
@@ -686,7 +686,7 @@ class RLValidationMixin:
|
||||
data.get("rl") == "grpo"
|
||||
and data.get("trl", {})
|
||||
and data.get("trl").get("use_liger_loss")
|
||||
and data.get("context_parallel_size", 1) > 1
|
||||
and data.get("sequence_parallel_degree", 1) > 1
|
||||
):
|
||||
raise ValueError("GRPO + SP + Liger not currently supported")
|
||||
return data
|
||||
@@ -913,30 +913,31 @@ class OptimizationValidationMixin:
|
||||
def check_tensor_parallel_size_update_ds_json(cls, data):
|
||||
tensor_parallel_size = data.get("tensor_parallel_size")
|
||||
if tensor_parallel_size is not None and tensor_parallel_size > 1:
|
||||
if data.get("deepspeed"):
|
||||
with open(data.get("deepspeed"), "r", encoding="utf-8") as ds_fin:
|
||||
ds_config = json.load(ds_fin)
|
||||
should_save = False
|
||||
if "tensor_parallel" not in ds_config:
|
||||
ds_config["tensor_parallel"] = {
|
||||
"autotp_size": tensor_parallel_size
|
||||
}
|
||||
should_save = True
|
||||
if (
|
||||
if not data.get("deepspeed"):
|
||||
raise ValueError(
|
||||
"Tensor parallelism (TP) is only supported with DeepSpeed"
|
||||
)
|
||||
with open(data.get("deepspeed"), "r", encoding="utf-8") as ds_fin:
|
||||
ds_config = json.load(ds_fin)
|
||||
should_save = False
|
||||
if "tensor_parallel" not in ds_config:
|
||||
ds_config["tensor_parallel"] = {"autotp_size": tensor_parallel_size}
|
||||
should_save = True
|
||||
if (
|
||||
"gather_16bit_weights_on_model_save"
|
||||
not in ds_config["zero_optimization"]
|
||||
):
|
||||
ds_config["zero_optimization"][
|
||||
"gather_16bit_weights_on_model_save"
|
||||
not in ds_config["zero_optimization"]
|
||||
):
|
||||
ds_config["zero_optimization"][
|
||||
"gather_16bit_weights_on_model_save"
|
||||
] = True
|
||||
should_save = True
|
||||
if should_save:
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
with open(
|
||||
Path(temp_dir) / "autotp_ds.json", "w", encoding="utf-8"
|
||||
) as ds_fout:
|
||||
json.dump(ds_config, ds_fout, indent=4)
|
||||
data["deepspeed"] = str(Path(temp_dir) / "autotp_ds.json")
|
||||
] = True
|
||||
should_save = True
|
||||
if should_save:
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
with open(
|
||||
Path(temp_dir) / "autotp_ds.json", "w", encoding="utf-8"
|
||||
) as ds_fout:
|
||||
json.dump(ds_config, ds_fout, indent=4)
|
||||
data["deepspeed"] = str(Path(temp_dir) / "autotp_ds.json")
|
||||
|
||||
return data
|
||||
|
||||
@@ -1234,13 +1235,13 @@ class ComplexValidationMixin:
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_context_parallel_size(self):
|
||||
if not self.context_parallel_size:
|
||||
self.context_parallel_size = 1
|
||||
elif self.context_parallel_size > 1:
|
||||
def check_sequence_parallel_degree(self):
|
||||
if not self.sequence_parallel_degree:
|
||||
self.sequence_parallel_degree = 1
|
||||
elif self.sequence_parallel_degree > 1:
|
||||
if not self.flash_attention:
|
||||
raise ValueError(
|
||||
"flash_attention: true must be set with context_parallel_size > 1"
|
||||
"flash_attention: true must be set with sequence_parallel_degree > 1"
|
||||
)
|
||||
|
||||
if self.sample_packing and self.micro_batch_size > 1:
|
||||
@@ -1253,14 +1254,14 @@ class ComplexValidationMixin:
|
||||
import ring_flash_attn # noqa: F401 # pylint:disable=unused-import
|
||||
except ImportError as exception:
|
||||
raise ImportError(
|
||||
"context_parallel_size > 1 but ring_flash_attn is not installed. "
|
||||
"sequence_parallel_degree > 1 but ring_flash_attn is not installed. "
|
||||
"Please install it with `pip install axolotl[ring-flash-attn] "
|
||||
"or `pip install ring-flash-attn>=0.1.4`."
|
||||
) from exception
|
||||
|
||||
LOG.warning(
|
||||
"Sequence parallelism (SP) is enabled with "
|
||||
f"context_parallel_size={self.context_parallel_size}. "
|
||||
f"sequence_parallel_degree={self.sequence_parallel_degree}. "
|
||||
"Please note that logged losses may differ slightly to the non-SP "
|
||||
"losses due to transformers Trainer implementation details. "
|
||||
"Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
|
||||
@@ -1271,7 +1272,7 @@ class ComplexValidationMixin:
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_ring_attn_func(self):
|
||||
if getattr(self, "context_parallel_size", 1) == 1:
|
||||
if getattr(self, "sequence_parallel_degree", 1) == 1:
|
||||
return self
|
||||
|
||||
if self.ring_attn_func is not None:
|
||||
|
||||
@@ -442,7 +442,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
- 1
|
||||
)
|
||||
* cfg.num_epochs
|
||||
* cfg.context_parallel_size
|
||||
* cfg.sequence_parallel_degree
|
||||
* cfg.tensor_parallel_size
|
||||
)
|
||||
LOG.debug(
|
||||
@@ -484,7 +484,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
math.floor(
|
||||
data_loader_len
|
||||
* cfg.num_epochs
|
||||
* cfg.context_parallel_size
|
||||
* cfg.sequence_parallel_degree
|
||||
* cfg.tensor_parallel_size
|
||||
)
|
||||
)
|
||||
@@ -511,7 +511,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
math.ceil(
|
||||
len(train_dataset)
|
||||
* cfg.num_epochs
|
||||
* cfg.context_parallel_size
|
||||
* cfg.sequence_parallel_degree
|
||||
* cfg.tensor_parallel_size
|
||||
/ cfg.batch_size
|
||||
)
|
||||
|
||||
@@ -64,7 +64,7 @@ def fixture_base_cfg():
|
||||
"dataloader_num_workers": 1,
|
||||
"dataloader_pin_memory": True,
|
||||
"dataloader_prefetch_factor": 2,
|
||||
"context_parallel_size": 1,
|
||||
"sequence_parallel_degree": 1,
|
||||
"tensor_parallel_size": 1,
|
||||
# Dtype
|
||||
"fp16": False,
|
||||
|
||||
@@ -67,7 +67,7 @@ class TestSequenceParallelism:
|
||||
"logging_steps": 1,
|
||||
"weight_decay": 0.0,
|
||||
"use_tensorboard": True,
|
||||
"context_parallel_size": 2,
|
||||
"sequence_parallel_degree": 2,
|
||||
"ring_attn_func": ring_attn_func,
|
||||
"save_first_step": False,
|
||||
}
|
||||
|
||||
@@ -298,7 +298,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"context_parallel_size": 2,
|
||||
"sequence_parallel_degree": 2,
|
||||
"flash_attention": True,
|
||||
"sequence_len": 1024,
|
||||
"special_tokens": {
|
||||
|
||||
@@ -111,7 +111,7 @@ class TestRingAttention:
|
||||
|
||||
# Call register_ring_attn with size 4
|
||||
register_ring_attn(
|
||||
context_parallel_size=4,
|
||||
sequence_parallel_degree=4,
|
||||
heads_k_stride=1,
|
||||
ring_attn_func=RingAttnFunc.VARLEN_LLAMA3,
|
||||
)
|
||||
@@ -156,24 +156,24 @@ class TestConfigValidation:
|
||||
[
|
||||
# Valid configuration
|
||||
(
|
||||
{"context_parallel_size": 2, "flash_attention": True},
|
||||
{"context_parallel_size": 2, "flash_attention": True},
|
||||
{"sequence_parallel_degree": 2, "flash_attention": True},
|
||||
{"sequence_parallel_degree": 2, "flash_attention": True},
|
||||
True,
|
||||
None,
|
||||
),
|
||||
# Default context_parallel_size
|
||||
({}, {"context_parallel_size": 1}, True, None),
|
||||
# Invalid: context_parallel_size > 1 without flash_attention
|
||||
# Default sequence_parallel_degree
|
||||
({}, {"sequence_parallel_degree": 1}, True, None),
|
||||
# Invalid: sequence_parallel_degree > 1 without flash_attention
|
||||
(
|
||||
{"context_parallel_size": 2, "flash_attention": False},
|
||||
{"sequence_parallel_degree": 2, "flash_attention": False},
|
||||
None,
|
||||
False,
|
||||
"flash_attention: true must be set",
|
||||
),
|
||||
# Invalid: context_parallel_size > 1 with sample_packing and micro_batch_size > 1
|
||||
# Invalid: sequence_parallel_degree > 1 with sample_packing and micro_batch_size > 1
|
||||
(
|
||||
{
|
||||
"context_parallel_size": 2,
|
||||
"sequence_parallel_degree": 2,
|
||||
"flash_attention": True,
|
||||
"sample_packing": True,
|
||||
"micro_batch_size": 2,
|
||||
@@ -186,13 +186,13 @@ class TestConfigValidation:
|
||||
# Valid: Basic GRPO config
|
||||
(
|
||||
{
|
||||
"context_parallel_size": 2,
|
||||
"sequence_parallel_degree": 2,
|
||||
"flash_attention": True,
|
||||
"micro_batch_size": 2,
|
||||
"trl": {"use_liger_loss": True},
|
||||
},
|
||||
{
|
||||
"context_parallel_size": 2,
|
||||
"sequence_parallel_degree": 2,
|
||||
"flash_attention": True,
|
||||
"micro_batch_size": 2,
|
||||
"trl": TRLConfig(use_liger_loss=True),
|
||||
@@ -204,7 +204,7 @@ class TestConfigValidation:
|
||||
(
|
||||
{
|
||||
"rl": "grpo",
|
||||
"context_parallel_size": 2,
|
||||
"sequence_parallel_degree": 2,
|
||||
"flash_attention": True,
|
||||
"micro_batch_size": 2,
|
||||
"trl": {"use_liger_loss": True},
|
||||
@@ -262,7 +262,7 @@ class TestConfigValidation:
|
||||
|
||||
# Apply updates to base config
|
||||
cfg = base_cfg | {
|
||||
"context_parallel_size": 2,
|
||||
"sequence_parallel_degree": 2,
|
||||
"flash_attention": True,
|
||||
"sample_packing": sample_packing,
|
||||
}
|
||||
@@ -282,7 +282,7 @@ class TestConfigValidation:
|
||||
|
||||
# Invalid configuration with invalid ring_attn_func
|
||||
cfg = base_cfg | {
|
||||
"context_parallel_size": 2,
|
||||
"sequence_parallel_degree": 2,
|
||||
"flash_attention": True,
|
||||
"ring_attn_func": "INVALID_FUNC",
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user