Refactor train cfg cli (#499)
* wip to cleanup cfg cli options * fix launcher * fix cli args
This commit is contained in:
@@ -5,12 +5,13 @@ import logging
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Optional, Tuple # noqa: F401
|
||||
from typing import Optional, Tuple # noqa: F401
|
||||
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
import transformers
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
from peft import PeftConfig
|
||||
from transformers import ( # noqa: F401
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
@@ -23,13 +24,17 @@ from transformers import ( # noqa: F401
|
||||
|
||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from peft import PeftConfig # noqa: F401
|
||||
|
||||
from axolotl.utils.dict import DictDefault # noqa: F401
|
||||
def load_model_config(cfg):
|
||||
model_config_name = cfg.base_model_config or cfg.base_model
|
||||
trust_remote_code: bool = False or cfg.trust_remote_code
|
||||
return AutoConfig.from_pretrained(
|
||||
model_config_name, trust_remote_code=trust_remote_code
|
||||
)
|
||||
|
||||
|
||||
def load_tokenizer(cfg):
|
||||
@@ -86,8 +91,10 @@ def load_tokenizer(cfg):
|
||||
|
||||
|
||||
def load_model(
|
||||
cfg, tokenizer
|
||||
): # type: (DictDefault, PreTrainedTokenizerBase) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||
cfg: DictDefault,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
inference: bool = False,
|
||||
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
|
||||
"""
|
||||
Load a model for a given configuration and tokenizer.
|
||||
"""
|
||||
@@ -97,14 +104,9 @@ def load_model(
|
||||
|
||||
# TODO refactor as a kwarg
|
||||
load_in_8bit = cfg.load_in_8bit
|
||||
cfg.is_llama_derived_model = (
|
||||
"llama" in base_model
|
||||
or (cfg.model_type and "llama" in cfg.model_type.lower())
|
||||
or cfg.is_llama_derived_model
|
||||
)
|
||||
|
||||
if cfg.is_llama_derived_model and cfg.flash_attention:
|
||||
if cfg.device not in ["mps", "cpu"] and not cfg.inference:
|
||||
if cfg.device not in ["mps", "cpu"] and not inference:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||
replace_llama_attn_with_flash_attn,
|
||||
)
|
||||
@@ -146,7 +148,7 @@ def load_model(
|
||||
if (
|
||||
cfg.is_llama_derived_model
|
||||
and (cfg.max_packed_sequence_len or cfg.sample_packing)
|
||||
and not cfg.inference
|
||||
and not inference
|
||||
):
|
||||
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
|
||||
|
||||
@@ -424,15 +426,15 @@ def load_model(
|
||||
return model, lora_config
|
||||
|
||||
|
||||
def load_adapter(model, cfg, adapter):
|
||||
# type: (PreTrainedModel, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||
def load_adapter(model, cfg, adapter, inference=False):
|
||||
# type: (PreTrainedModel, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||
|
||||
if adapter is None:
|
||||
return model, None
|
||||
if hasattr(model, "enable_input_require_grads"):
|
||||
model.enable_input_require_grads()
|
||||
if adapter in ["lora", "qlora"]:
|
||||
return load_lora(model, cfg)
|
||||
return load_lora(model, cfg, inference=inference)
|
||||
if adapter == "llama-adapter":
|
||||
return load_llama_adapter(model, cfg)
|
||||
|
||||
@@ -478,8 +480,8 @@ def find_all_linear_names(model):
|
||||
return list(lora_module_names)
|
||||
|
||||
|
||||
def load_lora(model, cfg):
|
||||
# type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||
def load_lora(model, cfg, inference=False):
|
||||
# type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||
|
||||
from peft import LoraConfig, PeftModel, get_peft_model
|
||||
|
||||
@@ -506,7 +508,7 @@ def load_lora(model, cfg):
|
||||
model = PeftModel.from_pretrained(
|
||||
model,
|
||||
cfg.lora_model_dir,
|
||||
is_trainable=not cfg.inference,
|
||||
is_trainable=(not inference),
|
||||
)
|
||||
else:
|
||||
model = get_peft_model(model, lora_config)
|
||||
|
||||
Reference in New Issue
Block a user