torch_dtype -> dtype (#3177)
* torch_dtype -> dtype * torch_dtype -> dtype
This commit is contained in:
@@ -85,9 +85,7 @@ def do_cli(model: Union[Path, str], output: Union[Path, str]) -> None:
|
||||
unpatch_llama4 = patch_llama4_linearized_modeling()
|
||||
from transformers import Llama4ForConditionalGeneration
|
||||
|
||||
model_ = Llama4ForConditionalGeneration.from_pretrained(
|
||||
model, torch_dtype=torch.bfloat16
|
||||
)
|
||||
model_ = Llama4ForConditionalGeneration.from_pretrained(model, dtype=torch.bfloat16)
|
||||
processor = AutoProcessor.from_pretrained(model)
|
||||
processor.save_pretrained(output)
|
||||
|
||||
|
||||
@@ -69,7 +69,7 @@ def do_quantize(
|
||||
config = AutoConfig.from_pretrained(model_path)
|
||||
torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else None
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, device_map="auto", torch_dtype=torch_dtype
|
||||
model_path, device_map="auto", dtype=torch_dtype
|
||||
)
|
||||
|
||||
LOG.info(
|
||||
|
||||
@@ -148,7 +148,7 @@ def load_sharded_model(
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
use_cache=False,
|
||||
torch_dtype=torch.float32,
|
||||
dtype=torch.float32,
|
||||
_attn_implementation=model_config._attn_implementation,
|
||||
trust_remote_code=cfg.trust_remote_code,
|
||||
)
|
||||
@@ -158,7 +158,7 @@ def load_sharded_model(
|
||||
with init_empty_weights():
|
||||
model = AutoModelForCausalLM.from_config(
|
||||
model_config,
|
||||
torch_dtype=torch_dtype,
|
||||
dtype=torch_dtype,
|
||||
trust_remote_code=cfg.trust_remote_code,
|
||||
)
|
||||
return model
|
||||
|
||||
Reference in New Issue
Block a user