Compare commits

1 commit

flex_patch...fsdp-defau

| Author | SHA1 | Date |
|---|---|---|
|  | 53ce90d21e |  |
````diff
@@ -665,6 +665,7 @@ fsdp:
 fsdp_config:
   fsdp_offload_params: true
   fsdp_state_dict_type: FULL_STATE_DICT
+  fsdp_sync_module_states: true
   fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
 ```
 
````
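This config addition corresponds to FSDP's `sync_module_states` constructor argument, which broadcasts rank 0's parameters and buffers to all other ranks at wrap time so every rank starts from identical weights. A minimal sketch of the raw PyTorch equivalent (the wrapper function is hypothetical, not axolotl code, and assumes an initialized process group):

```python
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def wrap_with_synced_states(model: nn.Module) -> FSDP:
    """Hypothetical helper showing what `fsdp_sync_module_states: true` enables."""
    # sync_module_states=True broadcasts rank 0's params/buffers to the other
    # ranks, e.g. when only rank 0 materialized the full checkpoint.
    return FSDP(model, sync_module_states=True)
```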
src/axolotl/monkeypatch/fsdp.py (new file, 45 lines)

```diff
@@ -0,0 +1,45 @@
+"""
+Monkeypatch to fix fsdp set state when no previous state was set
+"""
+
+import contextlib
+from typing import Generator, Optional
+
+import torch
+from torch import nn
+from torch.distributed.fsdp.api import (
+    OptimStateDictConfig,
+    StateDictConfig,
+    StateDictType,
+)
+from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
+
+
+@staticmethod
+@contextlib.contextmanager
+def state_dict_type_patch(
+    module: nn.Module,
+    state_dict_type: StateDictType,
+    state_dict_config: Optional[StateDictConfig] = None,
+    optim_state_dict_config: Optional[OptimStateDictConfig] = None,
+) -> Generator:
+    prev_state_dict_settings = FullyShardedDataParallel.set_state_dict_type(
+        module,
+        state_dict_type,
+        state_dict_config,
+        optim_state_dict_config,
+    )
+    yield
+    if prev_state_dict_settings.state_dict_type:
+        FullyShardedDataParallel.set_state_dict_type(
+            module,
+            prev_state_dict_settings.state_dict_type,
+            prev_state_dict_settings.state_dict_config,
+            prev_state_dict_settings.optim_state_dict_config,
+        )
+
+
+def replace_fsdp_state_dict_type():
+    torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel.state_dict_type = (
+        state_dict_type_patch
+    )
```
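The patch mirrors upstream `FullyShardedDataParallel.state_dict_type`, but restores the previous settings after the `yield` only when a previous `state_dict_type` actually existed, avoiding a failure when no state dict type had been set before entering the context. The call site is unchanged from the stock API; a minimal single-process sketch (the CPU/gloo setup is an assumption for illustration only):

```python
import torch.distributed as dist
from torch import nn
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    FullyShardedDataParallel as FSDP,
)

# Single-process process group so FSDP can be constructed (illustrative only).
dist.init_process_group(
    "gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
)
model = FSDP(nn.Linear(8, 8))

# Same usage whether or not the monkeypatch is applied.
with FSDP.state_dict_type(
    model,
    StateDictType.FULL_STATE_DICT,
    FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
):
    full_state_dict = model.state_dict()  # gathered, unsharded weights

dist.destroy_process_group()
```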
```diff
@@ -152,6 +152,16 @@ def validate_config(cfg):
     if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
         raise ValueError("FSDP is not supported for falcon models")
 
+    if (
+        cfg.fsdp
+        and cfg.fsdp_config
+        and cfg.fsdp_config.fsdp_state_dict_type
+        and not cfg.fsdp_config.fsdp_sync_module_states
+    ):
+        LOG.warning(
+            "We recommend setting fsdp_config.fsdp_sync_module_states to `true`"
+        )
+
     if (
         cfg.base_model and "mpt" in cfg.base_model.lower()
     ) and cfg.gradient_checkpointing:
```
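The new check fires only when FSDP is enabled and a state dict type is configured without `fsdp_sync_module_states`. A minimal sketch of the condition, using `SimpleNamespace` as a hypothetical stand-in for axolotl's cfg object:

```python
from types import SimpleNamespace

# Hypothetical stand-in for the validated cfg object.
cfg = SimpleNamespace(
    fsdp=["full_shard", "auto_wrap"],
    fsdp_config=SimpleNamespace(
        fsdp_state_dict_type="FULL_STATE_DICT",
        fsdp_sync_module_states=None,  # unset, so the warning fires
    ),
)

if (
    cfg.fsdp
    and cfg.fsdp_config
    and cfg.fsdp_config.fsdp_state_dict_type
    and not cfg.fsdp_config.fsdp_sync_module_states
):
    print("We recommend setting fsdp_config.fsdp_sync_module_states to `true`")
```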
```diff
@@ -471,6 +471,9 @@ def setup_fsdp_envs(cfg):
     os.environ[
         "FSDP_TRANSFORMER_CLS_TO_WRAP"
     ] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
+    from axolotl.monkeypatch.fsdp import replace_fsdp_state_dict_type
+
+    replace_fsdp_state_dict_type()
 
 
 def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
```
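`setup_fsdp_envs` now applies the monkeypatch by importing and calling `replace_fsdp_state_dict_type`, which rebinds the class attribute so every `FullyShardedDataParallel` instance picks up the patched context manager. A toy illustration of that attribute-assignment pattern (deliberately not axolotl code):

```python
class Greeter:
    def greet(self) -> str:
        return "hello"

def greet_patch(self) -> str:
    # Replacement behavior, analogous to state_dict_type_patch.
    return "HELLO"

# Rebinding the class attribute patches all existing and future instances.
Greeter.greet = greet_patch
assert Greeter().greet() == "HELLO"
```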