diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index a7abe38b6..4cf092ecd 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -377,6 +377,10 @@ def setup_fsdp_envs(cfg): os.environ["FSDP_SYNC_MODULE_STATES"] = "true" if cfg.fsdp_config.fsdp_state_dict_type: os.environ["FSDP_STATE_DICT_TYPE"] = cfg.fsdp_config.fsdp_state_dict_type + if cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap: + os.environ[ + "FSDP_TRANSFORMER_CLS_TO_WRAP" + ] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):