Merge branch 'main' into uv-first

This commit is contained in:
Dan Saunders
2025-10-04 09:07:22 -04:00
6 changed files with 9 additions and 13 deletions

View File

@@ -84,7 +84,7 @@ jobs:
uv pip show --system torch uv pip show --system torch
uv pip install --system wheel uv pip install --system wheel
printf "torch==${{ matrix.pytorch_version }}\n" > torch-constraints.txt printf "torch==${{ matrix.pytorch_version }}\n" > torch-constraints.txt
uv pip install --system --no-build-isolation -e ".[dev]" --constraints torch-constraints.txt uv pip install --system --no-cache-dir --no-build-isolation -e ".[dev]" --constraints torch-constraints.txt
set -o pipefail set -o pipefail
python scripts/unsloth_install.py | bash python scripts/unsloth_install.py | bash
python scripts/cutcrossentropy_install.py | bash python scripts/cutcrossentropy_install.py | bash
@@ -155,12 +155,10 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: | run: |
uv pip show --system torch uv pip show --system torch
uv pip install --system wheel uv pip install --system wheel build
uv pip install --system build
python -m build --sdist python -m build --sdist
uv pip install --system dist/*.tar.gz
printf "torch==${{ matrix.pytorch_version }}\n" > torch-constraints.txt printf "torch==${{ matrix.pytorch_version }}\n" > torch-constraints.txt
uv pip install --system ".[dev]" --constraints torch-constraints.txt uv pip install --no-cache-dir --no-build-isolation --system "dist/axolotl*.tar.gz[dev]" --constraints torch-constraints.txt
python scripts/unsloth_install.py | sh python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh python scripts/cutcrossentropy_install.py | sh

View File

@@ -85,9 +85,7 @@ def do_cli(model: Union[Path, str], output: Union[Path, str]) -> None:
unpatch_llama4 = patch_llama4_linearized_modeling() unpatch_llama4 = patch_llama4_linearized_modeling()
from transformers import Llama4ForConditionalGeneration from transformers import Llama4ForConditionalGeneration
model_ = Llama4ForConditionalGeneration.from_pretrained( model_ = Llama4ForConditionalGeneration.from_pretrained(model, dtype=torch.bfloat16)
model, torch_dtype=torch.bfloat16
)
processor = AutoProcessor.from_pretrained(model) processor = AutoProcessor.from_pretrained(model)
processor.save_pretrained(output) processor.save_pretrained(output)

View File

@@ -69,7 +69,7 @@ def do_quantize(
config = AutoConfig.from_pretrained(model_path) config = AutoConfig.from_pretrained(model_path)
torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else None torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else None
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
model_path, device_map="auto", torch_dtype=torch_dtype model_path, device_map="auto", dtype=torch_dtype
) )
LOG.info( LOG.info(

View File

@@ -148,7 +148,7 @@ def load_sharded_model(
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
model_name, model_name,
use_cache=False, use_cache=False,
torch_dtype=torch.float32, dtype=torch.float32,
_attn_implementation=model_config._attn_implementation, _attn_implementation=model_config._attn_implementation,
trust_remote_code=cfg.trust_remote_code, trust_remote_code=cfg.trust_remote_code,
) )
@@ -158,7 +158,7 @@ def load_sharded_model(
with init_empty_weights(): with init_empty_weights():
model = AutoModelForCausalLM.from_config( model = AutoModelForCausalLM.from_config(
model_config, model_config,
torch_dtype=torch_dtype, dtype=torch_dtype,
trust_remote_code=cfg.trust_remote_code, trust_remote_code=cfg.trust_remote_code,
) )
return model return model

View File

@@ -160,7 +160,7 @@ def test_geglu_model_integration():
"""Test GeGLU activation with Gemma model.""" """Test GeGLU activation with Gemma model."""
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
"trl-internal-testing/tiny-Gemma2ForCausalLM", "trl-internal-testing/tiny-Gemma2ForCausalLM",
torch_dtype=torch.float16, dtype=torch.float16,
device_map="cuda:0", device_map="cuda:0",
) )
peft_config = get_peft_config( peft_config = get_peft_config(

View File

@@ -39,7 +39,7 @@ def model():
dummy_model = AutoModelForCausalLM.from_pretrained( dummy_model = AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen2-0.5B", "Qwen/Qwen2-0.5B",
device_map="auto", device_map="auto",
torch_dtype=torch.bfloat16, dtype=torch.bfloat16,
) )
with torch.device(dummy_model.device): with torch.device(dummy_model.device):
dummy_model.model.embed_tokens = torch.nn.Embedding( dummy_model.model.embed_tokens = torch.nn.Embedding(