handle torch_compile set to auto (#2172) [skip ci]

* handle torch_compile set to auto

* update docs [skip ci]

* add tests
This commit is contained in:
Wing Lian
2024-12-17 16:42:41 -05:00
committed by GitHub
parent 10cfecf02e
commit 3798229d85
4 changed files with 64 additions and 4 deletions

View File

@@ -245,8 +245,8 @@ def validate_config(
) = merge_input_args()
if capabilities or env_capabilities:
if (capabilities and not env_capabilities) or (
env_capabilities and not capabilities
if (capabilities and env_capabilities is None) or (
env_capabilities and capabilities is None
):
raise ValueError(
"Both capabilities and env_capabilities must be provided or not provided."

View File

@@ -741,7 +741,7 @@ class AxolotlInputConfig(
special_tokens: Optional[SpecialTokensConfig] = None
tokens: Optional[List[str]] = None
torch_compile: Optional[bool] = None
torch_compile: Optional[Union[Literal["auto"], bool]] = None
torch_compile_backend: Optional[str] = None
torch_compile_mode: Optional[
Literal["default", "reduce-overhead", "max-autotune"]
@@ -1582,3 +1582,22 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
"ADOPT optimizer is incompatible with torch version < 2.5.1"
)
return data
@model_validator(mode="before")
@classmethod
def check_torch_compile_auto(cls, data):
if data.get("torch_compile") == "auto":
env_capabilities = data.get("env_capabilities", {})
if env_capabilities.get("torch_version"):
if version.parse(
env_capabilities.get("torch_version")
) >= version.parse("2.5.1"):
LOG.info(
"torch.compile is available, setting torch_compile to True"
)
data["torch_compile"] = True
else:
data["torch_compile"] = False
else:
data["torch_compile"] = False
return data