From 2c1cb8b300d05d1a41b5821dbe4ec2656aa0f86c Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 23 Jul 2025 08:43:34 -0400
Subject: [PATCH] fix for accelerator state getting reset and missing schema

---
 src/axolotl/core/builders/base.py   | 13 ++++++++++++-
 src/axolotl/loaders/model.py        | 20 ++++++++++++++++++++
 src/axolotl/utils/schemas/config.py | 12 ++++++++++++
 3 files changed, 44 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py
index d3a3b3242..7413168b8 100644
--- a/src/axolotl/core/builders/base.py
+++ b/src/axolotl/core/builders/base.py
@@ -27,6 +27,7 @@ import torch
 from transformers import (
     TrainerCallback,
 )
+from transformers.trainer_pt_utils import AcceleratorConfig
 from transformers.training_args import OptimizerNames
 
 from axolotl.integrations.base import PluginManager
@@ -434,8 +435,18 @@ class TrainerBuilderBase(abc.ABC):
         training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
 
     def _configure_accelerator_config(self, training_args_kwargs: dict):
+        use_configured_state = True
         if self.cfg.accelerator_config:
-            training_args_kwargs["accelerator_config"] = self.cfg.accelerator_config
+            use_configured_state = self.cfg.accelerator_config.pop(
+                "use_configured_state", use_configured_state
+            )
+            training_args_kwargs["accelerator_config"] = AcceleratorConfig(
+                use_configured_state=use_configured_state, **self.cfg.accelerator_config
+            )
+        else:
+            training_args_kwargs["accelerator_config"] = AcceleratorConfig(
+                use_configured_state=True,
+            )
 
     def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
         if self.cfg.activation_offloading is True:
diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py
index 4e440c8a6..313e89a9d 100644
--- a/src/axolotl/loaders/model.py
+++ b/src/axolotl/loaders/model.py
@@ -415,6 +415,26 @@
         device_mesh = torch.distributed.init_device_mesh(
             "cuda", mesh_shape, mesh_dim_names=mesh_dim_names
         )
+        submeshes = [
+            tuple(parallelism_config.dp_dim_names),
+            tuple(parallelism_config.dp_shard_cp_dim_names),
+            tuple(parallelism_config.dp_cp_dim_names),
+        ]
+        submesh_names = [
+            # create a submesh which is only used for distributing data across data parallel dims (no comms)
+            "dp",
+            # create a submesh which is used *just* for FSDP parameter gathering/scattering
+            # and gradients reduce-scattering
+            "dp_shard_cp",
+            # create a submesh which is used for correctly reducing loss across data replica/context parallel
+            "dp_cp",
+        ]
+        for submesh, submesh_name in zip(submeshes, submesh_names):
+            if submesh:
+                device_mesh[submesh]._flatten(  # pylint: disable=protected-access
+                    submesh_name
+                )
+        PartialState().parallelism_config = parallelism_config
         PartialState().device_mesh = device_mesh
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index fb6de2b5a..a60e1659f 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -644,6 +644,18 @@ class AxolotlInputConfig(
         },
     )
 
+    dp_shard_size: int | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Number of devices to shard across. If not set, will use all available devices."
+        },
+    )
+    sequence_parallel_degree: int | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Deprecated: use `context_parallel_size` instead"
+        },
+    )
     context_parallel_size: int | None = Field(
         default=None,
         json_schema_extra={