post merge lora fixes for CI (#3536) [skip ci]
* post merge lora fixes for CI * handle lora kernel auto-enable for moe without grouped_mm * prefer not to import torch in schema validation
This commit is contained in:
@@ -1385,6 +1385,39 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
if data.get("trust_remote_code"):
|
||||
return data
|
||||
|
||||
# Skip auto-enable for MoE models when native grouped_mm is unavailable
|
||||
# (torch < 2.9). The grouped_mm fallback in transformers uses torch.mm
|
||||
# with out= which bypasses autocast and fails on mixed dtypes during eval.
|
||||
env_capabilities = data.get("env_capabilities", {})
|
||||
torch_version = env_capabilities.get("torch_version")
|
||||
if torch_version is None:
|
||||
import torch
|
||||
|
||||
torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
|
||||
has_grouped_mm = version.parse(torch_version) >= version.parse("2.9.0")
|
||||
if not has_grouped_mm:
|
||||
is_moe = False
|
||||
model_type = data.get("model_config_type", "")
|
||||
if model_type and "moe" in model_type.lower():
|
||||
is_moe = True
|
||||
if not is_moe:
|
||||
try:
|
||||
from transformers import AutoConfig
|
||||
|
||||
base_model = data.get("base_model")
|
||||
if base_model:
|
||||
auto_cfg = AutoConfig.from_pretrained(
|
||||
base_model, trust_remote_code=False
|
||||
)
|
||||
if getattr(auto_cfg, "num_local_experts", None) or getattr(
|
||||
auto_cfg, "num_experts", None
|
||||
):
|
||||
is_moe = True
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
pass
|
||||
if is_moe:
|
||||
return data
|
||||
|
||||
# Check multi-GPU compatibility
|
||||
capabilities = data.get("capabilities")
|
||||
is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1
|
||||
|
||||
Reference in New Issue
Block a user