Add support for Accelerate CP, ND examples, and fix for parallel config w fsdp (#3019)
* fix for parallelism config from trainer * fix handling of parallelism_config w accelerate * add todo for removal * update to latest axolotl-contribs-mit for optimizer fix too * synchronize training after checkpoint save * dir spelling * use latest accelerate main * fix to not use partial state parallelism_config * more fixeS * use most recent accelerate fix * fix cpu_ram_efficient_loading to meta devices from rank 0 to prevent CPU RAM oom * improve handling of broadcasting fsdp2 state dict * support for openai chat template with thinking key as the reasoning trace * address PR feedback * refactor to remove dependency on PartialState for parallelism config * bump accelerate, gptoss fixes * limit meta fixes to fsdp2 for now * fixes for gpt oss * fixup examples, don't use cpu-ram-efficient-loading for now * remove problematic barrier * patch parallelism config * reorder comparison * device mesh fixes * make pure CP work * lint
This commit is contained in:
8
examples/distributed-parallel/README.md
Normal file
8
examples/distributed-parallel/README.md
Normal file
@@ -0,0 +1,8 @@
|
||||
# Distributed Parallel
|
||||
|
||||
See the accompanying blog post: [Accelerate ND-Parallel: A guide to Efficient Multi-GPU Training](https://huggingface.co/blog/accelerate-nd-parallel)
|
||||
|
||||
The examples provided are suitable for single node (8xGPU) SFT.
|
||||
|
||||
- Qwen 3 8B w/ FSDP + TP + CP: [YAML](./qwen3-8b-fsdp-tp-cp.yaml)
|
||||
- Llama 3.1 8B w/ HSDP + TP: [YAML](./llama-3_1-8b-hdsp-tp.yaml)
|
||||
47
examples/distributed-parallel/llama-3_1-8b-hdsp-tp.yaml
Normal file
47
examples/distributed-parallel/llama-3_1-8b-hdsp-tp.yaml
Normal file
@@ -0,0 +1,47 @@
|
||||
base_model: meta-llama/Llama-3.1-8B
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
dp_shard_size: 2
|
||||
dp_replicate_size: 2
|
||||
tensor_parallel_size: 2
|
||||
# context_parallel_size: 2
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
|
||||
special_tokens:
|
||||
pad_token: <|end_of_text|>
|
||||
|
||||
fsdp_version: 2
|
||||
fsdp_config:
|
||||
offload_params: false
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||
reshard_after_forward: true
|
||||
|
||||
datasets:
|
||||
- path: tatsu-lab/alpaca
|
||||
type: alpaca
|
||||
|
||||
output_dir: ./outputs/ndp-out/
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
flash_attention: true
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 2
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: constant_with_warmup
|
||||
learning_rate: 2e-6
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
logging_steps: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_ratio: 0.1
|
||||
46
examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml
Normal file
46
examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml
Normal file
@@ -0,0 +1,46 @@
|
||||
base_model: Qwen/Qwen3-8B
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
dp_shard_size: 2
|
||||
# dp_replicate_size: 1
|
||||
context_parallel_size: 2
|
||||
tensor_parallel_size: 2
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
|
||||
fsdp_version: 2
|
||||
fsdp_config:
|
||||
offload_params: false
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Qwen3DecoderLayer
|
||||
reshard_after_forward: true
|
||||
|
||||
datasets:
|
||||
- path: tatsu-lab/alpaca
|
||||
type: alpaca
|
||||
|
||||
output_dir: ./outputs/ndp-out/
|
||||
|
||||
sequence_len: 8192
|
||||
sample_packing: true
|
||||
flash_attention: true
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1 # must be 1 when using context parallel
|
||||
num_epochs: 2
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: constant_with_warmup
|
||||
learning_rate: 2e-6
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
logging_steps: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_ratio: 0.1
|
||||
|
||||
special_tokens:
|
||||
@@ -10,9 +10,10 @@ plugins:
|
||||
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
|
||||
|
||||
datasets:
|
||||
- path: winglian/pirate-ultrachat-10k
|
||||
- path: HuggingFaceH4/Multilingual-Thinking
|
||||
type: chat_template
|
||||
split: train
|
||||
field_thinking: thinking
|
||||
template_thinking_key: thinking
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0
|
||||
@@ -20,6 +21,7 @@ output_dir: ./outputs/gpt-oss-out/
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
@@ -47,11 +49,12 @@ activation_offloading: true
|
||||
logging_steps: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_ratio: 0.1
|
||||
warmup_ratio: 0.03
|
||||
|
||||
special_tokens:
|
||||
eot_tokens:
|
||||
- "<|end|>"
|
||||
- "<|return|>"
|
||||
|
||||
fsdp_version: 2
|
||||
fsdp_config:
|
||||
@@ -60,3 +63,4 @@ fsdp_config:
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: GptOssDecoderLayer
|
||||
reshard_after_forward: true
|
||||
# cpu_ram_efficient_loading: true
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
base_model: openai/gpt-oss-20b
|
||||
use_kernels: true
|
||||
use_kernels: false
|
||||
model_quantization_config: Mxfp4Config
|
||||
model_quantization_config_kwargs:
|
||||
dequantize: true
|
||||
@@ -10,9 +10,10 @@ plugins:
|
||||
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
|
||||
|
||||
datasets:
|
||||
- path: winglian/pirate-ultrachat-10k
|
||||
- path: HuggingFaceH4/Multilingual-Thinking
|
||||
type: chat_template
|
||||
split: train
|
||||
field_thinking: thinking
|
||||
template_thinking_key: thinking
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0
|
||||
@@ -47,11 +48,12 @@ activation_offloading: true
|
||||
logging_steps: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_ratio: 0.1
|
||||
warmup_ratio: 0.03
|
||||
|
||||
special_tokens:
|
||||
eot_tokens:
|
||||
- "<|end|>"
|
||||
- "<|return|>"
|
||||
|
||||
fsdp_version: 2
|
||||
fsdp_config:
|
||||
@@ -60,3 +62,4 @@ fsdp_config:
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: GptOssDecoderLayer
|
||||
reshard_after_forward: true
|
||||
# cpu_ram_efficient_loading: true
|
||||
|
||||
@@ -10,9 +10,10 @@ plugins:
|
||||
experimental_skip_move_to_device: true # prevent OOM by not putting model to GPU before sharding
|
||||
|
||||
datasets:
|
||||
- path: winglian/pirate-ultrachat-10k
|
||||
- path: HuggingFaceH4/Multilingual-Thinking
|
||||
type: chat_template
|
||||
split: train
|
||||
field_thinking: thinking
|
||||
template_thinking_key: thinking
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0
|
||||
@@ -26,11 +27,13 @@ lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.0 # dropout not supported when using LoRA over expert parameters
|
||||
lora_target_linear: true
|
||||
lora_target_parameters: # target the experts in the last two layers
|
||||
- "22._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
|
||||
- "22._checkpoint_wrapped_module.mlp.experts.down_proj"
|
||||
- "23._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
|
||||
- "23._checkpoint_wrapped_module.mlp.experts.down_proj"
|
||||
|
||||
# TODO: not supported for now, see peft#2710
|
||||
#lora_target_parameters: # target the experts in the last two layers
|
||||
# - "22._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
|
||||
# - "22._checkpoint_wrapped_module.mlp.experts.down_proj"
|
||||
# - "23._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
|
||||
# - "23._checkpoint_wrapped_module.mlp.experts.down_proj"
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
@@ -62,3 +65,4 @@ warmup_ratio: 0.1
|
||||
special_tokens:
|
||||
eot_tokens:
|
||||
- "<|end|>"
|
||||
- "<|return|>"
|
||||
|
||||
@@ -16,17 +16,18 @@ huggingface_hub>=0.33.0
|
||||
peft==0.17.0
|
||||
transformers==4.55.0
|
||||
tokenizers>=0.21.1
|
||||
accelerate @ git+https://github.com/huggingface/accelerate.git@9359a0194f210624f1e6e85c3d838fdd55c11152
|
||||
accelerate==1.10.0
|
||||
datasets==4.0.0
|
||||
deepspeed>=0.17.0
|
||||
trl==0.20.0
|
||||
trl==0.21.0
|
||||
hf_xet==1.1.5
|
||||
kernels==0.9.0
|
||||
trackio
|
||||
|
||||
optimum==1.16.2
|
||||
hf_transfer
|
||||
sentencepiece
|
||||
gradio==5.23.3
|
||||
gradio==5.41.1
|
||||
|
||||
modal==1.0.2
|
||||
pydantic==2.10.6
|
||||
@@ -68,6 +69,6 @@ torchao==0.12.0
|
||||
schedulefree==1.4.1
|
||||
|
||||
axolotl-contribs-lgpl==0.0.6
|
||||
axolotl-contribs-mit==0.0.4
|
||||
axolotl-contribs-mit==0.0.5
|
||||
|
||||
mistral-common==1.8.3
|
||||
|
||||
@@ -13,5 +13,5 @@ MOE_ARCH_BLOCK = {
|
||||
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
|
||||
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
|
||||
"deepseek_v2": "DeepseekV2MoE",
|
||||
"gpt_oss": "GptOssExperts",
|
||||
"gpt_oss": "GptOssDecoderLayer",
|
||||
}
|
||||
|
||||
@@ -24,7 +24,6 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from accelerate import PartialState
|
||||
from transformers import (
|
||||
TrainerCallback,
|
||||
)
|
||||
@@ -39,6 +38,7 @@ from axolotl.utils.callbacks import (
|
||||
SaveModelOnFirstStepCallback,
|
||||
)
|
||||
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
|
||||
from axolotl.utils.distributed import build_parallelism_config
|
||||
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
@@ -275,8 +275,9 @@ class TrainerBuilderBase(abc.ABC):
|
||||
optimizer_kwargs["dion_lr"] = training_args_kwargs["dion_learning_rate"]
|
||||
optimizer_kwargs["dion_mu"] = training_args_kwargs["dion_momentum"]
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
partial_state = PartialState()
|
||||
optimizer_kwargs["device_mesh"] = partial_state.device_mesh
|
||||
_, device_mesh = build_parallelism_config(self.cfg)
|
||||
if device_mesh is not None:
|
||||
optimizer_kwargs["device_mesh"] = device_mesh
|
||||
elif self.cfg.optimizer == "optimi_adamw":
|
||||
from optimi import AdamW
|
||||
|
||||
@@ -428,30 +429,12 @@ class TrainerBuilderBase(abc.ABC):
|
||||
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
|
||||
|
||||
def _configure_accelerator_config(self, training_args_kwargs: dict):
|
||||
partial_state = PartialState()
|
||||
has_pc_attr = (
|
||||
hasattr(partial_state, "parallelism_config")
|
||||
and partial_state.parallelism_config
|
||||
)
|
||||
has_pc_key = (
|
||||
"parallelism_config"
|
||||
in partial_state._shared_state # pylint: disable=protected-access
|
||||
and partial_state._shared_state[ # pylint: disable=protected-access
|
||||
"parallelism_config"
|
||||
]
|
||||
)
|
||||
use_configured_state = has_pc_attr or has_pc_key
|
||||
if self.cfg.accelerator_config:
|
||||
use_configured_state = self.cfg.accelerator_config.pop(
|
||||
"use_configured_state", use_configured_state
|
||||
)
|
||||
training_args_kwargs["accelerator_config"] = AcceleratorConfig(
|
||||
use_configured_state=use_configured_state, **self.cfg.accelerator_config
|
||||
**self.cfg.accelerator_config
|
||||
)
|
||||
else:
|
||||
training_args_kwargs["accelerator_config"] = AcceleratorConfig(
|
||||
use_configured_state=use_configured_state,
|
||||
)
|
||||
training_args_kwargs["accelerator_config"] = AcceleratorConfig()
|
||||
|
||||
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
|
||||
if self.cfg.activation_offloading is True:
|
||||
|
||||
@@ -363,7 +363,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
data_collator_kwargs["pad_to_multiple_of"] = multiple * math.ceil(
|
||||
self.cfg.sequence_len / multiple
|
||||
)
|
||||
else:
|
||||
elif self.cfg.pad_to_sequence_len is None:
|
||||
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
|
||||
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
||||
data_collator_kwargs["pad_to_multiple_of"] = multiple
|
||||
|
||||
@@ -10,8 +10,11 @@ from functools import partial, wraps
|
||||
from typing import Any, Callable, Literal, Optional
|
||||
|
||||
import datasets
|
||||
import safetensors
|
||||
import torch
|
||||
from accelerate.state import AcceleratorState
|
||||
from datasets import Dataset
|
||||
from peft import PeftModel
|
||||
from torch.utils.data import (
|
||||
BatchSampler,
|
||||
DataLoader,
|
||||
@@ -19,8 +22,10 @@ from torch.utils.data import (
|
||||
Sampler,
|
||||
SequentialSampler,
|
||||
)
|
||||
from transformers import Trainer
|
||||
from transformers import PreTrainedModel, Trainer
|
||||
from transformers.trainer import TRAINING_ARGS_NAME
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length, seed_worker
|
||||
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, is_peft_available
|
||||
from trl.trainer.utils import pad_to_length
|
||||
from typing_extensions import override
|
||||
|
||||
@@ -515,7 +520,18 @@ class AxolotlTrainer(
|
||||
|
||||
@wraps(Trainer.create_accelerator_and_postprocess)
|
||||
def create_accelerator_and_postprocess(self):
|
||||
res = super().create_accelerator_and_postprocess()
|
||||
# cleanup the PartialState states so Accelerate automatically configures everything from the env vars
|
||||
accelerator_config = self.args.accelerator_config.to_dict()
|
||||
use_configured_state = accelerator_config.get("use_configured_state", False)
|
||||
if not use_configured_state:
|
||||
AcceleratorState._reset_state( # pylint: disable=protected-access
|
||||
reset_partial_state=True
|
||||
)
|
||||
|
||||
super().create_accelerator_and_postprocess()
|
||||
|
||||
# now we need to put parallelism_config back on the PartialState since we rely on that info in other places
|
||||
# PartialState().parallelism_config = self.accelerator.state.parallelism_config
|
||||
|
||||
if self.is_fsdp_enabled:
|
||||
if (
|
||||
@@ -524,8 +540,6 @@ class AxolotlTrainer(
|
||||
):
|
||||
self.accelerator.state.fsdp_plugin.limit_all_gathers = True
|
||||
|
||||
return res
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def additional_accelerator_args(
|
||||
self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs
|
||||
@@ -590,3 +604,64 @@ class AxolotlTrainer(
|
||||
output_dir = os.path.join(run_dir, checkpoint_folder)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
return super()._save_checkpoint(model, trial, **kwargs)
|
||||
|
||||
# TODO(wing): remove once https://github.com/huggingface/transformers/pull/39866/files is merged
|
||||
def _save(self, output_dir: Optional[str] = None, state_dict=None):
|
||||
# If we are executing this function, we are the process zero, so we don't check for that.
|
||||
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
LOG.info(f"Saving model checkpoint to {output_dir}")
|
||||
supported_classes = (
|
||||
(PreTrainedModel,)
|
||||
if not is_peft_available()
|
||||
else (PreTrainedModel, PeftModel)
|
||||
)
|
||||
# Save a trained model and configuration using `save_pretrained()`.
|
||||
# They can then be reloaded using `from_pretrained()`
|
||||
if not isinstance(self.model, supported_classes):
|
||||
if state_dict is None:
|
||||
state_dict = self.model.state_dict()
|
||||
if isinstance(
|
||||
self.accelerator.unwrap_model(self.model, keep_torch_compile=False),
|
||||
supported_classes,
|
||||
):
|
||||
self.accelerator.unwrap_model(
|
||||
self.model, keep_torch_compile=False
|
||||
).save_pretrained(
|
||||
output_dir,
|
||||
state_dict=state_dict,
|
||||
safe_serialization=self.args.save_safetensors,
|
||||
)
|
||||
else:
|
||||
LOG.info(
|
||||
"Trainer.model is not a `PreTrainedModel`, only saving its state dict."
|
||||
)
|
||||
if self.args.save_safetensors:
|
||||
safetensors.torch.save_file(
|
||||
state_dict,
|
||||
os.path.join(output_dir, SAFE_WEIGHTS_NAME),
|
||||
metadata={"format": "pt"},
|
||||
)
|
||||
else:
|
||||
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
|
||||
else:
|
||||
self.model.save_pretrained(
|
||||
output_dir,
|
||||
state_dict=state_dict,
|
||||
safe_serialization=self.args.save_safetensors,
|
||||
is_main_process=self.accelerator.is_main_process,
|
||||
)
|
||||
|
||||
if self.processing_class is not None:
|
||||
self.processing_class.save_pretrained(output_dir)
|
||||
elif (
|
||||
self.data_collator is not None
|
||||
and hasattr(self.data_collator, "tokenizer")
|
||||
and self.data_collator.tokenizer is not None
|
||||
):
|
||||
LOG.info(
|
||||
"Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
|
||||
)
|
||||
self.data_collator.tokenizer.save_pretrained(output_dir)
|
||||
# Good practice: save your training arguments together with the trained model
|
||||
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
Mixin for correctly saving fsdp
|
||||
"""
|
||||
|
||||
from accelerate import PartialState
|
||||
from transformers import Trainer
|
||||
|
||||
|
||||
@@ -18,3 +19,15 @@ class DistributedParallelMixin(Trainer):
|
||||
):
|
||||
state_dict = self.accelerator.get_state_dict(self.model)
|
||||
super()._save(output_dir, state_dict=state_dict)
|
||||
|
||||
def create_accelerator_and_postprocess(self):
|
||||
super().create_accelerator_and_postprocess()
|
||||
if (
|
||||
self.accelerator.distributed_type == "FSDP"
|
||||
and self.accelerator.state.fsdp_plugin is None
|
||||
):
|
||||
# pylint: disable=protected-access
|
||||
# handle Context Parallelism without FSDP
|
||||
self.accelerator.state.distributed_type = "MULTI_GPU"
|
||||
self.accelerator.state._shared_state["distributed_type"] = "MULTI_GPU"
|
||||
PartialState().distributed_type = "MULTI_GPU"
|
||||
|
||||
@@ -13,7 +13,7 @@ import peft
|
||||
import torch
|
||||
import transformers
|
||||
import transformers.modeling_utils
|
||||
from accelerate import PartialState, init_empty_weights
|
||||
from accelerate import init_empty_weights
|
||||
from accelerate.parallelism_config import ParallelismConfig
|
||||
from peft import (
|
||||
PeftConfig,
|
||||
@@ -22,6 +22,7 @@ from peft import (
|
||||
PeftModelForCausalLM,
|
||||
prepare_model_for_kbit_training,
|
||||
)
|
||||
from torch.distributed import DeviceMesh
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForVision2Seq,
|
||||
@@ -49,7 +50,11 @@ from axolotl.loaders.utils import (
|
||||
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import get_device_count, get_device_type, get_world_size
|
||||
from axolotl.utils.distributed import (
|
||||
build_parallelism_config,
|
||||
get_device_count,
|
||||
get_device_type,
|
||||
)
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.model_shard_quant import load_sharded_model_quant
|
||||
from axolotl.utils.schemas.enums import RLType
|
||||
@@ -87,6 +92,7 @@ class ModelLoader:
|
||||
|
||||
use_parallel_config: bool | None = False
|
||||
parallelism_config: ParallelismConfig | None = None
|
||||
device_mesh: DeviceMesh | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -302,7 +308,10 @@ class ModelLoader:
|
||||
)
|
||||
|
||||
# Handle DeepSpeed Zero3
|
||||
if is_deepspeed_zero3_enabled():
|
||||
if (
|
||||
is_deepspeed_zero3_enabled()
|
||||
or os.getenv("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3"
|
||||
):
|
||||
self._set_z3_leaf_modules()
|
||||
|
||||
# Apply gradient checkpointing if needed
|
||||
@@ -407,85 +416,12 @@ class ModelLoader:
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@staticmethod
|
||||
def _get_parallel_config_kwargs(
|
||||
world_size: int,
|
||||
tensor_parallel_size: int = 1,
|
||||
context_parallel_size: int = 1,
|
||||
dp_shard_size: int | None = None,
|
||||
dp_replicate_size: int | None = None,
|
||||
is_fsdp: bool = False,
|
||||
):
|
||||
pc_kwargs = {}
|
||||
remaining_world_size = world_size
|
||||
|
||||
if tensor_parallel_size and tensor_parallel_size > 1:
|
||||
pc_kwargs["tp_size"] = tensor_parallel_size
|
||||
remaining_world_size = remaining_world_size // tensor_parallel_size
|
||||
|
||||
if context_parallel_size and context_parallel_size > 1:
|
||||
pc_kwargs["cp_size"] = context_parallel_size
|
||||
remaining_world_size = remaining_world_size // context_parallel_size
|
||||
|
||||
if dp_shard_size is None and dp_replicate_size in (None, 1):
|
||||
if remaining_world_size > 1:
|
||||
pc_kwargs["dp_shard_size"] = remaining_world_size
|
||||
remaining_world_size = 1
|
||||
|
||||
if dp_replicate_size and dp_replicate_size > 1:
|
||||
pc_kwargs["dp_replicate_size"] = dp_replicate_size
|
||||
remaining_world_size = remaining_world_size // dp_replicate_size
|
||||
|
||||
if remaining_world_size > 1 and dp_shard_size and dp_shard_size > 1:
|
||||
if not is_fsdp:
|
||||
raise ValueError(
|
||||
"dp_shard_size was configured without a corresponding fsdp_config! "
|
||||
"Please ensure you have configured FSDP using fsdp_config."
|
||||
)
|
||||
pc_kwargs["dp_shard_size"] = dp_shard_size
|
||||
remaining_world_size = remaining_world_size // dp_shard_size
|
||||
if remaining_world_size > 1 and "dp_replicate_size" not in pc_kwargs:
|
||||
pc_kwargs["dp_replicate_size"] = remaining_world_size
|
||||
remaining_world_size = 1
|
||||
|
||||
if remaining_world_size > 1:
|
||||
if "dp_shard_size" not in pc_kwargs and is_fsdp:
|
||||
pc_kwargs["dp_shard_size"] = remaining_world_size
|
||||
remaining_world_size = 1
|
||||
|
||||
if remaining_world_size > 1:
|
||||
raise ValueError(
|
||||
f"The configured parallelisms are incompatible with the current world size ({get_world_size()})!\n"
|
||||
f"{pc_kwargs}"
|
||||
)
|
||||
|
||||
return pc_kwargs
|
||||
|
||||
def _set_parallel_config(self):
|
||||
"""Set parallelism configuration (DP, FSDP, TP, CP) in PartialState/Accelerator"""
|
||||
pc_kwargs = ModelLoader._get_parallel_config_kwargs(
|
||||
get_world_size(),
|
||||
self.cfg.tensor_parallel_size,
|
||||
self.cfg.context_parallel_size,
|
||||
self.cfg.dp_shard_size,
|
||||
self.cfg.dp_replicate_size,
|
||||
bool(self.cfg.fsdp or self.cfg.fsdp_config),
|
||||
)
|
||||
|
||||
if pc_kwargs:
|
||||
self.parallelism_config = ParallelismConfig(
|
||||
**pc_kwargs,
|
||||
)
|
||||
device_mesh = self.parallelism_config.build_device_mesh("cuda")
|
||||
partial_state = PartialState()
|
||||
# fmt: off
|
||||
partial_state._shared_state["parallelism_config"] = ( # pylint: disable=protected-access
|
||||
self.parallelism_config
|
||||
)
|
||||
partial_state._shared_state["device_mesh"] = ( # pylint: disable=protected-access
|
||||
device_mesh
|
||||
)
|
||||
# fmt: on
|
||||
parallelism_config, device_mesh = build_parallelism_config(self.cfg)
|
||||
if parallelism_config:
|
||||
self.parallelism_config = parallelism_config
|
||||
self.device_mesh = device_mesh
|
||||
|
||||
def _set_auto_model_loader(self):
|
||||
"""Set `self.auto_model_loader`. Defaults to `transformers.AutoModelForCausalLM`
|
||||
@@ -738,7 +674,7 @@ class ModelLoader:
|
||||
if self.cfg.tensor_parallel_size > 1:
|
||||
self.model_kwargs["tp_size"] = self.cfg.tensor_parallel_size
|
||||
self.model_kwargs["tp_plan"] = "auto"
|
||||
self.model_kwargs["device_mesh"] = PartialState().device_mesh
|
||||
self.model_kwargs["device_mesh"] = self.device_mesh
|
||||
if "device_map" in self.model_kwargs:
|
||||
del self.model_kwargs["device_map"] # not compatible with `tp_plan`
|
||||
|
||||
@@ -754,6 +690,18 @@ class ModelLoader:
|
||||
elif self.is_qlora_and_fsdp_enabled:
|
||||
skip_move_to_device = True
|
||||
|
||||
if (
|
||||
self.cfg.tensor_parallel_size <= 1
|
||||
and self.cfg.fsdp_config.cpu_ram_efficient_loading
|
||||
and self.cfg.fsdp_version == 2
|
||||
):
|
||||
# setting device_map for TP is not supported
|
||||
local_rank = int(os.getenv("LOCAL_RANK", "0"))
|
||||
if local_rank == 0:
|
||||
self.model_kwargs["device_map"] = "cpu"
|
||||
else:
|
||||
self.model_kwargs["device_map"] = "meta"
|
||||
|
||||
if (
|
||||
self.is_qlora_and_fsdp_enabled
|
||||
and self.cfg.fsdp_config.cpu_ram_efficient_loading
|
||||
|
||||
@@ -104,6 +104,14 @@ class PatchManager:
|
||||
|
||||
def _apply_fsdp_patches(self):
|
||||
"""Apply patches for FSDP configurations."""
|
||||
if self.cfg.context_parallel_size > 1 or (
|
||||
self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2"
|
||||
):
|
||||
from axolotl.monkeypatch.accelerate.parallelism_config import (
|
||||
patch_parallelism_config,
|
||||
)
|
||||
|
||||
patch_parallelism_config()
|
||||
if self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2":
|
||||
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp2
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import functools
|
||||
import sys
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import nn
|
||||
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
@@ -36,25 +37,49 @@ def fsdp2_load_full_state_dict(
|
||||
|
||||
meta_sharded_sd = model.state_dict()
|
||||
sharded_sd = {}
|
||||
for param_name, full_tensor in full_sd.items():
|
||||
sharded_meta_param = meta_sharded_sd.get(param_name)
|
||||
full_tensor = full_tensor.to(sharded_meta_param.dtype).to(torch.device("cuda"))
|
||||
for param_name, sharded_meta_param in meta_sharded_sd.items():
|
||||
full_tensor = None
|
||||
if _accelerator.is_main_process:
|
||||
full_tensor = full_sd[param_name]
|
||||
full_tensor = full_tensor.to(sharded_meta_param.dtype)
|
||||
|
||||
if hasattr(sharded_meta_param, "device_mesh"):
|
||||
device_mesh = sharded_meta_param.device_mesh
|
||||
if _accelerator.is_main_process:
|
||||
full_tensor = full_tensor.to(device_mesh.device_type)
|
||||
else:
|
||||
full_tensor = torch.empty(
|
||||
sharded_meta_param.size(),
|
||||
device=device_mesh.device_type,
|
||||
dtype=sharded_meta_param.dtype,
|
||||
)
|
||||
sharded_param = distribute_tensor(
|
||||
full_tensor,
|
||||
sharded_meta_param.device_mesh,
|
||||
device_mesh,
|
||||
sharded_meta_param.placements,
|
||||
src_data_rank=0,
|
||||
)
|
||||
else:
|
||||
sharded_param = full_tensor
|
||||
# Non-sharded parameters
|
||||
if _accelerator.is_main_process:
|
||||
sharded_param = full_tensor.to(torch.device("cuda"))
|
||||
else:
|
||||
# broadcast manually
|
||||
sharded_param = torch.empty_like(
|
||||
sharded_meta_param,
|
||||
device=torch.device("cuda"),
|
||||
dtype=sharded_meta_param.dtype,
|
||||
)
|
||||
dist.broadcast(sharded_param, src=0)
|
||||
|
||||
if offload_to_cpu:
|
||||
sharded_param = sharded_param.cpu()
|
||||
|
||||
sharded_sd[param_name] = nn.Parameter(sharded_param)
|
||||
|
||||
del full_tensor
|
||||
full_sd[param_name] = None
|
||||
|
||||
model.load_state_dict(sharded_sd, assign=True, strict=True)
|
||||
end_time = time.time()
|
||||
LOG.debug(
|
||||
|
||||
77
src/axolotl/monkeypatch/accelerate/parallelism_config.py
Normal file
77
src/axolotl/monkeypatch/accelerate/parallelism_config.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""
|
||||
workaround to allow parallelism config for pure CP
|
||||
"""
|
||||
|
||||
# pylint: disable=protected-access
|
||||
import os
|
||||
import warnings
|
||||
|
||||
from accelerate import DistributedType
|
||||
|
||||
|
||||
def _validate_accelerator(self, accelerator):
|
||||
_warnings = set()
|
||||
if not accelerator.multi_device and self.total_size == 1:
|
||||
# No distributed setup, valid parallelism config
|
||||
return
|
||||
|
||||
# We need this to ensure DDP works
|
||||
if self.total_size == 1:
|
||||
self._set_size("dp_replicate", accelerator.num_processes)
|
||||
|
||||
if self.total_size != accelerator.num_processes:
|
||||
raise ValueError(
|
||||
f"ParallelismConfig total_size ({self.total_size}) does not match "
|
||||
f"num_processes ({accelerator.num_processes}). Please adjust dp_replicate_size/ "
|
||||
f"dp_shard_size/tp_size/cp_size."
|
||||
)
|
||||
|
||||
# allow parallelism config when not using fsdp if using pure context parallelism
|
||||
allow_parallelism_config = False
|
||||
|
||||
if (
|
||||
self.cp_size > 1 # pylint: disable=chained-comparison
|
||||
and self.dp_shard_size <= 1
|
||||
and os.environ.get("ACCELERATE_ALLOW_CP_STANDALONE", "false").lower() == "true"
|
||||
):
|
||||
allow_parallelism_config = True
|
||||
|
||||
if (
|
||||
self.total_size > 1
|
||||
and not allow_parallelism_config
|
||||
and not (accelerator.is_fsdp2 or accelerator.multi_device)
|
||||
):
|
||||
raise ValueError(
|
||||
f"ParallelismConfig is only compatible DistributedType.FSDP (version 2) or DistributedType.Multi{{Device}}, but got {accelerator.distributed_type}."
|
||||
)
|
||||
|
||||
for parallelism, size in self._sizes.items():
|
||||
if size == 1 and getattr(self, f"{parallelism}_handler", None) is not None:
|
||||
_warnings.add(
|
||||
f"ParallelismConfig.{parallelism}_handler is set, but {parallelism}_size is set to 1. This handler will be ignored."
|
||||
)
|
||||
|
||||
if _warnings and accelerator.is_main_process:
|
||||
warnings.warn(
|
||||
"ParallelismConfig has the following warnings:\n" + "\n".join(_warnings),
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
|
||||
def patched_is_fsdp2(self) -> bool:
|
||||
"""
|
||||
Patched version of is_fsdp2 that guards against a None fsdp_plugin.
|
||||
"""
|
||||
# The new logic checks if fsdp_plugin exists before accessing its attributes
|
||||
return (
|
||||
self.distributed_type == DistributedType.FSDP
|
||||
and self.fsdp_plugin
|
||||
and self.fsdp_plugin.fsdp_version == 2
|
||||
)
|
||||
|
||||
|
||||
def patch_parallelism_config():
|
||||
from accelerate.accelerator import AcceleratorState, ParallelismConfig
|
||||
|
||||
ParallelismConfig._validate_accelerator = _validate_accelerator
|
||||
AcceleratorState.is_fsdp2 = property(patched_is_fsdp2)
|
||||
@@ -36,6 +36,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
||||
"glm",
|
||||
"glm4",
|
||||
"smollm3",
|
||||
"gpt_oss",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -41,7 +41,9 @@ class ChatTemplatePrompter(Prompter):
|
||||
field_messages: str = "messages",
|
||||
field_system: str = "system",
|
||||
field_tools: str = "tools",
|
||||
field_thinking: str = "reasoning_content",
|
||||
roles: dict[str, list[str]] | None = None,
|
||||
template_thinking_key: str | None = "reasoning_content",
|
||||
chat_template_kwargs: dict[str, Any] | None = None,
|
||||
drop_system_message: bool = False,
|
||||
):
|
||||
@@ -50,8 +52,9 @@ class ChatTemplatePrompter(Prompter):
|
||||
message_property_mappings = {
|
||||
"role": "role",
|
||||
"content": "content",
|
||||
"reasoning_content": "reasoning_content",
|
||||
}
|
||||
if template_thinking_key and field_thinking:
|
||||
message_property_mappings[template_thinking_key] = field_thinking
|
||||
|
||||
if roles:
|
||||
self.roles = {s: t for t, sources in roles.items() for s in sources}
|
||||
@@ -74,10 +77,12 @@ class ChatTemplatePrompter(Prompter):
|
||||
self.field_messages = field_messages
|
||||
self.field_system = field_system
|
||||
self.field_tools = field_tools
|
||||
self.field_thinking = field_thinking
|
||||
self.tokenizer = tokenizer
|
||||
self.processor: ProcessorMixin | None = processor
|
||||
self.chat_template = chat_template
|
||||
self.chat_template_kwargs = chat_template_kwargs or {}
|
||||
self.template_thinking_key: str = template_thinking_key or "reasoning_content"
|
||||
self.max_length = max_length
|
||||
self.drop_system_message = drop_system_message
|
||||
|
||||
@@ -742,7 +747,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
|
||||
# get the thinking content
|
||||
thinking_content = content[t_start_idx + len(tpair[0]) : t_end_idx]
|
||||
transformed_message["reasoning_content"] = thinking_content.strip()
|
||||
transformed_message[self.prompter.template_thinking_key] = (
|
||||
thinking_content.strip()
|
||||
)
|
||||
|
||||
# take remainder of the content
|
||||
# strip whitespace from beginning of the remainder (thinking tokens)
|
||||
@@ -953,6 +960,10 @@ class StrategyLoader:
|
||||
None,
|
||||
),
|
||||
"field_messages": dataset_config.get("field_messages", "messages"),
|
||||
"field_thinking": dataset_config.get("field_thinking", "reasoning_content"),
|
||||
"template_thinking_key": dataset_config.get(
|
||||
"template_thinking_key", "reasoning_content"
|
||||
),
|
||||
"roles": dataset_config.get("roles"),
|
||||
"drop_system_message": dataset_config.get("drop_system_message", False),
|
||||
# we need to add one for detecting sequences with exceeding the `sequence_len` limit.
|
||||
|
||||
@@ -218,6 +218,7 @@ def execute_training(
|
||||
ring_attn_func=cfg.ring_attn_func,
|
||||
heads_k_stride=cfg.heads_k_stride,
|
||||
gather_outputs=cfg.rl is RLType.GRPO,
|
||||
device_mesh=trainer.accelerator.torch_device_mesh,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -274,7 +275,7 @@ def save_trained_model(
|
||||
# final model weights have already been saved by `ReLoRACallback.on_train_end`
|
||||
return
|
||||
|
||||
if trainer.is_fsdp_enabled:
|
||||
if trainer.is_fsdp_enabled or cfg.fsdp_config:
|
||||
if cfg.fsdp_config or cfg.fsdp:
|
||||
if cfg.fsdp_config.final_state_dict_type:
|
||||
state_dict_type = cfg.fsdp_config.final_state_dict_type
|
||||
|
||||
@@ -161,6 +161,8 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
Collator for multipack specific to the using the BatchSampler
|
||||
"""
|
||||
|
||||
squash_position_ids: bool = False
|
||||
|
||||
def __call__(self, features, return_tensors=None):
|
||||
if not isinstance(features[0], list):
|
||||
features: List[List[dict]] = [features]
|
||||
@@ -176,6 +178,15 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
if feature in item
|
||||
]
|
||||
out_features[i][feature] = np.concatenate(arrays)
|
||||
elif feature == "position_ids" and self.squash_position_ids:
|
||||
arrays = [
|
||||
np.array(item[feature]) for item in features_ if feature in item
|
||||
]
|
||||
# concatenate, get total length and create arange of new total position ids
|
||||
position_ids = np.concatenate(arrays)
|
||||
total_length = position_ids.shape[0]
|
||||
position_ids = np.arange(total_length)
|
||||
out_features[i][feature] = position_ids
|
||||
else:
|
||||
arrays = [
|
||||
np.array(item[feature]) for item in features_ if feature in item
|
||||
|
||||
@@ -5,8 +5,8 @@ import inspect
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from accelerate import PartialState
|
||||
from torch import nn
|
||||
from torch.distributed import DeviceMesh
|
||||
from torch.utils.hooks import RemovableHandle
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.utils import ModelOutput
|
||||
@@ -194,6 +194,7 @@ class SequenceParallelContextManager:
|
||||
ring_attn_func: RingAttnFunc,
|
||||
heads_k_stride: int | None,
|
||||
gather_outputs: bool,
|
||||
device_mesh: DeviceMesh | None = None,
|
||||
):
|
||||
self.models = models
|
||||
self.context_parallel_size = context_parallel_size
|
||||
@@ -201,6 +202,7 @@ class SequenceParallelContextManager:
|
||||
self.ring_attn_func = ring_attn_func
|
||||
self.heads_k_stride = heads_k_stride
|
||||
self.gather_outputs = gather_outputs
|
||||
self.device_mesh = device_mesh
|
||||
|
||||
self._register_ring_attn()
|
||||
|
||||
@@ -240,9 +242,8 @@ class SequenceParallelContextManager:
|
||||
|
||||
def _register_ring_attn(self):
|
||||
# Initialize ring attn for sequence parallelism
|
||||
partial_state = PartialState()
|
||||
register_ring_attn_from_device_mesh(
|
||||
device_mesh=partial_state.device_mesh,
|
||||
device_mesh=self.device_mesh,
|
||||
context_parallel_dim=("cp",),
|
||||
heads_k_stride=self.heads_k_stride,
|
||||
ring_attn_func=self.ring_attn_func,
|
||||
|
||||
@@ -8,6 +8,7 @@ from datetime import timedelta
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from accelerate import PartialState
|
||||
from accelerate.utils import ParallelismConfig
|
||||
from transformers.utils.import_utils import (
|
||||
is_torch_cuda_available,
|
||||
is_torch_mps_available,
|
||||
@@ -290,3 +291,77 @@ def reduce_and_broadcast(fn1, fn2):
|
||||
# Use compute_and_broadcast to compute the reduced value on the main process
|
||||
# and then broadcast it to all ranks
|
||||
return compute_and_broadcast(lambda: fn2(gathered_values))
|
||||
|
||||
|
||||
def build_parallelism_config(cfg):
    """Build an accelerate ``ParallelismConfig`` and device mesh from ``cfg``.

    Derives the per-axis sizes (TP / CP / DP-shard / DP-replicate) from the
    axolotl config via ``_get_parallel_config_kwargs`` and, when any axis is
    enabled, constructs the corresponding ``ParallelismConfig`` plus its CUDA
    device mesh.

    Returns:
        Tuple of ``(parallelism_config, device_mesh)``, or ``(None, None)``
        when no parallelism axis is configured.
    """
    kwargs = _get_parallel_config_kwargs(
        get_world_size(),
        cfg.tensor_parallel_size,
        cfg.context_parallel_size,
        cfg.dp_shard_size,
        cfg.dp_replicate_size,
        bool(cfg.fsdp or cfg.fsdp_config),
    )

    # No axis enabled: nothing to build.
    if not kwargs:
        return None, None

    parallelism_config = ParallelismConfig(**kwargs)
    return parallelism_config, parallelism_config.build_device_mesh("cuda")
||||
|
||||
|
||||
def _get_parallel_config_kwargs(
|
||||
world_size: int,
|
||||
tensor_parallel_size: int = 1,
|
||||
context_parallel_size: int = 1,
|
||||
dp_shard_size: int | None = None,
|
||||
dp_replicate_size: int | None = None,
|
||||
is_fsdp: bool = False,
|
||||
):
|
||||
pc_kwargs = {}
|
||||
remaining_world_size = world_size
|
||||
|
||||
if tensor_parallel_size and tensor_parallel_size > 1:
|
||||
pc_kwargs["tp_size"] = tensor_parallel_size
|
||||
remaining_world_size = remaining_world_size // tensor_parallel_size
|
||||
|
||||
if context_parallel_size and context_parallel_size > 1:
|
||||
pc_kwargs["cp_size"] = context_parallel_size
|
||||
remaining_world_size = remaining_world_size // context_parallel_size
|
||||
|
||||
if dp_shard_size is None and dp_replicate_size in (None, 1):
|
||||
if remaining_world_size > 1:
|
||||
pc_kwargs["dp_shard_size"] = remaining_world_size
|
||||
remaining_world_size = 1
|
||||
|
||||
if dp_replicate_size and dp_replicate_size > 1:
|
||||
pc_kwargs["dp_replicate_size"] = dp_replicate_size
|
||||
remaining_world_size = remaining_world_size // dp_replicate_size
|
||||
|
||||
if remaining_world_size > 1 and dp_shard_size and dp_shard_size > 1:
|
||||
if not is_fsdp:
|
||||
raise ValueError(
|
||||
"dp_shard_size was configured without a corresponding fsdp_config! "
|
||||
"Please ensure you have configured FSDP using fsdp_config."
|
||||
)
|
||||
pc_kwargs["dp_shard_size"] = dp_shard_size
|
||||
remaining_world_size = remaining_world_size // dp_shard_size
|
||||
if remaining_world_size > 1 and "dp_replicate_size" not in pc_kwargs:
|
||||
pc_kwargs["dp_replicate_size"] = remaining_world_size
|
||||
remaining_world_size = 1
|
||||
|
||||
if remaining_world_size > 1:
|
||||
if "dp_shard_size" not in pc_kwargs and is_fsdp:
|
||||
pc_kwargs["dp_shard_size"] = remaining_world_size
|
||||
remaining_world_size = 1
|
||||
|
||||
if remaining_world_size > 1:
|
||||
raise ValueError(
|
||||
f"The configured parallelisms are incompatible with the current world size ({get_world_size()})!\n"
|
||||
f"{pc_kwargs}"
|
||||
)
|
||||
|
||||
return pc_kwargs
|
||||
|
||||
@@ -118,6 +118,18 @@ class SFTDataset(BaseModel):
|
||||
"description": 'Key containing the tools (default: "tools"). Must be a list[dict] and follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).'
|
||||
},
|
||||
)
|
||||
field_thinking: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": 'Key containing the reasoning trace (default: "reasoning_content").'
|
||||
},
|
||||
)
|
||||
template_thinking_key: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "The key the chat template expects that indicates the reasoning trace."
|
||||
},
|
||||
)
|
||||
# deprecated, use message_property_mappings
|
||||
message_field_role: str | None = None
|
||||
# deprecated, use message_property_mappings
|
||||
|
||||
@@ -1147,6 +1147,19 @@ class ModelCompatibilityValidationMixin:
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_gpt_oss_fsdp_loading(cls, data):
|
||||
if data.get("model_quantization_config", "") == "Mxfp4Config":
|
||||
if (
|
||||
data.get("fsdp_config", {}).get("cpu_ram_efficient_loading", False)
|
||||
is True
|
||||
):
|
||||
raise ValueError(
|
||||
"FSDP cpu_ram_efficient_loading is not supported for Mxfp4Config model quantization."
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
class ComplexValidationMixin:
|
||||
"""Complex validation methods that involve multiple systems."""
|
||||
|
||||
@@ -597,6 +597,25 @@ def setup_fsdp_envs(cfg):
|
||||
os.environ["FSDP_RESHARD_AFTER_FORWARD"] = "true"
|
||||
|
||||
|
||||
def setup_parallelism_envs(cfg):
    """Export accelerate parallelism-config environment variables from ``cfg``.

    For each parallelism axis with a degree greater than one, sets the
    corresponding ``PARALLELISM_CONFIG_*`` variable. Standalone context
    parallelism additionally requires ``ACCELERATE_ALLOW_CP_STANDALONE``.
    When any axis is enabled, ``ACCELERATE_USE_PARALLELISM_CONFIG`` is set so
    accelerate picks the config up.
    """
    axis_env_pairs = (
        (cfg.tensor_parallel_size, "PARALLELISM_CONFIG_TP_SIZE"),
        (cfg.dp_shard_size, "PARALLELISM_CONFIG_DP_SHARD_SIZE"),
        (cfg.dp_replicate_size, "PARALLELISM_CONFIG_DP_REPLICATE_SIZE"),
        (cfg.context_parallel_size, "PARALLELISM_CONFIG_CP_SIZE"),
    )

    any_axis_enabled = False
    for degree, env_name in axis_env_pairs:
        if degree and degree > 1:
            any_axis_enabled = True
            os.environ[env_name] = str(degree)

    # CP without FSDP/DP is only allowed when explicitly opted in.
    if cfg.context_parallel_size and cfg.context_parallel_size > 1:
        os.environ["ACCELERATE_ALLOW_CP_STANDALONE"] = "true"

    if any_axis_enabled:
        os.environ["ACCELERATE_USE_PARALLELISM_CONFIG"] = "true"
||||
|
||||
|
||||
def prepare_optim_env(cfg):
|
||||
if not check_cuda_p2p_ib_support():
|
||||
if os.getenv("NCCL_P2P_DISABLE") is None:
|
||||
@@ -615,6 +634,7 @@ def prepare_optim_env(cfg):
|
||||
stage = deepspeed_config.get("zero_optimization", {}).get("stage", None)
|
||||
setup_deepspeed_env(cfg, stage=stage)
|
||||
|
||||
setup_parallelism_envs(cfg)
|
||||
setup_torch_compile_env(cfg)
|
||||
|
||||
if cfg.fp8:
|
||||
|
||||
@@ -9,6 +9,7 @@ from transformers.utils.import_utils import is_torch_mps_available
|
||||
|
||||
from axolotl.loaders import ModelLoader
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import _get_parallel_config_kwargs
|
||||
|
||||
|
||||
class TestModelsUtils:
|
||||
@@ -193,15 +194,13 @@ class TestModelsUtils:
|
||||
is_fsdp,
|
||||
expected,
|
||||
):
|
||||
res = (
|
||||
ModelLoader._get_parallel_config_kwargs( # pylint: disable=protected-access
|
||||
world_size,
|
||||
tensor_parallel_size,
|
||||
context_parallel_size,
|
||||
dp_shard_size,
|
||||
dp_replicate_size,
|
||||
is_fsdp,
|
||||
)
|
||||
res = _get_parallel_config_kwargs( # pylint: disable=protected-access
|
||||
world_size,
|
||||
tensor_parallel_size,
|
||||
context_parallel_size,
|
||||
dp_shard_size,
|
||||
dp_replicate_size,
|
||||
is_fsdp,
|
||||
)
|
||||
|
||||
if expected[0] > 1:
|
||||
|
||||
Reference in New Issue
Block a user