temporary: inference validation script
This commit is contained in:
42
examples/mistral/mixtral_fused.py
Normal file
42
examples/mistral/mixtral_fused.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from axolotl.monkeypatch.moe.moe import SparseMoeBlock
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
|
||||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||
|
||||
model_path = "mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||
|
||||
# Load model
|
||||
model = AutoModelForCausalLM.from_pretrained(model_path)
|
||||
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, MixtralSparseMoeBlock):
|
||||
smoe = SparseMoeBlock(
|
||||
experts=module.experts,
|
||||
gate=module.gate,
|
||||
hidden_dim=module.hidden_dim,
|
||||
ffn_dim=module.ffn_dim,
|
||||
num_experts=module.num_experts,
|
||||
top_k=module.top_k,
|
||||
)
|
||||
setattr(model, name, smoe)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
||||
|
||||
# Convert prompt to tokens
|
||||
prompt_template = "[INST] {prompt} [/INST]"
|
||||
|
||||
prompt = "You're standing on the surface of the Earth. "\
|
||||
"You walk one mile south, one mile west and one mile north. "\
|
||||
"You end up exactly where you started. Where are you?"
|
||||
|
||||
tokens = tokenizer(
|
||||
prompt_template.format(prompt=prompt),
|
||||
return_tensors='pt'
|
||||
).input_ids.cuda()
|
||||
|
||||
# Generate output
|
||||
generation_output = model.generate(
|
||||
tokens,
|
||||
streamer=streamer,
|
||||
max_new_tokens=512
|
||||
)
|
||||
Reference in New Issue
Block a user