From 396a7a74fc6a5d73c4d2d555810ed516f456c479 Mon Sep 17 00:00:00 2001
From: Jan Philipp Harries <2862336+jphme@users.noreply.github.com>
Date: Thu, 31 Aug 2023 19:37:47 +0200
Subject: [PATCH] Added advanced DDP args (#515)

* add ddp_config

* add advanced ddp config

* add ddp_config

* add advanced ddp config

---------

Co-authored-by: Jan Philipp Harries
---
 README.md                    | 5 +++++
 src/axolotl/utils/trainer.py | 9 +++++++++
 2 files changed, 14 insertions(+)

diff --git a/README.md b/README.md
index 204e2141a..19b164a8b 100644
--- a/README.md
+++ b/README.md
@@ -623,6 +623,11 @@ fsdp_config:
 # Deepspeed config path
 deepspeed:
 
+# Advanced DDP Arguments
+ddp_timeout:
+ddp_bucket_cap_mb:
+ddp_broadcast_buffers:
+
 # Path to torch distx for optim 'adamw_anyprecision'
 torchdistx_path:
 
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 0aceee519..f0669565f 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -579,6 +579,15 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
     if cfg.bench_dataset:
         training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset
 
+    # DDP Config
+    if cfg.ddp_timeout:
+        training_arguments_kwargs["ddp_timeout"] = cfg.ddp_timeout
+    # see https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
+    if cfg.ddp_bucket_cap_mb:
+        training_arguments_kwargs["ddp_bucket_cap_mb"] = cfg.ddp_bucket_cap_mb
+    if cfg.ddp_broadcast_buffers is not None:
+        training_arguments_kwargs["ddp_broadcast_buffers"] = cfg.ddp_broadcast_buffers
+
     training_args = AxolotlTrainingArguments(  # pylint: disable=unexpected-keyword-arg
         max_steps=total_num_steps if cfg.max_steps else -1,
         max_seq_length=cfg.sequence_len,
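
Example usage (a minimal sketch; the option names come from this patch, but the values shown are illustrative assumptions, not defaults it introduces). In an axolotl YAML config the new entries could be filled in like this:

# Advanced DDP Arguments
ddp_timeout: 7200             # seconds before a stuck collective op times out; forwarded to TrainingArguments.ddp_timeout
ddp_bucket_cap_mb: 25         # gradient bucket size in MiB passed through to DistributedDataParallel
ddp_broadcast_buffers: false  # skip broadcasting module buffers on each forward pass

Note the guards in trainer.py: ddp_timeout and ddp_bucket_cap_mb are only forwarded when truthy, so leaving them blank (or 0) keeps the torch/transformers defaults, while ddp_broadcast_buffers is checked with "is not None" so an explicit false is still passed through.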