handle torch_compile set to auto (#2172) [skip ci]

* handle torch_compile set to auto

* update docs [skip ci]

* add tests
This commit is contained in:
Wing Lian
2024-12-17 16:42:41 -05:00
committed by GitHub
parent 10cfecf02e
commit 3798229d85
4 changed files with 64 additions and 4 deletions

View File

@@ -337,7 +337,8 @@ comet_experiment_config: # Dictionary for additional configuration settings, see
output_dir: ./completed-model
# Whether to use torch.compile and which backend to use
torch_compile: # bool
# Set to `auto` to enable torch.compile only when the detected torch version is >= 2.5.1 (otherwise it is disabled)
torch_compile: # Optional[Union[Literal["auto"], bool]]
torch_compile_backend: # Optional[str]
# Training hyperparameters

View File

@@ -245,8 +245,8 @@ def validate_config(
) = merge_input_args()
if capabilities or env_capabilities:
if (capabilities and not env_capabilities) or (
env_capabilities and not capabilities
if (capabilities and env_capabilities is None) or (
env_capabilities and capabilities is None
):
raise ValueError(
"Both capabilities and env_capabilities must be provided or not provided."

View File

@@ -741,7 +741,7 @@ class AxolotlInputConfig(
special_tokens: Optional[SpecialTokensConfig] = None
tokens: Optional[List[str]] = None
torch_compile: Optional[bool] = None
torch_compile: Optional[Union[Literal["auto"], bool]] = None
torch_compile_backend: Optional[str] = None
torch_compile_mode: Optional[
Literal["default", "reduce-overhead", "max-autotune"]
@@ -1582,3 +1582,22 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
"ADOPT optimizer is incompatible with torch version < 2.5.1"
)
return data
@model_validator(mode="before")
@classmethod
def check_torch_compile_auto(cls, data):
    """Resolve ``torch_compile: auto`` into a concrete boolean.

    When ``torch_compile`` is the string ``"auto"``, it is replaced with
    ``True`` if ``env_capabilities["torch_version"]`` is present and is at
    least 2.5.1, and with ``False`` otherwise (older torch, or no version
    reported). Any other ``torch_compile`` value is left untouched.

    Args:
        data: Raw, pre-validation config mapping.

    Returns:
        The same mapping, with ``torch_compile`` possibly rewritten.
    """
    if data.get("torch_compile") != "auto":
        return data

    # `or {}` also covers an explicit `env_capabilities: None`; the
    # `.get(key, {})` default only applies when the key is absent.
    env_capabilities = data.get("env_capabilities") or {}
    torch_version = env_capabilities.get("torch_version")

    # Unknown torch version is treated as "not supported".
    compile_supported = bool(torch_version) and version.parse(
        torch_version
    ) >= version.parse("2.5.1")
    if compile_supported:
        LOG.info("torch.compile is available, setting torch_compile to True")

    data["torch_compile"] = compile_supported
    return data

View File

@@ -1196,6 +1196,46 @@ class TestValidation(BaseValidation):
)
class TestTorchCompileValidation(BaseValidation):
    """
    Validation tests covering how `torch_compile: auto` is resolved
    """

    def test_torch_compile_auto(self, minimal_cfg):
        # "auto" should be turned into a bool based on the reported torch version
        auto_cfg = DictDefault({"torch_compile": "auto"}) | minimal_cfg

        # (env_capabilities passed to validate_config, expected torch_compile)
        scenarios = (
            ({"torch_version": "2.5.1"}, True),
            ({"torch_version": "2.4.1"}, False),
            ({}, False),
        )
        for env_caps, expected in scenarios:
            resolved = validate_config(
                auto_cfg,
                capabilities={"bf16": True},
                env_capabilities=env_caps,
            )
            assert resolved.torch_compile is expected
class TestValidationCheckModelConfig(BaseValidation):
"""
Test the validation for the config when the model config is available