From 53ce90d21e991ca889b79009e6dcc2423345c124 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 15 Aug 2023 01:01:01 -0400
Subject: [PATCH] add sync_module_states parameter to fix resume from
 checkpoint with FSDP

fix formatting for linter

fixes FSDP resume from checkpoint (unpacked only)

chore: fix linter

chore: lint
---
 README.md                       |  1 +
 src/axolotl/monkeypatch/fsdp.py | 45 +++++++++++++++++++++++++++++++++
 src/axolotl/utils/config.py     | 10 ++++++++
 src/axolotl/utils/trainer.py    |  3 +++
 4 files changed, 59 insertions(+)
 create mode 100644 src/axolotl/monkeypatch/fsdp.py

diff --git a/README.md b/README.md
index 204e2141a..bf80790ed 100644
--- a/README.md
+++ b/README.md
@@ -665,6 +665,7 @@ fsdp:
 fsdp_config:
   fsdp_offload_params: true
   fsdp_state_dict_type: FULL_STATE_DICT
+  fsdp_sync_module_states: true
   fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
 ```
 
diff --git a/src/axolotl/monkeypatch/fsdp.py b/src/axolotl/monkeypatch/fsdp.py
new file mode 100644
index 000000000..2abd25154
--- /dev/null
+++ b/src/axolotl/monkeypatch/fsdp.py
@@ -0,0 +1,45 @@
+"""
+Monkeypatch to fix FSDP.state_dict_type restoring settings when no previous state dict settings were set
+"""
+
+import contextlib
+from typing import Generator, Optional
+
+import torch
+from torch import nn
+from torch.distributed.fsdp.api import (
+    OptimStateDictConfig,
+    StateDictConfig,
+    StateDictType,
+)
+from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
+
+
+@staticmethod
+@contextlib.contextmanager
+def state_dict_type_patch(
+    module: nn.Module,
+    state_dict_type: StateDictType,
+    state_dict_config: Optional[StateDictConfig] = None,
+    optim_state_dict_config: Optional[OptimStateDictConfig] = None,
+) -> Generator:
+    prev_state_dict_settings = FullyShardedDataParallel.set_state_dict_type(
+        module,
+        state_dict_type,
+        state_dict_config,
+        optim_state_dict_config,
+    )
+    yield
+    if prev_state_dict_settings.state_dict_type:
+        FullyShardedDataParallel.set_state_dict_type(
+            module,
+            prev_state_dict_settings.state_dict_type,
+            prev_state_dict_settings.state_dict_config,
+            prev_state_dict_settings.optim_state_dict_config,
+        )
+
+
+def replace_fsdp_state_dict_type():
+    torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel.state_dict_type = (
+        state_dict_type_patch
+    )
diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py
index abb3154d2..d471108cf 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -152,6 +152,16 @@ def validate_config(cfg):
     if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
         raise ValueError("FSDP is not supported for falcon models")
 
+    if (
+        cfg.fsdp
+        and cfg.fsdp_config
+        and cfg.fsdp_config.fsdp_state_dict_type
+        and not cfg.fsdp_config.fsdp_sync_module_states
+    ):
+        LOG.warning(
+            "We recommend setting fsdp_config.fsdp_sync_module_states to `true`"
+        )
+
     if (
         cfg.base_model and "mpt" in cfg.base_model.lower()
     ) and cfg.gradient_checkpointing:
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 37578908e..23bcd104e 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -471,6 +471,9 @@ def setup_fsdp_envs(cfg):
         os.environ[
             "FSDP_TRANSFORMER_CLS_TO_WRAP"
         ] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
+    from axolotl.monkeypatch.fsdp import replace_fsdp_state_dict_type
+
+    replace_fsdp_state_dict_type()
 
 
 def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
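
Note (not part of the patch): `state_dict_type_patch` is essentially the upstream `FullyShardedDataParallel.state_dict_type` context manager plus one guard, skipping the restore of previous state dict settings when none were ever recorded, which is the case the HF Trainer hits when resuming from an FSDP checkpoint. A minimal sketch of exercising the guarded path follows; it assumes a single-process CPU `gloo` group is enough to construct an FSDP-wrapped module, and the toy model and port are illustrative:

```python
import os

import torch.distributed as dist
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import StateDictType

from axolotl.monkeypatch.fsdp import replace_fsdp_state_dict_type

# stand up a throwaway single-process group so FSDP can be constructed
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")  # illustrative port
dist.init_process_group("gloo", rank=0, world_size=1)

replace_fsdp_state_dict_type()  # install the guarded context manager

model = FSDP(nn.Linear(8, 8))  # toy module, stands in for the real model

# set_state_dict_type() has never been called on `model`, so there are no
# previous settings to restore on exit; the patched guard skips the restore
# instead of trying to re-apply a missing state dict type.
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
    full_state = model.state_dict()

dist.destroy_process_group()
```

In axolotl itself the patch is installed by `setup_fsdp_envs(cfg)` above, so it is in place before the trainer touches `FSDP.state_dict_type` during `resume_from_checkpoint`.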
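
For readers landing on the README hunk without context, a sketch of the full `fsdp` section it amends. The `full_shard`/`auto_wrap` values are assumed from the surrounding README, not from this diff:

```yaml
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_offload_params: true
  fsdp_state_dict_type: FULL_STATE_DICT
  # broadcast module states from rank 0 at wrap time so every rank starts
  # from the same weights; this is what the new validate_config warning
  # nudges users toward whenever fsdp_state_dict_type is set
  fsdp_sync_module_states: true
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```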