more checks and fixes for deepspeed and fsdp (#1208) [skip ci]
This commit is contained in:
@@ -15,15 +15,6 @@
|
|||||||
"hysteresis": 2,
|
"hysteresis": 2,
|
||||||
"min_loss_scale": 1
|
"min_loss_scale": 1
|
||||||
},
|
},
|
||||||
"optimizer": {
|
|
||||||
"type": "AdamW",
|
|
||||||
"params": {
|
|
||||||
"lr": "auto",
|
|
||||||
"betas": "auto",
|
|
||||||
"eps": "auto",
|
|
||||||
"weight_decay": "auto"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"gradient_accumulation_steps": "auto",
|
"gradient_accumulation_steps": "auto",
|
||||||
"train_batch_size": "auto",
|
"train_batch_size": "auto",
|
||||||
"train_micro_batch_size_per_gpu": "auto",
|
"train_micro_batch_size_per_gpu": "auto",
|
||||||
|
|||||||
@@ -19,15 +19,6 @@
|
|||||||
"hysteresis": 2,
|
"hysteresis": 2,
|
||||||
"min_loss_scale": 1
|
"min_loss_scale": 1
|
||||||
},
|
},
|
||||||
"optimizer": {
|
|
||||||
"type": "AdamW",
|
|
||||||
"params": {
|
|
||||||
"lr": "auto",
|
|
||||||
"betas": "auto",
|
|
||||||
"eps": "auto",
|
|
||||||
"weight_decay": "auto"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"gradient_accumulation_steps": "auto",
|
"gradient_accumulation_steps": "auto",
|
||||||
"train_batch_size": "auto",
|
"train_batch_size": "auto",
|
||||||
"train_micro_batch_size_per_gpu": "auto",
|
"train_micro_batch_size_per_gpu": "auto",
|
||||||
|
|||||||
@@ -23,15 +23,6 @@
|
|||||||
"hysteresis": 2,
|
"hysteresis": 2,
|
||||||
"min_loss_scale": 1
|
"min_loss_scale": 1
|
||||||
},
|
},
|
||||||
"optimizer": {
|
|
||||||
"type": "AdamW",
|
|
||||||
"params": {
|
|
||||||
"lr": "auto",
|
|
||||||
"betas": "auto",
|
|
||||||
"eps": "auto",
|
|
||||||
"weight_decay": "auto"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"gradient_accumulation_steps": "auto",
|
"gradient_accumulation_steps": "auto",
|
||||||
"train_batch_size": "auto",
|
"train_batch_size": "auto",
|
||||||
"train_micro_batch_size_per_gpu": "auto",
|
"train_micro_batch_size_per_gpu": "auto",
|
||||||
|
|||||||
@@ -23,15 +23,6 @@
|
|||||||
"hysteresis": 2,
|
"hysteresis": 2,
|
||||||
"min_loss_scale": 1
|
"min_loss_scale": 1
|
||||||
},
|
},
|
||||||
"optimizer": {
|
|
||||||
"type": "AdamW",
|
|
||||||
"params": {
|
|
||||||
"lr": "auto",
|
|
||||||
"betas": "auto",
|
|
||||||
"eps": "auto",
|
|
||||||
"weight_decay": "auto"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"gradient_accumulation_steps": "auto",
|
"gradient_accumulation_steps": "auto",
|
||||||
"train_batch_size": "auto",
|
"train_batch_size": "auto",
|
||||||
"train_micro_batch_size_per_gpu": "auto",
|
"train_micro_batch_size_per_gpu": "auto",
|
||||||
|
|||||||
@@ -95,7 +95,7 @@ def normalize_config(cfg):
|
|||||||
save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
|
save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
|
||||||
if save_steps < 1.0: # prevent saves on every step
|
if save_steps < 1.0: # prevent saves on every step
|
||||||
cfg.save_steps = save_steps
|
cfg.save_steps = save_steps
|
||||||
if cfg.evals_per_epoch:
|
if (cfg.val_set_size or cfg.test_datasets) and cfg.evals_per_epoch:
|
||||||
eval_steps = 1.0 / (cfg.evals_per_epoch * cfg.num_epochs)
|
eval_steps = 1.0 / (cfg.evals_per_epoch * cfg.num_epochs)
|
||||||
if eval_steps < 1.0: # prevent evals on every step
|
if eval_steps < 1.0: # prevent evals on every step
|
||||||
cfg.eval_steps = eval_steps
|
cfg.eval_steps = eval_steps
|
||||||
@@ -485,35 +485,43 @@ def validate_config(cfg):
|
|||||||
"`use_reentrant` must be false when used with partially frozen model."
|
"`use_reentrant` must be false when used with partially frozen model."
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.flash_attention and cfg.deepspeed and Path(cfg.deepspeed).is_file():
|
if cfg.deepspeed and Path(cfg.deepspeed).is_file():
|
||||||
with open(cfg.deepspeed, encoding="utf-8") as file:
|
with open(cfg.deepspeed, encoding="utf-8") as file:
|
||||||
contents = file.read()
|
contents = file.read()
|
||||||
deepspeed_cfg: DictDefault = DictDefault(json.loads(contents))
|
deepspeed_cfg: DictDefault = DictDefault(json.loads(contents))
|
||||||
if (
|
if cfg.flash_attention:
|
||||||
deepspeed_cfg.zero_optimization
|
if (
|
||||||
and deepspeed_cfg.zero_optimization.stage == 3
|
deepspeed_cfg.zero_optimization
|
||||||
):
|
and deepspeed_cfg.zero_optimization.stage == 3
|
||||||
if not (
|
|
||||||
(
|
|
||||||
deepspeed_cfg.bf16
|
|
||||||
and deepspeed_cfg.bf16.enabled # pylint: disable=no-member
|
|
||||||
is True
|
|
||||||
)
|
|
||||||
or (
|
|
||||||
deepspeed_cfg.fp16
|
|
||||||
and deepspeed_cfg.fp16.enabled # pylint: disable=no-member
|
|
||||||
is True
|
|
||||||
)
|
|
||||||
):
|
):
|
||||||
raise ValueError(
|
if not (
|
||||||
"bf16.enabled or fp16.enabled must be set to true when using ZeRO-3 with flash-attention"
|
(
|
||||||
)
|
deepspeed_cfg.bf16
|
||||||
|
and deepspeed_cfg.bf16.enabled # pylint: disable=no-member
|
||||||
|
is True
|
||||||
|
)
|
||||||
|
or (
|
||||||
|
deepspeed_cfg.fp16
|
||||||
|
and deepspeed_cfg.fp16.enabled # pylint: disable=no-member
|
||||||
|
is True
|
||||||
|
)
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"bf16.enabled or fp16.enabled must be set to true when using ZeRO-3 with flash-attention"
|
||||||
|
)
|
||||||
|
if "8bit" in cfg.optimizer and deepspeed_cfg.optimizer:
|
||||||
|
LOG.warning(
|
||||||
|
f"conflicting optimizer: {cfg.optimizer} used alongside deepspeed optimizer."
|
||||||
|
)
|
||||||
|
|
||||||
if cfg.test_datasets and cfg.val_set_size:
|
if cfg.test_datasets and cfg.val_set_size:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"non-zero val_set_size should not be used with test_datasets configuration"
|
"non-zero val_set_size should not be used with test_datasets configuration"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if cfg.fsdp and "bnb" in cfg.optimizer:
|
||||||
|
raise ValueError(f"FSDP not compatible with {cfg.optimizer}")
|
||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
# MPT 7b
|
# MPT 7b
|
||||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||||
|
|||||||
@@ -642,15 +642,17 @@ def load_model(
|
|||||||
|
|
||||||
# make sure these are fp32 per Ramesh et al. (2021)
|
# make sure these are fp32 per Ramesh et al. (2021)
|
||||||
embedding_modules = get_linear_embedding_layers(cfg.model_config_type)
|
embedding_modules = get_linear_embedding_layers(cfg.model_config_type)
|
||||||
for name, module in model.named_modules():
|
if not cfg.fsdp:
|
||||||
if any(m in name for m in ["norm", "gate"]):
|
# FSDP doesn't like mixed Float and BFloat16
|
||||||
module.to(torch.float32)
|
for name, module in model.named_modules():
|
||||||
if model_config.model_type == "btlm":
|
if any(m in name for m in ["norm", "gate"]):
|
||||||
# don't upcast lm_head for btlm
|
|
||||||
continue
|
|
||||||
if any(m in name for m in embedding_modules):
|
|
||||||
if hasattr(module, "weight"):
|
|
||||||
module.to(torch.float32)
|
module.to(torch.float32)
|
||||||
|
if model_config.model_type == "btlm":
|
||||||
|
# don't upcast lm_head for btlm
|
||||||
|
continue
|
||||||
|
if any(m in name for m in embedding_modules):
|
||||||
|
if hasattr(module, "weight"):
|
||||||
|
module.to(torch.float32)
|
||||||
|
|
||||||
needs_fa2_dtype = cfg.adapter or cfg.fsdp
|
needs_fa2_dtype = cfg.adapter or cfg.fsdp
|
||||||
skip_prepare_model_for_kbit_training = False
|
skip_prepare_model_for_kbit_training = False
|
||||||
|
|||||||
Reference in New Issue
Block a user