Merge pull request #159 from AngainorDev/patch-1
Fix training over existing lora
This commit is contained in:
@@ -165,7 +165,7 @@ def train(
|
|||||||
cfg_keys = cfg.keys()
|
cfg_keys = cfg.keys()
|
||||||
for k, _ in kwargs.items():
|
for k, _ in kwargs.items():
|
||||||
# if not strict, allow writing to cfg even if it's not in the yml already
|
# if not strict, allow writing to cfg even if it's not in the yml already
|
||||||
if k in cfg_keys or cfg.strict is False:
|
if k in cfg_keys or not cfg.strict:
|
||||||
# handle booleans
|
# handle booleans
|
||||||
if isinstance(cfg[k], bool):
|
if isinstance(cfg[k], bool):
|
||||||
cfg[k] = bool(kwargs[k])
|
cfg[k] = bool(kwargs[k])
|
||||||
@@ -205,8 +205,8 @@ def train(
|
|||||||
logging.info(f"loading tokenizer... {tokenizer_config}")
|
logging.info(f"loading tokenizer... {tokenizer_config}")
|
||||||
tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
|
tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
|
||||||
|
|
||||||
if check_not_in(
|
if (
|
||||||
["inference", "shard", "merge_lora"], kwargs
|
check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
|
||||||
): # don't need to load dataset for these
|
): # don't need to load dataset for these
|
||||||
train_dataset, eval_dataset = load_prepare_datasets(
|
train_dataset, eval_dataset = load_prepare_datasets(
|
||||||
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
||||||
@@ -234,7 +234,6 @@ def train(
|
|||||||
tokenizer,
|
tokenizer,
|
||||||
cfg,
|
cfg,
|
||||||
adapter=cfg.adapter,
|
adapter=cfg.adapter,
|
||||||
inference=("inference" in kwargs),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if "merge_lora" in kwargs and cfg.adapter is not None:
|
if "merge_lora" in kwargs and cfg.adapter is not None:
|
||||||
@@ -247,7 +246,7 @@ def train(
|
|||||||
model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
||||||
return
|
return
|
||||||
|
|
||||||
if "inference" in kwargs:
|
if cfg.inference:
|
||||||
logging.info("calling do_inference function")
|
logging.info("calling do_inference function")
|
||||||
inf_kwargs: Dict[str, Any] = {}
|
inf_kwargs: Dict[str, Any] = {}
|
||||||
if "prompter" in kwargs:
|
if "prompter" in kwargs:
|
||||||
|
|||||||
@@ -77,15 +77,9 @@ def load_tokenizer(
|
|||||||
|
|
||||||
|
|
||||||
def load_model(
|
def load_model(
|
||||||
base_model,
|
base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora"
|
||||||
base_model_config,
|
|
||||||
model_type,
|
|
||||||
tokenizer,
|
|
||||||
cfg,
|
|
||||||
adapter="lora",
|
|
||||||
inference=False,
|
|
||||||
):
|
):
|
||||||
# type: (str, str, str, AutoTokenizer, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
# type: (str, str, str, AutoTokenizer, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||||
"""
|
"""
|
||||||
Load a model from a base model and a model type.
|
Load a model from a base model and a model type.
|
||||||
"""
|
"""
|
||||||
@@ -98,7 +92,7 @@ def load_model(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if cfg.is_llama_derived_model and cfg.flash_attention:
|
if cfg.is_llama_derived_model and cfg.flash_attention:
|
||||||
if cfg.device not in ["mps", "cpu"] and inference is False:
|
if cfg.device not in ["mps", "cpu"] and not cfg.inference:
|
||||||
from axolotl.flash_attn import replace_llama_attn_with_flash_attn
|
from axolotl.flash_attn import replace_llama_attn_with_flash_attn
|
||||||
|
|
||||||
logging.info("patching with flash attention")
|
logging.info("patching with flash attention")
|
||||||
@@ -439,6 +433,7 @@ def load_lora(model, cfg):
|
|||||||
model = PeftModel.from_pretrained(
|
model = PeftModel.from_pretrained(
|
||||||
model,
|
model,
|
||||||
cfg.lora_model_dir,
|
cfg.lora_model_dir,
|
||||||
|
is_trainable=not cfg.inference,
|
||||||
device_map=cfg.device_map,
|
device_map=cfg.device_map,
|
||||||
# torch_dtype=torch.float16,
|
# torch_dtype=torch.float16,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user