diff --git a/examples/distributed-parallel/README.md b/examples/distributed-parallel/README.md
new file mode 100644
index 000000000..5aff54cd1
--- /dev/null
+++ b/examples/distributed-parallel/README.md
@@ -0,0 +1,8 @@
+# Distributed Parallel
+
+See the accompanying blog post: [Accelerate ND-Parallel: A guide to Efficient Multi-GPU Training](https://huggingface.co/blog/accelerate-nd-parallel)
+
+The examples provided are suitable for single node (8xGPU) SFT.
+
+- Qwen 3 8B w/ FSDP + TP + CP: [YAML](./qwen3-8b-fsdp-tp-cp.yaml)
+- Llama 3.1 8B w/ HSDP + TP: [YAML](./llama-3_1-8b-hdsp-tp.yaml)
diff --git a/examples/distributed-parallel/llama-3_1-8b-hdsp-tp.yaml b/examples/distributed-parallel/llama-3_1-8b-hdsp-tp.yaml
new file mode 100644
index 000000000..5b3246f74
--- /dev/null
+++ b/examples/distributed-parallel/llama-3_1-8b-hdsp-tp.yaml
@@ -0,0 +1,47 @@
+base_model: meta-llama/Llama-3.1-8B
+
+plugins:
+  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
+
+dp_shard_size: 2
+dp_replicate_size: 2
+tensor_parallel_size: 2
+# context_parallel_size: 2
+
+dataset_prepared_path: last_run_prepared
+
+special_tokens:
+  pad_token: <|end_of_text|>
+
+fsdp_version: 2
+fsdp_config:
+  offload_params: false
+  state_dict_type: FULL_STATE_DICT
+  auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  transformer_layer_cls_to_wrap: LlamaDecoderLayer
+  reshard_after_forward: true
+
+datasets:
+  - path: tatsu-lab/alpaca
+    type: alpaca
+
+output_dir: ./outputs/ndp-out/
+
+sequence_len: 2048
+sample_packing: true
+flash_attention: true
+
+gradient_accumulation_steps: 1
+micro_batch_size: 1
+num_epochs: 2
+optimizer: adamw_torch_fused
+lr_scheduler: constant_with_warmup
+learning_rate: 2e-6
+
+bf16: true
+tf32: true
+
+logging_steps: 1
+saves_per_epoch: 1
+
+warmup_ratio: 0.1
diff --git a/examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml b/examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml
new file mode 100644
index 000000000..584a33f44
--- /dev/null
+++ b/examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml
@@ -0,0 +1,46 @@
+base_model: Qwen/Qwen3-8B
+
+plugins:
+  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
+
+dp_shard_size: 2
+# dp_replicate_size: 1
+context_parallel_size: 2
+tensor_parallel_size: 2
+
+dataset_prepared_path: last_run_prepared
+
+fsdp_version: 2
+fsdp_config:
+  offload_params: false
+  state_dict_type: FULL_STATE_DICT
+  auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  transformer_layer_cls_to_wrap: Qwen3DecoderLayer
+  reshard_after_forward: true
+
+datasets:
+  - path: tatsu-lab/alpaca
+    type: alpaca
+
+output_dir: ./outputs/ndp-out/
+
+sequence_len: 8192
+sample_packing: true
+flash_attention: true
+
+gradient_accumulation_steps: 1
+micro_batch_size: 1  # must be 1 when using context parallel
+num_epochs: 2
+optimizer: adamw_torch_fused
+lr_scheduler: constant_with_warmup
+learning_rate: 2e-6
+
+bf16: true
+tf32: true
+
+logging_steps: 1
+saves_per_epoch: 1
+
+warmup_ratio: 0.1
+
+special_tokens:
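A quick sanity check on the two configs above: the configured parallelism degrees must multiply out to the world size (8 GPUs on a single node), which is also what accelerate's `ParallelismConfig` validation later in this diff enforces. A minimal sketch with a hypothetical helper (not part of axolotl):

```python
# Hypothetical helper: the product of all parallelism degrees must equal world size.
def total_parallelism(dp_shard=1, dp_replicate=1, tp=1, cp=1) -> int:
    return dp_shard * dp_replicate * tp * cp


assert total_parallelism(dp_shard=2, dp_replicate=2, tp=2) == 8  # llama-3_1-8b-hdsp-tp
assert total_parallelism(dp_shard=2, tp=2, cp=2) == 8            # qwen3-8b-fsdp-tp-cp
```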
diff --git a/examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml b/examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml
index d55a272ba..b861876d1 100644
--- a/examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml
+++ b/examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml
@@ -10,9 +10,10 @@ plugins:
   experimental_skip_move_to_device: true  # prevent OOM by NOT putting model to GPU before sharding

 datasets:
-  - path: winglian/pirate-ultrachat-10k
+  - path: HuggingFaceH4/Multilingual-Thinking
     type: chat_template
-    split: train
+    field_thinking: thinking
+    template_thinking_key: thinking

 dataset_prepared_path: last_run_prepared
 val_set_size: 0
@@ -20,6 +21,7 @@ output_dir: ./outputs/gpt-oss-out/

 sequence_len: 4096
 sample_packing: true
+pad_to_sequence_len: true

 wandb_project:
 wandb_entity:
@@ -47,11 +49,12 @@ activation_offloading: true

 logging_steps: 1
 saves_per_epoch: 1
-warmup_ratio: 0.1
+warmup_ratio: 0.03

 special_tokens:
 eot_tokens:
   - "<|end|>"
+  - "<|return|>"

 fsdp_version: 2
 fsdp_config:
@@ -60,3 +63,4 @@ fsdp_config:
   auto_wrap_policy: TRANSFORMER_BASED_WRAP
   transformer_layer_cls_to_wrap: GptOssDecoderLayer
   reshard_after_forward: true
+# cpu_ram_efficient_loading: true
diff --git a/examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml b/examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml
index f9f2c1dce..6ec99304a 100644
--- a/examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml
+++ b/examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml
@@ -1,5 +1,5 @@
 base_model: openai/gpt-oss-20b
-use_kernels: true
+use_kernels: false
 model_quantization_config: Mxfp4Config
 model_quantization_config_kwargs:
   dequantize: true
@@ -10,9 +10,10 @@ plugins:
   experimental_skip_move_to_device: true  # prevent OOM by NOT putting model to GPU before sharding

 datasets:
-  - path: winglian/pirate-ultrachat-10k
+  - path: HuggingFaceH4/Multilingual-Thinking
     type: chat_template
-    split: train
+    field_thinking: thinking
+    template_thinking_key: thinking

 dataset_prepared_path: last_run_prepared
 val_set_size: 0
@@ -47,11 +48,12 @@ activation_offloading: true

 logging_steps: 1
 saves_per_epoch: 1
-warmup_ratio: 0.1
+warmup_ratio: 0.03

 special_tokens:
 eot_tokens:
   - "<|end|>"
+  - "<|return|>"

 fsdp_version: 2
 fsdp_config:
@@ -60,3 +62,4 @@ fsdp_config:
   auto_wrap_policy: TRANSFORMER_BASED_WRAP
   transformer_layer_cls_to_wrap: GptOssDecoderLayer
   reshard_after_forward: true
+# cpu_ram_efficient_loading: true
diff --git a/examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml b/examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml
index f7c332dfe..6016ce712 100644
--- a/examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml
+++ b/examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml
@@ -10,9 +10,10 @@ plugins:
   experimental_skip_move_to_device: true  # prevent OOM by not putting model to GPU before sharding

 datasets:
-  - path: winglian/pirate-ultrachat-10k
+  - path: HuggingFaceH4/Multilingual-Thinking
     type: chat_template
-    split: train
+    field_thinking: thinking
+    template_thinking_key: thinking

 dataset_prepared_path: last_run_prepared
 val_set_size: 0
@@ -26,11 +27,13 @@ lora_r: 8
 lora_alpha: 16
 lora_dropout: 0.0  # dropout not supported when using LoRA over expert parameters
 lora_target_linear: true
-lora_target_parameters:  # target the experts in the last two layers
-  - "22._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
-  - "22._checkpoint_wrapped_module.mlp.experts.down_proj"
-  - "23._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
-  - "23._checkpoint_wrapped_module.mlp.experts.down_proj"
+
+# TODO: not supported for now, see peft#2710
+#lora_target_parameters:  # target the experts in the last two layers
+#  - "22._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
+#  - "22._checkpoint_wrapped_module.mlp.experts.down_proj"
+#  - "23._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
+#  - "23._checkpoint_wrapped_module.mlp.experts.down_proj"

 wandb_project:
 wandb_entity:
@@ -62,3 +65,4 @@ warmup_ratio: 0.1
 special_tokens:
 eot_tokens:
   - "<|end|>"
+  - "<|return|>"
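For reference, the `model_quantization_config: Mxfp4Config` keys in the two FFT configs above correspond roughly to the following transformers call — a hedged sketch assuming the `Mxfp4Config` class that ships with transformers' gpt-oss support:

```python
from transformers import AutoModelForCausalLM, Mxfp4Config

# dequantize=True upcasts the MXFP4 expert weights so they are trainable in bf16.
quant_config = Mxfp4Config(dequantize=True)
model = AutoModelForCausalLM.from_pretrained(
    "openai/gpt-oss-20b",
    quantization_config=quant_config,
    torch_dtype="bfloat16",
)
```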
diff --git a/requirements.txt b/requirements.txt
index 0103ba919..370bf5a5e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -16,17 +16,18 @@ huggingface_hub>=0.33.0
 peft==0.17.0
 transformers==4.55.0
 tokenizers>=0.21.1
-accelerate @ git+https://github.com/huggingface/accelerate.git@9359a0194f210624f1e6e85c3d838fdd55c11152
+accelerate==1.10.0
 datasets==4.0.0
 deepspeed>=0.17.0
-trl==0.20.0
+trl==0.21.0
 hf_xet==1.1.5
 kernels==0.9.0
+trackio

 optimum==1.16.2
 hf_transfer
 sentencepiece
-gradio==5.23.3
+gradio==5.41.1

 modal==1.0.2
 pydantic==2.10.6
@@ -68,6 +69,6 @@ torchao==0.12.0
 schedulefree==1.4.1

 axolotl-contribs-lgpl==0.0.6
-axolotl-contribs-mit==0.0.4
+axolotl-contribs-mit==0.0.5

 mistral-common==1.8.3
diff --git a/src/axolotl/common/architectures.py b/src/axolotl/common/architectures.py
index 58d557e7e..ce945e670 100644
--- a/src/axolotl/common/architectures.py
+++ b/src/axolotl/common/architectures.py
@@ -13,5 +13,5 @@ MOE_ARCH_BLOCK = {
     "qwen2_moe": "Qwen2MoeSparseMoeBlock",
     "qwen3_moe": "Qwen3MoeSparseMoeBlock",
     "deepseek_v2": "DeepseekV2MoE",
-    "gpt_oss": "GptOssExperts",
+    "gpt_oss": "GptOssDecoderLayer",
 }
diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py
index 0472acee9..e1f649715 100644
--- a/src/axolotl/core/builders/base.py
+++ b/src/axolotl/core/builders/base.py
@@ -24,7 +24,6 @@ from pathlib import Path
 from typing import Any

 import torch
-from accelerate import PartialState
 from transformers import (
     TrainerCallback,
 )
@@ -39,6 +38,7 @@ from axolotl.utils.callbacks import (
     SaveModelOnFirstStepCallback,
 )
 from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
+from axolotl.utils.distributed import build_parallelism_config
 from axolotl.utils.schemas.enums import CustomSupportedOptimizers

 LOG = logging.getLogger(__name__)
@@ -275,8 +275,9 @@ class TrainerBuilderBase(abc.ABC):
             optimizer_kwargs["dion_lr"] = training_args_kwargs["dion_learning_rate"]
             optimizer_kwargs["dion_mu"] = training_args_kwargs["dion_momentum"]
             optimizer_kwargs.update(adam_kwargs)
-            partial_state = PartialState()
-            optimizer_kwargs["device_mesh"] = partial_state.device_mesh
+            _, device_mesh = build_parallelism_config(self.cfg)
+            if device_mesh is not None:
+                optimizer_kwargs["device_mesh"] = device_mesh
         elif self.cfg.optimizer == "optimi_adamw":
             from optimi import AdamW

@@ -428,30 +429,12 @@ class TrainerBuilderBase(abc.ABC):
             training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode

     def _configure_accelerator_config(self, training_args_kwargs: dict):
-        partial_state = PartialState()
-        has_pc_attr = (
-            hasattr(partial_state, "parallelism_config")
-            and partial_state.parallelism_config
-        )
-        has_pc_key = (
-            "parallelism_config"
-            in partial_state._shared_state  # pylint: disable=protected-access
-            and partial_state._shared_state[  # pylint: disable=protected-access
-                "parallelism_config"
-            ]
-        )
-        use_configured_state = has_pc_attr or has_pc_key
         if self.cfg.accelerator_config:
-            use_configured_state = self.cfg.accelerator_config.pop(
-                "use_configured_state", use_configured_state
-            )
             training_args_kwargs["accelerator_config"] = AcceleratorConfig(
-                use_configured_state=use_configured_state, **self.cfg.accelerator_config
+                **self.cfg.accelerator_config
             )
         else:
-            training_args_kwargs["accelerator_config"] = AcceleratorConfig(
-                use_configured_state=use_configured_state,
-            )
+            training_args_kwargs["accelerator_config"] = AcceleratorConfig()

     def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
         if self.cfg.activation_offloading is True:
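The `MOE_ARCH_BLOCK` change above swaps the ZeRO-3 leaf module for gpt-oss from the experts module to the whole decoder layer. A hedged sketch of how such an entry is typically consumed — the actual call site (`_set_z3_leaf_modules`) is not shown in this diff, so names below are illustrative:

```python
import sys

import transformers
from deepspeed.utils import set_z3_leaf_modules

from axolotl.common.architectures import MOE_ARCH_BLOCK


def set_moe_leaf_modules(model: transformers.PreTrainedModel) -> None:
    """Mark the MoE block as a ZeRO-3 leaf so DeepSpeed does not trace into expert routing."""
    block_name = MOE_ARCH_BLOCK.get(model.config.model_type)  # e.g. "GptOssDecoderLayer"
    if block_name is None:
        return
    # The block class is defined in the model's own modeling module.
    block_cls = getattr(sys.modules[type(model).__module__], block_name)
    set_z3_leaf_modules(model, [block_cls])
```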
diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py
index db35a2412..e5bc21762 100644
--- a/src/axolotl/core/builders/causal.py
+++ b/src/axolotl/core/builders/causal.py
@@ -363,7 +363,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 data_collator_kwargs["pad_to_multiple_of"] = multiple * math.ceil(
                     self.cfg.sequence_len / multiple
                 )
-            else:
+            elif self.cfg.pad_to_sequence_len is None:
                 # A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
                 # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
                 data_collator_kwargs["pad_to_multiple_of"] = multiple
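The new `elif` stops the collator from silently overriding `pad_to_multiple_of` when the user explicitly set `pad_to_sequence_len`. The retained arithmetic, worked through with concrete numbers:

```python
import math

multiple = 64  # A100 sweet spot per the comment above; 8 suffices on other GPUs
sequence_len = 4096
assert multiple * math.ceil(sequence_len / multiple) == 4096  # already aligned
sequence_len = 4100
assert multiple * math.ceil(sequence_len / multiple) == 4160  # rounded up to next multiple
```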
diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py
index 3540fb6a1..0f9f6e4c4 100644
--- a/src/axolotl/core/trainers/base.py
+++ b/src/axolotl/core/trainers/base.py
@@ -10,8 +10,11 @@ from functools import partial, wraps
 from typing import Any, Callable, Literal, Optional

 import datasets
+import safetensors
 import torch
+from accelerate.state import AcceleratorState
 from datasets import Dataset
+from peft import PeftModel
 from torch.utils.data import (
     BatchSampler,
     DataLoader,
@@ -19,8 +22,10 @@ from torch.utils.data import (
     Sampler,
     SequentialSampler,
 )
-from transformers import Trainer
+from transformers import PreTrainedModel, Trainer
+from transformers.trainer import TRAINING_ARGS_NAME
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length, seed_worker
+from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, is_peft_available
 from trl.trainer.utils import pad_to_length
 from typing_extensions import override

@@ -515,7 +520,18 @@ class AxolotlTrainer(

     @wraps(Trainer.create_accelerator_and_postprocess)
     def create_accelerator_and_postprocess(self):
-        res = super().create_accelerator_and_postprocess()
+        # cleanup the PartialState states so Accelerate automatically configures everything from the env vars
+        accelerator_config = self.args.accelerator_config.to_dict()
+        use_configured_state = accelerator_config.get("use_configured_state", False)
+        if not use_configured_state:
+            AcceleratorState._reset_state(  # pylint: disable=protected-access
+                reset_partial_state=True
+            )
+
+        super().create_accelerator_and_postprocess()
+
+        # now we need to put parallelism_config back on the PartialState since we rely on that info in other places
+        # PartialState().parallelism_config = self.accelerator.state.parallelism_config

         if self.is_fsdp_enabled:
             if (
@@ -524,8 +540,6 @@ class AxolotlTrainer(
             ):
                 self.accelerator.state.fsdp_plugin.limit_all_gathers = True

-        return res
-
     # pylint: disable=unused-argument
     def additional_accelerator_args(
         self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs
@@ -590,3 +604,64 @@ class AxolotlTrainer(
         output_dir = os.path.join(run_dir, checkpoint_folder)
         os.makedirs(output_dir, exist_ok=True)
         return super()._save_checkpoint(model, trial, **kwargs)
+
+    # TODO(wing): remove once https://github.com/huggingface/transformers/pull/39866/files is merged
+    def _save(self, output_dir: Optional[str] = None, state_dict=None):
+        # If we are executing this function, we are the process zero, so we don't check for that.
+        output_dir = output_dir if output_dir is not None else self.args.output_dir
+        os.makedirs(output_dir, exist_ok=True)
+        LOG.info(f"Saving model checkpoint to {output_dir}")
+        supported_classes = (
+            (PreTrainedModel,)
+            if not is_peft_available()
+            else (PreTrainedModel, PeftModel)
+        )
+        # Save a trained model and configuration using `save_pretrained()`.
+        # They can then be reloaded using `from_pretrained()`
+        if not isinstance(self.model, supported_classes):
+            if state_dict is None:
+                state_dict = self.model.state_dict()
+            if isinstance(
+                self.accelerator.unwrap_model(self.model, keep_torch_compile=False),
+                supported_classes,
+            ):
+                self.accelerator.unwrap_model(
+                    self.model, keep_torch_compile=False
+                ).save_pretrained(
+                    output_dir,
+                    state_dict=state_dict,
+                    safe_serialization=self.args.save_safetensors,
+                )
+            else:
+                LOG.info(
+                    "Trainer.model is not a `PreTrainedModel`, only saving its state dict."
+                )
+                if self.args.save_safetensors:
+                    safetensors.torch.save_file(
+                        state_dict,
+                        os.path.join(output_dir, SAFE_WEIGHTS_NAME),
+                        metadata={"format": "pt"},
+                    )
+                else:
+                    torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
+        else:
+            self.model.save_pretrained(
+                output_dir,
+                state_dict=state_dict,
+                safe_serialization=self.args.save_safetensors,
+                is_main_process=self.accelerator.is_main_process,
+            )
+
+        if self.processing_class is not None:
+            self.processing_class.save_pretrained(output_dir)
+        elif (
+            self.data_collator is not None
+            and hasattr(self.data_collator, "tokenizer")
+            and self.data_collator.tokenizer is not None
+        ):
+            LOG.info(
+                "Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
+            )
+            self.data_collator.tokenizer.save_pretrained(output_dir)
+        # Good practice: save your training arguments together with the trained model
+        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
diff --git a/src/axolotl/core/trainers/mixins/distributed_parallel.py b/src/axolotl/core/trainers/mixins/distributed_parallel.py
index d0f0f53df..d163e4eb5 100644
--- a/src/axolotl/core/trainers/mixins/distributed_parallel.py
+++ b/src/axolotl/core/trainers/mixins/distributed_parallel.py
@@ -2,6 +2,7 @@
 Mixin for correctly saving fsdp
 """

+from accelerate import PartialState
 from transformers import Trainer


@@ -18,3 +19,15 @@ class DistributedParallelMixin(Trainer):
         ):
             state_dict = self.accelerator.get_state_dict(self.model)
             super()._save(output_dir, state_dict=state_dict)
+
+    def create_accelerator_and_postprocess(self):
+        super().create_accelerator_and_postprocess()
+        if (
+            self.accelerator.distributed_type == "FSDP"
+            and self.accelerator.state.fsdp_plugin is None
+        ):
+            # pylint: disable=protected-access
+            # handle Context Parallelism without FSDP
+            self.accelerator.state.distributed_type = "MULTI_GPU"
+            self.accelerator.state._shared_state["distributed_type"] = "MULTI_GPU"
+            PartialState().distributed_type = "MULTI_GPU"
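The trainer change above deliberately tears down Accelerate's process state before calling `super().create_accelerator_and_postprocess()`, so the new `Accelerator` re-reads everything from environment variables (see `setup_parallelism_envs` later in this diff). A minimal sketch of that pattern, assuming the relevant env vars are already exported:

```python
from accelerate import Accelerator
from accelerate.state import AcceleratorState

# Drop any PartialState created earlier in the process (e.g. during model loading)
# so the Accelerator below reconfigures itself from the
# ACCELERATE_* / PARALLELISM_CONFIG_* environment variables.
AcceleratorState._reset_state(reset_partial_state=True)  # pylint: disable=protected-access
accelerator = Accelerator()
```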
diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py
index 7061e1ff3..95a56b326 100644
--- a/src/axolotl/loaders/model.py
+++ b/src/axolotl/loaders/model.py
@@ -13,7 +13,7 @@ import peft
 import torch
 import transformers
 import transformers.modeling_utils
-from accelerate import PartialState, init_empty_weights
+from accelerate import init_empty_weights
 from accelerate.parallelism_config import ParallelismConfig
 from peft import (
     PeftConfig,
@@ -22,6 +22,7 @@ from peft import (
     PeftModelForCausalLM,
     prepare_model_for_kbit_training,
 )
+from torch.distributed import DeviceMesh
 from transformers import (
     AutoModelForCausalLM,
     AutoModelForVision2Seq,
@@ -49,7 +50,11 @@ from axolotl.loaders.utils import (
 from axolotl.models.mamba import fix_mamba_attn_for_loss
 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.distributed import get_device_count, get_device_type, get_world_size
+from axolotl.utils.distributed import (
+    build_parallelism_config,
+    get_device_count,
+    get_device_type,
+)
 from axolotl.utils.logging import get_logger
 from axolotl.utils.model_shard_quant import load_sharded_model_quant
 from axolotl.utils.schemas.enums import RLType
@@ -87,6 +92,7 @@ class ModelLoader:

     use_parallel_config: bool | None = False
     parallelism_config: ParallelismConfig | None = None
+    device_mesh: DeviceMesh | None = None

     def __init__(
         self,
@@ -302,7 +308,10 @@ class ModelLoader:
         )

         # Handle DeepSpeed Zero3
-        if is_deepspeed_zero3_enabled():
+        if (
+            is_deepspeed_zero3_enabled()
+            or os.getenv("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3"
+        ):
             self._set_z3_leaf_modules()

         # Apply gradient checkpointing if needed
@@ -407,85 +416,12 @@ class ModelLoader:
         gc.collect()
         torch.cuda.empty_cache()

-    @staticmethod
-    def _get_parallel_config_kwargs(
-        world_size: int,
-        tensor_parallel_size: int = 1,
-        context_parallel_size: int = 1,
-        dp_shard_size: int | None = None,
-        dp_replicate_size: int | None = None,
-        is_fsdp: bool = False,
-    ):
-        pc_kwargs = {}
-        remaining_world_size = world_size
-
-        if tensor_parallel_size and tensor_parallel_size > 1:
-            pc_kwargs["tp_size"] = tensor_parallel_size
-            remaining_world_size = remaining_world_size // tensor_parallel_size
-
-        if context_parallel_size and context_parallel_size > 1:
-            pc_kwargs["cp_size"] = context_parallel_size
-            remaining_world_size = remaining_world_size // context_parallel_size
-
-        if dp_shard_size is None and dp_replicate_size in (None, 1):
-            if remaining_world_size > 1:
-                pc_kwargs["dp_shard_size"] = remaining_world_size
-                remaining_world_size = 1
-
-        if dp_replicate_size and dp_replicate_size > 1:
-            pc_kwargs["dp_replicate_size"] = dp_replicate_size
-            remaining_world_size = remaining_world_size // dp_replicate_size
-
-        if remaining_world_size > 1 and dp_shard_size and dp_shard_size > 1:
-            if not is_fsdp:
-                raise ValueError(
-                    "dp_shard_size was configured without a corresponding fsdp_config! "
-                    "Please ensure you have configured FSDP using fsdp_config."
-                )
-            pc_kwargs["dp_shard_size"] = dp_shard_size
-            remaining_world_size = remaining_world_size // dp_shard_size
-            if remaining_world_size > 1 and "dp_replicate_size" not in pc_kwargs:
-                pc_kwargs["dp_replicate_size"] = remaining_world_size
-                remaining_world_size = 1
-
-        if remaining_world_size > 1:
-            if "dp_shard_size" not in pc_kwargs and is_fsdp:
-                pc_kwargs["dp_shard_size"] = remaining_world_size
-                remaining_world_size = 1
-
-        if remaining_world_size > 1:
-            raise ValueError(
-                f"The configured parallelisms are incompatible with the current world size ({get_world_size()})!\n"
-                f"{pc_kwargs}"
-            )
-
-        return pc_kwargs
-
     def _set_parallel_config(self):
         """Set parallelism configuration (DP, FSDP, TP, CP) in PartialState/Accelerator"""
-        pc_kwargs = ModelLoader._get_parallel_config_kwargs(
-            get_world_size(),
-            self.cfg.tensor_parallel_size,
-            self.cfg.context_parallel_size,
-            self.cfg.dp_shard_size,
-            self.cfg.dp_replicate_size,
-            bool(self.cfg.fsdp or self.cfg.fsdp_config),
-        )
-
-        if pc_kwargs:
-            self.parallelism_config = ParallelismConfig(
-                **pc_kwargs,
-            )
-            device_mesh = self.parallelism_config.build_device_mesh("cuda")
-            partial_state = PartialState()
-            # fmt: off
-            partial_state._shared_state["parallelism_config"] = (  # pylint: disable=protected-access
-                self.parallelism_config
-            )
-            partial_state._shared_state["device_mesh"] = (  # pylint: disable=protected-access
-                device_mesh
-            )
-            # fmt: on
+        parallelism_config, device_mesh = build_parallelism_config(self.cfg)
+        if parallelism_config:
+            self.parallelism_config = parallelism_config
+            self.device_mesh = device_mesh

     def _set_auto_model_loader(self):
         """Set `self.auto_model_loader`. Defaults to `transformers.AutoModelForCausalLM`
@@ -738,7 +674,7 @@ class ModelLoader:
         if self.cfg.tensor_parallel_size > 1:
             self.model_kwargs["tp_size"] = self.cfg.tensor_parallel_size
             self.model_kwargs["tp_plan"] = "auto"
-            self.model_kwargs["device_mesh"] = PartialState().device_mesh
+            self.model_kwargs["device_mesh"] = self.device_mesh

             if "device_map" in self.model_kwargs:
                 del self.model_kwargs["device_map"]  # not compatible with `tp_plan`
@@ -754,6 +690,18 @@ class ModelLoader:
         elif self.is_qlora_and_fsdp_enabled:
             skip_move_to_device = True

+        if (
+            self.cfg.tensor_parallel_size <= 1
+            and self.cfg.fsdp_config.cpu_ram_efficient_loading
+            and self.cfg.fsdp_version == 2
+        ):
+            # setting device_map for TP is not supported
+            local_rank = int(os.getenv("LOCAL_RANK", "0"))
+            if local_rank == 0:
+                self.model_kwargs["device_map"] = "cpu"
+            else:
+                self.model_kwargs["device_map"] = "meta"
+
         if (
             self.is_qlora_and_fsdp_enabled
             and self.cfg.fsdp_config.cpu_ram_efficient_loading
diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py
index 4273f3cce..047eb20fd 100644
--- a/src/axolotl/loaders/patch_manager.py
+++ b/src/axolotl/loaders/patch_manager.py
@@ -104,6 +104,14 @@ class PatchManager:

     def _apply_fsdp_patches(self):
         """Apply patches for FSDP configurations."""
+        if self.cfg.context_parallel_size > 1 or (
+            self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2"
+        ):
+            from axolotl.monkeypatch.accelerate.parallelism_config import (
+                patch_parallelism_config,
+            )
+
+            patch_parallelism_config()
         if self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2":
             from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp2

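The `cpu_ram_efficient_loading` branch added above keeps host RAM flat on multi-GPU nodes: only local rank 0 materializes real weights (on CPU); every other rank builds the model on the meta device and receives its shards during FSDP preparation (see the fsdp2 state-dict loader below). The same pattern in isolation:

```python
import os

# Rank 0 loads real weights to CPU; other ranks hold a weightless meta-device model.
local_rank = int(os.getenv("LOCAL_RANK", "0"))
device_map = "cpu" if local_rank == 0 else "meta"
# e.g. AutoModelForCausalLM.from_pretrained(..., device_map=device_map)
```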
diff --git a/src/axolotl/monkeypatch/accelerate/fsdp2.py b/src/axolotl/monkeypatch/accelerate/fsdp2.py
index af262d18f..efc388294 100644
--- a/src/axolotl/monkeypatch/accelerate/fsdp2.py
+++ b/src/axolotl/monkeypatch/accelerate/fsdp2.py
@@ -7,6 +7,7 @@ import functools
 import sys

 import torch
+import torch.distributed as dist
 from torch import nn

 from axolotl.utils.bench import log_gpu_memory_usage
@@ -36,25 +37,49 @@ def fsdp2_load_full_state_dict(
     meta_sharded_sd = model.state_dict()
     sharded_sd = {}

-    for param_name, full_tensor in full_sd.items():
-        sharded_meta_param = meta_sharded_sd.get(param_name)
-        full_tensor = full_tensor.to(sharded_meta_param.dtype).to(torch.device("cuda"))
+    for param_name, sharded_meta_param in meta_sharded_sd.items():
+        full_tensor = None
+        if _accelerator.is_main_process:
+            full_tensor = full_sd[param_name]
+            full_tensor = full_tensor.to(sharded_meta_param.dtype)
+
         if hasattr(sharded_meta_param, "device_mesh"):
+            device_mesh = sharded_meta_param.device_mesh
+            if _accelerator.is_main_process:
+                full_tensor = full_tensor.to(device_mesh.device_type)
+            else:
+                full_tensor = torch.empty(
+                    sharded_meta_param.size(),
+                    device=device_mesh.device_type,
+                    dtype=sharded_meta_param.dtype,
+                )
             sharded_param = distribute_tensor(
                 full_tensor,
-                sharded_meta_param.device_mesh,
+                device_mesh,
                 sharded_meta_param.placements,
                 src_data_rank=0,
             )
         else:
-            sharded_param = full_tensor
+            # Non-sharded parameters
+            if _accelerator.is_main_process:
+                sharded_param = full_tensor.to(torch.device("cuda"))
+            else:
+                # broadcast manually
+                sharded_param = torch.empty_like(
+                    sharded_meta_param,
+                    device=torch.device("cuda"),
+                    dtype=sharded_meta_param.dtype,
+                )
+            dist.broadcast(sharded_param, src=0)

         if offload_to_cpu:
             sharded_param = sharded_param.cpu()

         sharded_sd[param_name] = nn.Parameter(sharded_param)
+        del full_tensor
         full_sd[param_name] = None
+
     model.load_state_dict(sharded_sd, assign=True, strict=True)
     end_time = time.time()
     LOG.debug(
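The patched loader above materializes each full tensor only on the main process; the other ranks allocate empty buffers that get filled either by `distribute_tensor` (sharded params) or an explicit broadcast (non-sharded params). The broadcast half of that, as a stand-alone sketch that assumes an initialized process group and one GPU per rank:

```python
import torch
import torch.distributed as dist


def broadcast_param_from_rank0(meta_or_full: torch.Tensor) -> torch.Tensor:
    """Rank 0 supplies the real tensor; other ranks receive it via broadcast."""
    if dist.get_rank() == 0:
        param = meta_or_full.to(torch.device("cuda"))
    else:
        # Same shape/dtype, but uninitialized until the broadcast fills it in.
        param = torch.empty_like(meta_or_full, device=torch.device("cuda"))
    dist.broadcast(param, src=0)
    return param
```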
diff --git a/src/axolotl/monkeypatch/accelerate/parallelism_config.py b/src/axolotl/monkeypatch/accelerate/parallelism_config.py
new file mode 100644
index 000000000..e3cafc87d
--- /dev/null
+++ b/src/axolotl/monkeypatch/accelerate/parallelism_config.py
@@ -0,0 +1,77 @@
+"""
+workaround to allow parallelism config for pure CP
+"""
+
+# pylint: disable=protected-access
+import os
+import warnings
+
+from accelerate import DistributedType
+
+
+def _validate_accelerator(self, accelerator):
+    _warnings = set()
+    if not accelerator.multi_device and self.total_size == 1:
+        # No distributed setup, valid parallelism config
+        return
+
+    # We need this to ensure DDP works
+    if self.total_size == 1:
+        self._set_size("dp_replicate", accelerator.num_processes)
+
+    if self.total_size != accelerator.num_processes:
+        raise ValueError(
+            f"ParallelismConfig total_size ({self.total_size}) does not match "
+            f"num_processes ({accelerator.num_processes}). Please adjust dp_replicate_size/ "
+            f"dp_shard_size/tp_size/cp_size."
+        )
+
+    # allow parallelism config when not using fsdp if using pure context parallelism
+    allow_parallelism_config = False
+
+    if (
+        self.cp_size > 1  # pylint: disable=chained-comparison
+        and self.dp_shard_size <= 1
+        and os.environ.get("ACCELERATE_ALLOW_CP_STANDALONE", "false").lower() == "true"
+    ):
+        allow_parallelism_config = True
+
+    if (
+        self.total_size > 1
+        and not allow_parallelism_config
+        and not (accelerator.is_fsdp2 or accelerator.multi_device)
+    ):
+        raise ValueError(
+            f"ParallelismConfig is only compatible with DistributedType.FSDP (version 2) or DistributedType.Multi{{Device}}, but got {accelerator.distributed_type}."
+        )
+
+    for parallelism, size in self._sizes.items():
+        if size == 1 and getattr(self, f"{parallelism}_handler", None) is not None:
+            _warnings.add(
+                f"ParallelismConfig.{parallelism}_handler is set, but {parallelism}_size is set to 1. This handler will be ignored."
+            )
+
+    if _warnings and accelerator.is_main_process:
+        warnings.warn(
+            "ParallelismConfig has the following warnings:\n" + "\n".join(_warnings),
+            UserWarning,
+        )
+
+
+def patched_is_fsdp2(self) -> bool:
+    """
+    Patched version of is_fsdp2 that guards against a None fsdp_plugin.
+    """
+    # The new logic checks if fsdp_plugin exists before accessing its attributes
+    return (
+        self.distributed_type == DistributedType.FSDP
+        and self.fsdp_plugin
+        and self.fsdp_plugin.fsdp_version == 2
+    )
+
+
+def patch_parallelism_config():
+    from accelerate.accelerator import AcceleratorState, ParallelismConfig
+
+    ParallelismConfig._validate_accelerator = _validate_accelerator
+    AcceleratorState.is_fsdp2 = property(patched_is_fsdp2)
diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py
index 9dc04c7b4..5fc5ae856 100644
--- a/src/axolotl/monkeypatch/multipack.py
+++ b/src/axolotl/monkeypatch/multipack.py
@@ -36,6 +36,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
     "glm",
     "glm4",
     "smollm3",
+    "gpt_oss",
 ]

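`patch_parallelism_config` rebinds a method and a property on accelerate's classes at import time. The property half of that trick, demonstrated on a stand-in class (not accelerate's real one):

```python
class FakeState:
    distributed_type = "FSDP"
    fsdp_plugin = None  # pure-CP runs have no FSDP plugin


def patched_is_fsdp2(self) -> bool:
    # Check the plugin exists before touching fsdp_version.
    return (
        self.distributed_type == "FSDP"
        and self.fsdp_plugin is not None
        and self.fsdp_plugin.fsdp_version == 2
    )


FakeState.is_fsdp2 = property(patched_is_fsdp2)
assert FakeState().is_fsdp2 is False  # False instead of an AttributeError
```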
"roles": dataset_config.get("roles"), "drop_system_message": dataset_config.get("drop_system_message", False), # we need to add one for detecting sequences with exceeding the `sequence_len` limit. diff --git a/src/axolotl/train.py b/src/axolotl/train.py index a693236d3..e8a2cbabe 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -218,6 +218,7 @@ def execute_training( ring_attn_func=cfg.ring_attn_func, heads_k_stride=cfg.heads_k_stride, gather_outputs=cfg.rl is RLType.GRPO, + device_mesh=trainer.accelerator.torch_device_mesh, ) ) @@ -274,7 +275,7 @@ def save_trained_model( # final model weights have already been saved by `ReLoRACallback.on_train_end` return - if trainer.is_fsdp_enabled: + if trainer.is_fsdp_enabled or cfg.fsdp_config: if cfg.fsdp_config or cfg.fsdp: if cfg.fsdp_config.final_state_dict_type: state_dict_type = cfg.fsdp_config.final_state_dict_type diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py index 25a871b2b..55e630fbe 100644 --- a/src/axolotl/utils/collators/batching.py +++ b/src/axolotl/utils/collators/batching.py @@ -161,6 +161,8 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): Collator for multipack specific to the using the BatchSampler """ + squash_position_ids: bool = False + def __call__(self, features, return_tensors=None): if not isinstance(features[0], list): features: List[List[dict]] = [features] @@ -176,6 +178,15 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): if feature in item ] out_features[i][feature] = np.concatenate(arrays) + elif feature == "position_ids" and self.squash_position_ids: + arrays = [ + np.array(item[feature]) for item in features_ if feature in item + ] + # concatenate, get total length and create arange of new total position ids + position_ids = np.concatenate(arrays) + total_length = position_ids.shape[0] + position_ids = np.arange(total_length) + out_features[i][feature] = position_ids else: arrays = [ np.array(item[feature]) for item in features_ if feature in item diff --git a/src/axolotl/utils/ctx_managers/sequence_parallel.py b/src/axolotl/utils/ctx_managers/sequence_parallel.py index 949c76f49..029d991dd 100644 --- a/src/axolotl/utils/ctx_managers/sequence_parallel.py +++ b/src/axolotl/utils/ctx_managers/sequence_parallel.py @@ -5,8 +5,8 @@ import inspect import torch import torch.distributed as dist -from accelerate import PartialState from torch import nn +from torch.distributed import DeviceMesh from torch.utils.hooks import RemovableHandle from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.utils import ModelOutput @@ -194,6 +194,7 @@ class SequenceParallelContextManager: ring_attn_func: RingAttnFunc, heads_k_stride: int | None, gather_outputs: bool, + device_mesh: DeviceMesh | None = None, ): self.models = models self.context_parallel_size = context_parallel_size @@ -201,6 +202,7 @@ class SequenceParallelContextManager: self.ring_attn_func = ring_attn_func self.heads_k_stride = heads_k_stride self.gather_outputs = gather_outputs + self.device_mesh = device_mesh self._register_ring_attn() @@ -240,9 +242,8 @@ class SequenceParallelContextManager: def _register_ring_attn(self): # Initialize ring attn for sequence parallelism - partial_state = PartialState() register_ring_attn_from_device_mesh( - device_mesh=partial_state.device_mesh, + device_mesh=self.device_mesh, context_parallel_dim=("cp",), heads_k_stride=self.heads_k_stride, ring_attn_func=self.ring_attn_func, diff --git 
diff --git a/src/axolotl/utils/distributed.py b/src/axolotl/utils/distributed.py
index 2192e7b9d..d2d1075cb 100644
--- a/src/axolotl/utils/distributed.py
+++ b/src/axolotl/utils/distributed.py
@@ -8,6 +8,7 @@ from datetime import timedelta
 import torch
 import torch.distributed as dist
 from accelerate import PartialState
+from accelerate.utils import ParallelismConfig
 from transformers.utils.import_utils import (
     is_torch_cuda_available,
     is_torch_mps_available,
@@ -290,3 +291,77 @@ def reduce_and_broadcast(fn1, fn2):
     # Use compute_and_broadcast to compute the reduced value on the main process
     # and then broadcast it to all ranks
     return compute_and_broadcast(lambda: fn2(gathered_values))
+
+
+def build_parallelism_config(cfg):
+    pc_kwargs = _get_parallel_config_kwargs(
+        get_world_size(),
+        cfg.tensor_parallel_size,
+        cfg.context_parallel_size,
+        cfg.dp_shard_size,
+        cfg.dp_replicate_size,
+        bool(cfg.fsdp or cfg.fsdp_config),
+    )
+
+    if pc_kwargs:
+        parallelism_config = ParallelismConfig(
+            **pc_kwargs,
+        )
+        device_mesh = parallelism_config.build_device_mesh("cuda")
+
+        return parallelism_config, device_mesh
+    return None, None
+
+
+def _get_parallel_config_kwargs(
+    world_size: int,
+    tensor_parallel_size: int = 1,
+    context_parallel_size: int = 1,
+    dp_shard_size: int | None = None,
+    dp_replicate_size: int | None = None,
+    is_fsdp: bool = False,
+):
+    pc_kwargs = {}
+    remaining_world_size = world_size
+
+    if tensor_parallel_size and tensor_parallel_size > 1:
+        pc_kwargs["tp_size"] = tensor_parallel_size
+        remaining_world_size = remaining_world_size // tensor_parallel_size
+
+    if context_parallel_size and context_parallel_size > 1:
+        pc_kwargs["cp_size"] = context_parallel_size
+        remaining_world_size = remaining_world_size // context_parallel_size
+
+    if dp_shard_size is None and dp_replicate_size in (None, 1):
+        if remaining_world_size > 1:
+            pc_kwargs["dp_shard_size"] = remaining_world_size
+            remaining_world_size = 1
+
+    if dp_replicate_size and dp_replicate_size > 1:
+        pc_kwargs["dp_replicate_size"] = dp_replicate_size
+        remaining_world_size = remaining_world_size // dp_replicate_size
+
+    if remaining_world_size > 1 and dp_shard_size and dp_shard_size > 1:
+        if not is_fsdp:
+            raise ValueError(
+                "dp_shard_size was configured without a corresponding fsdp_config! "
+                "Please ensure you have configured FSDP using fsdp_config."
+            )
+        pc_kwargs["dp_shard_size"] = dp_shard_size
+        remaining_world_size = remaining_world_size // dp_shard_size
+        if remaining_world_size > 1 and "dp_replicate_size" not in pc_kwargs:
+            pc_kwargs["dp_replicate_size"] = remaining_world_size
+            remaining_world_size = 1
+
+    if remaining_world_size > 1:
+        if "dp_shard_size" not in pc_kwargs and is_fsdp:
+            pc_kwargs["dp_shard_size"] = remaining_world_size
+            remaining_world_size = 1
+
+    if remaining_world_size > 1:
+        raise ValueError(
+            f"The configured parallelisms are incompatible with the current world size ({get_world_size()})!\n"
+            f"{pc_kwargs}"
+        )
+
+    return pc_kwargs
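A worked example of the factorization `_get_parallel_config_kwargs` implements (this mirrors the parametrized test in `tests/test_loaders.py` at the end of this diff): on 8 GPUs with tp=2 and cp=2, the leftover factor of 2 is assigned to `dp_shard_size` automatically.

```python
from axolotl.utils.distributed import _get_parallel_config_kwargs

kwargs = _get_parallel_config_kwargs(
    world_size=8,
    tensor_parallel_size=2,
    context_parallel_size=2,
    is_fsdp=True,
)
# 8 // 2 (tp) // 2 (cp) = 2 remaining, which becomes the FSDP shard dimension.
assert kwargs == {"tp_size": 2, "cp_size": 2, "dp_shard_size": 2}
```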
diff --git a/src/axolotl/utils/schemas/datasets.py b/src/axolotl/utils/schemas/datasets.py
index da8c545bc..d9c8042d4 100644
--- a/src/axolotl/utils/schemas/datasets.py
+++ b/src/axolotl/utils/schemas/datasets.py
@@ -118,6 +118,18 @@ class SFTDataset(BaseModel):
             "description": 'Key containing the tools (default: "tools"). Must be a list[dict] and follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).'
         },
     )
+    field_thinking: str | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": 'Key containing the reasoning trace (default: "reasoning_content").'
+        },
+    )
+    template_thinking_key: str | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "The key the chat template expects that indicates the reasoning trace."
+        },
+    )
     # deprecated, use message_property_mappings
     message_field_role: str | None = None
     # deprecated, use message_property_mappings
diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py
index ac3355f74..72991c947 100644
--- a/src/axolotl/utils/schemas/validation.py
+++ b/src/axolotl/utils/schemas/validation.py
@@ -1147,6 +1147,19 @@ class ModelCompatibilityValidationMixin:
             )
         return data

+    @model_validator(mode="before")
+    @classmethod
+    def check_gpt_oss_fsdp_loading(cls, data):
+        if data.get("model_quantization_config", "") == "Mxfp4Config":
+            if (
+                data.get("fsdp_config", {}).get("cpu_ram_efficient_loading", False)
+                is True
+            ):
+                raise ValueError(
+                    "FSDP cpu_ram_efficient_loading is not supported for Mxfp4Config model quantization."
+                )
+        return data
+

 class ComplexValidationMixin:
     """Complex validation methods that involve multiple systems."""
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 90ae1a889..26634cbbe 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -597,6 +597,25 @@ def setup_fsdp_envs(cfg):
         os.environ["FSDP_RESHARD_AFTER_FORWARD"] = "true"


+def setup_parallelism_envs(cfg):
+    set_accelerate_parallelism_config = False
+    if cfg.tensor_parallel_size and cfg.tensor_parallel_size > 1:
+        set_accelerate_parallelism_config = True
+        os.environ["PARALLELISM_CONFIG_TP_SIZE"] = str(cfg.tensor_parallel_size)
+    if cfg.dp_shard_size and cfg.dp_shard_size > 1:
+        set_accelerate_parallelism_config = True
+        os.environ["PARALLELISM_CONFIG_DP_SHARD_SIZE"] = str(cfg.dp_shard_size)
+    if cfg.dp_replicate_size and cfg.dp_replicate_size > 1:
+        set_accelerate_parallelism_config = True
+        os.environ["PARALLELISM_CONFIG_DP_REPLICATE_SIZE"] = str(cfg.dp_replicate_size)
+    if cfg.context_parallel_size and cfg.context_parallel_size > 1:
+        set_accelerate_parallelism_config = True
+        os.environ["PARALLELISM_CONFIG_CP_SIZE"] = str(cfg.context_parallel_size)
+        os.environ["ACCELERATE_ALLOW_CP_STANDALONE"] = "true"
+    if set_accelerate_parallelism_config:
+        os.environ["ACCELERATE_USE_PARALLELISM_CONFIG"] = "true"
+
+
 def prepare_optim_env(cfg):
     if not check_cuda_p2p_ib_support():
         if os.getenv("NCCL_P2P_DISABLE") is None:
@@ -615,6 +634,7 @@ def prepare_optim_env(cfg):
         stage = deepspeed_config.get("zero_optimization", {}).get("stage", None)
         setup_deepspeed_env(cfg, stage=stage)

+    setup_parallelism_envs(cfg)
     setup_torch_compile_env(cfg)

     if cfg.fp8:
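`setup_parallelism_envs` translates the axolotl config keys into the `PARALLELISM_CONFIG_*` environment variables that accelerate reads. For the `qwen3-8b-fsdp-tp-cp.yaml` example near the top of this diff it would export roughly the following — a sketch, relying on `DictDefault` returning `None` for unset keys:

```python
import os

from axolotl.utils.dict import DictDefault
from axolotl.utils.trainer import setup_parallelism_envs

cfg = DictDefault(dp_shard_size=2, tensor_parallel_size=2, context_parallel_size=2)
setup_parallelism_envs(cfg)

assert os.environ["PARALLELISM_CONFIG_TP_SIZE"] == "2"
assert os.environ["PARALLELISM_CONFIG_DP_SHARD_SIZE"] == "2"
assert os.environ["PARALLELISM_CONFIG_CP_SIZE"] == "2"
assert os.environ["ACCELERATE_ALLOW_CP_STANDALONE"] == "true"
assert os.environ["ACCELERATE_USE_PARALLELISM_CONFIG"] == "true"
```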
diff --git a/tests/test_loaders.py b/tests/test_loaders.py
index def7672b9..d45f41998 100644
--- a/tests/test_loaders.py
+++ b/tests/test_loaders.py
@@ -9,6 +9,7 @@ from transformers.utils.import_utils import is_torch_mps_available

 from axolotl.loaders import ModelLoader
 from axolotl.utils.dict import DictDefault
+from axolotl.utils.distributed import _get_parallel_config_kwargs


 class TestModelsUtils:
@@ -193,15 +194,13 @@ class TestModelsUtils:
         is_fsdp,
         expected,
     ):
-        res = (
-            ModelLoader._get_parallel_config_kwargs(  # pylint: disable=protected-access
-                world_size,
-                tensor_parallel_size,
-                context_parallel_size,
-                dp_shard_size,
-                dp_replicate_size,
-                is_fsdp,
-            )
+        res = _get_parallel_config_kwargs(  # pylint: disable=protected-access
+            world_size,
+            tensor_parallel_size,
+            context_parallel_size,
+            dp_shard_size,
+            dp_replicate_size,
+            is_fsdp,
         )

         if expected[0] > 1: