fix fsdp training args
@@ -34,6 +34,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
             apply_gradient_checkpointing(model, checkpoint_ratio=gradient_checkpointing_ratio)
         else:
             training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
+    if cfg.fsdp:
+        training_arguments_kwargs["fsdp"] = cfg.fsdp.split(" ")
+        if cfg.fsdp_transformer_layer_cls_to_wrap:
+            training_arguments_kwargs["fsdp_transformer_layer_cls_to_wrap"] = cfg.fsdp_transformer_layer_cls_to_wrap


     # deepspeed
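The first hunk moves the FSDP options into `training_arguments_kwargs`, so they are only passed along when actually set. A minimal sketch of that behavior, assuming a hypothetical config where `fsdp` is a space-separated option string:

```python
# Hypothetical config values for illustration; the real ones come from cfg.
fsdp = "full_shard auto_wrap"
fsdp_transformer_layer_cls_to_wrap = "LlamaDecoderLayer"

training_arguments_kwargs = {}
if fsdp:
    # "full_shard auto_wrap" -> ["full_shard", "auto_wrap"]
    training_arguments_kwargs["fsdp"] = fsdp.split(" ")
    if fsdp_transformer_layer_cls_to_wrap:
        training_arguments_kwargs[
            "fsdp_transformer_layer_cls_to_wrap"
        ] = fsdp_transformer_layer_cls_to_wrap
# When fsdp is unset or empty, neither key is added to the kwargs dict.
```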
@@ -64,8 +68,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         optim=cfg.optimizer if cfg.optimizer != "adam8bit" else cfg.optimizer,
         lr_scheduler_type=cfg.lr_scheduler if cfg.lr_scheduler else None,
         weight_decay=cfg.weight_decay if cfg.weight_decay else 0.0,
-        fsdp=cfg.fsdp.split(" ") if cfg.fsdp else None,
-        fsdp_transformer_layer_cls_to_wrap=cfg.fsdp_transformer_layer_cls_to_wrap if cfg.fsdp_transformer_layer_cls_to_wrap else None,
         **training_arguments_kwargs,
     )

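The second hunk drops the explicit `fsdp=...` keywords from the `TrainingArguments` call, so runs without FSDP no longer pass `fsdp=None`. A sketch of the resulting call under the same assumptions (other arguments omitted, `output_dir` is a placeholder):

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./out",  # placeholder; the real value comes from cfg
    weight_decay=0.0,
    # FSDP settings now arrive only via the kwargs dict built above,
    # so nothing FSDP-related is passed when cfg.fsdp is not set.
    **training_arguments_kwargs,
)
```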