fix: allow merge lora on pre-quantized model (#2511)
* fix: allow merge lora on pre-quantized model
* fix: remove unused sections per comment
This commit is contained in:
@@ -32,7 +32,13 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
 
     LOG.info("Running merge of LoRA with base model...")
     model = model.merge_and_unload(progressbar=True)
-    model.to(dtype=cfg.torch_dtype)
+    try:
+        model.to(dtype=cfg.torch_dtype)
+    except ValueError as e:
+        LOG.warning("Failed to convert model to dtype %s", cfg.torch_dtype)
+        LOG.warning("Ignore this if the base_model is pre-quantized.")
+        LOG.warning("Error raised: %s", e)
+
     model.generation_config.do_sample = True
 
     if cfg.local_rank == 0:
@@ -151,12 +151,6 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
             "Please make sure to point to a GPTQ model."
         )
 
-    if not cfg.gptq and quant_config_exists and not cfg.load_in_4bit:
-        raise ValueError(
-            "model_config.quantization_config is set but `gptq` flag is not. "
-            "Please use the `gptq` flag to train quantized model or point to a non-quantized model."
-        )
-
     lora_modules_to_save = get_linear_embedding_layers(model_config.model_type)
     if (
         cfg.adapter
Reference in New Issue
Block a user