Add support for Accelerate CP, ND examples, and fix for parallel config w fsdp (#3019)
* fix for parallelism config from trainer * fix handling of parallelism_config w accelerate * add todo for removal * update to latest axolotl-contribs-mit for optimizer fix too * synchronize training after checkpoint save * dir spelling * use latest accelerate main * fix to not use partial state parallelism_config * more fixeS * use most recent accelerate fix * fix cpu_ram_efficient_loading to meta devices from rank 0 to prevent CPU RAM oom * improve handling of broadcasting fsdp2 state dict * support for openai chat template with thinking key as the reasoning trace * address PR feedback * refactor to remove dependency on PartialState for parallelism config * bump accelerate, gptoss fixes * limit meta fixes to fsdp2 for now * fixes for gpt oss * fixup examples, don't use cpu-ram-efficient-loading for now * remove problematic barrier * patch parallelism config * reorder comparison * device mesh fixes * make pure CP work * lint
This commit is contained in:
8
examples/distributed-parallel/README.md
Normal file
8
examples/distributed-parallel/README.md
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
# Distributed Parallel
|
||||||
|
|
||||||
|
See the accompanying blog post: [Accelerate ND-Parallel: A guide to Efficient Multi-GPU Training](https://huggingface.co/blog/accelerate-nd-parallel)
|
||||||
|
|
||||||
|
The examples provided are suitable for single node (8xGPU) SFT.
|
||||||
|
|
||||||
|
- Qwen 3 8B w/ FSDP + TP + CP: [YAML](./qwen3-8b-fsdp-tp-cp.yaml)
|
||||||
|
- Llama 3.1 8B w/ HSDP + TP: [YAML](./llama-3_1-8b-hdsp-tp.yaml)
|
||||||
47
examples/distributed-parallel/llama-3_1-8b-hdsp-tp.yaml
Normal file
47
examples/distributed-parallel/llama-3_1-8b-hdsp-tp.yaml
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
base_model: meta-llama/Llama-3.1-8B
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
|
||||||
|
dp_shard_size: 2
|
||||||
|
dp_replicate_size: 2
|
||||||
|
tensor_parallel_size: 2
|
||||||
|
# context_parallel_size: 2
|
||||||
|
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
|
||||||
|
special_tokens:
|
||||||
|
pad_token: <|end_of_text|>
|
||||||
|
|
||||||
|
fsdp_version: 2
|
||||||
|
fsdp_config:
|
||||||
|
offload_params: false
|
||||||
|
state_dict_type: FULL_STATE_DICT
|
||||||
|
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||||
|
reshard_after_forward: true
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: tatsu-lab/alpaca
|
||||||
|
type: alpaca
|
||||||
|
|
||||||
|
output_dir: ./outputs/ndp-out/
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 2
|
||||||
|
optimizer: adamw_torch_fused
|
||||||
|
lr_scheduler: constant_with_warmup
|
||||||
|
learning_rate: 2e-6
|
||||||
|
|
||||||
|
bf16: true
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
logging_steps: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
46
examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml
Normal file
46
examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
base_model: Qwen/Qwen3-8B
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
|
||||||
|
dp_shard_size: 2
|
||||||
|
# dp_replicate_size: 1
|
||||||
|
context_parallel_size: 2
|
||||||
|
tensor_parallel_size: 2
|
||||||
|
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
|
||||||
|
fsdp_version: 2
|
||||||
|
fsdp_config:
|
||||||
|
offload_params: false
|
||||||
|
state_dict_type: FULL_STATE_DICT
|
||||||
|
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
transformer_layer_cls_to_wrap: Qwen3DecoderLayer
|
||||||
|
reshard_after_forward: true
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: tatsu-lab/alpaca
|
||||||
|
type: alpaca
|
||||||
|
|
||||||
|
output_dir: ./outputs/ndp-out/
|
||||||
|
|
||||||
|
sequence_len: 8192
|
||||||
|
sample_packing: true
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
micro_batch_size: 1 # must be 1 when using context parallel
|
||||||
|
num_epochs: 2
|
||||||
|
optimizer: adamw_torch_fused
|
||||||
|
lr_scheduler: constant_with_warmup
|
||||||
|
learning_rate: 2e-6
|
||||||
|
|
||||||
|
bf16: true
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
logging_steps: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
|
||||||
|
special_tokens:
|
||||||
@@ -10,9 +10,10 @@ plugins:
|
|||||||
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
|
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
|
||||||
|
|
||||||
datasets:
|
datasets:
|
||||||
- path: winglian/pirate-ultrachat-10k
|
- path: HuggingFaceH4/Multilingual-Thinking
|
||||||
type: chat_template
|
type: chat_template
|
||||||
split: train
|
field_thinking: thinking
|
||||||
|
template_thinking_key: thinking
|
||||||
|
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0
|
val_set_size: 0
|
||||||
@@ -20,6 +21,7 @@ output_dir: ./outputs/gpt-oss-out/
|
|||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
@@ -47,11 +49,12 @@ activation_offloading: true
|
|||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.03
|
||||||
|
|
||||||
special_tokens:
|
special_tokens:
|
||||||
eot_tokens:
|
eot_tokens:
|
||||||
- "<|end|>"
|
- "<|end|>"
|
||||||
|
- "<|return|>"
|
||||||
|
|
||||||
fsdp_version: 2
|
fsdp_version: 2
|
||||||
fsdp_config:
|
fsdp_config:
|
||||||
@@ -60,3 +63,4 @@ fsdp_config:
|
|||||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
transformer_layer_cls_to_wrap: GptOssDecoderLayer
|
transformer_layer_cls_to_wrap: GptOssDecoderLayer
|
||||||
reshard_after_forward: true
|
reshard_after_forward: true
|
||||||
|
# cpu_ram_efficient_loading: true
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
base_model: openai/gpt-oss-20b
|
base_model: openai/gpt-oss-20b
|
||||||
use_kernels: true
|
use_kernels: false
|
||||||
model_quantization_config: Mxfp4Config
|
model_quantization_config: Mxfp4Config
|
||||||
model_quantization_config_kwargs:
|
model_quantization_config_kwargs:
|
||||||
dequantize: true
|
dequantize: true
|
||||||
@@ -10,9 +10,10 @@ plugins:
|
|||||||
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
|
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
|
||||||
|
|
||||||
datasets:
|
datasets:
|
||||||
- path: winglian/pirate-ultrachat-10k
|
- path: HuggingFaceH4/Multilingual-Thinking
|
||||||
type: chat_template
|
type: chat_template
|
||||||
split: train
|
field_thinking: thinking
|
||||||
|
template_thinking_key: thinking
|
||||||
|
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0
|
val_set_size: 0
|
||||||
@@ -47,11 +48,12 @@ activation_offloading: true
|
|||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.03
|
||||||
|
|
||||||
special_tokens:
|
special_tokens:
|
||||||
eot_tokens:
|
eot_tokens:
|
||||||
- "<|end|>"
|
- "<|end|>"
|
||||||
|
- "<|return|>"
|
||||||
|
|
||||||
fsdp_version: 2
|
fsdp_version: 2
|
||||||
fsdp_config:
|
fsdp_config:
|
||||||
@@ -60,3 +62,4 @@ fsdp_config:
|
|||||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
transformer_layer_cls_to_wrap: GptOssDecoderLayer
|
transformer_layer_cls_to_wrap: GptOssDecoderLayer
|
||||||
reshard_after_forward: true
|
reshard_after_forward: true
|
||||||
|
# cpu_ram_efficient_loading: true
|
||||||
|
|||||||
@@ -10,9 +10,10 @@ plugins:
|
|||||||
experimental_skip_move_to_device: true # prevent OOM by not putting model to GPU before sharding
|
experimental_skip_move_to_device: true # prevent OOM by not putting model to GPU before sharding
|
||||||
|
|
||||||
datasets:
|
datasets:
|
||||||
- path: winglian/pirate-ultrachat-10k
|
- path: HuggingFaceH4/Multilingual-Thinking
|
||||||
type: chat_template
|
type: chat_template
|
||||||
split: train
|
field_thinking: thinking
|
||||||
|
template_thinking_key: thinking
|
||||||
|
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0
|
val_set_size: 0
|
||||||
@@ -26,11 +27,13 @@ lora_r: 8
|
|||||||
lora_alpha: 16
|
lora_alpha: 16
|
||||||
lora_dropout: 0.0 # dropout not supported when using LoRA over expert parameters
|
lora_dropout: 0.0 # dropout not supported when using LoRA over expert parameters
|
||||||
lora_target_linear: true
|
lora_target_linear: true
|
||||||
lora_target_parameters: # target the experts in the last two layers
|
|
||||||
- "22._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
|
# TODO: not supported for now, see peft#2710
|
||||||
- "22._checkpoint_wrapped_module.mlp.experts.down_proj"
|
#lora_target_parameters: # target the experts in the last two layers
|
||||||
- "23._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
|
# - "22._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
|
||||||
- "23._checkpoint_wrapped_module.mlp.experts.down_proj"
|
# - "22._checkpoint_wrapped_module.mlp.experts.down_proj"
|
||||||
|
# - "23._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
|
||||||
|
# - "23._checkpoint_wrapped_module.mlp.experts.down_proj"
|
||||||
|
|
||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
@@ -62,3 +65,4 @@ warmup_ratio: 0.1
|
|||||||
special_tokens:
|
special_tokens:
|
||||||
eot_tokens:
|
eot_tokens:
|
||||||
- "<|end|>"
|
- "<|end|>"
|
||||||
|
- "<|return|>"
|
||||||
|
|||||||
@@ -16,17 +16,18 @@ huggingface_hub>=0.33.0
|
|||||||
peft==0.17.0
|
peft==0.17.0
|
||||||
transformers==4.55.0
|
transformers==4.55.0
|
||||||
tokenizers>=0.21.1
|
tokenizers>=0.21.1
|
||||||
accelerate @ git+https://github.com/huggingface/accelerate.git@9359a0194f210624f1e6e85c3d838fdd55c11152
|
accelerate==1.10.0
|
||||||
datasets==4.0.0
|
datasets==4.0.0
|
||||||
deepspeed>=0.17.0
|
deepspeed>=0.17.0
|
||||||
trl==0.20.0
|
trl==0.21.0
|
||||||
hf_xet==1.1.5
|
hf_xet==1.1.5
|
||||||
kernels==0.9.0
|
kernels==0.9.0
|
||||||
|
trackio
|
||||||
|
|
||||||
optimum==1.16.2
|
optimum==1.16.2
|
||||||
hf_transfer
|
hf_transfer
|
||||||
sentencepiece
|
sentencepiece
|
||||||
gradio==5.23.3
|
gradio==5.41.1
|
||||||
|
|
||||||
modal==1.0.2
|
modal==1.0.2
|
||||||
pydantic==2.10.6
|
pydantic==2.10.6
|
||||||
@@ -68,6 +69,6 @@ torchao==0.12.0
|
|||||||
schedulefree==1.4.1
|
schedulefree==1.4.1
|
||||||
|
|
||||||
axolotl-contribs-lgpl==0.0.6
|
axolotl-contribs-lgpl==0.0.6
|
||||||
axolotl-contribs-mit==0.0.4
|
axolotl-contribs-mit==0.0.5
|
||||||
|
|
||||||
mistral-common==1.8.3
|
mistral-common==1.8.3
|
||||||
|
|||||||
@@ -13,5 +13,5 @@ MOE_ARCH_BLOCK = {
|
|||||||
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
|
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
|
||||||
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
|
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
|
||||||
"deepseek_v2": "DeepseekV2MoE",
|
"deepseek_v2": "DeepseekV2MoE",
|
||||||
"gpt_oss": "GptOssExperts",
|
"gpt_oss": "GptOssDecoderLayer",
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ from pathlib import Path
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from accelerate import PartialState
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
TrainerCallback,
|
TrainerCallback,
|
||||||
)
|
)
|
||||||
@@ -39,6 +38,7 @@ from axolotl.utils.callbacks import (
|
|||||||
SaveModelOnFirstStepCallback,
|
SaveModelOnFirstStepCallback,
|
||||||
)
|
)
|
||||||
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
|
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
|
||||||
|
from axolotl.utils.distributed import build_parallelism_config
|
||||||
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
|
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
@@ -275,8 +275,9 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
optimizer_kwargs["dion_lr"] = training_args_kwargs["dion_learning_rate"]
|
optimizer_kwargs["dion_lr"] = training_args_kwargs["dion_learning_rate"]
|
||||||
optimizer_kwargs["dion_mu"] = training_args_kwargs["dion_momentum"]
|
optimizer_kwargs["dion_mu"] = training_args_kwargs["dion_momentum"]
|
||||||
optimizer_kwargs.update(adam_kwargs)
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
partial_state = PartialState()
|
_, device_mesh = build_parallelism_config(self.cfg)
|
||||||
optimizer_kwargs["device_mesh"] = partial_state.device_mesh
|
if device_mesh is not None:
|
||||||
|
optimizer_kwargs["device_mesh"] = device_mesh
|
||||||
elif self.cfg.optimizer == "optimi_adamw":
|
elif self.cfg.optimizer == "optimi_adamw":
|
||||||
from optimi import AdamW
|
from optimi import AdamW
|
||||||
|
|
||||||
@@ -428,30 +429,12 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
|
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
|
||||||
|
|
||||||
def _configure_accelerator_config(self, training_args_kwargs: dict):
|
def _configure_accelerator_config(self, training_args_kwargs: dict):
|
||||||
partial_state = PartialState()
|
|
||||||
has_pc_attr = (
|
|
||||||
hasattr(partial_state, "parallelism_config")
|
|
||||||
and partial_state.parallelism_config
|
|
||||||
)
|
|
||||||
has_pc_key = (
|
|
||||||
"parallelism_config"
|
|
||||||
in partial_state._shared_state # pylint: disable=protected-access
|
|
||||||
and partial_state._shared_state[ # pylint: disable=protected-access
|
|
||||||
"parallelism_config"
|
|
||||||
]
|
|
||||||
)
|
|
||||||
use_configured_state = has_pc_attr or has_pc_key
|
|
||||||
if self.cfg.accelerator_config:
|
if self.cfg.accelerator_config:
|
||||||
use_configured_state = self.cfg.accelerator_config.pop(
|
|
||||||
"use_configured_state", use_configured_state
|
|
||||||
)
|
|
||||||
training_args_kwargs["accelerator_config"] = AcceleratorConfig(
|
training_args_kwargs["accelerator_config"] = AcceleratorConfig(
|
||||||
use_configured_state=use_configured_state, **self.cfg.accelerator_config
|
**self.cfg.accelerator_config
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
training_args_kwargs["accelerator_config"] = AcceleratorConfig(
|
training_args_kwargs["accelerator_config"] = AcceleratorConfig()
|
||||||
use_configured_state=use_configured_state,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
|
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
|
||||||
if self.cfg.activation_offloading is True:
|
if self.cfg.activation_offloading is True:
|
||||||
|
|||||||
@@ -363,7 +363,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
data_collator_kwargs["pad_to_multiple_of"] = multiple * math.ceil(
|
data_collator_kwargs["pad_to_multiple_of"] = multiple * math.ceil(
|
||||||
self.cfg.sequence_len / multiple
|
self.cfg.sequence_len / multiple
|
||||||
)
|
)
|
||||||
else:
|
elif self.cfg.pad_to_sequence_len is None:
|
||||||
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
|
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
|
||||||
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
||||||
data_collator_kwargs["pad_to_multiple_of"] = multiple
|
data_collator_kwargs["pad_to_multiple_of"] = multiple
|
||||||
|
|||||||
@@ -10,8 +10,11 @@ from functools import partial, wraps
|
|||||||
from typing import Any, Callable, Literal, Optional
|
from typing import Any, Callable, Literal, Optional
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
|
import safetensors
|
||||||
import torch
|
import torch
|
||||||
|
from accelerate.state import AcceleratorState
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
|
from peft import PeftModel
|
||||||
from torch.utils.data import (
|
from torch.utils.data import (
|
||||||
BatchSampler,
|
BatchSampler,
|
||||||
DataLoader,
|
DataLoader,
|
||||||
@@ -19,8 +22,10 @@ from torch.utils.data import (
|
|||||||
Sampler,
|
Sampler,
|
||||||
SequentialSampler,
|
SequentialSampler,
|
||||||
)
|
)
|
||||||
from transformers import Trainer
|
from transformers import PreTrainedModel, Trainer
|
||||||
|
from transformers.trainer import TRAINING_ARGS_NAME
|
||||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length, seed_worker
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length, seed_worker
|
||||||
|
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, is_peft_available
|
||||||
from trl.trainer.utils import pad_to_length
|
from trl.trainer.utils import pad_to_length
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
@@ -515,7 +520,18 @@ class AxolotlTrainer(
|
|||||||
|
|
||||||
@wraps(Trainer.create_accelerator_and_postprocess)
|
@wraps(Trainer.create_accelerator_and_postprocess)
|
||||||
def create_accelerator_and_postprocess(self):
|
def create_accelerator_and_postprocess(self):
|
||||||
res = super().create_accelerator_and_postprocess()
|
# cleanup the PartialState states so Accelerate automatically configures everything from the env vars
|
||||||
|
accelerator_config = self.args.accelerator_config.to_dict()
|
||||||
|
use_configured_state = accelerator_config.get("use_configured_state", False)
|
||||||
|
if not use_configured_state:
|
||||||
|
AcceleratorState._reset_state( # pylint: disable=protected-access
|
||||||
|
reset_partial_state=True
|
||||||
|
)
|
||||||
|
|
||||||
|
super().create_accelerator_and_postprocess()
|
||||||
|
|
||||||
|
# now we need to put parallelism_config back on the PartialState since we rely on that info in other places
|
||||||
|
# PartialState().parallelism_config = self.accelerator.state.parallelism_config
|
||||||
|
|
||||||
if self.is_fsdp_enabled:
|
if self.is_fsdp_enabled:
|
||||||
if (
|
if (
|
||||||
@@ -524,8 +540,6 @@ class AxolotlTrainer(
|
|||||||
):
|
):
|
||||||
self.accelerator.state.fsdp_plugin.limit_all_gathers = True
|
self.accelerator.state.fsdp_plugin.limit_all_gathers = True
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument
|
||||||
def additional_accelerator_args(
|
def additional_accelerator_args(
|
||||||
self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs
|
self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs
|
||||||
@@ -590,3 +604,64 @@ class AxolotlTrainer(
|
|||||||
output_dir = os.path.join(run_dir, checkpoint_folder)
|
output_dir = os.path.join(run_dir, checkpoint_folder)
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
return super()._save_checkpoint(model, trial, **kwargs)
|
return super()._save_checkpoint(model, trial, **kwargs)
|
||||||
|
|
||||||
|
# TODO(wing): remove once https://github.com/huggingface/transformers/pull/39866/files is merged
|
||||||
|
def _save(self, output_dir: Optional[str] = None, state_dict=None):
|
||||||
|
# If we are executing this function, we are the process zero, so we don't check for that.
|
||||||
|
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
LOG.info(f"Saving model checkpoint to {output_dir}")
|
||||||
|
supported_classes = (
|
||||||
|
(PreTrainedModel,)
|
||||||
|
if not is_peft_available()
|
||||||
|
else (PreTrainedModel, PeftModel)
|
||||||
|
)
|
||||||
|
# Save a trained model and configuration using `save_pretrained()`.
|
||||||
|
# They can then be reloaded using `from_pretrained()`
|
||||||
|
if not isinstance(self.model, supported_classes):
|
||||||
|
if state_dict is None:
|
||||||
|
state_dict = self.model.state_dict()
|
||||||
|
if isinstance(
|
||||||
|
self.accelerator.unwrap_model(self.model, keep_torch_compile=False),
|
||||||
|
supported_classes,
|
||||||
|
):
|
||||||
|
self.accelerator.unwrap_model(
|
||||||
|
self.model, keep_torch_compile=False
|
||||||
|
).save_pretrained(
|
||||||
|
output_dir,
|
||||||
|
state_dict=state_dict,
|
||||||
|
safe_serialization=self.args.save_safetensors,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
LOG.info(
|
||||||
|
"Trainer.model is not a `PreTrainedModel`, only saving its state dict."
|
||||||
|
)
|
||||||
|
if self.args.save_safetensors:
|
||||||
|
safetensors.torch.save_file(
|
||||||
|
state_dict,
|
||||||
|
os.path.join(output_dir, SAFE_WEIGHTS_NAME),
|
||||||
|
metadata={"format": "pt"},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
|
||||||
|
else:
|
||||||
|
self.model.save_pretrained(
|
||||||
|
output_dir,
|
||||||
|
state_dict=state_dict,
|
||||||
|
safe_serialization=self.args.save_safetensors,
|
||||||
|
is_main_process=self.accelerator.is_main_process,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.processing_class is not None:
|
||||||
|
self.processing_class.save_pretrained(output_dir)
|
||||||
|
elif (
|
||||||
|
self.data_collator is not None
|
||||||
|
and hasattr(self.data_collator, "tokenizer")
|
||||||
|
and self.data_collator.tokenizer is not None
|
||||||
|
):
|
||||||
|
LOG.info(
|
||||||
|
"Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
|
||||||
|
)
|
||||||
|
self.data_collator.tokenizer.save_pretrained(output_dir)
|
||||||
|
# Good practice: save your training arguments together with the trained model
|
||||||
|
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
Mixin for correctly saving fsdp
|
Mixin for correctly saving fsdp
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from accelerate import PartialState
|
||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
|
|
||||||
|
|
||||||
@@ -18,3 +19,15 @@ class DistributedParallelMixin(Trainer):
|
|||||||
):
|
):
|
||||||
state_dict = self.accelerator.get_state_dict(self.model)
|
state_dict = self.accelerator.get_state_dict(self.model)
|
||||||
super()._save(output_dir, state_dict=state_dict)
|
super()._save(output_dir, state_dict=state_dict)
|
||||||
|
|
||||||
|
def create_accelerator_and_postprocess(self):
|
||||||
|
super().create_accelerator_and_postprocess()
|
||||||
|
if (
|
||||||
|
self.accelerator.distributed_type == "FSDP"
|
||||||
|
and self.accelerator.state.fsdp_plugin is None
|
||||||
|
):
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
# handle Context Parallelism without FSDP
|
||||||
|
self.accelerator.state.distributed_type = "MULTI_GPU"
|
||||||
|
self.accelerator.state._shared_state["distributed_type"] = "MULTI_GPU"
|
||||||
|
PartialState().distributed_type = "MULTI_GPU"
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import peft
|
|||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
import transformers.modeling_utils
|
import transformers.modeling_utils
|
||||||
from accelerate import PartialState, init_empty_weights
|
from accelerate import init_empty_weights
|
||||||
from accelerate.parallelism_config import ParallelismConfig
|
from accelerate.parallelism_config import ParallelismConfig
|
||||||
from peft import (
|
from peft import (
|
||||||
PeftConfig,
|
PeftConfig,
|
||||||
@@ -22,6 +22,7 @@ from peft import (
|
|||||||
PeftModelForCausalLM,
|
PeftModelForCausalLM,
|
||||||
prepare_model_for_kbit_training,
|
prepare_model_for_kbit_training,
|
||||||
)
|
)
|
||||||
|
from torch.distributed import DeviceMesh
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
AutoModelForVision2Seq,
|
AutoModelForVision2Seq,
|
||||||
@@ -49,7 +50,11 @@ from axolotl.loaders.utils import (
|
|||||||
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import get_device_count, get_device_type, get_world_size
|
from axolotl.utils.distributed import (
|
||||||
|
build_parallelism_config,
|
||||||
|
get_device_count,
|
||||||
|
get_device_type,
|
||||||
|
)
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.model_shard_quant import load_sharded_model_quant
|
from axolotl.utils.model_shard_quant import load_sharded_model_quant
|
||||||
from axolotl.utils.schemas.enums import RLType
|
from axolotl.utils.schemas.enums import RLType
|
||||||
@@ -87,6 +92,7 @@ class ModelLoader:
|
|||||||
|
|
||||||
use_parallel_config: bool | None = False
|
use_parallel_config: bool | None = False
|
||||||
parallelism_config: ParallelismConfig | None = None
|
parallelism_config: ParallelismConfig | None = None
|
||||||
|
device_mesh: DeviceMesh | None = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -302,7 +308,10 @@ class ModelLoader:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Handle DeepSpeed Zero3
|
# Handle DeepSpeed Zero3
|
||||||
if is_deepspeed_zero3_enabled():
|
if (
|
||||||
|
is_deepspeed_zero3_enabled()
|
||||||
|
or os.getenv("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3"
|
||||||
|
):
|
||||||
self._set_z3_leaf_modules()
|
self._set_z3_leaf_modules()
|
||||||
|
|
||||||
# Apply gradient checkpointing if needed
|
# Apply gradient checkpointing if needed
|
||||||
@@ -407,85 +416,12 @@ class ModelLoader:
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_parallel_config_kwargs(
|
|
||||||
world_size: int,
|
|
||||||
tensor_parallel_size: int = 1,
|
|
||||||
context_parallel_size: int = 1,
|
|
||||||
dp_shard_size: int | None = None,
|
|
||||||
dp_replicate_size: int | None = None,
|
|
||||||
is_fsdp: bool = False,
|
|
||||||
):
|
|
||||||
pc_kwargs = {}
|
|
||||||
remaining_world_size = world_size
|
|
||||||
|
|
||||||
if tensor_parallel_size and tensor_parallel_size > 1:
|
|
||||||
pc_kwargs["tp_size"] = tensor_parallel_size
|
|
||||||
remaining_world_size = remaining_world_size // tensor_parallel_size
|
|
||||||
|
|
||||||
if context_parallel_size and context_parallel_size > 1:
|
|
||||||
pc_kwargs["cp_size"] = context_parallel_size
|
|
||||||
remaining_world_size = remaining_world_size // context_parallel_size
|
|
||||||
|
|
||||||
if dp_shard_size is None and dp_replicate_size in (None, 1):
|
|
||||||
if remaining_world_size > 1:
|
|
||||||
pc_kwargs["dp_shard_size"] = remaining_world_size
|
|
||||||
remaining_world_size = 1
|
|
||||||
|
|
||||||
if dp_replicate_size and dp_replicate_size > 1:
|
|
||||||
pc_kwargs["dp_replicate_size"] = dp_replicate_size
|
|
||||||
remaining_world_size = remaining_world_size // dp_replicate_size
|
|
||||||
|
|
||||||
if remaining_world_size > 1 and dp_shard_size and dp_shard_size > 1:
|
|
||||||
if not is_fsdp:
|
|
||||||
raise ValueError(
|
|
||||||
"dp_shard_size was configured without a corresponding fsdp_config! "
|
|
||||||
"Please ensure you have configured FSDP using fsdp_config."
|
|
||||||
)
|
|
||||||
pc_kwargs["dp_shard_size"] = dp_shard_size
|
|
||||||
remaining_world_size = remaining_world_size // dp_shard_size
|
|
||||||
if remaining_world_size > 1 and "dp_replicate_size" not in pc_kwargs:
|
|
||||||
pc_kwargs["dp_replicate_size"] = remaining_world_size
|
|
||||||
remaining_world_size = 1
|
|
||||||
|
|
||||||
if remaining_world_size > 1:
|
|
||||||
if "dp_shard_size" not in pc_kwargs and is_fsdp:
|
|
||||||
pc_kwargs["dp_shard_size"] = remaining_world_size
|
|
||||||
remaining_world_size = 1
|
|
||||||
|
|
||||||
if remaining_world_size > 1:
|
|
||||||
raise ValueError(
|
|
||||||
f"The configured parallelisms are incompatible with the current world size ({get_world_size()})!\n"
|
|
||||||
f"{pc_kwargs}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return pc_kwargs
|
|
||||||
|
|
||||||
def _set_parallel_config(self):
|
def _set_parallel_config(self):
|
||||||
"""Set parallelism configuration (DP, FSDP, TP, CP) in PartialState/Accelerator"""
|
"""Set parallelism configuration (DP, FSDP, TP, CP) in PartialState/Accelerator"""
|
||||||
pc_kwargs = ModelLoader._get_parallel_config_kwargs(
|
parallelism_config, device_mesh = build_parallelism_config(self.cfg)
|
||||||
get_world_size(),
|
if parallelism_config:
|
||||||
self.cfg.tensor_parallel_size,
|
self.parallelism_config = parallelism_config
|
||||||
self.cfg.context_parallel_size,
|
self.device_mesh = device_mesh
|
||||||
self.cfg.dp_shard_size,
|
|
||||||
self.cfg.dp_replicate_size,
|
|
||||||
bool(self.cfg.fsdp or self.cfg.fsdp_config),
|
|
||||||
)
|
|
||||||
|
|
||||||
if pc_kwargs:
|
|
||||||
self.parallelism_config = ParallelismConfig(
|
|
||||||
**pc_kwargs,
|
|
||||||
)
|
|
||||||
device_mesh = self.parallelism_config.build_device_mesh("cuda")
|
|
||||||
partial_state = PartialState()
|
|
||||||
# fmt: off
|
|
||||||
partial_state._shared_state["parallelism_config"] = ( # pylint: disable=protected-access
|
|
||||||
self.parallelism_config
|
|
||||||
)
|
|
||||||
partial_state._shared_state["device_mesh"] = ( # pylint: disable=protected-access
|
|
||||||
device_mesh
|
|
||||||
)
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
def _set_auto_model_loader(self):
|
def _set_auto_model_loader(self):
|
||||||
"""Set `self.auto_model_loader`. Defaults to `transformers.AutoModelForCausalLM`
|
"""Set `self.auto_model_loader`. Defaults to `transformers.AutoModelForCausalLM`
|
||||||
@@ -738,7 +674,7 @@ class ModelLoader:
|
|||||||
if self.cfg.tensor_parallel_size > 1:
|
if self.cfg.tensor_parallel_size > 1:
|
||||||
self.model_kwargs["tp_size"] = self.cfg.tensor_parallel_size
|
self.model_kwargs["tp_size"] = self.cfg.tensor_parallel_size
|
||||||
self.model_kwargs["tp_plan"] = "auto"
|
self.model_kwargs["tp_plan"] = "auto"
|
||||||
self.model_kwargs["device_mesh"] = PartialState().device_mesh
|
self.model_kwargs["device_mesh"] = self.device_mesh
|
||||||
if "device_map" in self.model_kwargs:
|
if "device_map" in self.model_kwargs:
|
||||||
del self.model_kwargs["device_map"] # not compatible with `tp_plan`
|
del self.model_kwargs["device_map"] # not compatible with `tp_plan`
|
||||||
|
|
||||||
@@ -754,6 +690,18 @@ class ModelLoader:
|
|||||||
elif self.is_qlora_and_fsdp_enabled:
|
elif self.is_qlora_and_fsdp_enabled:
|
||||||
skip_move_to_device = True
|
skip_move_to_device = True
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.cfg.tensor_parallel_size <= 1
|
||||||
|
and self.cfg.fsdp_config.cpu_ram_efficient_loading
|
||||||
|
and self.cfg.fsdp_version == 2
|
||||||
|
):
|
||||||
|
# setting device_map for TP is not supported
|
||||||
|
local_rank = int(os.getenv("LOCAL_RANK", "0"))
|
||||||
|
if local_rank == 0:
|
||||||
|
self.model_kwargs["device_map"] = "cpu"
|
||||||
|
else:
|
||||||
|
self.model_kwargs["device_map"] = "meta"
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self.is_qlora_and_fsdp_enabled
|
self.is_qlora_and_fsdp_enabled
|
||||||
and self.cfg.fsdp_config.cpu_ram_efficient_loading
|
and self.cfg.fsdp_config.cpu_ram_efficient_loading
|
||||||
|
|||||||
@@ -104,6 +104,14 @@ class PatchManager:
|
|||||||
|
|
||||||
def _apply_fsdp_patches(self):
|
def _apply_fsdp_patches(self):
|
||||||
"""Apply patches for FSDP configurations."""
|
"""Apply patches for FSDP configurations."""
|
||||||
|
if self.cfg.context_parallel_size > 1 or (
|
||||||
|
self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2"
|
||||||
|
):
|
||||||
|
from axolotl.monkeypatch.accelerate.parallelism_config import (
|
||||||
|
patch_parallelism_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
patch_parallelism_config()
|
||||||
if self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2":
|
if self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2":
|
||||||
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp2
|
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp2
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import functools
|
|||||||
import sys
|
import sys
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
@@ -36,25 +37,49 @@ def fsdp2_load_full_state_dict(
|
|||||||
|
|
||||||
meta_sharded_sd = model.state_dict()
|
meta_sharded_sd = model.state_dict()
|
||||||
sharded_sd = {}
|
sharded_sd = {}
|
||||||
for param_name, full_tensor in full_sd.items():
|
for param_name, sharded_meta_param in meta_sharded_sd.items():
|
||||||
sharded_meta_param = meta_sharded_sd.get(param_name)
|
full_tensor = None
|
||||||
full_tensor = full_tensor.to(sharded_meta_param.dtype).to(torch.device("cuda"))
|
if _accelerator.is_main_process:
|
||||||
|
full_tensor = full_sd[param_name]
|
||||||
|
full_tensor = full_tensor.to(sharded_meta_param.dtype)
|
||||||
|
|
||||||
if hasattr(sharded_meta_param, "device_mesh"):
|
if hasattr(sharded_meta_param, "device_mesh"):
|
||||||
|
device_mesh = sharded_meta_param.device_mesh
|
||||||
|
if _accelerator.is_main_process:
|
||||||
|
full_tensor = full_tensor.to(device_mesh.device_type)
|
||||||
|
else:
|
||||||
|
full_tensor = torch.empty(
|
||||||
|
sharded_meta_param.size(),
|
||||||
|
device=device_mesh.device_type,
|
||||||
|
dtype=sharded_meta_param.dtype,
|
||||||
|
)
|
||||||
sharded_param = distribute_tensor(
|
sharded_param = distribute_tensor(
|
||||||
full_tensor,
|
full_tensor,
|
||||||
sharded_meta_param.device_mesh,
|
device_mesh,
|
||||||
sharded_meta_param.placements,
|
sharded_meta_param.placements,
|
||||||
src_data_rank=0,
|
src_data_rank=0,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
sharded_param = full_tensor
|
# Non-sharded parameters
|
||||||
|
if _accelerator.is_main_process:
|
||||||
|
sharded_param = full_tensor.to(torch.device("cuda"))
|
||||||
|
else:
|
||||||
|
# broadcast manually
|
||||||
|
sharded_param = torch.empty_like(
|
||||||
|
sharded_meta_param,
|
||||||
|
device=torch.device("cuda"),
|
||||||
|
dtype=sharded_meta_param.dtype,
|
||||||
|
)
|
||||||
|
dist.broadcast(sharded_param, src=0)
|
||||||
|
|
||||||
if offload_to_cpu:
|
if offload_to_cpu:
|
||||||
sharded_param = sharded_param.cpu()
|
sharded_param = sharded_param.cpu()
|
||||||
|
|
||||||
sharded_sd[param_name] = nn.Parameter(sharded_param)
|
sharded_sd[param_name] = nn.Parameter(sharded_param)
|
||||||
|
|
||||||
del full_tensor
|
del full_tensor
|
||||||
full_sd[param_name] = None
|
full_sd[param_name] = None
|
||||||
|
|
||||||
model.load_state_dict(sharded_sd, assign=True, strict=True)
|
model.load_state_dict(sharded_sd, assign=True, strict=True)
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
LOG.debug(
|
LOG.debug(
|
||||||
|
|||||||
77
src/axolotl/monkeypatch/accelerate/parallelism_config.py
Normal file
77
src/axolotl/monkeypatch/accelerate/parallelism_config.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
"""
|
||||||
|
workaround to allow parallelism config for pure CP
|
||||||
|
"""
|
||||||
|
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
import os
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
from accelerate import DistributedType
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_accelerator(self, accelerator):
|
||||||
|
_warnings = set()
|
||||||
|
if not accelerator.multi_device and self.total_size == 1:
|
||||||
|
# No distributed setup, valid parallelism config
|
||||||
|
return
|
||||||
|
|
||||||
|
# We need this to ensure DDP works
|
||||||
|
if self.total_size == 1:
|
||||||
|
self._set_size("dp_replicate", accelerator.num_processes)
|
||||||
|
|
||||||
|
if self.total_size != accelerator.num_processes:
|
||||||
|
raise ValueError(
|
||||||
|
f"ParallelismConfig total_size ({self.total_size}) does not match "
|
||||||
|
f"num_processes ({accelerator.num_processes}). Please adjust dp_replicate_size/ "
|
||||||
|
f"dp_shard_size/tp_size/cp_size."
|
||||||
|
)
|
||||||
|
|
||||||
|
# allow parallelism config when not using fsdp if using pure context parallelism
|
||||||
|
allow_parallelism_config = False
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.cp_size > 1 # pylint: disable=chained-comparison
|
||||||
|
and self.dp_shard_size <= 1
|
||||||
|
and os.environ.get("ACCELERATE_ALLOW_CP_STANDALONE", "false").lower() == "true"
|
||||||
|
):
|
||||||
|
allow_parallelism_config = True
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.total_size > 1
|
||||||
|
and not allow_parallelism_config
|
||||||
|
and not (accelerator.is_fsdp2 or accelerator.multi_device)
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"ParallelismConfig is only compatible DistributedType.FSDP (version 2) or DistributedType.Multi{{Device}}, but got {accelerator.distributed_type}."
|
||||||
|
)
|
||||||
|
|
||||||
|
for parallelism, size in self._sizes.items():
|
||||||
|
if size == 1 and getattr(self, f"{parallelism}_handler", None) is not None:
|
||||||
|
_warnings.add(
|
||||||
|
f"ParallelismConfig.{parallelism}_handler is set, but {parallelism}_size is set to 1. This handler will be ignored."
|
||||||
|
)
|
||||||
|
|
||||||
|
if _warnings and accelerator.is_main_process:
|
||||||
|
warnings.warn(
|
||||||
|
"ParallelismConfig has the following warnings:\n" + "\n".join(_warnings),
|
||||||
|
UserWarning,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def patched_is_fsdp2(self) -> bool:
|
||||||
|
"""
|
||||||
|
Patched version of is_fsdp2 that guards against a None fsdp_plugin.
|
||||||
|
"""
|
||||||
|
# The new logic checks if fsdp_plugin exists before accessing its attributes
|
||||||
|
return (
|
||||||
|
self.distributed_type == DistributedType.FSDP
|
||||||
|
and self.fsdp_plugin
|
||||||
|
and self.fsdp_plugin.fsdp_version == 2
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_parallelism_config():
|
||||||
|
from accelerate.accelerator import AcceleratorState, ParallelismConfig
|
||||||
|
|
||||||
|
ParallelismConfig._validate_accelerator = _validate_accelerator
|
||||||
|
AcceleratorState.is_fsdp2 = property(patched_is_fsdp2)
|
||||||
@@ -36,6 +36,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
|||||||
"glm",
|
"glm",
|
||||||
"glm4",
|
"glm4",
|
||||||
"smollm3",
|
"smollm3",
|
||||||
|
"gpt_oss",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -41,7 +41,9 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
field_messages: str = "messages",
|
field_messages: str = "messages",
|
||||||
field_system: str = "system",
|
field_system: str = "system",
|
||||||
field_tools: str = "tools",
|
field_tools: str = "tools",
|
||||||
|
field_thinking: str = "reasoning_content",
|
||||||
roles: dict[str, list[str]] | None = None,
|
roles: dict[str, list[str]] | None = None,
|
||||||
|
template_thinking_key: str | None = "reasoning_content",
|
||||||
chat_template_kwargs: dict[str, Any] | None = None,
|
chat_template_kwargs: dict[str, Any] | None = None,
|
||||||
drop_system_message: bool = False,
|
drop_system_message: bool = False,
|
||||||
):
|
):
|
||||||
@@ -50,8 +52,9 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
message_property_mappings = {
|
message_property_mappings = {
|
||||||
"role": "role",
|
"role": "role",
|
||||||
"content": "content",
|
"content": "content",
|
||||||
"reasoning_content": "reasoning_content",
|
|
||||||
}
|
}
|
||||||
|
if template_thinking_key and field_thinking:
|
||||||
|
message_property_mappings[template_thinking_key] = field_thinking
|
||||||
|
|
||||||
if roles:
|
if roles:
|
||||||
self.roles = {s: t for t, sources in roles.items() for s in sources}
|
self.roles = {s: t for t, sources in roles.items() for s in sources}
|
||||||
@@ -74,10 +77,12 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
self.field_messages = field_messages
|
self.field_messages = field_messages
|
||||||
self.field_system = field_system
|
self.field_system = field_system
|
||||||
self.field_tools = field_tools
|
self.field_tools = field_tools
|
||||||
|
self.field_thinking = field_thinking
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.processor: ProcessorMixin | None = processor
|
self.processor: ProcessorMixin | None = processor
|
||||||
self.chat_template = chat_template
|
self.chat_template = chat_template
|
||||||
self.chat_template_kwargs = chat_template_kwargs or {}
|
self.chat_template_kwargs = chat_template_kwargs or {}
|
||||||
|
self.template_thinking_key: str = template_thinking_key or "reasoning_content"
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
self.drop_system_message = drop_system_message
|
self.drop_system_message = drop_system_message
|
||||||
|
|
||||||
@@ -742,7 +747,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
|
|
||||||
# get the thinking content
|
# get the thinking content
|
||||||
thinking_content = content[t_start_idx + len(tpair[0]) : t_end_idx]
|
thinking_content = content[t_start_idx + len(tpair[0]) : t_end_idx]
|
||||||
transformed_message["reasoning_content"] = thinking_content.strip()
|
transformed_message[self.prompter.template_thinking_key] = (
|
||||||
|
thinking_content.strip()
|
||||||
|
)
|
||||||
|
|
||||||
# take remainder of the content
|
# take remainder of the content
|
||||||
# strip whitespace from beginning of the remainder (thinking tokens)
|
# strip whitespace from beginning of the remainder (thinking tokens)
|
||||||
@@ -953,6 +960,10 @@ class StrategyLoader:
|
|||||||
None,
|
None,
|
||||||
),
|
),
|
||||||
"field_messages": dataset_config.get("field_messages", "messages"),
|
"field_messages": dataset_config.get("field_messages", "messages"),
|
||||||
|
"field_thinking": dataset_config.get("field_thinking", "reasoning_content"),
|
||||||
|
"template_thinking_key": dataset_config.get(
|
||||||
|
"template_thinking_key", "reasoning_content"
|
||||||
|
),
|
||||||
"roles": dataset_config.get("roles"),
|
"roles": dataset_config.get("roles"),
|
||||||
"drop_system_message": dataset_config.get("drop_system_message", False),
|
"drop_system_message": dataset_config.get("drop_system_message", False),
|
||||||
# we need to add one for detecting sequences with exceeding the `sequence_len` limit.
|
# we need to add one for detecting sequences with exceeding the `sequence_len` limit.
|
||||||
|
|||||||
@@ -218,6 +218,7 @@ def execute_training(
|
|||||||
ring_attn_func=cfg.ring_attn_func,
|
ring_attn_func=cfg.ring_attn_func,
|
||||||
heads_k_stride=cfg.heads_k_stride,
|
heads_k_stride=cfg.heads_k_stride,
|
||||||
gather_outputs=cfg.rl is RLType.GRPO,
|
gather_outputs=cfg.rl is RLType.GRPO,
|
||||||
|
device_mesh=trainer.accelerator.torch_device_mesh,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -274,7 +275,7 @@ def save_trained_model(
|
|||||||
# final model weights have already been saved by `ReLoRACallback.on_train_end`
|
# final model weights have already been saved by `ReLoRACallback.on_train_end`
|
||||||
return
|
return
|
||||||
|
|
||||||
if trainer.is_fsdp_enabled:
|
if trainer.is_fsdp_enabled or cfg.fsdp_config:
|
||||||
if cfg.fsdp_config or cfg.fsdp:
|
if cfg.fsdp_config or cfg.fsdp:
|
||||||
if cfg.fsdp_config.final_state_dict_type:
|
if cfg.fsdp_config.final_state_dict_type:
|
||||||
state_dict_type = cfg.fsdp_config.final_state_dict_type
|
state_dict_type = cfg.fsdp_config.final_state_dict_type
|
||||||
|
|||||||
@@ -161,6 +161,8 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
Collator for multipack specific to the using the BatchSampler
|
Collator for multipack specific to the using the BatchSampler
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
squash_position_ids: bool = False
|
||||||
|
|
||||||
def __call__(self, features, return_tensors=None):
|
def __call__(self, features, return_tensors=None):
|
||||||
if not isinstance(features[0], list):
|
if not isinstance(features[0], list):
|
||||||
features: List[List[dict]] = [features]
|
features: List[List[dict]] = [features]
|
||||||
@@ -176,6 +178,15 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
if feature in item
|
if feature in item
|
||||||
]
|
]
|
||||||
out_features[i][feature] = np.concatenate(arrays)
|
out_features[i][feature] = np.concatenate(arrays)
|
||||||
|
elif feature == "position_ids" and self.squash_position_ids:
|
||||||
|
arrays = [
|
||||||
|
np.array(item[feature]) for item in features_ if feature in item
|
||||||
|
]
|
||||||
|
# concatenate, get total length and create arange of new total position ids
|
||||||
|
position_ids = np.concatenate(arrays)
|
||||||
|
total_length = position_ids.shape[0]
|
||||||
|
position_ids = np.arange(total_length)
|
||||||
|
out_features[i][feature] = position_ids
|
||||||
else:
|
else:
|
||||||
arrays = [
|
arrays = [
|
||||||
np.array(item[feature]) for item in features_ if feature in item
|
np.array(item[feature]) for item in features_ if feature in item
|
||||||
|
|||||||
@@ -5,8 +5,8 @@ import inspect
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from accelerate import PartialState
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from torch.distributed import DeviceMesh
|
||||||
from torch.utils.hooks import RemovableHandle
|
from torch.utils.hooks import RemovableHandle
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
from transformers.utils import ModelOutput
|
from transformers.utils import ModelOutput
|
||||||
@@ -194,6 +194,7 @@ class SequenceParallelContextManager:
|
|||||||
ring_attn_func: RingAttnFunc,
|
ring_attn_func: RingAttnFunc,
|
||||||
heads_k_stride: int | None,
|
heads_k_stride: int | None,
|
||||||
gather_outputs: bool,
|
gather_outputs: bool,
|
||||||
|
device_mesh: DeviceMesh | None = None,
|
||||||
):
|
):
|
||||||
self.models = models
|
self.models = models
|
||||||
self.context_parallel_size = context_parallel_size
|
self.context_parallel_size = context_parallel_size
|
||||||
@@ -201,6 +202,7 @@ class SequenceParallelContextManager:
|
|||||||
self.ring_attn_func = ring_attn_func
|
self.ring_attn_func = ring_attn_func
|
||||||
self.heads_k_stride = heads_k_stride
|
self.heads_k_stride = heads_k_stride
|
||||||
self.gather_outputs = gather_outputs
|
self.gather_outputs = gather_outputs
|
||||||
|
self.device_mesh = device_mesh
|
||||||
|
|
||||||
self._register_ring_attn()
|
self._register_ring_attn()
|
||||||
|
|
||||||
@@ -240,9 +242,8 @@ class SequenceParallelContextManager:
|
|||||||
|
|
||||||
def _register_ring_attn(self):
|
def _register_ring_attn(self):
|
||||||
# Initialize ring attn for sequence parallelism
|
# Initialize ring attn for sequence parallelism
|
||||||
partial_state = PartialState()
|
|
||||||
register_ring_attn_from_device_mesh(
|
register_ring_attn_from_device_mesh(
|
||||||
device_mesh=partial_state.device_mesh,
|
device_mesh=self.device_mesh,
|
||||||
context_parallel_dim=("cp",),
|
context_parallel_dim=("cp",),
|
||||||
heads_k_stride=self.heads_k_stride,
|
heads_k_stride=self.heads_k_stride,
|
||||||
ring_attn_func=self.ring_attn_func,
|
ring_attn_func=self.ring_attn_func,
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from datetime import timedelta
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from accelerate import PartialState
|
from accelerate import PartialState
|
||||||
|
from accelerate.utils import ParallelismConfig
|
||||||
from transformers.utils.import_utils import (
|
from transformers.utils.import_utils import (
|
||||||
is_torch_cuda_available,
|
is_torch_cuda_available,
|
||||||
is_torch_mps_available,
|
is_torch_mps_available,
|
||||||
@@ -290,3 +291,77 @@ def reduce_and_broadcast(fn1, fn2):
|
|||||||
# Use compute_and_broadcast to compute the reduced value on the main process
|
# Use compute_and_broadcast to compute the reduced value on the main process
|
||||||
# and then broadcast it to all ranks
|
# and then broadcast it to all ranks
|
||||||
return compute_and_broadcast(lambda: fn2(gathered_values))
|
return compute_and_broadcast(lambda: fn2(gathered_values))
|
||||||
|
|
||||||
|
|
||||||
|
def build_parallelism_config(cfg):
|
||||||
|
pc_kwargs = _get_parallel_config_kwargs(
|
||||||
|
get_world_size(),
|
||||||
|
cfg.tensor_parallel_size,
|
||||||
|
cfg.context_parallel_size,
|
||||||
|
cfg.dp_shard_size,
|
||||||
|
cfg.dp_replicate_size,
|
||||||
|
bool(cfg.fsdp or cfg.fsdp_config),
|
||||||
|
)
|
||||||
|
|
||||||
|
if pc_kwargs:
|
||||||
|
parallelism_config = ParallelismConfig(
|
||||||
|
**pc_kwargs,
|
||||||
|
)
|
||||||
|
device_mesh = parallelism_config.build_device_mesh("cuda")
|
||||||
|
|
||||||
|
return parallelism_config, device_mesh
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_parallel_config_kwargs(
|
||||||
|
world_size: int,
|
||||||
|
tensor_parallel_size: int = 1,
|
||||||
|
context_parallel_size: int = 1,
|
||||||
|
dp_shard_size: int | None = None,
|
||||||
|
dp_replicate_size: int | None = None,
|
||||||
|
is_fsdp: bool = False,
|
||||||
|
):
|
||||||
|
pc_kwargs = {}
|
||||||
|
remaining_world_size = world_size
|
||||||
|
|
||||||
|
if tensor_parallel_size and tensor_parallel_size > 1:
|
||||||
|
pc_kwargs["tp_size"] = tensor_parallel_size
|
||||||
|
remaining_world_size = remaining_world_size // tensor_parallel_size
|
||||||
|
|
||||||
|
if context_parallel_size and context_parallel_size > 1:
|
||||||
|
pc_kwargs["cp_size"] = context_parallel_size
|
||||||
|
remaining_world_size = remaining_world_size // context_parallel_size
|
||||||
|
|
||||||
|
if dp_shard_size is None and dp_replicate_size in (None, 1):
|
||||||
|
if remaining_world_size > 1:
|
||||||
|
pc_kwargs["dp_shard_size"] = remaining_world_size
|
||||||
|
remaining_world_size = 1
|
||||||
|
|
||||||
|
if dp_replicate_size and dp_replicate_size > 1:
|
||||||
|
pc_kwargs["dp_replicate_size"] = dp_replicate_size
|
||||||
|
remaining_world_size = remaining_world_size // dp_replicate_size
|
||||||
|
|
||||||
|
if remaining_world_size > 1 and dp_shard_size and dp_shard_size > 1:
|
||||||
|
if not is_fsdp:
|
||||||
|
raise ValueError(
|
||||||
|
"dp_shard_size was configured without a corresponding fsdp_config! "
|
||||||
|
"Please ensure you have configured FSDP using fsdp_config."
|
||||||
|
)
|
||||||
|
pc_kwargs["dp_shard_size"] = dp_shard_size
|
||||||
|
remaining_world_size = remaining_world_size // dp_shard_size
|
||||||
|
if remaining_world_size > 1 and "dp_replicate_size" not in pc_kwargs:
|
||||||
|
pc_kwargs["dp_replicate_size"] = remaining_world_size
|
||||||
|
remaining_world_size = 1
|
||||||
|
|
||||||
|
if remaining_world_size > 1:
|
||||||
|
if "dp_shard_size" not in pc_kwargs and is_fsdp:
|
||||||
|
pc_kwargs["dp_shard_size"] = remaining_world_size
|
||||||
|
remaining_world_size = 1
|
||||||
|
|
||||||
|
if remaining_world_size > 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"The configured parallelisms are incompatible with the current world size ({get_world_size()})!\n"
|
||||||
|
f"{pc_kwargs}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return pc_kwargs
|
||||||
|
|||||||
@@ -118,6 +118,18 @@ class SFTDataset(BaseModel):
|
|||||||
"description": 'Key containing the tools (default: "tools"). Must be a list[dict] and follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).'
|
"description": 'Key containing the tools (default: "tools"). Must be a list[dict] and follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).'
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
field_thinking: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": 'Key containing the reasoning trace (default: "reasoning_content").'
|
||||||
|
},
|
||||||
|
)
|
||||||
|
template_thinking_key: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "The key the chat template expects that indicates the reasoning trace."
|
||||||
|
},
|
||||||
|
)
|
||||||
# deprecated, use message_property_mappings
|
# deprecated, use message_property_mappings
|
||||||
message_field_role: str | None = None
|
message_field_role: str | None = None
|
||||||
# deprecated, use message_property_mappings
|
# deprecated, use message_property_mappings
|
||||||
|
|||||||
@@ -1147,6 +1147,19 @@ class ModelCompatibilityValidationMixin:
|
|||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_gpt_oss_fsdp_loading(cls, data):
|
||||||
|
if data.get("model_quantization_config", "") == "Mxfp4Config":
|
||||||
|
if (
|
||||||
|
data.get("fsdp_config", {}).get("cpu_ram_efficient_loading", False)
|
||||||
|
is True
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"FSDP cpu_ram_efficient_loading is not supported for Mxfp4Config model quantization."
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
class ComplexValidationMixin:
|
class ComplexValidationMixin:
|
||||||
"""Complex validation methods that involve multiple systems."""
|
"""Complex validation methods that involve multiple systems."""
|
||||||
|
|||||||
@@ -597,6 +597,25 @@ def setup_fsdp_envs(cfg):
|
|||||||
os.environ["FSDP_RESHARD_AFTER_FORWARD"] = "true"
|
os.environ["FSDP_RESHARD_AFTER_FORWARD"] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
def setup_parallelism_envs(cfg):
|
||||||
|
set_accelerate_parallelism_config = False
|
||||||
|
if cfg.tensor_parallel_size and cfg.tensor_parallel_size > 1:
|
||||||
|
set_accelerate_parallelism_config = True
|
||||||
|
os.environ["PARALLELISM_CONFIG_TP_SIZE"] = str(cfg.tensor_parallel_size)
|
||||||
|
if cfg.dp_shard_size and cfg.dp_shard_size > 1:
|
||||||
|
set_accelerate_parallelism_config = True
|
||||||
|
os.environ["PARALLELISM_CONFIG_DP_SHARD_SIZE"] = str(cfg.dp_shard_size)
|
||||||
|
if cfg.dp_replicate_size and cfg.dp_replicate_size > 1:
|
||||||
|
set_accelerate_parallelism_config = True
|
||||||
|
os.environ["PARALLELISM_CONFIG_DP_REPLICATE_SIZE"] = str(cfg.dp_replicate_size)
|
||||||
|
if cfg.context_parallel_size and cfg.context_parallel_size > 1:
|
||||||
|
set_accelerate_parallelism_config = True
|
||||||
|
os.environ["PARALLELISM_CONFIG_CP_SIZE"] = str(cfg.context_parallel_size)
|
||||||
|
os.environ["ACCELERATE_ALLOW_CP_STANDALONE"] = "true"
|
||||||
|
if set_accelerate_parallelism_config:
|
||||||
|
os.environ["ACCELERATE_USE_PARALLELISM_CONFIG"] = "true"
|
||||||
|
|
||||||
|
|
||||||
def prepare_optim_env(cfg):
|
def prepare_optim_env(cfg):
|
||||||
if not check_cuda_p2p_ib_support():
|
if not check_cuda_p2p_ib_support():
|
||||||
if os.getenv("NCCL_P2P_DISABLE") is None:
|
if os.getenv("NCCL_P2P_DISABLE") is None:
|
||||||
@@ -615,6 +634,7 @@ def prepare_optim_env(cfg):
|
|||||||
stage = deepspeed_config.get("zero_optimization", {}).get("stage", None)
|
stage = deepspeed_config.get("zero_optimization", {}).get("stage", None)
|
||||||
setup_deepspeed_env(cfg, stage=stage)
|
setup_deepspeed_env(cfg, stage=stage)
|
||||||
|
|
||||||
|
setup_parallelism_envs(cfg)
|
||||||
setup_torch_compile_env(cfg)
|
setup_torch_compile_env(cfg)
|
||||||
|
|
||||||
if cfg.fp8:
|
if cfg.fp8:
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from transformers.utils.import_utils import is_torch_mps_available
|
|||||||
|
|
||||||
from axolotl.loaders import ModelLoader
|
from axolotl.loaders import ModelLoader
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.distributed import _get_parallel_config_kwargs
|
||||||
|
|
||||||
|
|
||||||
class TestModelsUtils:
|
class TestModelsUtils:
|
||||||
@@ -193,15 +194,13 @@ class TestModelsUtils:
|
|||||||
is_fsdp,
|
is_fsdp,
|
||||||
expected,
|
expected,
|
||||||
):
|
):
|
||||||
res = (
|
res = _get_parallel_config_kwargs( # pylint: disable=protected-access
|
||||||
ModelLoader._get_parallel_config_kwargs( # pylint: disable=protected-access
|
world_size,
|
||||||
world_size,
|
tensor_parallel_size,
|
||||||
tensor_parallel_size,
|
context_parallel_size,
|
||||||
context_parallel_size,
|
dp_shard_size,
|
||||||
dp_shard_size,
|
dp_replicate_size,
|
||||||
dp_replicate_size,
|
is_fsdp,
|
||||||
is_fsdp,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if expected[0] > 1:
|
if expected[0] > 1:
|
||||||
|
|||||||
Reference in New Issue
Block a user