From a6bfbe34009686c59bbf5a198f8ad047019a78d6 Mon Sep 17 00:00:00 2001
From: VED <146507396+ved1beta@users.noreply.github.com>
Date: Wed, 1 Oct 2025 13:32:51 +0530
Subject: [PATCH] torch_dtype -> dtype (#3177)

* torch_dtype -> dtype

* torch_dtype -> dtype
---
 src/axolotl/cli/delinearize_llama4.py                       | 4 +---
 src/axolotl/cli/quantize.py                                 | 2 +-
 src/axolotl/utils/model_shard_quant.py                      | 4 ++--
 tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py | 2 +-
 tests/e2e/test_quantization.py                              | 2 +-
 5 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/src/axolotl/cli/delinearize_llama4.py b/src/axolotl/cli/delinearize_llama4.py
index 90227fccd..4f5448a14 100644
--- a/src/axolotl/cli/delinearize_llama4.py
+++ b/src/axolotl/cli/delinearize_llama4.py
@@ -85,9 +85,7 @@ def do_cli(model: Union[Path, str], output: Union[Path, str]) -> None:
     unpatch_llama4 = patch_llama4_linearized_modeling()
     from transformers import Llama4ForConditionalGeneration
 
-    model_ = Llama4ForConditionalGeneration.from_pretrained(
-        model, torch_dtype=torch.bfloat16
-    )
+    model_ = Llama4ForConditionalGeneration.from_pretrained(model, dtype=torch.bfloat16)
     processor = AutoProcessor.from_pretrained(model)
 
     processor.save_pretrained(output)
diff --git a/src/axolotl/cli/quantize.py b/src/axolotl/cli/quantize.py
index 6838f47d8..c11bcc6d9 100644
--- a/src/axolotl/cli/quantize.py
+++ b/src/axolotl/cli/quantize.py
@@ -69,7 +69,7 @@ def do_quantize(
     config = AutoConfig.from_pretrained(model_path)
     torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else None
     model = AutoModelForCausalLM.from_pretrained(
-        model_path, device_map="auto", torch_dtype=torch_dtype
+        model_path, device_map="auto", dtype=torch_dtype
     )
 
     LOG.info(
diff --git a/src/axolotl/utils/model_shard_quant.py b/src/axolotl/utils/model_shard_quant.py
index f20a9625e..ca152113a 100644
--- a/src/axolotl/utils/model_shard_quant.py
+++ b/src/axolotl/utils/model_shard_quant.py
@@ -148,7 +148,7 @@ def load_sharded_model(
         model = AutoModelForCausalLM.from_pretrained(
             model_name,
             use_cache=False,
-            torch_dtype=torch.float32,
+            dtype=torch.float32,
             _attn_implementation=model_config._attn_implementation,
             trust_remote_code=cfg.trust_remote_code,
         )
@@ -158,7 +158,7 @@ def load_sharded_model(
         with init_empty_weights():
             model = AutoModelForCausalLM.from_config(
                 model_config,
-                torch_dtype=torch_dtype,
+                dtype=torch_dtype,
                 trust_remote_code=cfg.trust_remote_code,
             )
     return model
diff --git a/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py b/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py
index 2180eb99d..73f883858 100644
--- a/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py
+++ b/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py
@@ -160,7 +160,7 @@ def test_geglu_model_integration():
     """Test GeGLU activation with Gemma model."""
     model = AutoModelForCausalLM.from_pretrained(
         "trl-internal-testing/tiny-Gemma2ForCausalLM",
-        torch_dtype=torch.float16,
+        dtype=torch.float16,
         device_map="cuda:0",
     )
     peft_config = get_peft_config(
diff --git a/tests/e2e/test_quantization.py b/tests/e2e/test_quantization.py
index b64aef51a..706279c6c 100644
--- a/tests/e2e/test_quantization.py
+++ b/tests/e2e/test_quantization.py
@@ -39,7 +39,7 @@ def model():
     dummy_model = AutoModelForCausalLM.from_pretrained(
         "Qwen/Qwen2-0.5B",
         device_map="auto",
-        torch_dtype=torch.bfloat16,
+        dtype=torch.bfloat16,
     )
     with torch.device(dummy_model.device):
         dummy_model.model.embed_tokens = torch.nn.Embedding(