save tokenizer too
This commit is contained in:
@@ -11,7 +11,7 @@ import torch
|
|||||||
from huggingface_hub import snapshot_download, split_torch_state_dict_into_shards
|
from huggingface_hub import snapshot_download, split_torch_state_dict_into_shards
|
||||||
from safetensors.torch import save_file
|
from safetensors.torch import save_file
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from transformers import AutoConfig
|
from transformers import AutoConfig, AutoTokenizer
|
||||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
|
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
|
||||||
|
|
||||||
from axolotl.integrations.rrt.modeling.modeling_rrt_llama import (
|
from axolotl.integrations.rrt.modeling.modeling_rrt_llama import (
|
||||||
@@ -277,6 +277,7 @@ def convert_llama_to_rrt(
|
|||||||
]
|
]
|
||||||
|
|
||||||
config = AutoConfig.from_pretrained(model_name)
|
config = AutoConfig.from_pretrained(model_name)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
num_hidden_layers = config.num_hidden_layers
|
num_hidden_layers = config.num_hidden_layers
|
||||||
if num_hidden_layers % recurse_layers != 0:
|
if num_hidden_layers % recurse_layers != 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -293,6 +294,7 @@ def convert_llama_to_rrt(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
config.save_pretrained(output_dir)
|
config.save_pretrained(output_dir)
|
||||||
|
tokenizer.save_pretrained(output_dir)
|
||||||
model_path = Path(snapshot_download(model_name, ignore_patterns="*.pth"))
|
model_path = Path(snapshot_download(model_name, ignore_patterns="*.pth"))
|
||||||
|
|
||||||
# create a new state_dict to store the RRT model weights
|
# create a new state_dict to store the RRT model weights
|
||||||
|
|||||||
Reference in New Issue
Block a user