Add support for Accelerate CP, ND parallelism examples, and fix parallelism config with FSDP (#3019)

* fix for parallelism config from trainer

* fix handling of parallelism_config w accelerate

* add todo for removal

* update to latest axolotl-contribs-mit for optimizer fix too

* synchronize training after checkpoint save

* dir spelling

* use latest accelerate main

* fix to not use partial state parallelism_config

* more fixes

* use most recent accelerate fix

* fix cpu_ram_efficient_loading to load to meta devices on ranks other than 0 to prevent CPU RAM OOM

* improve handling of broadcasting fsdp2 state dict

* support for openai chat template with thinking key as the reasoning trace

* address PR feedback

* refactor to remove dependency on PartialState for parallelism config

* bump accelerate, gptoss fixes

* limit meta fixes to fsdp2 for now

* fixes for gpt oss

* fixup examples, don't use cpu-ram-efficient-loading for now

* remove problematic barrier

* patch parallelism config

* reorder comparison

* device mesh fixes

* make pure CP work

* lint
Wing Lian
2025-08-07 21:22:15 -04:00
committed by GitHub
parent ca796fb56e
commit 9d5c95db6f
26 changed files with 534 additions and 148 deletions

View File

@@ -0,0 +1,8 @@
# Distributed Parallel
See the accompanying blog post: [Accelerate ND-Parallel: A guide to Efficient Multi-GPU Training](https://huggingface.co/blog/accelerate-nd-parallel)
The examples provided are suitable for single node (8xGPU) SFT.
- Qwen 3 8B w/ FSDP + TP + CP: [YAML](./qwen3-8b-fsdp-tp-cp.yaml)
- Llama 3.1 8B w/ HSDP + TP: [YAML](./llama-3_1-8b-hsdp-tp.yaml)
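
Both layouts multiply out to the node's 8 GPUs; as a quick sanity check (sizes copied from the YAMLs, variable names here are only illustrative):

```python
# The product of all parallelism sizes must equal the world size (8 GPUs here).
qwen3 = {"dp_shard": 2, "cp": 2, "tp": 2}            # FSDP + TP + CP
llama = {"dp_shard": 2, "dp_replicate": 2, "tp": 2}  # HSDP + TP

for mesh in (qwen3, llama):
    world = 1
    for size in mesh.values():
        world *= size
    assert world == 8
```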

View File

@@ -0,0 +1,47 @@
base_model: meta-llama/Llama-3.1-8B
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
dp_shard_size: 2
dp_replicate_size: 2
tensor_parallel_size: 2
# context_parallel_size: 2
dataset_prepared_path: last_run_prepared
special_tokens:
pad_token: <|end_of_text|>
fsdp_version: 2
fsdp_config:
offload_params: false
state_dict_type: FULL_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: LlamaDecoderLayer
reshard_after_forward: true
datasets:
- path: tatsu-lab/alpaca
type: alpaca
output_dir: ./outputs/ndp-out/
sequence_len: 2048
sample_packing: true
flash_attention: true
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 2
optimizer: adamw_torch_fused
lr_scheduler: constant_with_warmup
learning_rate: 2e-6
bf16: true
tf32: true
logging_steps: 1
saves_per_epoch: 1
warmup_ratio: 0.1

View File

@@ -0,0 +1,46 @@
base_model: Qwen/Qwen3-8B
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
dp_shard_size: 2
# dp_replicate_size: 1
context_parallel_size: 2
tensor_parallel_size: 2
dataset_prepared_path: last_run_prepared
fsdp_version: 2
fsdp_config:
offload_params: false
state_dict_type: FULL_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Qwen3DecoderLayer
reshard_after_forward: true
datasets:
- path: tatsu-lab/alpaca
type: alpaca
output_dir: ./outputs/ndp-out/
sequence_len: 8192
sample_packing: true
flash_attention: true
gradient_accumulation_steps: 1
micro_batch_size: 1 # must be 1 when using context parallel
num_epochs: 2
optimizer: adamw_torch_fused
lr_scheduler: constant_with_warmup
learning_rate: 2e-6
bf16: true
tf32: true
logging_steps: 1
saves_per_epoch: 1
warmup_ratio: 0.1
special_tokens:

View File

@@ -10,9 +10,10 @@ plugins:
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
datasets:
- path: winglian/pirate-ultrachat-10k
- path: HuggingFaceH4/Multilingual-Thinking
type: chat_template
split: train
field_thinking: thinking
template_thinking_key: thinking
dataset_prepared_path: last_run_prepared
val_set_size: 0
@@ -20,6 +21,7 @@ output_dir: ./outputs/gpt-oss-out/
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
@@ -47,11 +49,12 @@ activation_offloading: true
logging_steps: 1
saves_per_epoch: 1
warmup_ratio: 0.1
warmup_ratio: 0.03
special_tokens:
eot_tokens:
- "<|end|>"
- "<|return|>"
fsdp_version: 2
fsdp_config:
@@ -60,3 +63,4 @@ fsdp_config:
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: GptOssDecoderLayer
reshard_after_forward: true
# cpu_ram_efficient_loading: true

View File

@@ -1,5 +1,5 @@
base_model: openai/gpt-oss-20b
use_kernels: true
use_kernels: false
model_quantization_config: Mxfp4Config
model_quantization_config_kwargs:
dequantize: true
@@ -10,9 +10,10 @@ plugins:
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
datasets:
- path: winglian/pirate-ultrachat-10k
- path: HuggingFaceH4/Multilingual-Thinking
type: chat_template
split: train
field_thinking: thinking
template_thinking_key: thinking
dataset_prepared_path: last_run_prepared
val_set_size: 0
@@ -47,11 +48,12 @@ activation_offloading: true
logging_steps: 1
saves_per_epoch: 1
warmup_ratio: 0.1
warmup_ratio: 0.03
special_tokens:
eot_tokens:
- "<|end|>"
- "<|return|>"
fsdp_version: 2
fsdp_config:
@@ -60,3 +62,4 @@ fsdp_config:
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: GptOssDecoderLayer
reshard_after_forward: true
# cpu_ram_efficient_loading: true

View File

@@ -10,9 +10,10 @@ plugins:
experimental_skip_move_to_device: true # prevent OOM by not putting model to GPU before sharding
datasets:
- path: winglian/pirate-ultrachat-10k
- path: HuggingFaceH4/Multilingual-Thinking
type: chat_template
split: train
field_thinking: thinking
template_thinking_key: thinking
dataset_prepared_path: last_run_prepared
val_set_size: 0
@@ -26,11 +27,13 @@ lora_r: 8
lora_alpha: 16
lora_dropout: 0.0 # dropout not supported when using LoRA over expert parameters
lora_target_linear: true
lora_target_parameters: # target the experts in the last two layers
- "22._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
- "22._checkpoint_wrapped_module.mlp.experts.down_proj"
- "23._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
- "23._checkpoint_wrapped_module.mlp.experts.down_proj"
# TODO: not supported for now, see peft#2710
#lora_target_parameters: # target the experts in the last two layers
# - "22._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
# - "22._checkpoint_wrapped_module.mlp.experts.down_proj"
# - "23._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
# - "23._checkpoint_wrapped_module.mlp.experts.down_proj"
wandb_project:
wandb_entity:
@@ -62,3 +65,4 @@ warmup_ratio: 0.1
special_tokens:
eot_tokens:
- "<|end|>"
- "<|return|>"

View File

@@ -16,17 +16,18 @@ huggingface_hub>=0.33.0
peft==0.17.0
transformers==4.55.0
tokenizers>=0.21.1
accelerate @ git+https://github.com/huggingface/accelerate.git@9359a0194f210624f1e6e85c3d838fdd55c11152
accelerate==1.10.0
datasets==4.0.0
deepspeed>=0.17.0
trl==0.20.0
trl==0.21.0
hf_xet==1.1.5
kernels==0.9.0
trackio
optimum==1.16.2
hf_transfer
sentencepiece
gradio==5.23.3
gradio==5.41.1
modal==1.0.2
pydantic==2.10.6
@@ -68,6 +69,6 @@ torchao==0.12.0
schedulefree==1.4.1
axolotl-contribs-lgpl==0.0.6
axolotl-contribs-mit==0.0.4
axolotl-contribs-mit==0.0.5
mistral-common==1.8.3

View File

@@ -13,5 +13,5 @@ MOE_ARCH_BLOCK = {
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
"deepseek_v2": "DeepseekV2MoE",
"gpt_oss": "GptOssExperts",
"gpt_oss": "GptOssDecoderLayer",
}

View File

@@ -24,7 +24,6 @@ from pathlib import Path
from typing import Any
import torch
from accelerate import PartialState
from transformers import (
TrainerCallback,
)
@@ -39,6 +38,7 @@ from axolotl.utils.callbacks import (
SaveModelOnFirstStepCallback,
)
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
from axolotl.utils.distributed import build_parallelism_config
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
LOG = logging.getLogger(__name__)
@@ -275,8 +275,9 @@ class TrainerBuilderBase(abc.ABC):
optimizer_kwargs["dion_lr"] = training_args_kwargs["dion_learning_rate"]
optimizer_kwargs["dion_mu"] = training_args_kwargs["dion_momentum"]
optimizer_kwargs.update(adam_kwargs)
partial_state = PartialState()
optimizer_kwargs["device_mesh"] = partial_state.device_mesh
_, device_mesh = build_parallelism_config(self.cfg)
if device_mesh is not None:
optimizer_kwargs["device_mesh"] = device_mesh
elif self.cfg.optimizer == "optimi_adamw":
from optimi import AdamW
@@ -428,30 +429,12 @@ class TrainerBuilderBase(abc.ABC):
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
def _configure_accelerator_config(self, training_args_kwargs: dict):
partial_state = PartialState()
has_pc_attr = (
hasattr(partial_state, "parallelism_config")
and partial_state.parallelism_config
)
has_pc_key = (
"parallelism_config"
in partial_state._shared_state # pylint: disable=protected-access
and partial_state._shared_state[ # pylint: disable=protected-access
"parallelism_config"
]
)
use_configured_state = has_pc_attr or has_pc_key
if self.cfg.accelerator_config:
use_configured_state = self.cfg.accelerator_config.pop(
"use_configured_state", use_configured_state
)
training_args_kwargs["accelerator_config"] = AcceleratorConfig(
use_configured_state=use_configured_state, **self.cfg.accelerator_config
**self.cfg.accelerator_config
)
else:
training_args_kwargs["accelerator_config"] = AcceleratorConfig(
use_configured_state=use_configured_state,
)
training_args_kwargs["accelerator_config"] = AcceleratorConfig()
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
if self.cfg.activation_offloading is True:
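
Stepping back to the optimizer change in this file: a minimal sketch of what `build_parallelism_config` hands back (the helper lives in `axolotl.utils.distributed`, further down in this diff), assuming an initialized process group across 8 ranks. The mesh only exists when parallelism sizes are actually configured, which is why the `device_mesh` kwarg is now set conditionally:

```python
from accelerate.parallelism_config import ParallelismConfig

# build_parallelism_config returns (None, None) for a plain single-GPU run,
# so the Dion optimizer only receives a device_mesh when one actually exists.
pc = ParallelismConfig(tp_size=2, cp_size=2, dp_shard_size=2)  # 8 ranks total
device_mesh = pc.build_device_mesh("cuda")

optimizer_kwargs = {}
if device_mesh is not None:
    optimizer_kwargs["device_mesh"] = device_mesh
```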

View File

@@ -363,7 +363,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
data_collator_kwargs["pad_to_multiple_of"] = multiple * math.ceil(
self.cfg.sequence_len / multiple
)
else:
elif self.cfg.pad_to_sequence_len is None:
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
data_collator_kwargs["pad_to_multiple_of"] = multiple

View File

@@ -10,8 +10,11 @@ from functools import partial, wraps
from typing import Any, Callable, Literal, Optional
import datasets
import safetensors
import torch
from accelerate.state import AcceleratorState
from datasets import Dataset
from peft import PeftModel
from torch.utils.data import (
BatchSampler,
DataLoader,
@@ -19,8 +22,10 @@ from torch.utils.data import (
Sampler,
SequentialSampler,
)
from transformers import Trainer
from transformers import PreTrainedModel, Trainer
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length, seed_worker
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, is_peft_available
from trl.trainer.utils import pad_to_length
from typing_extensions import override
@@ -515,7 +520,18 @@ class AxolotlTrainer(
@wraps(Trainer.create_accelerator_and_postprocess)
def create_accelerator_and_postprocess(self):
res = super().create_accelerator_and_postprocess()
# clean up the PartialState state so Accelerate automatically configures everything from the env vars
accelerator_config = self.args.accelerator_config.to_dict()
use_configured_state = accelerator_config.get("use_configured_state", False)
if not use_configured_state:
AcceleratorState._reset_state( # pylint: disable=protected-access
reset_partial_state=True
)
super().create_accelerator_and_postprocess()
# now we need to put parallelism_config back on the PartialState since we rely on that info in other places
# PartialState().parallelism_config = self.accelerator.state.parallelism_config
if self.is_fsdp_enabled:
if (
@@ -524,8 +540,6 @@ class AxolotlTrainer(
):
self.accelerator.state.fsdp_plugin.limit_all_gathers = True
return res
# pylint: disable=unused-argument
def additional_accelerator_args(
self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs
@@ -590,3 +604,64 @@ class AxolotlTrainer(
output_dir = os.path.join(run_dir, checkpoint_folder)
os.makedirs(output_dir, exist_ok=True)
return super()._save_checkpoint(model, trial, **kwargs)
# TODO(wing): remove once https://github.com/huggingface/transformers/pull/39866/files is merged
def _save(self, output_dir: Optional[str] = None, state_dict=None):
# If we are executing this function, we are the process zero, so we don't check for that.
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
LOG.info(f"Saving model checkpoint to {output_dir}")
supported_classes = (
(PreTrainedModel,)
if not is_peft_available()
else (PreTrainedModel, PeftModel)
)
# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
if not isinstance(self.model, supported_classes):
if state_dict is None:
state_dict = self.model.state_dict()
if isinstance(
self.accelerator.unwrap_model(self.model, keep_torch_compile=False),
supported_classes,
):
self.accelerator.unwrap_model(
self.model, keep_torch_compile=False
).save_pretrained(
output_dir,
state_dict=state_dict,
safe_serialization=self.args.save_safetensors,
)
else:
LOG.info(
"Trainer.model is not a `PreTrainedModel`, only saving its state dict."
)
if self.args.save_safetensors:
safetensors.torch.save_file(
state_dict,
os.path.join(output_dir, SAFE_WEIGHTS_NAME),
metadata={"format": "pt"},
)
else:
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
self.model.save_pretrained(
output_dir,
state_dict=state_dict,
safe_serialization=self.args.save_safetensors,
is_main_process=self.accelerator.is_main_process,
)
if self.processing_class is not None:
self.processing_class.save_pretrained(output_dir)
elif (
self.data_collator is not None
and hasattr(self.data_collator, "tokenizer")
and self.data_collator.tokenizer is not None
):
LOG.info(
"Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
)
self.data_collator.tokenizer.save_pretrained(output_dir)
# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
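
A hedged sketch of the reset-and-rebuild pattern used in `create_accelerator_and_postprocess` above: tearing down the shared state forces Accelerate to reconstruct everything (including `parallelism_config`) from environment variables, instead of reusing whatever an earlier `PartialState()` cached:

```python
from accelerate.state import AcceleratorState

# pylint: disable=protected-access
# Dropping the cached singleton state means the next Accelerator /
# AcceleratorState construction re-reads the ACCELERATE_* and
# PARALLELISM_CONFIG_* env vars rather than stale in-process state.
AcceleratorState._reset_state(reset_partial_state=True)
```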

View File

@@ -2,6 +2,7 @@
Mixin for correctly saving fsdp
"""
from accelerate import PartialState
from transformers import Trainer
@@ -18,3 +19,15 @@ class DistributedParallelMixin(Trainer):
):
state_dict = self.accelerator.get_state_dict(self.model)
super()._save(output_dir, state_dict=state_dict)
def create_accelerator_and_postprocess(self):
super().create_accelerator_and_postprocess()
if (
self.accelerator.distributed_type == "FSDP"
and self.accelerator.state.fsdp_plugin is None
):
# pylint: disable=protected-access
# handle Context Parallelism without FSDP
self.accelerator.state.distributed_type = "MULTI_GPU"
self.accelerator.state._shared_state["distributed_type"] = "MULTI_GPU"
PartialState().distributed_type = "MULTI_GPU"

View File

@@ -13,7 +13,7 @@ import peft
import torch
import transformers
import transformers.modeling_utils
from accelerate import PartialState, init_empty_weights
from accelerate import init_empty_weights
from accelerate.parallelism_config import ParallelismConfig
from peft import (
PeftConfig,
@@ -22,6 +22,7 @@ from peft import (
PeftModelForCausalLM,
prepare_model_for_kbit_training,
)
from torch.distributed import DeviceMesh
from transformers import (
AutoModelForCausalLM,
AutoModelForVision2Seq,
@@ -49,7 +50,11 @@ from axolotl.loaders.utils import (
from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import get_device_count, get_device_type, get_world_size
from axolotl.utils.distributed import (
build_parallelism_config,
get_device_count,
get_device_type,
)
from axolotl.utils.logging import get_logger
from axolotl.utils.model_shard_quant import load_sharded_model_quant
from axolotl.utils.schemas.enums import RLType
@@ -87,6 +92,7 @@ class ModelLoader:
use_parallel_config: bool | None = False
parallelism_config: ParallelismConfig | None = None
device_mesh: DeviceMesh | None = None
def __init__(
self,
@@ -302,7 +308,10 @@ class ModelLoader:
)
# Handle DeepSpeed Zero3
if is_deepspeed_zero3_enabled():
if (
is_deepspeed_zero3_enabled()
or os.getenv("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3"
):
self._set_z3_leaf_modules()
# Apply gradient checkpointing if needed
@@ -407,85 +416,12 @@ class ModelLoader:
gc.collect()
torch.cuda.empty_cache()
@staticmethod
def _get_parallel_config_kwargs(
world_size: int,
tensor_parallel_size: int = 1,
context_parallel_size: int = 1,
dp_shard_size: int | None = None,
dp_replicate_size: int | None = None,
is_fsdp: bool = False,
):
pc_kwargs = {}
remaining_world_size = world_size
if tensor_parallel_size and tensor_parallel_size > 1:
pc_kwargs["tp_size"] = tensor_parallel_size
remaining_world_size = remaining_world_size // tensor_parallel_size
if context_parallel_size and context_parallel_size > 1:
pc_kwargs["cp_size"] = context_parallel_size
remaining_world_size = remaining_world_size // context_parallel_size
if dp_shard_size is None and dp_replicate_size in (None, 1):
if remaining_world_size > 1:
pc_kwargs["dp_shard_size"] = remaining_world_size
remaining_world_size = 1
if dp_replicate_size and dp_replicate_size > 1:
pc_kwargs["dp_replicate_size"] = dp_replicate_size
remaining_world_size = remaining_world_size // dp_replicate_size
if remaining_world_size > 1 and dp_shard_size and dp_shard_size > 1:
if not is_fsdp:
raise ValueError(
"dp_shard_size was configured without a corresponding fsdp_config! "
"Please ensure you have configured FSDP using fsdp_config."
)
pc_kwargs["dp_shard_size"] = dp_shard_size
remaining_world_size = remaining_world_size // dp_shard_size
if remaining_world_size > 1 and "dp_replicate_size" not in pc_kwargs:
pc_kwargs["dp_replicate_size"] = remaining_world_size
remaining_world_size = 1
if remaining_world_size > 1:
if "dp_shard_size" not in pc_kwargs and is_fsdp:
pc_kwargs["dp_shard_size"] = remaining_world_size
remaining_world_size = 1
if remaining_world_size > 1:
raise ValueError(
f"The configured parallelisms are incompatible with the current world size ({get_world_size()})!\n"
f"{pc_kwargs}"
)
return pc_kwargs
def _set_parallel_config(self):
"""Set parallelism configuration (DP, FSDP, TP, CP) in PartialState/Accelerator"""
pc_kwargs = ModelLoader._get_parallel_config_kwargs(
get_world_size(),
self.cfg.tensor_parallel_size,
self.cfg.context_parallel_size,
self.cfg.dp_shard_size,
self.cfg.dp_replicate_size,
bool(self.cfg.fsdp or self.cfg.fsdp_config),
)
if pc_kwargs:
self.parallelism_config = ParallelismConfig(
**pc_kwargs,
)
device_mesh = self.parallelism_config.build_device_mesh("cuda")
partial_state = PartialState()
# fmt: off
partial_state._shared_state["parallelism_config"] = ( # pylint: disable=protected-access
self.parallelism_config
)
partial_state._shared_state["device_mesh"] = ( # pylint: disable=protected-access
device_mesh
)
# fmt: on
parallelism_config, device_mesh = build_parallelism_config(self.cfg)
if parallelism_config:
self.parallelism_config = parallelism_config
self.device_mesh = device_mesh
def _set_auto_model_loader(self):
"""Set `self.auto_model_loader`. Defaults to `transformers.AutoModelForCausalLM`
@@ -738,7 +674,7 @@ class ModelLoader:
if self.cfg.tensor_parallel_size > 1:
self.model_kwargs["tp_size"] = self.cfg.tensor_parallel_size
self.model_kwargs["tp_plan"] = "auto"
self.model_kwargs["device_mesh"] = PartialState().device_mesh
self.model_kwargs["device_mesh"] = self.device_mesh
if "device_map" in self.model_kwargs:
del self.model_kwargs["device_map"] # not compatible with `tp_plan`
@@ -754,6 +690,18 @@ class ModelLoader:
elif self.is_qlora_and_fsdp_enabled:
skip_move_to_device = True
if (
self.cfg.tensor_parallel_size <= 1
and self.cfg.fsdp_config.cpu_ram_efficient_loading
and self.cfg.fsdp_version == 2
):
# setting device_map for TP is not supported
local_rank = int(os.getenv("LOCAL_RANK", "0"))
if local_rank == 0:
self.model_kwargs["device_map"] = "cpu"
else:
self.model_kwargs["device_map"] = "meta"
if (
self.is_qlora_and_fsdp_enabled
and self.cfg.fsdp_config.cpu_ram_efficient_loading
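
The rank-0/meta split above works because meta tensors carry shape and dtype but allocate no storage; a small illustration (plain PyTorch, not axolotl code):

```python
import torch

# Rank 0 materializes the checkpoint on CPU; every other rank builds the
# model on "meta", paying zero RAM until FSDP2 scatters real shards.
w = torch.empty(8192, 8192, dtype=torch.bfloat16, device="meta")
print(w.is_meta, w.shape)  # True torch.Size([8192, 8192]) -- no memory allocated
```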

View File

@@ -104,6 +104,14 @@ class PatchManager:
def _apply_fsdp_patches(self):
"""Apply patches for FSDP configurations."""
if self.cfg.context_parallel_size > 1 or (
self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2"
):
from axolotl.monkeypatch.accelerate.parallelism_config import (
patch_parallelism_config,
)
patch_parallelism_config()
if self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2":
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp2

View File

@@ -7,6 +7,7 @@ import functools
import sys
import torch
import torch.distributed as dist
from torch import nn
from axolotl.utils.bench import log_gpu_memory_usage
@@ -36,25 +37,49 @@ def fsdp2_load_full_state_dict(
meta_sharded_sd = model.state_dict()
sharded_sd = {}
for param_name, full_tensor in full_sd.items():
sharded_meta_param = meta_sharded_sd.get(param_name)
full_tensor = full_tensor.to(sharded_meta_param.dtype).to(torch.device("cuda"))
for param_name, sharded_meta_param in meta_sharded_sd.items():
full_tensor = None
if _accelerator.is_main_process:
full_tensor = full_sd[param_name]
full_tensor = full_tensor.to(sharded_meta_param.dtype)
if hasattr(sharded_meta_param, "device_mesh"):
device_mesh = sharded_meta_param.device_mesh
if _accelerator.is_main_process:
full_tensor = full_tensor.to(device_mesh.device_type)
else:
full_tensor = torch.empty(
sharded_meta_param.size(),
device=device_mesh.device_type,
dtype=sharded_meta_param.dtype,
)
sharded_param = distribute_tensor(
full_tensor,
sharded_meta_param.device_mesh,
device_mesh,
sharded_meta_param.placements,
src_data_rank=0,
)
else:
sharded_param = full_tensor
# Non-sharded parameters
if _accelerator.is_main_process:
sharded_param = full_tensor.to(torch.device("cuda"))
else:
# broadcast manually
sharded_param = torch.empty_like(
sharded_meta_param,
device=torch.device("cuda"),
dtype=sharded_meta_param.dtype,
)
dist.broadcast(sharded_param, src=0)
if offload_to_cpu:
sharded_param = sharded_param.cpu()
sharded_sd[param_name] = nn.Parameter(sharded_param)
del full_tensor
full_sd[param_name] = None
model.load_state_dict(sharded_sd, assign=True, strict=True)
end_time = time.time()
LOG.debug(

View File

@@ -0,0 +1,77 @@
"""
workaround to allow parallelism config for pure CP
"""
# pylint: disable=protected-access
import os
import warnings
from accelerate import DistributedType
def _validate_accelerator(self, accelerator):
_warnings = set()
if not accelerator.multi_device and self.total_size == 1:
# No distributed setup, valid parallelism config
return
# We need this to ensure DDP works
if self.total_size == 1:
self._set_size("dp_replicate", accelerator.num_processes)
if self.total_size != accelerator.num_processes:
raise ValueError(
f"ParallelismConfig total_size ({self.total_size}) does not match "
f"num_processes ({accelerator.num_processes}). Please adjust dp_replicate_size/ "
f"dp_shard_size/tp_size/cp_size."
)
# allow parallelism config when not using fsdp if using pure context parallelism
allow_parallelism_config = False
if (
self.cp_size > 1 # pylint: disable=chained-comparison
and self.dp_shard_size <= 1
and os.environ.get("ACCELERATE_ALLOW_CP_STANDALONE", "false").lower() == "true"
):
allow_parallelism_config = True
if (
self.total_size > 1
and not allow_parallelism_config
and not (accelerator.is_fsdp2 or accelerator.multi_device)
):
raise ValueError(
f"ParallelismConfig is only compatible DistributedType.FSDP (version 2) or DistributedType.Multi{{Device}}, but got {accelerator.distributed_type}."
)
for parallelism, size in self._sizes.items():
if size == 1 and getattr(self, f"{parallelism}_handler", None) is not None:
_warnings.add(
f"ParallelismConfig.{parallelism}_handler is set, but {parallelism}_size is set to 1. This handler will be ignored."
)
if _warnings and accelerator.is_main_process:
warnings.warn(
"ParallelismConfig has the following warnings:\n" + "\n".join(_warnings),
UserWarning,
)
def patched_is_fsdp2(self) -> bool:
"""
Patched version of is_fsdp2 that guards against a None fsdp_plugin.
"""
# The new logic checks if fsdp_plugin exists before accessing its attributes
return (
self.distributed_type == DistributedType.FSDP
and self.fsdp_plugin
and self.fsdp_plugin.fsdp_version == 2
)
def patch_parallelism_config():
from accelerate.accelerator import AcceleratorState, ParallelismConfig
ParallelismConfig._validate_accelerator = _validate_accelerator
AcceleratorState.is_fsdp2 = property(patched_is_fsdp2)
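
Usage note (hedged): the patch has to land before the `Accelerator` is constructed, and standalone CP stays opt-in through the env var the patched validator checks:

```python
import os

# set by setup_parallelism_envs whenever context_parallel_size > 1
os.environ["ACCELERATE_ALLOW_CP_STANDALONE"] = "true"
patch_parallelism_config()  # swap in the lenient validator and the guarded is_fsdp2
```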

View File

@@ -36,6 +36,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"glm",
"glm4",
"smollm3",
"gpt_oss",
]

View File

@@ -41,7 +41,9 @@ class ChatTemplatePrompter(Prompter):
field_messages: str = "messages",
field_system: str = "system",
field_tools: str = "tools",
field_thinking: str = "reasoning_content",
roles: dict[str, list[str]] | None = None,
template_thinking_key: str | None = "reasoning_content",
chat_template_kwargs: dict[str, Any] | None = None,
drop_system_message: bool = False,
):
@@ -50,8 +52,9 @@ class ChatTemplatePrompter(Prompter):
message_property_mappings = {
"role": "role",
"content": "content",
"reasoning_content": "reasoning_content",
}
if template_thinking_key and field_thinking:
message_property_mappings[template_thinking_key] = field_thinking
if roles:
self.roles = {s: t for t, sources in roles.items() for s in sources}
@@ -74,10 +77,12 @@ class ChatTemplatePrompter(Prompter):
self.field_messages = field_messages
self.field_system = field_system
self.field_tools = field_tools
self.field_thinking = field_thinking
self.tokenizer = tokenizer
self.processor: ProcessorMixin | None = processor
self.chat_template = chat_template
self.chat_template_kwargs = chat_template_kwargs or {}
self.template_thinking_key: str = template_thinking_key or "reasoning_content"
self.max_length = max_length
self.drop_system_message = drop_system_message
@@ -742,7 +747,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
# get the thinking content
thinking_content = content[t_start_idx + len(tpair[0]) : t_end_idx]
transformed_message["reasoning_content"] = thinking_content.strip()
transformed_message[self.prompter.template_thinking_key] = (
thinking_content.strip()
)
# take remainder of the content
# strip whitespace from beginning of the remainder (thinking tokens)
@@ -953,6 +960,10 @@ class StrategyLoader:
None,
),
"field_messages": dataset_config.get("field_messages", "messages"),
"field_thinking": dataset_config.get("field_thinking", "reasoning_content"),
"template_thinking_key": dataset_config.get(
"template_thinking_key", "reasoning_content"
),
"roles": dataset_config.get("roles"),
"drop_system_message": dataset_config.get("drop_system_message", False),
# we need to add one for detecting sequences exceeding the `sequence_len` limit.
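
Taken together, the two new options map a dataset column onto the message key the chat template renders as the reasoning trace. With the values from the gpt-oss examples above (`field_thinking: thinking`, `template_thinking_key: thinking`), the mapping works out to:

```python
# Mirrors the message_property_mappings logic added above (illustrative values).
field_thinking = "thinking"         # dataset column holding the trace
template_thinking_key = "thinking"  # key the chat template expects

message_property_mappings = {"role": "role", "content": "content"}
if template_thinking_key and field_thinking:
    message_property_mappings[template_thinking_key] = field_thinking
# -> {"role": "role", "content": "content", "thinking": "thinking"}
```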

View File

@@ -218,6 +218,7 @@ def execute_training(
ring_attn_func=cfg.ring_attn_func,
heads_k_stride=cfg.heads_k_stride,
gather_outputs=cfg.rl is RLType.GRPO,
device_mesh=trainer.accelerator.torch_device_mesh,
)
)
@@ -274,7 +275,7 @@ def save_trained_model(
# final model weights have already been saved by `ReLoRACallback.on_train_end`
return
if trainer.is_fsdp_enabled:
if trainer.is_fsdp_enabled or cfg.fsdp_config:
if cfg.fsdp_config or cfg.fsdp:
if cfg.fsdp_config.final_state_dict_type:
state_dict_type = cfg.fsdp_config.final_state_dict_type

View File

@@ -161,6 +161,8 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
Collator for multipack, specific to using the BatchSampler
"""
squash_position_ids: bool = False
def __call__(self, features, return_tensors=None):
if not isinstance(features[0], list):
features: List[List[dict]] = [features]
@@ -176,6 +178,15 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
if feature in item
]
out_features[i][feature] = np.concatenate(arrays)
elif feature == "position_ids" and self.squash_position_ids:
arrays = [
np.array(item[feature]) for item in features_ if feature in item
]
# concatenate, get total length and create arange of new total position ids
position_ids = np.concatenate(arrays)
total_length = position_ids.shape[0]
position_ids = np.arange(total_length)
out_features[i][feature] = position_ids
else:
arrays = [
np.array(item[feature]) for item in features_ if feature in item
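
A hedged illustration of the new `squash_position_ids` branch: packed samples normally keep their own restarting position ids, but squashing replaces them with one contiguous range, which is what a sequence-dimension split (e.g. context parallelism) expects:

```python
import numpy as np

arrays = [np.array([0, 1, 2]), np.array([0, 1])]  # per-sample position_ids
position_ids = np.concatenate(arrays)             # [0 1 2 0 1]
total_length = position_ids.shape[0]
position_ids = np.arange(total_length)            # [0 1 2 3 4]
```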

View File

@@ -5,8 +5,8 @@ import inspect
import torch
import torch.distributed as dist
from accelerate import PartialState
from torch import nn
from torch.distributed import DeviceMesh
from torch.utils.hooks import RemovableHandle
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import ModelOutput
@@ -194,6 +194,7 @@ class SequenceParallelContextManager:
ring_attn_func: RingAttnFunc,
heads_k_stride: int | None,
gather_outputs: bool,
device_mesh: DeviceMesh | None = None,
):
self.models = models
self.context_parallel_size = context_parallel_size
@@ -201,6 +202,7 @@ class SequenceParallelContextManager:
self.ring_attn_func = ring_attn_func
self.heads_k_stride = heads_k_stride
self.gather_outputs = gather_outputs
self.device_mesh = device_mesh
self._register_ring_attn()
@@ -240,9 +242,8 @@ class SequenceParallelContextManager:
def _register_ring_attn(self):
# Initialize ring attn for sequence parallelism
partial_state = PartialState()
register_ring_attn_from_device_mesh(
device_mesh=partial_state.device_mesh,
device_mesh=self.device_mesh,
context_parallel_dim=("cp",),
heads_k_stride=self.heads_k_stride,
ring_attn_func=self.ring_attn_func,

View File

@@ -8,6 +8,7 @@ from datetime import timedelta
import torch
import torch.distributed as dist
from accelerate import PartialState
from accelerate.utils import ParallelismConfig
from transformers.utils.import_utils import (
is_torch_cuda_available,
is_torch_mps_available,
@@ -290,3 +291,77 @@ def reduce_and_broadcast(fn1, fn2):
# Use compute_and_broadcast to compute the reduced value on the main process
# and then broadcast it to all ranks
return compute_and_broadcast(lambda: fn2(gathered_values))
def build_parallelism_config(cfg):
pc_kwargs = _get_parallel_config_kwargs(
get_world_size(),
cfg.tensor_parallel_size,
cfg.context_parallel_size,
cfg.dp_shard_size,
cfg.dp_replicate_size,
bool(cfg.fsdp or cfg.fsdp_config),
)
if pc_kwargs:
parallelism_config = ParallelismConfig(
**pc_kwargs,
)
device_mesh = parallelism_config.build_device_mesh("cuda")
return parallelism_config, device_mesh
return None, None
def _get_parallel_config_kwargs(
world_size: int,
tensor_parallel_size: int = 1,
context_parallel_size: int = 1,
dp_shard_size: int | None = None,
dp_replicate_size: int | None = None,
is_fsdp: bool = False,
):
pc_kwargs = {}
remaining_world_size = world_size
if tensor_parallel_size and tensor_parallel_size > 1:
pc_kwargs["tp_size"] = tensor_parallel_size
remaining_world_size = remaining_world_size // tensor_parallel_size
if context_parallel_size and context_parallel_size > 1:
pc_kwargs["cp_size"] = context_parallel_size
remaining_world_size = remaining_world_size // context_parallel_size
if dp_shard_size is None and dp_replicate_size in (None, 1):
if remaining_world_size > 1:
pc_kwargs["dp_shard_size"] = remaining_world_size
remaining_world_size = 1
if dp_replicate_size and dp_replicate_size > 1:
pc_kwargs["dp_replicate_size"] = dp_replicate_size
remaining_world_size = remaining_world_size // dp_replicate_size
if remaining_world_size > 1 and dp_shard_size and dp_shard_size > 1:
if not is_fsdp:
raise ValueError(
"dp_shard_size was configured without a corresponding fsdp_config! "
"Please ensure you have configured FSDP using fsdp_config."
)
pc_kwargs["dp_shard_size"] = dp_shard_size
remaining_world_size = remaining_world_size // dp_shard_size
if remaining_world_size > 1 and "dp_replicate_size" not in pc_kwargs:
pc_kwargs["dp_replicate_size"] = remaining_world_size
remaining_world_size = 1
if remaining_world_size > 1:
if "dp_shard_size" not in pc_kwargs and is_fsdp:
pc_kwargs["dp_shard_size"] = remaining_world_size
remaining_world_size = 1
if remaining_world_size > 1:
raise ValueError(
f"The configured parallelisms are incompatible with the current world size ({get_world_size()})!\n"
f"{pc_kwargs}"
)
return pc_kwargs
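
A worked example, matching the qwen3 config above on a single 8-GPU node: TP and CP each consume a factor of 2, and the leftover world size falls through to `dp_shard_size`:

```python
kwargs = _get_parallel_config_kwargs(
    world_size=8,
    tensor_parallel_size=2,
    context_parallel_size=2,
    is_fsdp=True,  # fsdp_config is present in the example YAML
)
assert kwargs == {"tp_size": 2, "cp_size": 2, "dp_shard_size": 2}
```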

View File

@@ -118,6 +118,18 @@ class SFTDataset(BaseModel):
"description": 'Key containing the tools (default: "tools"). Must be a list[dict] and follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).'
},
)
field_thinking: str | None = Field(
default=None,
json_schema_extra={
"description": 'Key containing the reasoning trace (default: "reasoning_content").'
},
)
template_thinking_key: str | None = Field(
default=None,
json_schema_extra={
"description": "The key the chat template expects that indicates the reasoning trace."
},
)
# deprecated, use message_property_mappings
message_field_role: str | None = None
# deprecated, use message_property_mappings

View File

@@ -1147,6 +1147,19 @@ class ModelCompatibilityValidationMixin:
)
return data
@model_validator(mode="before")
@classmethod
def check_gpt_oss_fsdp_loading(cls, data):
if data.get("model_quantization_config", "") == "Mxfp4Config":
if (
data.get("fsdp_config", {}).get("cpu_ram_efficient_loading", False)
is True
):
raise ValueError(
"FSDP cpu_ram_efficient_loading is not supported for Mxfp4Config model quantization."
)
return data
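
For illustration, the shape of a config this validator rejects (hypothetical dict, mirroring the keys checked above):

```python
bad = {
    "model_quantization_config": "Mxfp4Config",
    "fsdp_config": {"cpu_ram_efficient_loading": True},
}
# check_gpt_oss_fsdp_loading(bad) raises:
# ValueError: FSDP cpu_ram_efficient_loading is not supported for Mxfp4Config ...
```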
class ComplexValidationMixin:
"""Complex validation methods that involve multiple systems."""

View File

@@ -597,6 +597,25 @@ def setup_fsdp_envs(cfg):
os.environ["FSDP_RESHARD_AFTER_FORWARD"] = "true"
def setup_parallelism_envs(cfg):
set_accelerate_parallelism_config = False
if cfg.tensor_parallel_size and cfg.tensor_parallel_size > 1:
set_accelerate_parallelism_config = True
os.environ["PARALLELISM_CONFIG_TP_SIZE"] = str(cfg.tensor_parallel_size)
if cfg.dp_shard_size and cfg.dp_shard_size > 1:
set_accelerate_parallelism_config = True
os.environ["PARALLELISM_CONFIG_DP_SHARD_SIZE"] = str(cfg.dp_shard_size)
if cfg.dp_replicate_size and cfg.dp_replicate_size > 1:
set_accelerate_parallelism_config = True
os.environ["PARALLELISM_CONFIG_DP_REPLICATE_SIZE"] = str(cfg.dp_replicate_size)
if cfg.context_parallel_size and cfg.context_parallel_size > 1:
set_accelerate_parallelism_config = True
os.environ["PARALLELISM_CONFIG_CP_SIZE"] = str(cfg.context_parallel_size)
os.environ["ACCELERATE_ALLOW_CP_STANDALONE"] = "true"
if set_accelerate_parallelism_config:
os.environ["ACCELERATE_USE_PARALLELISM_CONFIG"] = "true"
def prepare_optim_env(cfg):
if not check_cuda_p2p_ib_support():
if os.getenv("NCCL_P2P_DISABLE") is None:
@@ -615,6 +634,7 @@ def prepare_optim_env(cfg):
stage = deepspeed_config.get("zero_optimization", {}).get("stage", None)
setup_deepspeed_env(cfg, stage=stage)
setup_parallelism_envs(cfg)
setup_torch_compile_env(cfg)
if cfg.fp8:

View File

@@ -9,6 +9,7 @@ from transformers.utils.import_utils import is_torch_mps_available
from axolotl.loaders import ModelLoader
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import _get_parallel_config_kwargs
class TestModelsUtils:
@@ -193,15 +194,13 @@ class TestModelsUtils:
is_fsdp,
expected,
):
res = (
ModelLoader._get_parallel_config_kwargs( # pylint: disable=protected-access
world_size,
tensor_parallel_size,
context_parallel_size,
dp_shard_size,
dp_replicate_size,
is_fsdp,
)
res = _get_parallel_config_kwargs( # pylint: disable=protected-access
world_size,
tensor_parallel_size,
context_parallel_size,
dp_shard_size,
dp_replicate_size,
is_fsdp,
)
if expected[0] > 1: