fix: freeze base_model and register config into Auto class
@@ -45,6 +45,10 @@ def do_linearize(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
     # load model
     model, tokenizer = load_model_and_tokenizer(cfg=cfg)

+    # freeze model
+    for p in model.parameters():
+        p.requires_grad = False
+
     # load config
     base_config = load_model_config(cfg)
@@ -56,6 +60,18 @@ def do_linearize(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
         model, config=linear_llama_config, train_attention=True
     )

+    # register model
+    from transformers import AutoConfig, AutoModel
+
+    AutoConfig.register("linear_llama", LinearLlamaConfig)
+    AutoModel.register(LinearLlamaConfig, LinearLlamaForCausalLM)
+
+    # set save_path, save tokenizer and model config.
+    save_path = str(os.path.join(cfg.output_dir, "distilled"))
+    tokenizer.save_pretrained(save_path)
+    if hasattr(model, "config"):
+        model.config.save_pretrained(save_path)
+
     # Get datasets
     dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
     train_dataset = dataset_meta.train_dataset
@@ -86,14 +102,9 @@ def do_linearize(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
     # NOTE: If in peft mode, consider whether to auto-merge

     # save model
-    save_path = str(os.path.join(cfg.output_dir, "distilled"))
-    tokenizer.save_pretrained(save_path)
-    if hasattr(model, "config"):
-        model.config.save_pretrained(save_path)
-
     safe_serialization = cfg.save_safetensors is True
     # NOTE: may need to consider other ways of saving due to multi-gpu etc
-    model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
+    model.save_pretrained(save_path, safe_serialization=safe_serialization)

     # cleanup
     plugin_manager = PluginManager.get_instance()
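
Note on the freeze: the loop added in the first hunk disables gradients for every base parameter before the linearized attention is swapped in with train_attention=True. A minimal sanity check that only the intended parameters remain trainable might look like the sketch below; the "attn" name filter is an assumption about this project's module naming, not something the diff guarantees.

# Count what is still trainable after the freeze + linearization steps above.
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(f"{len(trainable)} trainable tensors")
# Assumption: the swapped-in attention parameters carry "attn" in their names.
assert all("attn" in name for name in trainable)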
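
Note on the registration: AutoConfig.register and AutoModel.register only mutate in-process mappings, so a consumer of the "distilled" directory has to repeat them (or ship the classes via trust_remote_code-style auto_map) before the Auto classes can resolve the custom type. A sketch of the reload path this commit enables, assuming LinearLlamaConfig sets model_type = "linear_llama" and both classes are importable from this project:

from transformers import AutoConfig, AutoModel

# Re-register in the loading process; registration does not persist on disk.
AutoConfig.register("linear_llama", LinearLlamaConfig)
AutoModel.register(LinearLlamaConfig, LinearLlamaForCausalLM)

# save_path is the "distilled" directory written by do_linearize.
config = AutoConfig.from_pretrained(save_path)  # resolves via model_type in config.json
model = AutoModel.from_pretrained(save_path)    # instantiates LinearLlamaForCausalLM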
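
Note on the save: after the last hunk, the weights land in the same "distilled" directory as the tokenizer and config rather than in cfg.output_dir directly, and the `is True` comparison deliberately coerces an unset cfg.save_safetensors (None) to False. What the flag selects, as a sketch:

# cfg.save_safetensors may be True, False, or unset (None);
# "is True" maps anything but an explicit True to False.
safe_serialization = cfg.save_safetensors is True

# True  -> writes model.safetensors (pickle-free weight format)
# False -> writes pytorch_model.bin (torch.save / pickle format)
model.save_pretrained(save_path, safe_serialization=safe_serialization)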