Fix Deepspeed loading (#950)
* add check for zero3 * freeze parameters * fixes for deepspeed loading * fix model parameter check * unfrozen parameters in example mixtral and logging when unfreezing
This commit is contained in:
39
deepspeed/zero3_bf16.json
Normal file
39
deepspeed/zero3_bf16.json
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
{
|
||||||
|
"zero_optimization": {
|
||||||
|
"stage": 3,
|
||||||
|
"overlap_comm": true,
|
||||||
|
"contiguous_gradients": true,
|
||||||
|
"sub_group_size": 0,
|
||||||
|
"reduce_bucket_size": "auto",
|
||||||
|
"stage3_prefetch_bucket_size": "auto",
|
||||||
|
"stage3_param_persistence_threshold": "auto",
|
||||||
|
"stage3_max_live_parameters": 0,
|
||||||
|
"stage3_max_reuse_distance": 0,
|
||||||
|
"stage3_gather_16bit_weights_on_model_save": true
|
||||||
|
},
|
||||||
|
"bf16": {
|
||||||
|
"enabled": true
|
||||||
|
},
|
||||||
|
"fp16": {
|
||||||
|
"enabled": "auto",
|
||||||
|
"auto_cast": false,
|
||||||
|
"loss_scale": 0,
|
||||||
|
"initial_scale_power": 32,
|
||||||
|
"loss_scale_window": 1000,
|
||||||
|
"hysteresis": 2,
|
||||||
|
"min_loss_scale": 1
|
||||||
|
},
|
||||||
|
"optimizer": {
|
||||||
|
"type": "AdamW",
|
||||||
|
"params": {
|
||||||
|
"lr": "auto",
|
||||||
|
"betas": "auto",
|
||||||
|
"eps": "auto",
|
||||||
|
"weight_decay": "auto"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"gradient_accumulation_steps": "auto",
|
||||||
|
"train_batch_size": "auto",
|
||||||
|
"train_micro_batch_size_per_gpu": "auto",
|
||||||
|
"wall_clock_breakdown": false
|
||||||
|
}
|
||||||
@@ -14,6 +14,15 @@ dataset_prepared_path: last_run_prepared
|
|||||||
val_set_size: 0.0
|
val_set_size: 0.0
|
||||||
output_dir: ./qlora-out
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
|
## You can optionally freeze the entire model and unfreeze a subset of parameters
|
||||||
|
unfrozen_parameters:
|
||||||
|
# - lm_head.*
|
||||||
|
# - model.embed_tokens.*
|
||||||
|
# - model.layers.2[0-9]+.block_sparse_moe.gate.*
|
||||||
|
# - model.layers.2[0-9]+.block_sparse_moe.experts.*
|
||||||
|
# - model.layers.3[0-9]+.block_sparse_moe.gate.*
|
||||||
|
# - model.layers.3[0-9]+.block_sparse_moe.experts.*
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|
||||||
|
|||||||
@@ -22,8 +22,8 @@ LOG = logging.getLogger("axolotl.cli.train")
|
|||||||
|
|
||||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
print_axolotl_text_art()
|
|
||||||
parsed_cfg = load_cfg(config, **kwargs)
|
parsed_cfg = load_cfg(config, **kwargs)
|
||||||
|
print_axolotl_text_art()
|
||||||
check_accelerate_default_config()
|
check_accelerate_default_config()
|
||||||
check_user_token()
|
check_user_token()
|
||||||
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from axolotl.common.cli import TrainerCliArgs
|
|||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.monkeypatch import neft_embeddings
|
from axolotl.monkeypatch import neft_embeddings
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.freeze import freeze_parameters_except
|
||||||
from axolotl.utils.models import load_model, load_tokenizer
|
from axolotl.utils.models import load_model, load_tokenizer
|
||||||
from axolotl.utils.trainer import setup_trainer
|
from axolotl.utils.trainer import setup_trainer
|
||||||
|
|
||||||
@@ -78,6 +79,9 @@ def train(
|
|||||||
)
|
)
|
||||||
resume_from_checkpoint = cfg.resume_from_checkpoint
|
resume_from_checkpoint = cfg.resume_from_checkpoint
|
||||||
|
|
||||||
|
if cfg.unfrozen_parameters:
|
||||||
|
freeze_parameters_except(model, cfg.unfrozen_parameters)
|
||||||
|
|
||||||
trainer = setup_trainer(
|
trainer = setup_trainer(
|
||||||
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
|
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
|
||||||
)
|
)
|
||||||
|
|||||||
38
src/axolotl/utils/freeze.py
Normal file
38
src/axolotl/utils/freeze.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
"""
|
||||||
|
module to freeze/unfreeze parameters by name
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
|
||||||
|
from axolotl.utils.distributed import is_main_process
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.utils.freeze")
|
||||||
|
|
||||||
|
|
||||||
|
def freeze_parameters_except(model, regex_patterns):
|
||||||
|
"""
|
||||||
|
Freezes all layers of the given model except for the layers that match given regex patterns.
|
||||||
|
Periods in the patterns are treated as literal periods, not as wildcard characters.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- model (nn.Module): The PyTorch model to be modified.
|
||||||
|
- regex_patterns (list of str): List of regex patterns to match layer names to keep unfrozen.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None; the model is modified in place.
|
||||||
|
"""
|
||||||
|
# Escape periods and compile the regex patterns
|
||||||
|
compiled_patterns = [
|
||||||
|
re.compile(pattern.replace(".", "\\.")) for pattern in regex_patterns
|
||||||
|
]
|
||||||
|
|
||||||
|
# First, freeze all parameters in the model
|
||||||
|
for param in model.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
# Unfreeze layers that match the regex patterns
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if any(pattern.match(name) for pattern in compiled_patterns):
|
||||||
|
if is_main_process():
|
||||||
|
LOG.debug(f"unfreezing {name}")
|
||||||
|
param.requires_grad = True
|
||||||
@@ -21,6 +21,7 @@ from transformers import ( # noqa: F401
|
|||||||
PreTrainedModel,
|
PreTrainedModel,
|
||||||
PreTrainedTokenizerBase,
|
PreTrainedTokenizerBase,
|
||||||
)
|
)
|
||||||
|
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
||||||
|
|
||||||
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
||||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
||||||
@@ -285,6 +286,9 @@ def load_model(
|
|||||||
model_kwargs["max_memory"] = cfg.max_memory
|
model_kwargs["max_memory"] = cfg.max_memory
|
||||||
model_kwargs["torch_dtype"] = cfg.torch_dtype
|
model_kwargs["torch_dtype"] = cfg.torch_dtype
|
||||||
|
|
||||||
|
if is_deepspeed_zero3_enabled():
|
||||||
|
del model_kwargs["device_map"]
|
||||||
|
|
||||||
if cfg.model_revision:
|
if cfg.model_revision:
|
||||||
model_kwargs["revision"] = cfg.model_revision
|
model_kwargs["revision"] = cfg.model_revision
|
||||||
if cfg.gptq:
|
if cfg.gptq:
|
||||||
|
|||||||
@@ -276,6 +276,7 @@ def prepare_optim_env(cfg):
|
|||||||
setup_fsdp_envs(cfg)
|
setup_fsdp_envs(cfg)
|
||||||
elif cfg.deepspeed:
|
elif cfg.deepspeed:
|
||||||
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
|
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
|
||||||
|
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
|
||||||
|
|
||||||
|
|
||||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
||||||
|
|||||||
Reference in New Issue
Block a user