From e923e62d2495766c21c11ff131fddec6fc927ebc Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 25 Jan 2024 20:01:45 -0500
Subject: [PATCH] more checks and fixes for deepspeed and fsdp (#1208)

[skip ci]
---
 deepspeed_configs/zero1.json      |  9 ------
 deepspeed_configs/zero2.json      |  9 ------
 deepspeed_configs/zero3.json      |  9 ------
 deepspeed_configs/zero3_bf16.json |  9 ------
 src/axolotl/utils/config.py       | 48 ++++++++++++++++++-------------
 src/axolotl/utils/models.py       | 18 ++++++------
 6 files changed, 38 insertions(+), 64 deletions(-)

diff --git a/deepspeed_configs/zero1.json b/deepspeed_configs/zero1.json
index c76a20637..787fc0d6b 100644
--- a/deepspeed_configs/zero1.json
+++ b/deepspeed_configs/zero1.json
@@ -15,15 +15,6 @@
     "hysteresis": 2,
     "min_loss_scale": 1
   },
-  "optimizer": {
-    "type": "AdamW",
-    "params": {
-      "lr": "auto",
-      "betas": "auto",
-      "eps": "auto",
-      "weight_decay": "auto"
-    }
-  },
   "gradient_accumulation_steps": "auto",
   "train_batch_size": "auto",
   "train_micro_batch_size_per_gpu": "auto",
diff --git a/deepspeed_configs/zero2.json b/deepspeed_configs/zero2.json
index 3f3baa3ec..5b22d996c 100644
--- a/deepspeed_configs/zero2.json
+++ b/deepspeed_configs/zero2.json
@@ -19,15 +19,6 @@
     "hysteresis": 2,
     "min_loss_scale": 1
   },
-  "optimizer": {
-    "type": "AdamW",
-    "params": {
-      "lr": "auto",
-      "betas": "auto",
-      "eps": "auto",
-      "weight_decay": "auto"
-    }
-  },
   "gradient_accumulation_steps": "auto",
   "train_batch_size": "auto",
   "train_micro_batch_size_per_gpu": "auto",
diff --git a/deepspeed_configs/zero3.json b/deepspeed_configs/zero3.json
index cf64e83ac..a185afab4 100644
--- a/deepspeed_configs/zero3.json
+++ b/deepspeed_configs/zero3.json
@@ -23,15 +23,6 @@
     "hysteresis": 2,
     "min_loss_scale": 1
   },
-  "optimizer": {
-    "type": "AdamW",
-    "params": {
-      "lr": "auto",
-      "betas": "auto",
-      "eps": "auto",
-      "weight_decay": "auto"
-    }
-  },
   "gradient_accumulation_steps": "auto",
   "train_batch_size": "auto",
   "train_micro_batch_size_per_gpu": "auto",
diff --git a/deepspeed_configs/zero3_bf16.json b/deepspeed_configs/zero3_bf16.json
index 42d10b6bd..263caa393 100644
--- a/deepspeed_configs/zero3_bf16.json
+++ b/deepspeed_configs/zero3_bf16.json
@@ -23,15 +23,6 @@
     "hysteresis": 2,
     "min_loss_scale": 1
   },
-  "optimizer": {
-    "type": "AdamW",
-    "params": {
-      "lr": "auto",
-      "betas": "auto",
-      "eps": "auto",
-      "weight_decay": "auto"
-    }
-  },
   "gradient_accumulation_steps": "auto",
   "train_batch_size": "auto",
   "train_micro_batch_size_per_gpu": "auto",
diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py
index dcd795099..c27849d83 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -95,7 +95,7 @@ def normalize_config(cfg):
         save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
         if save_steps < 1.0:  # prevent saves on every step
             cfg.save_steps = save_steps
-    if cfg.evals_per_epoch:
+    if (cfg.val_set_size or cfg.test_datasets) and cfg.evals_per_epoch:
         eval_steps = 1.0 / (cfg.evals_per_epoch * cfg.num_epochs)
         if eval_steps < 1.0:  # prevent evals on every step
             cfg.eval_steps = eval_steps
@@ -485,35 +485,43 @@ def validate_config(cfg):
                 "`use_reentrant` must be false when used with partially frozen model."
             )
 
-    if cfg.flash_attention and cfg.deepspeed and Path(cfg.deepspeed).is_file():
+    if cfg.deepspeed and Path(cfg.deepspeed).is_file():
         with open(cfg.deepspeed, encoding="utf-8") as file:
             contents = file.read()
             deepspeed_cfg: DictDefault = DictDefault(json.loads(contents))
-            if (
-                deepspeed_cfg.zero_optimization
-                and deepspeed_cfg.zero_optimization.stage == 3
-            ):
-                if not (
-                    (
-                        deepspeed_cfg.bf16
-                        and deepspeed_cfg.bf16.enabled  # pylint: disable=no-member
-                        is True
-                    )
-                    or (
-                        deepspeed_cfg.fp16
-                        and deepspeed_cfg.fp16.enabled  # pylint: disable=no-member
-                        is True
-                    )
+            if cfg.flash_attention:
+                if (
+                    deepspeed_cfg.zero_optimization
+                    and deepspeed_cfg.zero_optimization.stage == 3
                 ):
-                    raise ValueError(
-                        "bf16.enabled or fp16.enabled must be set to true when using ZeRO-3 with flash-attention"
-                    )
+                    if not (
+                        (
+                            deepspeed_cfg.bf16
+                            and deepspeed_cfg.bf16.enabled  # pylint: disable=no-member
+                            is True
+                        )
+                        or (
+                            deepspeed_cfg.fp16
+                            and deepspeed_cfg.fp16.enabled  # pylint: disable=no-member
+                            is True
+                        )
+                    ):
+                        raise ValueError(
+                            "bf16.enabled or fp16.enabled must be set to true when using ZeRO-3 with flash-attention"
+                        )
+            if "8bit" in cfg.optimizer and deepspeed_cfg.optimizer:
+                LOG.warning(
+                    f"conflicting optimizer: {cfg.optimizer} used alongside deepspeed optimizer."
+                )
 
     if cfg.test_datasets and cfg.val_set_size:
         raise ValueError(
             "non-zero val_set_size should not be used with test_datasets configuration"
         )
 
+    if cfg.fsdp and "bnb" in cfg.optimizer:
+        raise ValueError(f"FSDP not compatible with {cfg.optimizer}")
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 7ca9abbb5..72427f645 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -642,15 +642,17 @@ def load_model(
     # make sure these are fp32 per Ramesh et al. (2021)
     embedding_modules = get_linear_embedding_layers(cfg.model_config_type)
 
-    for name, module in model.named_modules():
-        if any(m in name for m in ["norm", "gate"]):
-            module.to(torch.float32)
-        if model_config.model_type == "btlm":
-            # don't upcast lm_head for btlm
-            continue
-        if any(m in name for m in embedding_modules):
-            if hasattr(module, "weight"):
+    if not cfg.fsdp:
+        # FSDP doesn't like mixed Float and BFloat16
+        for name, module in model.named_modules():
+            if any(m in name for m in ["norm", "gate"]):
                 module.to(torch.float32)
+            if model_config.model_type == "btlm":
+                # don't upcast lm_head for btlm
+                continue
+            if any(m in name for m in embedding_modules):
+                if hasattr(module, "weight"):
+                    module.to(torch.float32)
 
     needs_fa2_dtype = cfg.adapter or cfg.fsdp
     skip_prepare_model_for_kbit_training = False
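
Notes (illustrative, not part of the commit):

Dropping the hard-coded AdamW "optimizer" blocks from the four ZeRO JSONs
means DeepSpeed now uses the optimizer the trainer builds from axolotl's
`optimizer:` setting instead of constructing its own, and the new
LOG.warning in validate_config flags configs where an 8-bit optimizer in
the YAML would still collide with a DeepSpeed-managed one. A minimal
standalone sketch of that check (`ds_path` and `yaml_optimizer` are
illustrative stand-ins for cfg.deepspeed and cfg.optimizer, not axolotl
APIs):

    import json

    def optimizer_conflict(ds_path: str, yaml_optimizer: str) -> bool:
        """True when an 8-bit optimizer in the YAML collides with a
        DeepSpeed-managed optimizer defined in the JSON."""
        with open(ds_path, encoding="utf-8") as file:
            ds_cfg = json.load(file)
        return "8bit" in yaml_optimizer and "optimizer" in ds_cfg

    # After this patch the stock configs define no "optimizer" key:
    print(optimizer_conflict("deepspeed_configs/zero1.json", "adamw_bnb_8bit"))
    # -> False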
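In normalize_config, evals_per_epoch is now only translated into eval_steps
when an eval split actually exists (val_set_size or test_datasets), since
an eval interval is meaningless without one. The fractional value it
produces is how HF's Trainer expresses "a fraction of total training
steps"; a quick worked example of the arithmetic:

    # evals_per_epoch=4 over num_epochs=2 -> evaluate every 1/8 of training
    evals_per_epoch, num_epochs = 4, 2
    eval_steps = 1.0 / (evals_per_epoch * num_epochs)
    print(eval_steps)  # 0.125, i.e. 8 evals spread evenly across the run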
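The models.py guard exists because FSDP flattens each wrapped module's
parameters into a single FlatParameter, which requires one uniform dtype;
selectively upcasting norm/gate/embedding weights to fp32 in a bf16 model
therefore breaks sharding. A minimal sketch of the dtype mix the previously
unconditional loop produced (plain PyTorch, no distributed setup required):

    import torch
    import torch.nn as nn

    # a bf16 model with a norm layer: a transformer block in miniature
    model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8)).to(torch.bfloat16)

    # upcast norms to fp32, as the loop did before this patch
    for module in model.modules():
        if isinstance(module, nn.LayerNorm):
            module.to(torch.float32)

    print({p.dtype for p in model.parameters()})
    # {torch.bfloat16, torch.float32} -- FSDP refuses to flatten this mix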