Compare commits
4 Commits
nd_paralle
...
testingci
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e36d3c9f30 | ||
|
|
53614391ed | ||
|
|
1407aac779 | ||
|
|
b34c3371ed |
4
.github/workflows/base.yml
vendored
4
.github/workflows/base.yml
vendored
@@ -17,7 +17,7 @@ on:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
build-base:
|
build-base:
|
||||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
|
||||||
timeout-minutes: 480
|
timeout-minutes: 480
|
||||||
# this job needs to be run on self-hosted GPU runners...
|
# this job needs to be run on self-hosted GPU runners...
|
||||||
runs-on: ubuntu-latest-m
|
runs-on: ubuntu-latest-m
|
||||||
@@ -108,7 +108,7 @@ jobs:
|
|||||||
PYTORCH_VERSION=${{ matrix.pytorch }}
|
PYTORCH_VERSION=${{ matrix.pytorch }}
|
||||||
TORCH_CUDA_ARCH_LIST=${{ matrix.torch_cuda_arch_list }}
|
TORCH_CUDA_ARCH_LIST=${{ matrix.torch_cuda_arch_list }}
|
||||||
build-base-uv:
|
build-base-uv:
|
||||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
|
||||||
timeout-minutes: 480
|
timeout-minutes: 480
|
||||||
runs-on: ubuntu-latest-m
|
runs-on: ubuntu-latest-m
|
||||||
strategy:
|
strategy:
|
||||||
|
|||||||
2
.github/workflows/lint.yml
vendored
2
.github/workflows/lint.yml
vendored
@@ -3,6 +3,7 @@ on:
|
|||||||
# check on PRs, and manual triggers
|
# check on PRs, and manual triggers
|
||||||
merge_group:
|
merge_group:
|
||||||
pull_request:
|
pull_request:
|
||||||
|
types: [opened, synchronize, reopened, ready_for_review]
|
||||||
paths:
|
paths:
|
||||||
- '**.py'
|
- '**.py'
|
||||||
- 'requirements.txt'
|
- 'requirements.txt'
|
||||||
@@ -16,6 +17,7 @@ jobs:
|
|||||||
pre-commit:
|
pre-commit:
|
||||||
name: pre-commit
|
name: pre-commit
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
if: ${{ !github.event.pull_request.draft }}
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- uses: actions/setup-python@v5
|
- uses: actions/setup-python@v5
|
||||||
|
|||||||
2
.github/workflows/multi-gpu-e2e.yml
vendored
2
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -21,7 +21,7 @@ concurrency:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test-axolotl-multigpu:
|
test-axolotl-multigpu:
|
||||||
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
|
|||||||
3
.github/workflows/preview-docs.yml
vendored
3
.github/workflows/preview-docs.yml
vendored
@@ -2,7 +2,7 @@ name: Preview
|
|||||||
on:
|
on:
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
pull_request:
|
pull_request:
|
||||||
types: [opened, synchronize, reopened]
|
types: [opened, synchronize, reopened, ready_for_review]
|
||||||
|
|
||||||
# Run the workflow only when one of these files changes
|
# Run the workflow only when one of these files changes
|
||||||
paths:
|
paths:
|
||||||
@@ -25,6 +25,7 @@ permissions:
|
|||||||
jobs:
|
jobs:
|
||||||
preview:
|
preview:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
if: ${{ !github.event.pull_request.draft }}
|
||||||
steps:
|
steps:
|
||||||
- name: Check out repository
|
- name: Check out repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|||||||
9
.github/workflows/tests.yml
vendored
9
.github/workflows/tests.yml
vendored
@@ -13,6 +13,7 @@ on:
|
|||||||
- 'cicd/cicd.sh'
|
- 'cicd/cicd.sh'
|
||||||
- 'cicd/Dockerfile.jinja'
|
- 'cicd/Dockerfile.jinja'
|
||||||
pull_request:
|
pull_request:
|
||||||
|
types: [opened, synchronize, reopened, ready_for_review]
|
||||||
paths:
|
paths:
|
||||||
- '**.py'
|
- '**.py'
|
||||||
- 'requirements.txt'
|
- 'requirements.txt'
|
||||||
@@ -34,6 +35,7 @@ jobs:
|
|||||||
pre-commit:
|
pre-commit:
|
||||||
name: pre-commit
|
name: pre-commit
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
if: ${{ !github.event.pull_request.draft }}
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- uses: actions/setup-python@v5
|
- uses: actions/setup-python@v5
|
||||||
@@ -47,6 +49,7 @@ jobs:
|
|||||||
pytest:
|
pytest:
|
||||||
name: PyTest
|
name: PyTest
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
if: ${{ !github.event.pull_request.draft }}
|
||||||
# needs: [preload-cache]
|
# needs: [preload-cache]
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
@@ -121,6 +124,7 @@ jobs:
|
|||||||
pytest-sdist:
|
pytest-sdist:
|
||||||
name: PyTest from Source Dist
|
name: PyTest from Source Dist
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
if: ${{ !github.event.pull_request.draft }}
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
@@ -185,7 +189,7 @@ jobs:
|
|||||||
|
|
||||||
docker-e2e-tests-1st:
|
docker-e2e-tests-1st:
|
||||||
# Run this job first as a gate for running the remainder of the test matrix
|
# Run this job first as a gate for running the remainder of the test matrix
|
||||||
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' && !github.event.pull_request.draft }}
|
||||||
# this job needs to be run on self-hosted GPU runners...
|
# this job needs to be run on self-hosted GPU runners...
|
||||||
runs-on: [self-hosted, modal]
|
runs-on: [self-hosted, modal]
|
||||||
timeout-minutes: 120
|
timeout-minutes: 120
|
||||||
@@ -235,7 +239,7 @@ jobs:
|
|||||||
modal run cicd.e2e_tests
|
modal run cicd.e2e_tests
|
||||||
|
|
||||||
docker-e2e-tests:
|
docker-e2e-tests:
|
||||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && !github.event.pull_request.draft }}
|
||||||
# this job needs to be run on self-hosted GPU runners...
|
# this job needs to be run on self-hosted GPU runners...
|
||||||
runs-on: [self-hosted, modal]
|
runs-on: [self-hosted, modal]
|
||||||
timeout-minutes: 120
|
timeout-minutes: 120
|
||||||
@@ -289,6 +293,7 @@ jobs:
|
|||||||
runs-on: [self-hosted, modal]
|
runs-on: [self-hosted, modal]
|
||||||
timeout-minutes: 90
|
timeout-minutes: 90
|
||||||
needs: [docker-e2e-tests]
|
needs: [docker-e2e-tests]
|
||||||
|
if: ${{ !github.event.pull_request.draft }}
|
||||||
|
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ To enable sequence parallelism, add the following to your configuration file:
|
|||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
# Set to a divisor (> 1) of the number of GPUs available
|
# Set to a divisor (> 1) of the number of GPUs available
|
||||||
context_parallel_size: 4 # Split sequences across 4 GPUs
|
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
|
||||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||||
heads_k_stride: 1
|
heads_k_stride: 1
|
||||||
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
|
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
|
||||||
@@ -30,7 +30,7 @@ heads_k_stride: 1
|
|||||||
ring_attn_func:
|
ring_attn_func:
|
||||||
```
|
```
|
||||||
|
|
||||||
The `context_parallel_size` should be a divisor of the total number of GPUs. For example:
|
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
|
||||||
|
|
||||||
- With 8 GPUs, valid values would be 2, 4, or 8
|
- With 8 GPUs, valid values would be 2, 4, or 8
|
||||||
- With 4 GPUs, valid values would be 2 or 4
|
- With 4 GPUs, valid values would be 2 or 4
|
||||||
@@ -66,7 +66,7 @@ sequence_len: 8192
|
|||||||
|
|
||||||
...
|
...
|
||||||
|
|
||||||
context_parallel_size: 4 # Split each sequence into 4 parts, one per GPU
|
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
|
||||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||||
heads_k_stride: 1
|
heads_k_stride: 1
|
||||||
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
|
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
|
||||||
@@ -89,12 +89,12 @@ Sequence parallelism is compatible with Axolotl's sample packing functionality.
|
|||||||
|
|
||||||
## Effect on Batch Size
|
## Effect on Batch Size
|
||||||
|
|
||||||
When using sequence parallelism, your effective global batch size is **divided** by the `context_parallel_size`. This happens because:
|
When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:
|
||||||
|
|
||||||
- Each group of `context_parallel_size` GPUs works on the same batch (just different parts of each sequence)
|
- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
|
||||||
- The number of batches processed per step decreases
|
- The number of batches processed per step decreases
|
||||||
|
|
||||||
For example:
|
For example:
|
||||||
- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
|
- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
|
||||||
- With 8 GPUs and `context_parallel_size=4`: Only 2 different batches processed per step (each split across 4 GPUs)
|
- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
|
||||||
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4
|
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4
|
||||||
|
|||||||
@@ -13,9 +13,9 @@ packaging==23.2
|
|||||||
|
|
||||||
huggingface_hub>=0.33.0
|
huggingface_hub>=0.33.0
|
||||||
peft==0.16.0
|
peft==0.16.0
|
||||||
transformers @ git+https://github.com/huggingface/transformers.git@82603b6cc284dbdf2b7a7cf070feb6a2c3bb53cf
|
transformers==4.53.2
|
||||||
tokenizers>=0.21.1
|
tokenizers>=0.21.1
|
||||||
accelerate @ git+https://github.com/SalmanMohammadi/accelerate.git@device_mesh_parallelism_config
|
accelerate==1.9.0
|
||||||
datasets==4.0.0
|
datasets==4.0.0
|
||||||
deepspeed>=0.17.0
|
deepspeed>=0.17.0
|
||||||
trl==0.19.1
|
trl==0.19.1
|
||||||
@@ -62,7 +62,7 @@ langdetect==1.0.9
|
|||||||
immutabledict==4.2.0
|
immutabledict==4.2.0
|
||||||
antlr4-python3-runtime==4.13.2
|
antlr4-python3-runtime==4.13.2
|
||||||
|
|
||||||
torchao==0.10.0
|
torchao==0.12.0
|
||||||
schedulefree==1.4.1
|
schedulefree==1.4.1
|
||||||
|
|
||||||
axolotl-contribs-lgpl==0.0.6
|
axolotl-contribs-lgpl==0.0.6
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
|||||||
load_in_8bit=False,
|
load_in_8bit=False,
|
||||||
load_in_4bit=False,
|
load_in_4bit=False,
|
||||||
flash_attention=False,
|
flash_attention=False,
|
||||||
context_parallel_size=None,
|
sequence_parallel_degree=None,
|
||||||
deepspeed=None,
|
deepspeed=None,
|
||||||
fsdp=None,
|
fsdp=None,
|
||||||
fsdp_config=None,
|
fsdp_config=None,
|
||||||
|
|||||||
@@ -27,7 +27,6 @@ import torch
|
|||||||
from transformers import (
|
from transformers import (
|
||||||
TrainerCallback,
|
TrainerCallback,
|
||||||
)
|
)
|
||||||
from transformers.trainer_pt_utils import AcceleratorConfig
|
|
||||||
from transformers.training_args import OptimizerNames
|
from transformers.training_args import OptimizerNames
|
||||||
|
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
@@ -435,18 +434,8 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
|
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
|
||||||
|
|
||||||
def _configure_accelerator_config(self, training_args_kwargs: dict):
|
def _configure_accelerator_config(self, training_args_kwargs: dict):
|
||||||
use_configured_state = True
|
|
||||||
if self.cfg.accelerator_config:
|
if self.cfg.accelerator_config:
|
||||||
use_configured_state = self.cfg.accelerator_config.pop(
|
training_args_kwargs["accelerator_config"] = self.cfg.accelerator_config
|
||||||
"use_configured_state", use_configured_state
|
|
||||||
)
|
|
||||||
training_args_kwargs["accelerator_config"] = AcceleratorConfig(
|
|
||||||
use_configured_state=use_configured_state, **self.cfg.accelerator_config
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
training_args_kwargs["accelerator_config"] = AcceleratorConfig(
|
|
||||||
use_configured_state=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
|
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
|
||||||
if self.cfg.activation_offloading is True:
|
if self.cfg.activation_offloading is True:
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
if self.cfg.rl is RLType.GRPO:
|
if self.cfg.rl is RLType.GRPO:
|
||||||
trainer_cls = GRPOStrategy.get_trainer_class(
|
trainer_cls = GRPOStrategy.get_trainer_class(
|
||||||
sequence_parallel=self.cfg.context_parallel_size > 1
|
sequence_parallel=self.cfg.sequence_parallel_degree > 1
|
||||||
)
|
)
|
||||||
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
|
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
|
||||||
|
|
||||||
|
|||||||
@@ -82,8 +82,8 @@ class GRPOStrategy:
|
|||||||
grpo_args_kwargs["log_completions"] = trl.log_completions
|
grpo_args_kwargs["log_completions"] = trl.log_completions
|
||||||
grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print
|
grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print
|
||||||
|
|
||||||
if cfg.context_parallel_size > 1:
|
if cfg.sequence_parallel_degree > 1:
|
||||||
grpo_args_kwargs["context_parallel_size"] = cfg.context_parallel_size
|
grpo_args_kwargs["sequence_parallel_degree"] = cfg.sequence_parallel_degree
|
||||||
|
|
||||||
if trl.reward_weights:
|
if trl.reward_weights:
|
||||||
grpo_args_kwargs["reward_weights"] = trl.reward_weights
|
grpo_args_kwargs["reward_weights"] = trl.reward_weights
|
||||||
|
|||||||
@@ -13,4 +13,4 @@ from axolotl.core.training_args import AxolotlTrainingMixins
|
|||||||
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
|
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
|
||||||
"""Axolotl GRPO Config for GRPO training"""
|
"""Axolotl GRPO Config for GRPO training"""
|
||||||
|
|
||||||
context_parallel_size: int | None = None
|
sequence_parallel_degree: int | None = None
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
|
|||||||
- Data is properly distributed across SP groups.
|
- Data is properly distributed across SP groups.
|
||||||
|
|
||||||
In the table below, the values represent dataset indices. Each SP group has
|
In the table below, the values represent dataset indices. Each SP group has
|
||||||
`context_parallel_size = 2` GPUs working together on the same data. There are 2
|
`sequence_parallel_degree = 2` GPUs working together on the same data. There are 2
|
||||||
SP groups (SP0 and SP1), with `world_size = 4` total GPUs.
|
SP groups (SP0 and SP1), with `world_size = 4` total GPUs.
|
||||||
|
|
||||||
Sequence Parallel Groups
|
Sequence Parallel Groups
|
||||||
@@ -45,7 +45,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
|
|||||||
rank: Rank of current process.
|
rank: Rank of current process.
|
||||||
batch_size: Number of samples per batch.
|
batch_size: Number of samples per batch.
|
||||||
repeat_count: How many times to repeat the full sampling process.
|
repeat_count: How many times to repeat the full sampling process.
|
||||||
context_parallel_size: Number of ranks in a sequence parallel group.
|
sequence_parallel_degree: Number of ranks in a sequence parallel group.
|
||||||
shuffle: Whether to shuffle the dataset.
|
shuffle: Whether to shuffle the dataset.
|
||||||
seed: Random seed for shuffling.
|
seed: Random seed for shuffling.
|
||||||
drop_last: Whether to drop the last incomplete batch.
|
drop_last: Whether to drop the last incomplete batch.
|
||||||
@@ -59,7 +59,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
|
|||||||
rank: int,
|
rank: int,
|
||||||
batch_size: int = 1,
|
batch_size: int = 1,
|
||||||
repeat_count: int = 1,
|
repeat_count: int = 1,
|
||||||
context_parallel_size: int = 1,
|
sequence_parallel_degree: int = 1,
|
||||||
shuffle: bool = True,
|
shuffle: bool = True,
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
drop_last: bool = False,
|
drop_last: bool = False,
|
||||||
@@ -77,9 +77,9 @@ class SequenceParallelRepeatRandomSampler(Sampler):
|
|||||||
self.rank = rank
|
self.rank = rank
|
||||||
|
|
||||||
# Sequence parallelism parameters
|
# Sequence parallelism parameters
|
||||||
self.context_parallel_size = context_parallel_size
|
self.sequence_parallel_degree = sequence_parallel_degree
|
||||||
self.num_sp_groups = world_size // context_parallel_size
|
self.num_sp_groups = world_size // sequence_parallel_degree
|
||||||
self.sp_group_id = rank // context_parallel_size
|
self.sp_group_id = rank // sequence_parallel_degree
|
||||||
|
|
||||||
# Adjust dataset size for distributed sampling
|
# Adjust dataset size for distributed sampling
|
||||||
self.num_samples = len(self.dataset)
|
self.num_samples = len(self.dataset)
|
||||||
|
|||||||
@@ -100,7 +100,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
|||||||
|
|
||||||
# Get number of SP groups (number of processes divided by SP degree)
|
# Get number of SP groups (number of processes divided by SP degree)
|
||||||
num_processes = self.accelerator.num_processes
|
num_processes = self.accelerator.num_processes
|
||||||
num_sp_groups = num_processes // self.args.context_parallel_size
|
num_sp_groups = num_processes // self.args.sequence_parallel_degree
|
||||||
|
|
||||||
# Calculate batch size per SP group (not per process)
|
# Calculate batch size per SP group (not per process)
|
||||||
sp_group_batch_size = self.args.per_device_train_batch_size * num_sp_groups
|
sp_group_batch_size = self.args.per_device_train_batch_size * num_sp_groups
|
||||||
@@ -130,7 +130,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
|||||||
|
|
||||||
if self.num_generations not in possible_values:
|
if self.num_generations not in possible_values:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"With sequence parallelism (degree {self.args.context_parallel_size}), "
|
f"With sequence parallelism (degree {self.args.sequence_parallel_degree}), "
|
||||||
f"the eval batch size per SP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) "
|
f"the eval batch size per SP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) "
|
||||||
f"must be evenly divisible by the number of generations per prompt "
|
f"must be evenly divisible by the number of generations per prompt "
|
||||||
f"({self.num_generations}). Given the current eval batch size, "
|
f"({self.num_generations}). Given the current eval batch size, "
|
||||||
@@ -167,9 +167,9 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
|||||||
rank=self.rank,
|
rank=self.rank,
|
||||||
batch_size=effective_batch_size
|
batch_size=effective_batch_size
|
||||||
// self.num_generations
|
// self.num_generations
|
||||||
// self.args.context_parallel_size,
|
// self.args.sequence_parallel_degree,
|
||||||
repeat_count=self.num_iterations * self.args.gradient_accumulation_steps,
|
repeat_count=self.num_iterations * self.args.gradient_accumulation_steps,
|
||||||
context_parallel_size=self.args.context_parallel_size,
|
sequence_parallel_degree=self.args.sequence_parallel_degree,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
seed=self.args.seed,
|
seed=self.args.seed,
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
@@ -235,7 +235,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
|||||||
# TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
|
# TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
|
||||||
# if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
|
# if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
|
||||||
# slice each batch along the sequence dimension).
|
# slice each batch along the sequence dimension).
|
||||||
if self.args.context_parallel_size > 1:
|
if self.args.sequence_parallel_degree > 1:
|
||||||
return dataloader
|
return dataloader
|
||||||
|
|
||||||
# Otherwise prepare with accelerator
|
# Otherwise prepare with accelerator
|
||||||
@@ -308,18 +308,18 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
|||||||
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
|
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
|
||||||
all_prompts_text = gather_object(prompts_text)
|
all_prompts_text = gather_object(prompts_text)
|
||||||
if self.accelerator.is_main_process:
|
if self.accelerator.is_main_process:
|
||||||
if self.args.context_parallel_size > 1:
|
if self.args.sequence_parallel_degree > 1:
|
||||||
# Calculate sequence parallel group information
|
# Calculate sequence parallel group information
|
||||||
world_size = self.accelerator.num_processes
|
world_size = self.accelerator.num_processes
|
||||||
context_parallel_size = self.args.context_parallel_size
|
sequence_parallel_degree = self.args.sequence_parallel_degree
|
||||||
num_sp_groups = world_size // context_parallel_size
|
num_sp_groups = world_size // sequence_parallel_degree
|
||||||
|
|
||||||
# Since processes in the same SP group have the same prompts, we need to ensure
|
# Since processes in the same SP group have the same prompts, we need to ensure
|
||||||
# we only take one copy of each prompt from each SP group
|
# we only take one copy of each prompt from each SP group
|
||||||
ordered_set_of_prompts = []
|
ordered_set_of_prompts = []
|
||||||
for sp_group_id in range(num_sp_groups):
|
for sp_group_id in range(num_sp_groups):
|
||||||
# Get the first process from each SP group (typically the group leader)
|
# Get the first process from each SP group (typically the group leader)
|
||||||
group_leader_rank = sp_group_id * context_parallel_size
|
group_leader_rank = sp_group_id * sequence_parallel_degree
|
||||||
|
|
||||||
# Extract prompts from this SP group, accounting for num_generations duplicates
|
# Extract prompts from this SP group, accounting for num_generations duplicates
|
||||||
# We only need prompts from one rank in each SP group
|
# We only need prompts from one rank in each SP group
|
||||||
@@ -335,7 +335,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
|||||||
# num_generations outputs for each one. This is faster than generating outputs for each duplicate
|
# num_generations outputs for each one. This is faster than generating outputs for each duplicate
|
||||||
# prompt individually.
|
# prompt individually.
|
||||||
ordered_set_of_prompts = all_prompts_text[
|
ordered_set_of_prompts = all_prompts_text[
|
||||||
:: self.num_generations * self.args.context_parallel_size
|
:: self.num_generations * self.args.sequence_parallel_degree
|
||||||
]
|
]
|
||||||
|
|
||||||
with profiling_context(self, "vLLM.generate"):
|
with profiling_context(self, "vLLM.generate"):
|
||||||
@@ -352,14 +352,14 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
completion_ids = [None] * (
|
completion_ids = [None] * (
|
||||||
len(all_prompts_text) // self.args.context_parallel_size
|
len(all_prompts_text) // self.args.sequence_parallel_degree
|
||||||
)
|
)
|
||||||
|
|
||||||
# Broadcast the completions from the main process to all processes
|
# Broadcast the completions from the main process to all processes
|
||||||
completion_ids = broadcast_object_list(completion_ids, from_process=0)
|
completion_ids = broadcast_object_list(completion_ids, from_process=0)
|
||||||
|
|
||||||
# Determine the appropriate slice based on sequence parallelism
|
# Determine the appropriate slice based on sequence parallelism
|
||||||
if self.args.context_parallel_size > 1:
|
if self.args.sequence_parallel_degree > 1:
|
||||||
# Calculate SP group ID (which group of ranks this rank belongs to)
|
# Calculate SP group ID (which group of ranks this rank belongs to)
|
||||||
sp_group_id = self.accelerator.process_index // self.local_world_size
|
sp_group_id = self.accelerator.process_index // self.local_world_size
|
||||||
|
|
||||||
@@ -583,7 +583,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
|||||||
advantages = advantages / (std_grouped_rewards + 1e-4)
|
advantages = advantages / (std_grouped_rewards + 1e-4)
|
||||||
|
|
||||||
# Slice to keep only the local part of the data
|
# Slice to keep only the local part of the data
|
||||||
if self.args.context_parallel_size > 1:
|
if self.args.sequence_parallel_degree > 1:
|
||||||
# Calculate SP group ID (which group of ranks this rank belongs to)
|
# Calculate SP group ID (which group of ranks this rank belongs to)
|
||||||
sp_group_id = self.accelerator.process_index // self.local_world_size
|
sp_group_id = self.accelerator.process_index // self.local_world_size
|
||||||
|
|
||||||
|
|||||||
@@ -13,11 +13,9 @@ class CheckpointSaveMixin(Trainer):
|
|||||||
def _save_optimizer_and_scheduler(self, output_dir):
|
def _save_optimizer_and_scheduler(self, output_dir):
|
||||||
try:
|
try:
|
||||||
super()._save_optimizer_and_scheduler(output_dir)
|
super()._save_optimizer_and_scheduler(output_dir)
|
||||||
except (NotImplementedError, KeyError) as exc:
|
except NotImplementedError as exc:
|
||||||
# TODO: fix fsdp2 optimizer saving
|
LOG.warning(
|
||||||
LOG.warning_once(
|
|
||||||
f"Trainer does not support saving optimizer and scheduler: {exc}\n"
|
f"Trainer does not support saving optimizer and scheduler: {exc}\n"
|
||||||
"Optimizer and scheduler states were not saved - resuming from checkpoints "
|
"Optimizer and scheduler states were not saved - resuming from checkpoints "
|
||||||
"for this training run will not be possible.",
|
"for this training run will not be possible."
|
||||||
main_process_only=True,
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -16,6 +16,8 @@
|
|||||||
Module for handling LIGER input arguments.
|
Module for handling LIGER input arguments.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, model_validator
|
from pydantic import BaseModel, model_validator
|
||||||
|
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
@@ -28,13 +30,13 @@ class LigerArgs(BaseModel):
|
|||||||
Input args for LIGER.
|
Input args for LIGER.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
liger_rope: bool | None = None
|
liger_rope: Optional[bool] = None
|
||||||
liger_rms_norm: bool | None = None
|
liger_rms_norm: Optional[bool] = None
|
||||||
liger_layer_norm: bool | None = None
|
liger_layer_norm: Optional[bool] = None
|
||||||
liger_swiglu: bool | None = None
|
liger_swiglu: Optional[bool] = None
|
||||||
liger_glu_activation: bool | None = None
|
liger_glu_activation: Optional[bool] = None
|
||||||
liger_cross_entropy: bool | None = None
|
liger_cross_entropy: Optional[bool] = None
|
||||||
liger_fused_linear_cross_entropy: bool | None = None
|
liger_fused_linear_cross_entropy: Optional[bool] = None
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -60,13 +62,3 @@ class LigerArgs(BaseModel):
|
|||||||
"You cannot have both `liger_glu_activation` and `tiled_mlp` set."
|
"You cannot have both `liger_glu_activation` and `tiled_mlp` set."
|
||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def check_liger_rms_norm_tensor_parallel(cls, data):
|
|
||||||
if data.get("liger_rms_norm") and data.get("tensor_parallel_size", 1) > 1:
|
|
||||||
raise ValueError(
|
|
||||||
"`liger_rms_norm` is incompatible with tensor parallelism, "
|
|
||||||
"see https://github.com/linkedin/Liger-Kernel/issues/826"
|
|
||||||
)
|
|
||||||
return data
|
|
||||||
|
|||||||
@@ -13,8 +13,7 @@ import peft
|
|||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
import transformers.modeling_utils
|
import transformers.modeling_utils
|
||||||
from accelerate import PartialState, init_empty_weights
|
from accelerate import init_empty_weights
|
||||||
from accelerate.utils.dataclasses import ParallelismConfig
|
|
||||||
from peft import (
|
from peft import (
|
||||||
PeftConfig,
|
PeftConfig,
|
||||||
PeftMixedModel,
|
PeftMixedModel,
|
||||||
@@ -52,7 +51,6 @@ from axolotl.utils.dict import DictDefault
|
|||||||
from axolotl.utils.distributed import (
|
from axolotl.utils.distributed import (
|
||||||
get_device_count,
|
get_device_count,
|
||||||
get_device_type,
|
get_device_type,
|
||||||
get_world_size,
|
|
||||||
)
|
)
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.model_shard_quant import load_sharded_model_quant
|
from axolotl.utils.model_shard_quant import load_sharded_model_quant
|
||||||
@@ -184,7 +182,6 @@ class ModelLoader:
|
|||||||
|
|
||||||
def _apply_pre_model_load_setup(self):
|
def _apply_pre_model_load_setup(self):
|
||||||
"""Apply patches and setup configurations before model loading."""
|
"""Apply patches and setup configurations before model loading."""
|
||||||
self._set_parallel_config()
|
|
||||||
self._set_auto_model_loader()
|
self._set_auto_model_loader()
|
||||||
self._set_device_map_config()
|
self._set_device_map_config()
|
||||||
if self.cfg.revision_of_model:
|
if self.cfg.revision_of_model:
|
||||||
@@ -392,52 +389,6 @@ class ModelLoader:
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def _set_parallel_config(self):
    """Build the parallelism config and device mesh and store both on `PartialState`.

    The data-parallel replicate size starts at the world size and is
    integer-divided by every configured degree (`dp_shard_size`,
    `tensor_parallel_size`, `context_parallel_size`) that is greater than 1.
    """
    cfg = self.cfg
    remaining_replicas = get_world_size()
    mesh_kwargs = {}

    # (ParallelismConfig kwarg name, configured degree); insertion order is kept
    configured_degrees = (
        ("dp_shard_size", cfg.dp_shard_size),
        ("tp_size", cfg.tensor_parallel_size),
        ("cp_size", cfg.context_parallel_size),
    )
    for kwarg_name, degree in configured_degrees:
        if degree and degree > 1:
            mesh_kwargs[kwarg_name] = degree
            remaining_replicas = remaining_replicas // degree

    if remaining_replicas > 1:
        mesh_kwargs["dp_replicate_size"] = remaining_replicas

    parallelism_config = ParallelismConfig(**mesh_kwargs)
    mesh_dim_names, mesh_shape = parallelism_config.get_mesh()
    device_mesh = torch.distributed.init_device_mesh(
        "cuda", mesh_shape, mesh_dim_names=mesh_dim_names
    )

    flattened_submeshes = (
        # submesh only used for distributing data across data parallel dims (no comms)
        ("dp", tuple(parallelism_config.dp_dim_names)),
        # submesh used *just* for FSDP parameter gathering/scattering
        # and gradients reduce-scattering
        ("dp_shard_cp", tuple(parallelism_config.dp_shard_cp_dim_names)),
        # submesh used for correctly reducing loss across data replica/context parallel
        ("dp_cp", tuple(parallelism_config.dp_cp_dim_names)),
    )
    for flat_name, dim_names in flattened_submeshes:
        if not dim_names:
            # nothing to flatten when none of these dims are part of the mesh
            continue
        device_mesh[dim_names]._flatten(flat_name)  # pylint: disable=protected-access

    PartialState().parallelism_config = parallelism_config
    PartialState().device_mesh = device_mesh
|
|
||||||
|
|
||||||
def _set_auto_model_loader(self):
|
def _set_auto_model_loader(self):
|
||||||
"""Set `self.auto_model_loader`. Defaults to `transformers.AutoModelForCausalLM`
|
"""Set `self.auto_model_loader`. Defaults to `transformers.AutoModelForCausalLM`
|
||||||
(set at `__init__`). When using a multimodal model, `self.auto_model_loader`
|
(set at `__init__`). When using a multimodal model, `self.auto_model_loader`
|
||||||
@@ -670,14 +621,6 @@ class ModelLoader:
|
|||||||
def _build_model(self) -> bool:
|
def _build_model(self) -> bool:
|
||||||
"""Load model, with load strategy depending on config."""
|
"""Load model, with load strategy depending on config."""
|
||||||
skip_move_to_device = False
|
skip_move_to_device = False
|
||||||
|
|
||||||
if self.cfg.tensor_parallel_size > 1:
|
|
||||||
self.model_kwargs["tp_size"] = self.cfg.tensor_parallel_size
|
|
||||||
self.model_kwargs["tp_plan"] = "auto"
|
|
||||||
self.model_kwargs["device_mesh"] = PartialState().device_mesh
|
|
||||||
if "device_map" in self.model_kwargs:
|
|
||||||
del self.model_kwargs["device_map"] # not compatible with `tp_plan`
|
|
||||||
|
|
||||||
if self.is_fsdp_enabled:
|
if self.is_fsdp_enabled:
|
||||||
if self.cfg.fsdp_config.cpu_ram_efficient_loading:
|
if self.cfg.fsdp_config.cpu_ram_efficient_loading:
|
||||||
skip_move_to_device = True
|
skip_move_to_device = True
|
||||||
|
|||||||
@@ -261,14 +261,14 @@ class PatchManager:
|
|||||||
|
|
||||||
def _apply_sequence_parallel_patches(self):
|
def _apply_sequence_parallel_patches(self):
|
||||||
"""Apply sequence parallelism patches."""
|
"""Apply sequence parallelism patches."""
|
||||||
if self.cfg.context_parallel_size and self.cfg.context_parallel_size > 1:
|
if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:
|
||||||
from axolotl.monkeypatch.ring_attn.patch import (
|
from axolotl.monkeypatch.ring_attn.patch import (
|
||||||
patch_prepare_data_loader,
|
patch_prepare_data_loader,
|
||||||
patch_prepare_device_mesh,
|
patch_prepare_device_mesh,
|
||||||
)
|
)
|
||||||
|
|
||||||
patch_prepare_data_loader()
|
patch_prepare_data_loader()
|
||||||
patch_prepare_device_mesh(self.cfg.context_parallel_size, self.cfg.fsdp)
|
patch_prepare_device_mesh(self.cfg.sequence_parallel_degree, self.cfg.fsdp)
|
||||||
|
|
||||||
def _apply_tiled_mlp(self, model_type: str):
|
def _apply_tiled_mlp(self, model_type: str):
|
||||||
if self.cfg.tiled_mlp:
|
if self.cfg.tiled_mlp:
|
||||||
|
|||||||
@@ -1,352 +0,0 @@
|
|||||||
"""
|
|
||||||
monkeypatch for accelerate fsdp2 fix when modifying ordereddict during iteration, and saving full state dicts
|
|
||||||
"""
|
|
||||||
|
|
||||||
import copy
|
|
||||||
import functools
|
|
||||||
import sys
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
|
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def fsdp2_load_full_state_dict(
    _accelerator, model: torch.nn.Module, full_sd: dict, offload_to_cpu: bool = False
):
    """
    Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the
    parameters from rank 0 to all other ranks. This function modifies the model in-place.

    Every tensor in `full_sd` is cast to the dtype of the matching entry of `model.state_dict()` and moved to
    CUDA. When that entry exposes a `device_mesh`, the tensor is distributed over it with the entry's placements;
    otherwise the full tensor is used as-is.

    Args:
        _accelerator (`Accelerator`): The accelerator instance (not used by this function)
        model (`torch.nn.Module`):
            The model to load the state dict into, expected to be on meta device or a VRAM spike can occur
        full_sd (`dict`): The full state dict to load, can only be on rank 0. Its values are replaced with
            `None` as they are consumed.
        offload_to_cpu (`bool`, defaults to `False`): Whether to move every resulting parameter to CPU

    Returns:
        `torch.nn.Module`: The same `model`, with the new parameters assigned.

    Raises:
        KeyError: If `full_sd` contains a name that is not present in `model.state_dict()`.
    """
    import time

    from torch.distributed.tensor import distribute_tensor

    LOG.info("Broadcasting full state dict to all ranks...")

    start_time = time.time()

    meta_sharded_sd = model.state_dict()
    sharded_sd = {}
    for param_name, full_tensor in full_sd.items():
        sharded_meta_param = meta_sharded_sd.get(param_name)
        if sharded_meta_param is None:
            # fail with the offending name instead of an AttributeError on `None.dtype`
            raise KeyError(
                f"`{param_name}` from the full state dict is not present in the model's state dict"
            )
        full_tensor = full_tensor.to(sharded_meta_param.dtype).to(torch.device("cuda"))
        if hasattr(sharded_meta_param, "device_mesh"):
            sharded_param = distribute_tensor(
                full_tensor,
                sharded_meta_param.device_mesh,
                sharded_meta_param.placements,
                src_data_rank=0,
            )
        else:
            sharded_param = full_tensor

        if offload_to_cpu:
            sharded_param = sharded_param.cpu()

        sharded_sd[param_name] = nn.Parameter(sharded_param)
        # drop references to the full tensor as soon as it has been consumed
        del full_tensor
        full_sd[param_name] = None
    model.load_state_dict(sharded_sd, assign=True, strict=True)
    end_time = time.time()
    LOG.debug(
        f"Time taken to load full state dict: {(end_time - start_time):.2f} seconds"
    )
    log_gpu_memory_usage(LOG, "Memory usage after broadcasting full state dict", 0)
    return model
|
|
||||||
|
|
||||||
|
|
||||||
def get_state_dict(self, model, unwrap=True):
    """
    Returns the state dictionary of a model sent through [`Accelerator.prepare`] potentially without full
    precision.

    Args:
        model (`torch.nn.Module`):
            A PyTorch model sent through [`Accelerator.prepare`]
        unwrap (`bool`, *optional*, defaults to `True`):
            Whether to return the original underlying state_dict of `model` or to return the wrapped state_dict

    Returns:
        `dict`: The state dictionary of the model potentially without full precision.

    Raises:
        ImportError: If DeepSpeed tensor parallelism is used with deepspeed < 0.16.4.
        ValueError: If DeepSpeed ZeRO-3/TP sharding is used but 16bit weight gathering on save is disabled.

    Example:

    ```python
    >>> import torch
    >>> from accelerate import Accelerator

    >>> accelerator = Accelerator()
    >>> net = torch.nn.Linear(2, 2)
    >>> net = accelerator.prepare(net)
    >>> state_dict = accelerator.get_state_dict(net)
    ```
    """
    from accelerate import DistributedType
    from accelerate.utils import compare_versions

    if self.distributed_type == DistributedType.DEEPSPEED:
        zero3_sharding = self.deepspeed_config["zero_optimization"]["stage"] == 3
        tp_sharding = (
            self.deepspeed_config.get("tensor_parallel", {}).get("autotp_size", 0) > 1
        )
        if zero3_sharding or tp_sharding:
            if model.zero_gather_16bit_weights_on_model_save():
                if tp_sharding and not compare_versions("deepspeed", ">=", "0.16.4"):
                    raise ImportError(
                        "Deepspeed TP requires deepspeed >= 0.16.4, Please update DeepSpeed via `pip install deepspeed -U`."
                    )
                state_dict = (
                    model._consolidated_16bit_state_dict()  # pylint: disable=protected-access
                    if tp_sharding
                    else model._zero3_consolidated_16bit_state_dict()  # pylint: disable=protected-access
                )
            else:
                raise ValueError(
                    "Cannot get 16bit model weights because `stage3_gather_16bit_weights_on_model_save` in DeepSpeed config is False. "
                    "To save the model weights in 16bit, set `stage3_gather_16bit_weights_on_model_save` to True in DeepSpeed config file or "
                    "set `zero3_save_16bit_model` to True when using `accelerate config`. "
                    "To save the full checkpoint, run `model.save_checkpoint(save_dir)` and use `zero_to_fp32.py` to recover weights."
                )
        else:
            from deepspeed.checkpoint.utils import clone_tensors_for_torch_save

            state_dict = clone_tensors_for_torch_save(
                self.unwrap_model(model).state_dict()
            )
    elif self.is_fsdp2:
        # https://github.com/pytorch/torchtune/blob/main/torchtune/training/_distributed.py#L465
        state_dict = {}
        sharded_state_dict = model.state_dict()
        for param_name, param in sharded_state_dict.items():
            if param.is_cpu:
                param = param.to(torch.device("cuda"))

            # Only entries exposing `full_tensor` (DTensor) need gathering; plain
            # tensors are already complete and have no such method.
            if hasattr(param, "full_tensor"):
                param = param.full_tensor()
            # the gathered copy is only kept on rank 0; other ranks return an empty dict
            if torch.distributed.get_rank() == 0:
                state_dict[param_name] = param.cpu()
            torch.distributed.barrier()
    elif self.distributed_type == DistributedType.FSDP:
        from torch.distributed.fsdp import FullStateDictConfig
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        from torch.distributed.fsdp import StateDictType

        full_state_dict_config = FullStateDictConfig(
            offload_to_cpu=True, rank0_only=True
        )
        with FSDP.state_dict_type(
            model, StateDictType.FULL_STATE_DICT, full_state_dict_config
        ):
            state_dict = model.state_dict()
    else:
        if unwrap:
            model = self.unwrap_model(model)
        state_dict = model.state_dict()

    return state_dict
|
|
||||||
|
|
||||||
|
|
||||||
def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
|
|
||||||
"""Helper function to process LoRA modules for FSDP2."""
|
|
||||||
from torch.distributed.fsdp import fully_shard
|
|
||||||
|
|
||||||
log_bias_dtype_mismatch = False
|
|
||||||
|
|
||||||
# Linear4Bit will keep it's bias term in fp32. If the weight dtype is in bf16 we are not able to
|
|
||||||
# wrap this. Therefore we must ensure the bias has the same dtype as the weight
|
|
||||||
if module.base_layer.bias is not None:
|
|
||||||
if module.base_layer.weight.dtype != module.base_layer.bias.dtype:
|
|
||||||
log_bias_dtype_mismatch = True
|
|
||||||
module.base_layer.bias.data = module.base_layer.bias.data.to(
|
|
||||||
module.base_layer.weight.dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
for active_adapter in module.active_adapters:
|
|
||||||
if module.lora_A:
|
|
||||||
fully_shard(module.lora_A[active_adapter], **fsdp2_kwargs)
|
|
||||||
if module.lora_B:
|
|
||||||
fully_shard(module.lora_B[active_adapter], **fsdp2_kwargs)
|
|
||||||
if module.lora_embedding_A:
|
|
||||||
fully_shard(module.lora_embedding_A[active_adapter], **fsdp2_kwargs)
|
|
||||||
if module.lora_embedding_B:
|
|
||||||
fully_shard(module.lora_embedding_B[active_adapter], **fsdp2_kwargs)
|
|
||||||
if module.lora_magnitude_vector:
|
|
||||||
fully_shard(module.lora_magnitude_vector[active_adapter], **fsdp2_kwargs)
|
|
||||||
return log_bias_dtype_mismatch
|
|
||||||
|
|
||||||
|
|
||||||
def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
    """Prepares the model for FSDP2 in-place. Also returns the model to avoid misuse of the original model.

    A model that is already an `FSDPModule` (directly, or as the `_orig_mod` of a
    compiled module) is returned untouched.

    Args:
        accelerator (`Accelerator`): The accelerator instance
        model (`torch.nn.Module`): The model to prepare

    Returns:
        `torch.nn.Module`: Prepared model
    """
    from accelerate.utils import get_module_children_bottom_up, is_compiled_module
    from accelerate.utils.fsdp_utils import fsdp2_prepare_auto_wrap_policy
    from accelerate.utils.modeling import get_non_persistent_buffers
    from peft import PeftModel
    from peft.tuners.lora import LoraLayer
    from torch.distributed.fsdp import (
        CPUOffloadPolicy,
        FSDPModule,
        MixedPrecisionPolicy,
        fully_shard,
    )

    is_type_fsdp = isinstance(model, FSDPModule) or (
        is_compiled_module(model)
        and isinstance(model._orig_mod, FSDPModule)  # pylint: disable=protected-access
    )
    if is_type_fsdp:
        return model

    fsdp2_plugin = accelerator.state.fsdp_plugin

    # snapshot taken before any device move; reloaded below for cpu_ram_efficient_loading
    original_sd = model.state_dict()

    # We set `auto_wrap_policy` to `functools.partial` to avoid creating it again
    # This is because of `apply_activation_checkpointing` which can reuse this function
    fsdp2_plugin.set_auto_wrap_policy(model)

    if fsdp2_plugin.activation_checkpointing:
        from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
            CheckpointImpl,
            apply_activation_checkpointing,
            checkpoint_wrapper,
        )

        # Apply activation checkpointing before applying `fully_shard`
        apply_activation_checkpointing(
            model,
            checkpoint_wrapper_fn=functools.partial(
                checkpoint_wrapper,
                checkpoint_impl=CheckpointImpl.NO_REENTRANT,
            ),
            auto_wrap_policy=fsdp2_plugin.auto_wrap_policy,
        )

    fsdp2_kwargs = {
        "reshard_after_forward": fsdp2_plugin.reshard_after_forward,
        "offload_policy": fsdp2_plugin.cpu_offload,
        # `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy`
        "mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
        "mesh": accelerator.torch_device_mesh[tuple(accelerator.parallelism_config.model_shard_dim_names)],
    }

    model_has_params4bit = False
    for _, param in model.named_parameters():
        # this is a temporary fix whereby loading models with bnb params cannot be moved from
        # GPU to a meta device with FSDP2 because torch operations don't return the original class type
        # bypassing the move to meta will still cause the VRAM spike, but at least it still will load
        if param.__class__.__name__ == "Params4bit":
            model_has_params4bit = True
            break

    if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:
        # Context: `fully_shard` moves the model to GPU if it was on CPU, however it can also be on `meta` and then it stays there even after `fully_shard`
        # For this reason, we need to move the model to `meta` device, as then sharding happens on `meta` device
        # If we kept the model on CPU (`cpu_ram_efficient_loading` has model be on CPU on all ranks, though non-main ranks only have `torch.empty`), `fully_shard` would move it to GPU
        # Afterwards, when we call `fsdp2_load_full_state_dict`, us creating the state_dict would result into briefly having two copies of model state_dict on the GPU -> VRAM spike

        # We need to keep the original non-persistent buffers, as those MAY not be in the state_dict, resulting in them staying on meta device
        # Also, these buffers aren't getting sharded by default
        # We get the FQNs of all non-persistent buffers, to re-register them after
        non_persistent_buffer_fqns = get_non_persistent_buffers(
            model, recurse=True, fqns=True
        )
        original_non_persistent_buffers = copy.deepcopy(
            {k: v for k, v in model.named_buffers() if k in non_persistent_buffer_fqns}
        )
        # We move the model to meta device, as then sharding happens on meta device
        model = model.to(torch.device("meta"))
        # We need to re-tie the weights, not exactly sure why, but if we don't do this, reference to `lm_head/embed_tokens` stay hanging -> more VRAM usage
        # We assume `transformers` models have a `tie_weights` method if they support it
        if hasattr(model, "tie_weights"):
            model.tie_weights()

    is_peft_model = isinstance(model, PeftModel)

    auto_wrap_policy = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)
    log_bias_dtype_mismatch = False
    if auto_wrap_policy is not None:
        # the last entry is skipped here; `model` itself is sharded right after the loop
        for module in get_module_children_bottom_up(model)[:-1]:
            if is_peft_model and isinstance(module, LoraLayer):
                module_log_bias_mismatch = _process_lora_module_for_fsdp(
                    module, fsdp2_kwargs
                )
                log_bias_dtype_mismatch |= module_log_bias_mismatch
            if auto_wrap_policy(module) and not isinstance(module, FSDPModule):
                fully_shard(module, **fsdp2_kwargs)

    fully_shard(model, **fsdp2_kwargs)

    if log_bias_dtype_mismatch:
        LOG.warning(
            "Bias dtype mismatch detected in LoRA base linear layer. Bias parameters have been cast to weight dtype."
        )

    if fsdp2_plugin.cpu_ram_efficient_loading:
        offload_to_cpu = isinstance(fsdp2_plugin.cpu_offload, CPUOffloadPolicy)
        fsdp2_load_full_state_dict(
            accelerator, model, original_sd, offload_to_cpu=offload_to_cpu
        )

    if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:
        # We re-register the buffers, as they may not be in the state_dict
        for fqn, buffer_tensor in original_non_persistent_buffers.items():
            buffer_tensor = buffer_tensor.to(accelerator.device)

            if "." in fqn:
                parent_fqn, local_buffer_name = fqn.rsplit(".", 1)
                parent_module = model.get_submodule(parent_fqn)
            else:
                local_buffer_name = fqn
                parent_module = model

            parent_module.register_buffer(
                local_buffer_name, buffer_tensor, persistent=False
            )

        # We need to tie the weights again, as call to `load_full_state_dict` breaks the tie
        # Needs to be called both here and above
        # removing this call makes the model have a slightly different loss
        # removing the call above leads to extra memory usage as explained in the comment above
        if hasattr(model, "tie_weights"):
            model.tie_weights()
    return model
|
|
||||||
|
|
||||||
|
|
||||||
def patch_accelerate_fsdp2():
    """Install this module's FSDP2 replacements into `accelerate`.

    Replaces `accelerate.accelerator.fsdp2_prepare_model` and
    `accelerate.Accelerator.get_state_dict`, and additionally sets an attribute
    literally named "Accelerator.get_state_dict" on the `accelerate` module
    object found in `sys.modules`.
    """
    import accelerate

    accelerate.accelerator.fsdp2_prepare_model = fsdp2_prepare_model
    setattr(accelerate.Accelerator, "get_state_dict", get_state_dict)

    accelerate_module = sys.modules["accelerate"]
    setattr(accelerate_module, "Accelerator.get_state_dict", get_state_dict)
|
|
||||||
|
|||||||
@@ -162,14 +162,14 @@ def create_ring_flash_attention_forward(
|
|||||||
|
|
||||||
|
|
||||||
def register_ring_attn(
|
def register_ring_attn(
|
||||||
context_parallel_size: int,
|
sequence_parallel_degree: int,
|
||||||
heads_k_stride: int | None,
|
heads_k_stride: int | None,
|
||||||
ring_attn_func: RingAttnFunc | None,
|
ring_attn_func: RingAttnFunc | None,
|
||||||
):
|
):
|
||||||
"""Create ring attention group and substitute flash attn with ring flash attn.
|
"""Create ring attention group and substitute flash attn with ring flash attn.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
context_parallel_size: Sequence parallelism factor.
|
sequence_parallel_degree: Sequence parallelism factor.
|
||||||
heads_k_stride: Sequence parallelism K head stride size. Passed through to
|
heads_k_stride: Sequence parallelism K head stride size. Passed through to
|
||||||
`varlen_llama3` `ring_flash_attn` implementation.
|
`varlen_llama3` `ring_flash_attn` implementation.
|
||||||
ring_attn_func: `ring_flash_attn` ring attention implemention. If sample
|
ring_attn_func: `ring_flash_attn` ring attention implemention. If sample
|
||||||
@@ -182,25 +182,25 @@ def register_ring_attn(
|
|||||||
if rank == 0:
|
if rank == 0:
|
||||||
LOG.info(
|
LOG.info(
|
||||||
"Enabling ring attention sequence parallelism: "
|
"Enabling ring attention sequence parallelism: "
|
||||||
f"each sequence will be processed across {context_parallel_size} GPUs"
|
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
|
||||||
)
|
)
|
||||||
|
|
||||||
assert context_parallel_size <= world_size, (
|
assert sequence_parallel_degree <= world_size, (
|
||||||
f"context_parallel_size ({context_parallel_size}) "
|
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
||||||
f"must be less than or equal to world_size ({world_size})"
|
f"must be less than or equal to world_size ({world_size})"
|
||||||
)
|
)
|
||||||
assert world_size % context_parallel_size == 0, (
|
assert world_size % sequence_parallel_degree == 0, (
|
||||||
f"context_parallel_size ({context_parallel_size}) "
|
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
||||||
f"must evenly divide world_size ({world_size})"
|
f"must evenly divide world_size ({world_size})"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assign ranks to sequence parallel groups
|
# Assign ranks to sequence parallel groups
|
||||||
group_assignments = {}
|
group_assignments = {}
|
||||||
for i in range(world_size // context_parallel_size):
|
for i in range(world_size // sequence_parallel_degree):
|
||||||
ring_attn_ranks = list(
|
ring_attn_ranks = list(
|
||||||
range(
|
range(
|
||||||
i * context_parallel_size,
|
i * sequence_parallel_degree,
|
||||||
(i + 1) * context_parallel_size,
|
(i + 1) * sequence_parallel_degree,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
|
group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
|
||||||
@@ -299,12 +299,12 @@ def patch_prepare_data_loader():
|
|||||||
LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support")
|
LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support")
|
||||||
|
|
||||||
|
|
||||||
def patch_prepare_device_mesh(context_parallel_size: int, fsdp: bool = False):
|
def patch_prepare_device_mesh(sequence_parallel_degree: int, fsdp: bool = False):
|
||||||
"""Patches the `Accelerator._prepare_device_mesh` method to create a device mesh
|
"""Patches the `Accelerator._prepare_device_mesh` method to create a device mesh
|
||||||
that includes sequence parallelism with the specified degree.
|
that includes sequence parallelism with the specified degree.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
context_parallel_size: The degree of sequence parallelism to use.
|
sequence_parallel_degree: The degree of sequence parallelism to use.
|
||||||
fsdp: Whether to use FSDP.
|
fsdp: Whether to use FSDP.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -323,8 +323,8 @@ def patch_prepare_device_mesh(context_parallel_size: int, fsdp: bool = False):
|
|||||||
# Create device mesh with sequence parallelism
|
# Create device mesh with sequence parallelism
|
||||||
world_size = dist.get_world_size()
|
world_size = dist.get_world_size()
|
||||||
mesh_shape = (
|
mesh_shape = (
|
||||||
world_size // context_parallel_size,
|
world_size // sequence_parallel_degree,
|
||||||
context_parallel_size,
|
sequence_parallel_degree,
|
||||||
)
|
)
|
||||||
device_ids = list(range(world_size))
|
device_ids = list(range(world_size))
|
||||||
|
|
||||||
@@ -344,5 +344,5 @@ def patch_prepare_device_mesh(context_parallel_size: int, fsdp: bool = False):
|
|||||||
|
|
||||||
LOG.info(
|
LOG.info(
|
||||||
"Successfully patched Accelerator._prepare_device_mesh "
|
"Successfully patched Accelerator._prepare_device_mesh "
|
||||||
f"with context_parallel_size={context_parallel_size}"
|
f"with sequence_parallel_degree={sequence_parallel_degree}"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -202,7 +202,7 @@ def execute_training(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.context_parallel_size > 1:
|
if cfg.sequence_parallel_degree > 1:
|
||||||
models = [trainer.model]
|
models = [trainer.model]
|
||||||
if hasattr(trainer, "ref_model") and trainer.ref_model:
|
if hasattr(trainer, "ref_model") and trainer.ref_model:
|
||||||
models.append(trainer.ref_model)
|
models.append(trainer.ref_model)
|
||||||
@@ -210,7 +210,7 @@ def execute_training(
|
|||||||
stack.enter_context(
|
stack.enter_context(
|
||||||
SequenceParallelContextManager(
|
SequenceParallelContextManager(
|
||||||
models=models,
|
models=models,
|
||||||
context_parallel_size=cfg.context_parallel_size,
|
sequence_parallel_degree=cfg.sequence_parallel_degree,
|
||||||
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
|
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
|
||||||
ring_attn_func=cfg.ring_attn_func,
|
ring_attn_func=cfg.ring_attn_func,
|
||||||
heads_k_stride=cfg.heads_k_stride,
|
heads_k_stride=cfg.heads_k_stride,
|
||||||
|
|||||||
@@ -167,7 +167,7 @@ class SequenceParallelContextManager:
|
|||||||
Args:
|
Args:
|
||||||
models: List of models to apply sequence parallelism to pre- and post- forward
|
models: List of models to apply sequence parallelism to pre- and post- forward
|
||||||
hooks.
|
hooks.
|
||||||
context_parallel_size: Number of processes to split sequences over.
|
sequence_parallel_degree: Number of processes to split sequences over.
|
||||||
gradient_accumulation_steps: Number of steps to accumulate gradients over.
|
gradient_accumulation_steps: Number of steps to accumulate gradients over.
|
||||||
ring_attn_func: Which ring attention function to use. Currently unused.
|
ring_attn_func: Which ring attention function to use. Currently unused.
|
||||||
heads_k_stride: Sequence parallelism K head stride size. Passed through to
|
heads_k_stride: Sequence parallelism K head stride size. Passed through to
|
||||||
@@ -179,14 +179,14 @@ class SequenceParallelContextManager:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
models: list[nn.Module],
|
models: list[nn.Module],
|
||||||
context_parallel_size: int,
|
sequence_parallel_degree: int,
|
||||||
gradient_accumulation_steps: int,
|
gradient_accumulation_steps: int,
|
||||||
ring_attn_func: RingAttnFunc,
|
ring_attn_func: RingAttnFunc,
|
||||||
heads_k_stride: int | None,
|
heads_k_stride: int | None,
|
||||||
gather_outputs: bool,
|
gather_outputs: bool,
|
||||||
):
|
):
|
||||||
self.models = models
|
self.models = models
|
||||||
self.context_parallel_size = context_parallel_size
|
self.sequence_parallel_degree = sequence_parallel_degree
|
||||||
self.gradient_accumulation_steps = gradient_accumulation_steps
|
self.gradient_accumulation_steps = gradient_accumulation_steps
|
||||||
self.ring_attn_func = ring_attn_func
|
self.ring_attn_func = ring_attn_func
|
||||||
self.heads_k_stride = heads_k_stride
|
self.heads_k_stride = heads_k_stride
|
||||||
@@ -231,7 +231,7 @@ class SequenceParallelContextManager:
|
|||||||
def _register_ring_attn(self):
|
def _register_ring_attn(self):
|
||||||
# Initialize ring attn for sequence parallelism
|
# Initialize ring attn for sequence parallelism
|
||||||
register_ring_attn(
|
register_ring_attn(
|
||||||
context_parallel_size=self.context_parallel_size,
|
sequence_parallel_degree=self.sequence_parallel_degree,
|
||||||
heads_k_stride=self.heads_k_stride,
|
heads_k_stride=self.heads_k_stride,
|
||||||
ring_attn_func=self.ring_attn_func,
|
ring_attn_func=self.ring_attn_func,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -644,19 +644,7 @@ class AxolotlInputConfig(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
dp_shard_size: int | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "Number of devices to shard across. If not set, will use all available devices."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
sequence_parallel_degree: int | None = Field(
|
sequence_parallel_degree: int | None = Field(
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "Deprecated: use `context_parallel_size` instead"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
context_parallel_size: int | None = Field(
|
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Set to a divisor of the number of GPUs available to split sequences into chunks of equal size. Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM. E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized subsequences, or set to 4 to split into four equal-sized subsequences. See https://docs.axolotl.ai/docs/sequence_parallelism.html for more details."
|
"description": "Set to a divisor of the number of GPUs available to split sequences into chunks of equal size. Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM. E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized subsequences, or set to 4 to split into four equal-sized subsequences. See https://docs.axolotl.ai/docs/sequence_parallelism.html for more details."
|
||||||
|
|||||||
@@ -686,7 +686,7 @@ class RLValidationMixin:
|
|||||||
data.get("rl") == "grpo"
|
data.get("rl") == "grpo"
|
||||||
and data.get("trl", {})
|
and data.get("trl", {})
|
||||||
and data.get("trl").get("use_liger_loss")
|
and data.get("trl").get("use_liger_loss")
|
||||||
and data.get("context_parallel_size", 1) > 1
|
and data.get("sequence_parallel_degree", 1) > 1
|
||||||
):
|
):
|
||||||
raise ValueError("GRPO + SP + Liger not currently supported")
|
raise ValueError("GRPO + SP + Liger not currently supported")
|
||||||
return data
|
return data
|
||||||
@@ -913,14 +913,15 @@ class OptimizationValidationMixin:
|
|||||||
def check_tensor_parallel_size_update_ds_json(cls, data):
|
def check_tensor_parallel_size_update_ds_json(cls, data):
|
||||||
tensor_parallel_size = data.get("tensor_parallel_size")
|
tensor_parallel_size = data.get("tensor_parallel_size")
|
||||||
if tensor_parallel_size is not None and tensor_parallel_size > 1:
|
if tensor_parallel_size is not None and tensor_parallel_size > 1:
|
||||||
if data.get("deepspeed"):
|
if not data.get("deepspeed"):
|
||||||
|
raise ValueError(
|
||||||
|
"Tensor parallelism (TP) is only supported with DeepSpeed"
|
||||||
|
)
|
||||||
with open(data.get("deepspeed"), "r", encoding="utf-8") as ds_fin:
|
with open(data.get("deepspeed"), "r", encoding="utf-8") as ds_fin:
|
||||||
ds_config = json.load(ds_fin)
|
ds_config = json.load(ds_fin)
|
||||||
should_save = False
|
should_save = False
|
||||||
if "tensor_parallel" not in ds_config:
|
if "tensor_parallel" not in ds_config:
|
||||||
ds_config["tensor_parallel"] = {
|
ds_config["tensor_parallel"] = {"autotp_size": tensor_parallel_size}
|
||||||
"autotp_size": tensor_parallel_size
|
|
||||||
}
|
|
||||||
should_save = True
|
should_save = True
|
||||||
if (
|
if (
|
||||||
"gather_16bit_weights_on_model_save"
|
"gather_16bit_weights_on_model_save"
|
||||||
@@ -1234,13 +1235,13 @@ class ComplexValidationMixin:
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def check_context_parallel_size(self):
|
def check_sequence_parallel_degree(self):
|
||||||
if not self.context_parallel_size:
|
if not self.sequence_parallel_degree:
|
||||||
self.context_parallel_size = 1
|
self.sequence_parallel_degree = 1
|
||||||
elif self.context_parallel_size > 1:
|
elif self.sequence_parallel_degree > 1:
|
||||||
if not self.flash_attention:
|
if not self.flash_attention:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"flash_attention: true must be set with context_parallel_size > 1"
|
"flash_attention: true must be set with sequence_parallel_degree > 1"
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.sample_packing and self.micro_batch_size > 1:
|
if self.sample_packing and self.micro_batch_size > 1:
|
||||||
@@ -1253,14 +1254,14 @@ class ComplexValidationMixin:
|
|||||||
import ring_flash_attn # noqa: F401 # pylint:disable=unused-import
|
import ring_flash_attn # noqa: F401 # pylint:disable=unused-import
|
||||||
except ImportError as exception:
|
except ImportError as exception:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"context_parallel_size > 1 but ring_flash_attn is not installed. "
|
"sequence_parallel_degree > 1 but ring_flash_attn is not installed. "
|
||||||
"Please install it with `pip install axolotl[ring-flash-attn] "
|
"Please install it with `pip install axolotl[ring-flash-attn] "
|
||||||
"or `pip install ring-flash-attn>=0.1.4`."
|
"or `pip install ring-flash-attn>=0.1.4`."
|
||||||
) from exception
|
) from exception
|
||||||
|
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"Sequence parallelism (SP) is enabled with "
|
"Sequence parallelism (SP) is enabled with "
|
||||||
f"context_parallel_size={self.context_parallel_size}. "
|
f"sequence_parallel_degree={self.sequence_parallel_degree}. "
|
||||||
"Please note that logged losses may differ slightly to the non-SP "
|
"Please note that logged losses may differ slightly to the non-SP "
|
||||||
"losses due to transformers Trainer implementation details. "
|
"losses due to transformers Trainer implementation details. "
|
||||||
"Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
|
"Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
|
||||||
@@ -1271,7 +1272,7 @@ class ComplexValidationMixin:
|
|||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def validate_ring_attn_func(self):
|
def validate_ring_attn_func(self):
|
||||||
if getattr(self, "context_parallel_size", 1) == 1:
|
if getattr(self, "sequence_parallel_degree", 1) == 1:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
if self.ring_attn_func is not None:
|
if self.ring_attn_func is not None:
|
||||||
|
|||||||
@@ -442,7 +442,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
- 1
|
- 1
|
||||||
)
|
)
|
||||||
* cfg.num_epochs
|
* cfg.num_epochs
|
||||||
* cfg.context_parallel_size
|
* cfg.sequence_parallel_degree
|
||||||
* cfg.tensor_parallel_size
|
* cfg.tensor_parallel_size
|
||||||
)
|
)
|
||||||
LOG.debug(
|
LOG.debug(
|
||||||
@@ -484,7 +484,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
math.floor(
|
math.floor(
|
||||||
data_loader_len
|
data_loader_len
|
||||||
* cfg.num_epochs
|
* cfg.num_epochs
|
||||||
* cfg.context_parallel_size
|
* cfg.sequence_parallel_degree
|
||||||
* cfg.tensor_parallel_size
|
* cfg.tensor_parallel_size
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -511,7 +511,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
math.ceil(
|
math.ceil(
|
||||||
len(train_dataset)
|
len(train_dataset)
|
||||||
* cfg.num_epochs
|
* cfg.num_epochs
|
||||||
* cfg.context_parallel_size
|
* cfg.sequence_parallel_degree
|
||||||
* cfg.tensor_parallel_size
|
* cfg.tensor_parallel_size
|
||||||
/ cfg.batch_size
|
/ cfg.batch_size
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -64,7 +64,7 @@ def fixture_base_cfg():
|
|||||||
"dataloader_num_workers": 1,
|
"dataloader_num_workers": 1,
|
||||||
"dataloader_pin_memory": True,
|
"dataloader_pin_memory": True,
|
||||||
"dataloader_prefetch_factor": 2,
|
"dataloader_prefetch_factor": 2,
|
||||||
"context_parallel_size": 1,
|
"sequence_parallel_degree": 1,
|
||||||
"tensor_parallel_size": 1,
|
"tensor_parallel_size": 1,
|
||||||
# Dtype
|
# Dtype
|
||||||
"fp16": False,
|
"fp16": False,
|
||||||
|
|||||||
@@ -67,7 +67,7 @@ class TestSequenceParallelism:
|
|||||||
"logging_steps": 1,
|
"logging_steps": 1,
|
||||||
"weight_decay": 0.0,
|
"weight_decay": 0.0,
|
||||||
"use_tensorboard": True,
|
"use_tensorboard": True,
|
||||||
"context_parallel_size": 2,
|
"sequence_parallel_degree": 2,
|
||||||
"ring_attn_func": ring_attn_func,
|
"ring_attn_func": ring_attn_func,
|
||||||
"save_first_step": False,
|
"save_first_step": False,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -298,7 +298,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
|||||||
"lora_alpha": 16,
|
"lora_alpha": 16,
|
||||||
"lora_dropout": 0.05,
|
"lora_dropout": 0.05,
|
||||||
"lora_target_linear": True,
|
"lora_target_linear": True,
|
||||||
"context_parallel_size": 2,
|
"sequence_parallel_degree": 2,
|
||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
"sequence_len": 1024,
|
"sequence_len": 1024,
|
||||||
"special_tokens": {
|
"special_tokens": {
|
||||||
|
|||||||
@@ -111,7 +111,7 @@ class TestRingAttention:
|
|||||||
|
|
||||||
# Call register_ring_attn with size 4
|
# Call register_ring_attn with size 4
|
||||||
register_ring_attn(
|
register_ring_attn(
|
||||||
context_parallel_size=4,
|
sequence_parallel_degree=4,
|
||||||
heads_k_stride=1,
|
heads_k_stride=1,
|
||||||
ring_attn_func=RingAttnFunc.VARLEN_LLAMA3,
|
ring_attn_func=RingAttnFunc.VARLEN_LLAMA3,
|
||||||
)
|
)
|
||||||
@@ -156,24 +156,24 @@ class TestConfigValidation:
|
|||||||
[
|
[
|
||||||
# Valid configuration
|
# Valid configuration
|
||||||
(
|
(
|
||||||
{"context_parallel_size": 2, "flash_attention": True},
|
{"sequence_parallel_degree": 2, "flash_attention": True},
|
||||||
{"context_parallel_size": 2, "flash_attention": True},
|
{"sequence_parallel_degree": 2, "flash_attention": True},
|
||||||
True,
|
True,
|
||||||
None,
|
None,
|
||||||
),
|
),
|
||||||
# Default context_parallel_size
|
# Default sequence_parallel_degree
|
||||||
({}, {"context_parallel_size": 1}, True, None),
|
({}, {"sequence_parallel_degree": 1}, True, None),
|
||||||
# Invalid: context_parallel_size > 1 without flash_attention
|
# Invalid: sequence_parallel_degree > 1 without flash_attention
|
||||||
(
|
(
|
||||||
{"context_parallel_size": 2, "flash_attention": False},
|
{"sequence_parallel_degree": 2, "flash_attention": False},
|
||||||
None,
|
None,
|
||||||
False,
|
False,
|
||||||
"flash_attention: true must be set",
|
"flash_attention: true must be set",
|
||||||
),
|
),
|
||||||
# Invalid: context_parallel_size > 1 with sample_packing and micro_batch_size > 1
|
# Invalid: sequence_parallel_degree > 1 with sample_packing and micro_batch_size > 1
|
||||||
(
|
(
|
||||||
{
|
{
|
||||||
"context_parallel_size": 2,
|
"sequence_parallel_degree": 2,
|
||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"micro_batch_size": 2,
|
"micro_batch_size": 2,
|
||||||
@@ -186,13 +186,13 @@ class TestConfigValidation:
|
|||||||
# Valid: Basic GRPO config
|
# Valid: Basic GRPO config
|
||||||
(
|
(
|
||||||
{
|
{
|
||||||
"context_parallel_size": 2,
|
"sequence_parallel_degree": 2,
|
||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
"micro_batch_size": 2,
|
"micro_batch_size": 2,
|
||||||
"trl": {"use_liger_loss": True},
|
"trl": {"use_liger_loss": True},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"context_parallel_size": 2,
|
"sequence_parallel_degree": 2,
|
||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
"micro_batch_size": 2,
|
"micro_batch_size": 2,
|
||||||
"trl": TRLConfig(use_liger_loss=True),
|
"trl": TRLConfig(use_liger_loss=True),
|
||||||
@@ -204,7 +204,7 @@ class TestConfigValidation:
|
|||||||
(
|
(
|
||||||
{
|
{
|
||||||
"rl": "grpo",
|
"rl": "grpo",
|
||||||
"context_parallel_size": 2,
|
"sequence_parallel_degree": 2,
|
||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
"micro_batch_size": 2,
|
"micro_batch_size": 2,
|
||||||
"trl": {"use_liger_loss": True},
|
"trl": {"use_liger_loss": True},
|
||||||
@@ -262,7 +262,7 @@ class TestConfigValidation:
|
|||||||
|
|
||||||
# Apply updates to base config
|
# Apply updates to base config
|
||||||
cfg = base_cfg | {
|
cfg = base_cfg | {
|
||||||
"context_parallel_size": 2,
|
"sequence_parallel_degree": 2,
|
||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
"sample_packing": sample_packing,
|
"sample_packing": sample_packing,
|
||||||
}
|
}
|
||||||
@@ -282,7 +282,7 @@ class TestConfigValidation:
|
|||||||
|
|
||||||
# Invalid configuration with invalid ring_attn_func
|
# Invalid configuration with invalid ring_attn_func
|
||||||
cfg = base_cfg | {
|
cfg = base_cfg | {
|
||||||
"context_parallel_size": 2,
|
"sequence_parallel_degree": 2,
|
||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
"ring_attn_func": "INVALID_FUNC",
|
"ring_attn_func": "INVALID_FUNC",
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user