fix fsdp training args
This commit is contained in:
@@ -34,6 +34,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
apply_gradient_checkpointing(model, checkpoint_ratio=gradient_checkpointing_ratio)
|
apply_gradient_checkpointing(model, checkpoint_ratio=gradient_checkpointing_ratio)
|
||||||
else:
|
else:
|
||||||
training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
|
training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
|
||||||
|
if cfg.fsdp:
|
||||||
|
training_arguments_kwargs["fsdp"] = cfg.fsdp.split(" ")
|
||||||
|
if cfg.fsdp_transformer_layer_cls_to_wrap:
|
||||||
|
training_arguments_kwargs["fsdp_transformer_layer_cls_to_wrap"] = cfg.fsdp_transformer_layer_cls_to_wrap
|
||||||
|
|
||||||
|
|
||||||
# deepspeed
|
# deepspeed
|
||||||
@@ -64,8 +68,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
optim=cfg.optimizer if cfg.optimizer != "adam8bit" else cfg.optimizer,
|
optim=cfg.optimizer if cfg.optimizer != "adam8bit" else cfg.optimizer,
|
||||||
lr_scheduler_type=cfg.lr_scheduler if cfg.lr_scheduler else None,
|
lr_scheduler_type=cfg.lr_scheduler if cfg.lr_scheduler else None,
|
||||||
weight_decay=cfg.weight_decay if cfg.weight_decay else 0.0,
|
weight_decay=cfg.weight_decay if cfg.weight_decay else 0.0,
|
||||||
fsdp=cfg.fsdp.split(" ") if cfg.fsdp else None,
|
|
||||||
fsdp_transformer_layer_cls_to_wrap=cfg.fsdp_transformer_layer_cls_to_wrap if cfg.fsdp_transformer_layer_cls_to_wrap else None,
|
|
||||||
**training_arguments_kwargs,
|
**training_arguments_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user