fix for accelerator state getting reset and missing schema

Wing Lian
2025-07-23 08:43:34 -04:00
parent cca207eec4
commit 2c1cb8b300
3 changed files with 44 additions and 1 deletion


@@ -27,6 +27,7 @@ import torch
 from transformers import (
     TrainerCallback,
 )
+from transformers.trainer_pt_utils import AcceleratorConfig
 from transformers.training_args import OptimizerNames
 
 from axolotl.integrations.base import PluginManager
@@ -434,8 +435,18 @@ class TrainerBuilderBase(abc.ABC):
             training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
 
     def _configure_accelerator_config(self, training_args_kwargs: dict):
+        use_configured_state = True
         if self.cfg.accelerator_config:
-            training_args_kwargs["accelerator_config"] = self.cfg.accelerator_config
+            use_configured_state = self.cfg.accelerator_config.pop(
+                "use_configured_state", use_configured_state
+            )
+            training_args_kwargs["accelerator_config"] = AcceleratorConfig(
+                use_configured_state=use_configured_state, **self.cfg.accelerator_config
+            )
+        else:
+            training_args_kwargs["accelerator_config"] = AcceleratorConfig(
+                use_configured_state=True,
+            )
 
     def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
         if self.cfg.activation_offloading is True:
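
The added `use_configured_state=True` path is what stops the accelerator state from being reset: it tells the transformers Trainer to reuse the PartialState/Accelerator that axolotl has already configured (including the device mesh set up in the ModelLoader hunk below) instead of constructing a fresh one. A minimal standalone sketch of the same pop-and-construct pattern, assuming a hypothetical user-supplied accelerator_config dict (values are illustrative, not from this commit):

    # Illustrative only; user_cfg stands in for a parsed `accelerator_config:` block.
    from transformers.trainer_pt_utils import AcceleratorConfig

    user_cfg = {"split_batches": True}

    # Default to reusing the pre-configured accelerator state unless the user opts out,
    # and pop the key so it is not passed to AcceleratorConfig twice.
    use_configured_state = user_cfg.pop("use_configured_state", True)
    accelerator_config = AcceleratorConfig(
        use_configured_state=use_configured_state, **user_cfg
    )
    assert accelerator_config.use_configured_state is True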


@@ -415,6 +415,26 @@ class ModelLoader:
             device_mesh = torch.distributed.init_device_mesh(
                 "cuda", mesh_shape, mesh_dim_names=mesh_dim_names
             )
+            submeshes = [
+                tuple(parallelism_config.dp_dim_names),
+                tuple(parallelism_config.dp_shard_cp_dim_names),
+                tuple(parallelism_config.dp_cp_dim_names),
+            ]
+            submesh_names = [
+                # create a submesh which is only used for distributing data across data parallel dims (no comms)
+                "dp",
+                # create a submesh which is used *just* for FSDP parameter gathering/scattering
+                # and gradients reduce-scattering
+                "dp_shard_cp",
+                # create a submesh which is used for correctly reducing loss across data replica/context parallel
+                "dp_cp",
+            ]
+            for submesh, submesh_name in zip(submeshes, submesh_names):
+                if submesh:
+                    device_mesh[submesh]._flatten(  # pylint: disable=protected-access
+                        submesh_name
+                    )
+
             PartialState().parallelism_config = parallelism_config
             PartialState().device_mesh = device_mesh
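
For reference, the parallelism_config object used above (accelerate's ParallelismConfig) exposes the dim-name groupings (dp_dim_names, dp_shard_cp_dim_names, dp_cp_dim_names), and DeviceMesh._flatten collapses a contiguous slice of the mesh into a single named dim that can later be looked up on the root mesh. A standalone sketch of the flattening pattern, assuming 4 GPUs launched via torchrun and illustrative mesh sizes/names (not taken from this diff):

    # Hypothetical sketch; dp_replicate x dp_shard x cp = 1 x 2 x 2 on 4 GPUs under torchrun.
    import torch

    mesh = torch.distributed.init_device_mesh(
        "cuda", (1, 2, 2), mesh_dim_names=("dp_replicate", "dp_shard", "cp")
    )

    # "dp": dims used only to split data across data-parallel ranks (no collectives run on it)
    mesh[("dp_replicate", "dp_shard")]._flatten("dp")  # pylint: disable=protected-access
    # "dp_shard_cp": dims FSDP gathers/scatters parameters and reduce-scatters gradients over
    mesh[("dp_shard", "cp")]._flatten("dp_shard_cp")  # pylint: disable=protected-access
    # "dp_cp": dims the loss is averaged over (data replicas plus context-parallel ranks)
    dp_cp_mesh = mesh[("dp_replicate", "dp_shard", "cp")]._flatten("dp_cp")  # pylint: disable=protected-access
    print(dp_cp_mesh.size())  # 4: every rank participates in loss reduction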


@@ -644,6 +644,18 @@ class AxolotlInputConfig(
         },
     )
+    dp_shard_size: int | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Number of devices to shard across. If not set, will use all available devices."
+        },
+    )
+    sequence_parallel_degree: int | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Deprecated: use `context_parallel_size` instead"
+        },
+    )
     context_parallel_size: int | None = Field(
         default=None,
         json_schema_extra={