From ad2b48c0fa61ff55a40279a360d491ebc78c024f Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sun, 30 Apr 2023 13:32:07 -0400
Subject: [PATCH] fsdp config dict fix, todo list, add torchdistx support

---
 TODO.md                      | 10 ++++++++++
 src/axolotl/utils/models.py  |  5 +++++
 src/axolotl/utils/trainer.py | 12 +++++++++---
 3 files changed, 24 insertions(+), 3 deletions(-)
 create mode 100644 TODO.md

diff --git a/TODO.md b/TODO.md
new file mode 100644
index 000000000..2002bbbaf
--- /dev/null
+++ b/TODO.md
@@ -0,0 +1,10 @@
+# todo list
+
+- [ ] Validation of parameters for combinations that won't work
+
+
+
+## things that are known not to work
+
+- FSDP offload and gradient_checkpointing - https://github.com/pytorch/pytorch/issues/82203
+- adamw_bnb_8bit doesn't play well with FSDP offload
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 89d6f9d14..bd73fc76a 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -179,6 +179,11 @@ def load_model(
                 m.scales = m.scales.half()
                 m.bias = m.bias.half()
 
+    if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) > 1:
+        model.is_parallelizable = True
+        model.model_parallel = True
+
+    # TODO resume_from_checkpoint handling
     return model, tokenizer, lora_config
 
 
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 1da3fe845..4d4719969 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -1,5 +1,7 @@
+import importlib
 import math
 import os
+import sys
 from pathlib import Path
 
 import bitsandbytes as bnb
@@ -35,9 +37,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     else:
         training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
     if cfg.fsdp:
-        training_arguments_kwargs["fsdp"] = cfg.fsdp.split(" ")
-        if cfg.fsdp_transformer_layer_cls_to_wrap:
-            training_arguments_kwargs["fsdp_transformer_layer_cls_to_wrap"] = cfg.fsdp_transformer_layer_cls_to_wrap
+        training_arguments_kwargs["fsdp"] = cfg.fsdp
+        if cfg.fsdp_config:
+            training_arguments_kwargs["fsdp_config"] = dict(cfg.fsdp_config)
 
     # deepspeed
 
@@ -73,6 +75,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
 
     trainer_kwargs = {}
 
+    if cfg.optimizer == "adamw_anyprecision":
+        if Path(cfg.torchdistx_path).exists():
+            sys.path.append(cfg.torchdistx_path)
+            torchdistx = importlib.import_module('torchdistx')
 if cfg.optimizer == "adam8bit" and not cfg.load_4bit and not "deepspeed" in training_arguments_kwargs:
         decay_parameters = get_parameter_names(model, [nn.LayerNorm])
         decay_parameters = [name for name in decay_parameters if "bias" not in name]
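
Note on the fsdp change: instead of splitting cfg.fsdp into a list and forwarding only a top-level fsdp_transformer_layer_cls_to_wrap argument, setup_trainer now passes cfg.fsdp straight through and hands the whole cfg.fsdp_config mapping to the Hugging Face TrainingArguments as a dict. The sketch below is a minimal illustration of how those kwargs are expected to land in transformers.TrainingArguments; the concrete fsdp options and fsdp_config keys shown are assumptions based on the transformers conventions of that period, not values taken from this commit.

    # Hypothetical stand-ins for cfg.fsdp and cfg.fsdp_config; only the kwarg
    # names "fsdp" and "fsdp_config" come from the patch itself.
    from transformers import TrainingArguments

    training_arguments_kwargs = {}
    training_arguments_kwargs["fsdp"] = "full_shard auto_wrap"  # pass-through of cfg.fsdp
    training_arguments_kwargs["fsdp_config"] = {                # dict(cfg.fsdp_config)
        "fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
        "fsdp_min_num_params": 0,
    }

    training_args = TrainingArguments(
        output_dir="./out",
        per_device_train_batch_size=1,
        **training_arguments_kwargs,
    )

Passing the whole mapping through means any additional FSDP knob in the user config reaches the trainer without needing a new top-level kwarg for each one.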
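Note on the torchdistx hook: when cfg.optimizer is adamw_anyprecision, the patch appends cfg.torchdistx_path to sys.path and imports torchdistx before the trainer is constructed, so a missing or broken local build fails early. Below is a rough sketch of the surrounding flow, assuming the usual transformers behaviour of resolving the adamw_anyprecision optim value via torchdistx's AnyPrecisionAdamW; the path and the TrainingArguments call are placeholders, not part of this commit.

    import importlib
    import sys
    from pathlib import Path

    from transformers import TrainingArguments

    # Hypothetical stand-in for cfg.torchdistx_path (a locally built checkout).
    torchdistx_path = "/opt/torchdistx"

    # Make the local torchdistx build importable before the Trainer needs it,
    # surfacing an ImportError up front rather than mid-setup.
    if Path(torchdistx_path).exists():
        sys.path.append(torchdistx_path)
        importlib.import_module("torchdistx")

    # transformers is then expected to construct the optimizer from
    # torchdistx.optimizers.AnyPrecisionAdamW when optim is set like this:
    args = TrainingArguments(output_dir="./out", optim="adamw_anyprecision")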