Compare commits

...

1 Commit

Author: Wing Lian
SHA1: 53ce90d21e
Message: add sync_model_states parameter to fix resume from checkpoint with fsdp
  * fix formatting for linter
  * fixes FSDP resume from checkpoint (unpacked only)
  * chore: fix linter
  * chore: lint
Date: 2023-08-30 21:15:50 -07:00
4 changed files with 59 additions and 0 deletions

View File

@@ -665,6 +665,7 @@ fsdp:
 fsdp_config:
   fsdp_offload_params: true
   fsdp_state_dict_type: FULL_STATE_DICT
+  fsdp_sync_module_states: true
   fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
 ```
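
For reference, `fsdp_sync_module_states` corresponds to PyTorch FSDP's `sync_module_states` constructor flag. A minimal sketch of what that flag does, not part of this PR (the toy model and process-group setup are assumptions; run under `torchrun`):

```python
import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group("nccl")  # torchrun supplies rank/world-size env vars
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = nn.Linear(1024, 1024).cuda()
# With sync_module_states=True, FSDP broadcasts rank 0's parameters and
# buffers to all ranks at wrap time, so ranks that did not load the
# checkpoint still start from rank 0's restored weights.
model = FSDP(model, sync_module_states=True)
```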

View File

@@ -0,0 +1,45 @@
"""
Monkeypatch to fix fsdp set state when no previous state was set
"""
import contextlib
from typing import Generator, Optional
import torch
from torch import nn
from torch.distributed.fsdp.api import (
OptimStateDictConfig,
StateDictConfig,
StateDictType,
)
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
@staticmethod
@contextlib.contextmanager
def state_dict_type_patch(
module: nn.Module,
state_dict_type: StateDictType,
state_dict_config: Optional[StateDictConfig] = None,
optim_state_dict_config: Optional[OptimStateDictConfig] = None,
) -> Generator:
prev_state_dict_settings = FullyShardedDataParallel.set_state_dict_type(
module,
state_dict_type,
state_dict_config,
optim_state_dict_config,
)
yield
if prev_state_dict_settings.state_dict_type:
FullyShardedDataParallel.set_state_dict_type(
module,
prev_state_dict_settings.state_dict_type,
prev_state_dict_settings.state_dict_config,
prev_state_dict_settings.optim_state_dict_config,
)
def replace_fsdp_state_dict_type():
torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel.state_dict_type = (
state_dict_type_patch
)
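
The upstream `FullyShardedDataParallel.state_dict_type` context manager restores the previous state-dict settings unconditionally on exit; when no state-dict type had ever been set, that restore fails, which is what broke resuming from a checkpoint under FSDP. The patch above guards the restore. An illustrative use of the context manager (`model` is assumed to be an FSDP-wrapped module; the configs here are examples, not from this PR):

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import FullStateDictConfig, StateDictType

with FSDP.state_dict_type(
    model,
    StateDictType.FULL_STATE_DICT,
    FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
):
    # Inside the context, state_dict() gathers full, unsharded tensors.
    state_dict = model.state_dict()
# On exit, the patched version restores prior settings only if they exist.
```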

View File

@@ -152,6 +152,16 @@ def validate_config(cfg):
     if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
         raise ValueError("FSDP is not supported for falcon models")
 
+    if (
+        cfg.fsdp
+        and cfg.fsdp_config
+        and cfg.fsdp_config.fsdp_state_dict_type
+        and not cfg.fsdp_config.fsdp_sync_module_states
+    ):
+        LOG.warning(
+            "We recommend setting fsdp_config.fsdp_sync_module_states to `true`"
+        )
+
     if (
         cfg.base_model and "mpt" in cfg.base_model.lower()
     ) and cfg.gradient_checkpointing:
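
To make the new check concrete, here is a toy reproduction of the configuration it warns about; `SimpleNamespace` stands in for axolotl's attribute-style config object, purely as an assumption for illustration:

```python
import logging
from types import SimpleNamespace

LOG = logging.getLogger(__name__)

cfg = SimpleNamespace(
    fsdp=["full_shard", "auto_wrap"],
    fsdp_config=SimpleNamespace(
        fsdp_state_dict_type="FULL_STATE_DICT",
        fsdp_sync_module_states=None,  # left unset, so the warning fires
    ),
)

if (
    cfg.fsdp
    and cfg.fsdp_config
    and cfg.fsdp_config.fsdp_state_dict_type
    and not cfg.fsdp_config.fsdp_sync_module_states
):
    LOG.warning(
        "We recommend setting fsdp_config.fsdp_sync_module_states to `true`"
    )
```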

View File

@@ -471,6 +471,9 @@ def setup_fsdp_envs(cfg):
         os.environ[
             "FSDP_TRANSFORMER_CLS_TO_WRAP"
         ] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
 
+    from axolotl.monkeypatch.fsdp import replace_fsdp_state_dict_type
+
+    replace_fsdp_state_dict_type()
 
 
 def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
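
The patch only takes effect for `with FSDP.state_dict_type(...)` blocks entered after `replace_fsdp_state_dict_type()` runs, which is why it is applied during FSDP environment setup rather than later. A quick, illustrative sanity check (not part of this PR):

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from axolotl.monkeypatch.fsdp import replace_fsdp_state_dict_type

replace_fsdp_state_dict_type()
# Class-level access unwraps the staticmethod, exposing the patched function.
print(FSDP.state_dict_type.__name__)  # -> "state_dict_type_patch"
```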