save tokenizer too
This commit is contained in:
@@ -11,7 +11,7 @@ import torch
 from huggingface_hub import snapshot_download, split_torch_state_dict_into_shards
 from safetensors.torch import save_file
 from tqdm import tqdm
-from transformers import AutoConfig
+from transformers import AutoConfig, AutoTokenizer
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME

 from axolotl.integrations.rrt.modeling.modeling_rrt_llama import (
@@ -277,6 +277,7 @@ def convert_llama_to_rrt(
     ]

     config = AutoConfig.from_pretrained(model_name)
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
     num_hidden_layers = config.num_hidden_layers
     if num_hidden_layers % recurse_layers != 0:
         raise ValueError(
@@ -293,6 +294,7 @@ def convert_llama_to_rrt(
         }
     )
     config.save_pretrained(output_dir)
+    tokenizer.save_pretrained(output_dir)
    model_path = Path(snapshot_download(model_name, ignore_patterns="*.pth"))

     # create a new state_dict to store the RRT model weights
Reference in New Issue
Block a user