fix for accelerator state getting reset and missing schema
This commit is contained in:
@@ -27,6 +27,7 @@ import torch
|
|||||||
from transformers import (
|
from transformers import (
|
||||||
TrainerCallback,
|
TrainerCallback,
|
||||||
)
|
)
|
||||||
|
from transformers.trainer_pt_utils import AcceleratorConfig
|
||||||
from transformers.training_args import OptimizerNames
|
from transformers.training_args import OptimizerNames
|
||||||
|
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
@@ -434,8 +435,18 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
|
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
|
||||||
|
|
||||||
def _configure_accelerator_config(self, training_args_kwargs: dict):
|
def _configure_accelerator_config(self, training_args_kwargs: dict):
|
||||||
|
use_configured_state = True
|
||||||
if self.cfg.accelerator_config:
|
if self.cfg.accelerator_config:
|
||||||
training_args_kwargs["accelerator_config"] = self.cfg.accelerator_config
|
use_configured_state = self.cfg.accelerator_config.pop(
|
||||||
|
"use_configured_state", use_configured_state
|
||||||
|
)
|
||||||
|
training_args_kwargs["accelerator_config"] = AcceleratorConfig(
|
||||||
|
use_configured_state=use_configured_state, **self.cfg.accelerator_config
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
training_args_kwargs["accelerator_config"] = AcceleratorConfig(
|
||||||
|
use_configured_state=True,
|
||||||
|
)
|
||||||
|
|
||||||
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
|
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
|
||||||
if self.cfg.activation_offloading is True:
|
if self.cfg.activation_offloading is True:
|
||||||
|
|||||||
@@ -415,6 +415,26 @@ class ModelLoader:
|
|||||||
device_mesh = torch.distributed.init_device_mesh(
|
device_mesh = torch.distributed.init_device_mesh(
|
||||||
"cuda", mesh_shape, mesh_dim_names=mesh_dim_names
|
"cuda", mesh_shape, mesh_dim_names=mesh_dim_names
|
||||||
)
|
)
|
||||||
|
submeshes = [
|
||||||
|
tuple(parallelism_config.dp_dim_names),
|
||||||
|
tuple(parallelism_config.dp_shard_cp_dim_names),
|
||||||
|
tuple(parallelism_config.dp_cp_dim_names),
|
||||||
|
]
|
||||||
|
submesh_names = [
|
||||||
|
# create a submesh which is only used for distributing data across data parallel dims (no comms)
|
||||||
|
"dp",
|
||||||
|
# create a submesh which is used *just* for FSDP parameter gathering/scattering
|
||||||
|
# and gradients reduce-scattering
|
||||||
|
"dp_shard_cp",
|
||||||
|
# create a submesh which is used for correctly reducing loss across data replica/context parallel
|
||||||
|
"dp_cp",
|
||||||
|
]
|
||||||
|
for submesh, submesh_name in zip(submeshes, submesh_names):
|
||||||
|
if submesh:
|
||||||
|
device_mesh[submesh]._flatten( # pylint: disable=protected-access
|
||||||
|
submesh_name
|
||||||
|
)
|
||||||
|
|
||||||
PartialState().parallelism_config = parallelism_config
|
PartialState().parallelism_config = parallelism_config
|
||||||
PartialState().device_mesh = device_mesh
|
PartialState().device_mesh = device_mesh
|
||||||
|
|
||||||
|
|||||||
@@ -644,6 +644,18 @@ class AxolotlInputConfig(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
dp_shard_size: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Number of devices to shard across. If not set, will use all available devices."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
sequence_parallel_degree: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Deprecated: use `context_parallel_size` instead"
|
||||||
|
},
|
||||||
|
)
|
||||||
context_parallel_size: int | None = Field(
|
context_parallel_size: int | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
|
|||||||
Reference in New Issue
Block a user