fix for accelerator state getting reset and missing schema
@@ -27,6 +27,7 @@ import torch
from transformers import (
    TrainerCallback,
)
from transformers.trainer_pt_utils import AcceleratorConfig
from transformers.training_args import OptimizerNames

from axolotl.integrations.base import PluginManager

@@ -434,8 +435,18 @@ class TrainerBuilderBase(abc.ABC):
            training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode

    def _configure_accelerator_config(self, training_args_kwargs: dict):
        use_configured_state = True
        if self.cfg.accelerator_config:
            training_args_kwargs["accelerator_config"] = self.cfg.accelerator_config
            use_configured_state = self.cfg.accelerator_config.pop(
                "use_configured_state", use_configured_state
            )
            training_args_kwargs["accelerator_config"] = AcceleratorConfig(
                use_configured_state=use_configured_state, **self.cfg.accelerator_config
            )
        else:
            training_args_kwargs["accelerator_config"] = AcceleratorConfig(
                use_configured_state=True,
            )

    def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
        if self.cfg.activation_offloading is True:

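Note (not part of the diff): a minimal sketch of how the new `_configure_accelerator_config` resolves `use_configured_state`. When the user supplies an `accelerator_config` block, an explicit `use_configured_state` key wins; otherwise it defaults to True so a pre-initialized accelerator/`PartialState` is reused instead of being reset. The helper name and the sample dict below are hypothetical.

from transformers.trainer_pt_utils import AcceleratorConfig

def build_accelerator_config(accelerator_cfg: dict | None) -> AcceleratorConfig:
    # default to True so the Trainer does not re-create (and reset) existing accelerator state
    use_configured_state = True
    if accelerator_cfg:
        # an explicit user value takes precedence; remaining keys are forwarded as-is
        use_configured_state = accelerator_cfg.pop("use_configured_state", use_configured_state)
        return AcceleratorConfig(use_configured_state=use_configured_state, **accelerator_cfg)
    return AcceleratorConfig(use_configured_state=True)

print(build_accelerator_config({"split_batches": True}).use_configured_state)  # True
print(build_accelerator_config(None).use_configured_state)                     # True
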
@@ -415,6 +415,26 @@ class ModelLoader:
        device_mesh = torch.distributed.init_device_mesh(
            "cuda", mesh_shape, mesh_dim_names=mesh_dim_names
        )
        submeshes = [
            tuple(parallelism_config.dp_dim_names),
            tuple(parallelism_config.dp_shard_cp_dim_names),
            tuple(parallelism_config.dp_cp_dim_names),
        ]
        submesh_names = [
            # create a submesh which is only used for distributing data across data parallel dims (no comms)
            "dp",
            # create a submesh which is used *just* for FSDP parameter gathering/scattering
            # and gradients reduce-scattering
            "dp_shard_cp",
            # create a submesh which is used for correctly reducing loss across data replica/context parallel
            "dp_cp",
        ]
        for submesh, submesh_name in zip(submeshes, submesh_names):
            if submesh:
                device_mesh[submesh]._flatten(  # pylint: disable=protected-access
                    submesh_name
                )

        PartialState().parallelism_config = parallelism_config
        PartialState().device_mesh = device_mesh

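Note (not part of the diff): a minimal sketch of the submesh-flattening pattern above, assuming a torchrun launch on 8 CUDA devices and a recent PyTorch; the mesh sizes and dim names are illustrative, and `DeviceMesh._flatten` is a private API that may change.

import torch.distributed as dist

def build_mesh():
    # hypothetical layout: 2 data replicas x 2 FSDP shards x 2 context-parallel ranks
    mesh = dist.init_device_mesh(
        "cuda", (2, 2, 2), mesh_dim_names=("dp_replicate", "dp_shard", "cp")
    )
    # "dp" spans both data-parallel dims and is only used to split data (no comms)
    mesh["dp_replicate", "dp_shard"]._flatten("dp")  # pylint: disable=protected-access
    # "dp_shard_cp" is the group FSDP uses for parameter gather/scatter and gradient reduce-scatter
    mesh["dp_shard", "cp"]._flatten("dp_shard_cp")  # pylint: disable=protected-access
    # "dp_cp" spans replicas, shards, and context-parallel ranks for loss reduction
    mesh["dp_replicate", "dp_shard", "cp"]._flatten("dp_cp")  # pylint: disable=protected-access
    return mesh
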
@@ -644,6 +644,18 @@ class AxolotlInputConfig(
        },
    )

    dp_shard_size: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Number of devices to shard across. If not set, will use all available devices."
        },
    )
    sequence_parallel_degree: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Deprecated: use `context_parallel_size` instead"
        },
    )
    context_parallel_size: int | None = Field(
        default=None,
        json_schema_extra={

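Note (not part of the diff): a small sketch, using a standalone hypothetical model rather than AxolotlInputConfig itself, of how a pydantic field declared this way surfaces in the generated JSON schema, which appears to be the "missing schema" part of the commit title.

from pydantic import BaseModel, Field

class ParallelismFields(BaseModel):
    dp_shard_size: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Number of devices to shard across. If not set, will use all available devices."
        },
    )
    context_parallel_size: int | None = Field(default=None)

print(ParallelismFields(dp_shard_size=4).model_dump())
# {'dp_shard_size': 4, 'context_parallel_size': None}
print(ParallelismFields.model_json_schema()["properties"]["dp_shard_size"])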