post merge lora fixes for CI (#3536) [skip ci]

* post merge lora fixes for CI

* handle lora kernel auto-enable for moe without grouped_mm

* prefer not to import torch in schema validation
This commit is contained in:
Wing Lian
2026-03-23 02:26:10 -04:00
committed by GitHub
parent 0e583efeaa
commit 86be9f329e
2 changed files with 67 additions and 3 deletions

View File

@@ -1385,6 +1385,39 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
if data.get("trust_remote_code"):
return data
# Skip auto-enable for MoE models when native grouped_mm is unavailable
# (torch < 2.9). The grouped_mm fallback in transformers calls torch.mm
# with an `out=` argument, which bypasses autocast and fails on mixed
# dtypes during eval.
env_capabilities = data.get("env_capabilities", {})
torch_version = env_capabilities.get("torch_version")
if torch_version is None:
import torch
torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
has_grouped_mm = version.parse(torch_version) >= version.parse("2.9.0")
if not has_grouped_mm:
is_moe = False
model_type = data.get("model_config_type", "")
if model_type and "moe" in model_type.lower():
is_moe = True
if not is_moe:
try:
from transformers import AutoConfig
base_model = data.get("base_model")
if base_model:
auto_cfg = AutoConfig.from_pretrained(
base_model, trust_remote_code=False
)
if getattr(auto_cfg, "num_local_experts", None) or getattr(
auto_cfg, "num_experts", None
):
is_moe = True
except Exception: # pylint: disable=broad-exception-caught
pass
if is_moe:
return data
# Check multi-GPU compatibility
capabilities = data.get("capabilities")
is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1