fix: freeze base_model and register config into Auto class

Author: NanoCode012
Date:   2025-02-05 15:59:06 +07:00
Parent: 253dcdd0cf
Commit: 8294e6218f


@@ -45,6 +45,10 @@ def do_linearize(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
     # load model
     model, tokenizer = load_model_and_tokenizer(cfg=cfg)
+
+    # freeze model
+    for p in model.parameters():
+        p.requires_grad = False
     # load config
     base_config = load_model_config(cfg)
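
The freeze in this hunk is the standard PyTorch pattern: disable gradients on every parameter up front, and let the later conversion step (note train_attention=True in the next hunk) re-enable just the weights being trained. A minimal standalone sketch of that pattern, using a tiny public checkpoint as an illustrative stand-in for the real base model:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

for p in model.parameters():
    p.requires_grad = False  # freeze everything up front

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable params: {trainable}/{total}")  # 0/total after a full freeze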
@@ -56,6 +60,18 @@ def do_linearize(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
         model, config=linear_llama_config, train_attention=True
     )
+
+    # register model
+    from transformers import AutoConfig, AutoModel
+
+    AutoConfig.register("linear_llama", LinearLlamaConfig)
+    AutoModel.register(LinearLlamaConfig, LinearLlamaForCausalLM)
+
+    # set save_path, save tokenizer and model config.
+    save_path = str(os.path.join(cfg.output_dir, "distilled"))
+    tokenizer.save_pretrained(save_path)
+    if hasattr(model, "config"):
+        model.config.save_pretrained(save_path)
     # Get datasets
     dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
     train_dataset = dataset_meta.train_dataset
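
The register() calls above wire the project's custom classes into the transformers Auto machinery. A sketch of the round-trip this enables, assuming LinearLlamaConfig declares model_type = "linear_llama" and both classes are importable from this project (class names taken from the diff above):

from transformers import AutoConfig, AutoModel

# Assumed importable from this project.
AutoConfig.register("linear_llama", LinearLlamaConfig)
AutoModel.register(LinearLlamaConfig, LinearLlamaForCausalLM)

save_path = "outputs/distilled"  # as constructed in the diff

# config.save_pretrained(save_path) writes "model_type": "linear_llama" into
# config.json, so the Auto classes can resolve the custom types later:
config = AutoConfig.from_pretrained(save_path)  # -> LinearLlamaConfig
model = AutoModel.from_pretrained(save_path)    # -> LinearLlamaForCausalLM

Two caveats worth noting: registration is per-process, so any consumer of the saved checkpoint has to repeat the register() calls before from_pretrained(); and because the class is registered under AutoModel rather than AutoModelForCausalLM, only AutoModel.from_pretrained() will resolve it unless the task-specific Auto class is registered as well.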
@@ -86,14 +102,9 @@ def do_linearize(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
     # NOTE: If in peft mode, consider whether to auto-merge
     # save model
-    save_path = str(os.path.join(cfg.output_dir, "distilled"))
-    tokenizer.save_pretrained(save_path)
-    if hasattr(model, "config"):
-        model.config.save_pretrained(save_path)
-
     safe_serialization = cfg.save_safetensors is True
     # NOTE: may need to consider other ways of saving due to multi-gpu etc
-    model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
+    model.save_pretrained(save_path, safe_serialization=safe_serialization)

     # cleanup
     plugin_manager = PluginManager.get_instance()
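
The net effect of this hunk is that the weights now land in the same "distilled" directory as the tokenizer and config saved earlier, instead of the bare cfg.output_dir. A standalone sketch of the corrected layout (paths and model name are illustrative, not from this commit):

import os

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")

output_dir = "./outputs"  # stand-in for cfg.output_dir
save_path = str(os.path.join(output_dir, "distilled"))

# tokenizer, config, and weights all end up in one loadable directory
tokenizer.save_pretrained(save_path)
model.config.save_pretrained(save_path)
model.save_pretrained(save_path, safe_serialization=True)  # -> model.safetensors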