fix fsdp training args

commit 29936bba7f
Author: Wing Lian
Date:   2023-04-30 00:56:28 -04:00
parent 78821815de


@@ -34,6 +34,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         apply_gradient_checkpointing(model, checkpoint_ratio=gradient_checkpointing_ratio)
     else:
         training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
+    if cfg.fsdp:
+        training_arguments_kwargs["fsdp"] = cfg.fsdp.split(" ")
+    if cfg.fsdp_transformer_layer_cls_to_wrap:
+        training_arguments_kwargs["fsdp_transformer_layer_cls_to_wrap"] = cfg.fsdp_transformer_layer_cls_to_wrap
     # deepspeed
@@ -64,8 +68,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         optim=cfg.optimizer if cfg.optimizer != "adam8bit" else cfg.optimizer,
         lr_scheduler_type=cfg.lr_scheduler if cfg.lr_scheduler else None,
         weight_decay=cfg.weight_decay if cfg.weight_decay else 0.0,
-        fsdp=cfg.fsdp.split(" ") if cfg.fsdp else None,
-        fsdp_transformer_layer_cls_to_wrap=cfg.fsdp_transformer_layer_cls_to_wrap if cfg.fsdp_transformer_layer_cls_to_wrap else None,
         **training_arguments_kwargs,
     )
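
Net effect of the change: the FSDP options are now added to training_arguments_kwargs only when they are set in the config, instead of always being passed to TrainingArguments (as None whenever FSDP is disabled). Below is a minimal sketch of the resulting pattern, assuming a cfg object with optional fsdp and fsdp_transformer_layer_cls_to_wrap fields as in the diff; the build_training_args helper and the output_dir value are illustrative, not part of the commit.

from transformers import TrainingArguments

def build_training_args(cfg):
    # Only include optional kwargs that are actually configured, so unset
    # options keep the TrainingArguments defaults instead of being
    # overridden with None.
    training_arguments_kwargs = {}
    if cfg.fsdp:
        # e.g. cfg.fsdp == "full_shard auto_wrap" -> ["full_shard", "auto_wrap"]
        training_arguments_kwargs["fsdp"] = cfg.fsdp.split(" ")
    if cfg.fsdp_transformer_layer_cls_to_wrap:
        # e.g. "LlamaDecoderLayer"
        training_arguments_kwargs["fsdp_transformer_layer_cls_to_wrap"] = (
            cfg.fsdp_transformer_layer_cls_to_wrap
        )
    return TrainingArguments(
        output_dir="./out",  # illustrative; the real call passes many more args
        **training_arguments_kwargs,
    )

Routing the options through the kwargs dict avoids handing transformers an explicit fsdp=None, which would override the library default and can trip up downstream checks that expect a bool, string, or list.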