From c67fb7158312e47e3326f077f74485cf0a23b51a Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 31 Jan 2024 18:13:29 -0500
Subject: [PATCH] Peft deepspeed resume (#1227)

* import deepspeed integration
* monkeypatch peft adapater with deepspeed for resume from checkpoint
* fix patch
* fix patches attempt 2
* make sure to set lora_model_dir
* skip pylint for deepspeed.utils
* pick up upstream fix in transformers
* remove monkeypatch for deepspeed/peft fix
* no need to set the lora_model_dir on resume
* unset load_in_*bit when using quant config
* guard before del
* better handling of load_in* kwargs
---
 requirements.txt            |  2 +-
 src/axolotl/cli/train.py    |  7 ++++---
 src/axolotl/train.py        | 30 +++++++++++++++---------------
 src/axolotl/utils/models.py | 22 +++++++++++++++-------
 4 files changed, 35 insertions(+), 26 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 075f63622..2e978c16d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,7 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 packaging==23.2
 peft @ git+https://github.com/huggingface/peft.git
-transformers==4.37.0
+transformers @ git+https://github.com/huggingface/transformers.git@bebeeee01275c32fccec3fa36d8b148d3813a7dc
 tokenizers==0.15.0
 bitsandbytes>=0.41.1
 accelerate==0.26.1
diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py
index 7ab06422f..6cbe2c960 100644
--- a/src/axolotl/cli/train.py
+++ b/src/axolotl/cli/train.py
@@ -6,8 +6,9 @@ from pathlib import Path
 from typing import Tuple
 
 import fire
-import transformers
-from transformers import PreTrainedModel, PreTrainedTokenizer
+from transformers.hf_argparser import HfArgumentParser
+from transformers.modeling_utils import PreTrainedModel
+from transformers.tokenization_utils import PreTrainedTokenizer
 
 from axolotl.cli import (
     check_accelerate_default_config,
@@ -27,7 +28,7 @@ LOG = logging.getLogger("axolotl.cli.train")
 def do_cli(config: Path = Path("examples/"), **kwargs):
     # pylint: disable=duplicate-code
     parsed_cfg = load_cfg(config, **kwargs)
-    parser = transformers.HfArgumentParser((TrainerCliArgs))
+    parser = HfArgumentParser((TrainerCliArgs))
     parsed_cli_args, _ = parser.parse_args_into_dataclasses(
         return_remaining_strings=True
     )
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index fb163a9d1..d0f58bca9 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -57,6 +57,21 @@ def train(
     eval_dataset = dataset_meta.eval_dataset
     total_num_steps = dataset_meta.total_num_steps
 
+    if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
+        possible_checkpoints = [
+            str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
+        ]
+        if len(possible_checkpoints) > 0:
+            sorted_paths = sorted(
+                possible_checkpoints,
+                key=lambda path: int(path.split("-")[-1]),
+            )
+            cfg.resume_from_checkpoint = sorted_paths[-1]
+            LOG.info(
+                f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
+            )
+    resume_from_checkpoint = cfg.resume_from_checkpoint
+
     # Load the model and tokenizer
     msg = "loading model"
     if cfg.adapter:
@@ -79,21 +94,6 @@ def train(
 
     safe_serialization = cfg.save_safetensors is True
 
-    if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
-        possible_checkpoints = [
-            str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
-        ]
-        if len(possible_checkpoints) > 0:
-            sorted_paths = sorted(
-                possible_checkpoints,
-                key=lambda path: int(path.split("-")[-1]),
-            )
-            cfg.resume_from_checkpoint = sorted_paths[-1]
-            LOG.info(
-                f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
-            )
-    resume_from_checkpoint = cfg.resume_from_checkpoint
-
     if cfg.unfrozen_parameters:
         freeze_parameters_except(model, cfg.unfrozen_parameters)
 
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index bee6af373..ad3ac7bc6 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -473,6 +473,18 @@ def load_model(
             **bnb_config,
         )
 
+    if cfg.load_in_8bit and cfg.adapter is not None:
+        model_kwargs["load_in_8bit"] = True
+    if cfg.load_in_4bit and cfg.adapter is not None:
+        model_kwargs["load_in_4bit"] = True
+
+    # no longer needed per https://github.com/huggingface/transformers/pull/26610
+    if "quantization_config" in model_kwargs or cfg.gptq:
+        if "load_in_8bit" in model_kwargs:
+            del model_kwargs["load_in_8bit"]
+        if "load_in_4bit" in model_kwargs:
+            del model_kwargs["load_in_4bit"]
+
     # sample packing uses custom FA2 patch
     if cfg.flash_attention:
         if not cfg.sample_packing:
@@ -506,8 +518,6 @@
             model = LlamaForCausalLM.from_pretrained(
                 base_model,
                 config=model_config,
-                load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
-                load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 **model_kwargs,
             )
 
@@ -575,8 +585,6 @@
             model = getattr(transformers, model_type).from_pretrained(
                 base_model,
                 config=model_config,
-                load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
-                load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 trust_remote_code=cfg.trust_remote_code or False,
                 **model_kwargs,
             )
@@ -608,8 +616,6 @@
             model = AutoModelForCausalLM.from_pretrained(
                 base_model,
                 config=model_config,
-                load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
-                load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 trust_remote_code=cfg.trust_remote_code or False,
                 **model_kwargs,
             )
@@ -678,7 +684,9 @@
     skip_prepare_model_for_kbit_training = False
 
     if cfg.model_config_type == "mixtral" and is_deepspeed_zero3_enabled():
-        from deepspeed.utils import set_z3_leaf_modules
+        from deepspeed.utils import (  # pylint: disable=no-name-in-module
+            set_z3_leaf_modules,
+        )
         from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
 
         set_z3_leaf_modules(model, [MixtralSparseMoeBlock])