Peft deepspeed resume (#1227)
* import deepspeed integration * monkeypatch peft adapater with deepspeed for resume from checkpoint * fix patch * fix patches attempt 2 * make sure to set lora_model_dir * skip pylint for deepspeed.utils * pick up upstream fix in transformers * remove monkeypatch for deepspeed/peft fix * no need to set the lora_model_dir on resume * unset load_in_*bit when using quant config * guard before del * better handling of load_in* kwargs
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||||
packaging==23.2
|
packaging==23.2
|
||||||
peft @ git+https://github.com/huggingface/peft.git
|
peft @ git+https://github.com/huggingface/peft.git
|
||||||
transformers==4.37.0
|
transformers @ git+https://github.com/huggingface/transformers.git@bebeeee01275c32fccec3fa36d8b148d3813a7dc
|
||||||
tokenizers==0.15.0
|
tokenizers==0.15.0
|
||||||
bitsandbytes>=0.41.1
|
bitsandbytes>=0.41.1
|
||||||
accelerate==0.26.1
|
accelerate==0.26.1
|
||||||
|
|||||||
@@ -6,8 +6,9 @@ from pathlib import Path
|
|||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
import transformers
|
from transformers.hf_argparser import HfArgumentParser
|
||||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
from transformers.modeling_utils import PreTrainedModel
|
||||||
|
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||||
|
|
||||||
from axolotl.cli import (
|
from axolotl.cli import (
|
||||||
check_accelerate_default_config,
|
check_accelerate_default_config,
|
||||||
@@ -27,7 +28,7 @@ LOG = logging.getLogger("axolotl.cli.train")
|
|||||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
parsed_cfg = load_cfg(config, **kwargs)
|
parsed_cfg = load_cfg(config, **kwargs)
|
||||||
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
parser = HfArgumentParser((TrainerCliArgs))
|
||||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||||
return_remaining_strings=True
|
return_remaining_strings=True
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -57,6 +57,21 @@ def train(
|
|||||||
eval_dataset = dataset_meta.eval_dataset
|
eval_dataset = dataset_meta.eval_dataset
|
||||||
total_num_steps = dataset_meta.total_num_steps
|
total_num_steps = dataset_meta.total_num_steps
|
||||||
|
|
||||||
|
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
|
||||||
|
possible_checkpoints = [
|
||||||
|
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
|
||||||
|
]
|
||||||
|
if len(possible_checkpoints) > 0:
|
||||||
|
sorted_paths = sorted(
|
||||||
|
possible_checkpoints,
|
||||||
|
key=lambda path: int(path.split("-")[-1]),
|
||||||
|
)
|
||||||
|
cfg.resume_from_checkpoint = sorted_paths[-1]
|
||||||
|
LOG.info(
|
||||||
|
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
|
||||||
|
)
|
||||||
|
resume_from_checkpoint = cfg.resume_from_checkpoint
|
||||||
|
|
||||||
# Load the model and tokenizer
|
# Load the model and tokenizer
|
||||||
msg = "loading model"
|
msg = "loading model"
|
||||||
if cfg.adapter:
|
if cfg.adapter:
|
||||||
@@ -79,21 +94,6 @@ def train(
|
|||||||
|
|
||||||
safe_serialization = cfg.save_safetensors is True
|
safe_serialization = cfg.save_safetensors is True
|
||||||
|
|
||||||
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
|
|
||||||
possible_checkpoints = [
|
|
||||||
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
|
|
||||||
]
|
|
||||||
if len(possible_checkpoints) > 0:
|
|
||||||
sorted_paths = sorted(
|
|
||||||
possible_checkpoints,
|
|
||||||
key=lambda path: int(path.split("-")[-1]),
|
|
||||||
)
|
|
||||||
cfg.resume_from_checkpoint = sorted_paths[-1]
|
|
||||||
LOG.info(
|
|
||||||
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
|
|
||||||
)
|
|
||||||
resume_from_checkpoint = cfg.resume_from_checkpoint
|
|
||||||
|
|
||||||
if cfg.unfrozen_parameters:
|
if cfg.unfrozen_parameters:
|
||||||
freeze_parameters_except(model, cfg.unfrozen_parameters)
|
freeze_parameters_except(model, cfg.unfrozen_parameters)
|
||||||
|
|
||||||
|
|||||||
@@ -473,6 +473,18 @@ def load_model(
|
|||||||
**bnb_config,
|
**bnb_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if cfg.load_in_8bit and cfg.adapter is not None:
|
||||||
|
model_kwargs["load_in_8bit"] = True
|
||||||
|
if cfg.load_in_4bit and cfg.adapter is not None:
|
||||||
|
model_kwargs["load_in_4bit"] = True
|
||||||
|
|
||||||
|
# no longer needed per https://github.com/huggingface/transformers/pull/26610
|
||||||
|
if "quantization_config" in model_kwargs or cfg.gptq:
|
||||||
|
if "load_in_8bit" in model_kwargs:
|
||||||
|
del model_kwargs["load_in_8bit"]
|
||||||
|
if "load_in_4bit" in model_kwargs:
|
||||||
|
del model_kwargs["load_in_4bit"]
|
||||||
|
|
||||||
# sample packing uses custom FA2 patch
|
# sample packing uses custom FA2 patch
|
||||||
if cfg.flash_attention:
|
if cfg.flash_attention:
|
||||||
if not cfg.sample_packing:
|
if not cfg.sample_packing:
|
||||||
@@ -506,8 +518,6 @@ def load_model(
|
|||||||
model = LlamaForCausalLM.from_pretrained(
|
model = LlamaForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
config=model_config,
|
config=model_config,
|
||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
|
||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -575,8 +585,6 @@ def load_model(
|
|||||||
model = getattr(transformers, model_type).from_pretrained(
|
model = getattr(transformers, model_type).from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
config=model_config,
|
config=model_config,
|
||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
|
||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
@@ -608,8 +616,6 @@ def load_model(
|
|||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
config=model_config,
|
config=model_config,
|
||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
|
||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
@@ -678,7 +684,9 @@ def load_model(
|
|||||||
skip_prepare_model_for_kbit_training = False
|
skip_prepare_model_for_kbit_training = False
|
||||||
|
|
||||||
if cfg.model_config_type == "mixtral" and is_deepspeed_zero3_enabled():
|
if cfg.model_config_type == "mixtral" and is_deepspeed_zero3_enabled():
|
||||||
from deepspeed.utils import set_z3_leaf_modules
|
from deepspeed.utils import ( # pylint: disable=no-name-in-module
|
||||||
|
set_z3_leaf_modules,
|
||||||
|
)
|
||||||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||||
|
|
||||||
set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
|
set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
|
||||||
|
|||||||
Reference in New Issue
Block a user