Added advanced DDP args (#515)
* add ddp_config
* add advanced ddp config
* add ddp_config
* add advanced ddp config

Co-authored-by: Jan Philipp Harries <jphme@users.noreply.github.com>
Commit 396a7a74fc (parent b21e4a20fe), committed by GitHub.
@@ -623,6 +623,11 @@ fsdp_config:
 
 # Deepspeed config path
 deepspeed:
 
+# Advanced DDP Arguments
+ddp_timeout:
+ddp_bucket_cap_mb:
+ddp_broadcast_buffers:
+
 # Path to torch distx for optim 'adamw_anyprecision'
 torchdistx_path:
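As a usage sketch, these keys are forwarded into the trainer's arguments and correspond to the DistributedDataParallel flags referenced in the linked PyTorch docs; the values below are illustrative examples, not defaults:

ddp_timeout: 3600              # illustrative: seconds before distributed collectives time out
ddp_bucket_cap_mb: 25          # illustrative: gradient bucket size in MB passed to DistributedDataParallel
ddp_broadcast_buffers: false   # illustrative: skip broadcasting module buffers at the start of each forward pass

Leaving a key unset keeps the upstream default, since setup_trainer only forwards values that are actually set in the config.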
@@ -579,6 +579,15 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
     if cfg.bench_dataset:
         training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset
 
+    # DDP Config
+    if cfg.ddp_timeout:
+        training_arguments_kwargs["ddp_timeout"] = cfg.ddp_timeout
+    # see https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
+    if cfg.ddp_bucket_cap_mb:
+        training_arguments_kwargs["ddp_bucket_cap_mb"] = cfg.ddp_bucket_cap_mb
+    if cfg.ddp_broadcast_buffers is not None:
+        training_arguments_kwargs["ddp_broadcast_buffers"] = cfg.ddp_broadcast_buffers
+
     training_args = AxolotlTrainingArguments(  # pylint: disable=unexpected-keyword-arg
         max_steps=total_num_steps if cfg.max_steps else -1,
         max_seq_length=cfg.sequence_len,
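Note that, unlike the other two options, ddp_broadcast_buffers is gated with `is not None` rather than truthiness, so an explicit false still reaches the trainer. A minimal config line exercising that case (illustrative):

ddp_broadcast_buffers: false   # forwarded even though the value is falsy, because the check is `is not None`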