save tokenizer too

This commit is contained in:
Wing Lian
2025-01-20 13:31:26 -05:00
parent 474ba1a1b8
commit b582d340b0

View File

@@ -11,7 +11,7 @@ import torch
from huggingface_hub import snapshot_download, split_torch_state_dict_into_shards
from safetensors.torch import save_file
from tqdm import tqdm
from transformers import AutoConfig
from transformers import AutoConfig, AutoTokenizer
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
from axolotl.integrations.rrt.modeling.modeling_rrt_llama import (
@@ -277,6 +277,7 @@ def convert_llama_to_rrt(
]
config = AutoConfig.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
num_hidden_layers = config.num_hidden_layers
if num_hidden_layers % recurse_layers != 0:
raise ValueError(
@@ -293,6 +294,7 @@ def convert_llama_to_rrt(
}
)
config.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
model_path = Path(snapshot_download(model_name, ignore_patterns="*.pth"))
# create a new state_dict to store the RRT model weights