handle torch_compile set to auto (#2172) [skip ci]
* handle torch_compile set to auto * update docs [skip ci] * add tests
This commit is contained in:
@@ -1196,6 +1196,46 @@ class TestValidation(BaseValidation):
|
||||
)
|
||||
|
||||
|
||||
class TestTorchCompileValidation(BaseValidation):
|
||||
"""
|
||||
test suite for when torch_compile is set to 'auto'
|
||||
"""
|
||||
|
||||
def test_torch_compile_auto(self, minimal_cfg):
|
||||
cfg = (
|
||||
DictDefault(
|
||||
{
|
||||
"torch_compile": "auto",
|
||||
}
|
||||
)
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
env_capabilities = {"torch_version": "2.5.1"}
|
||||
capabilities = {"bf16": True}
|
||||
updated_cfg = validate_config(
|
||||
cfg, capabilities=capabilities, env_capabilities=env_capabilities
|
||||
)
|
||||
|
||||
assert updated_cfg.torch_compile is True
|
||||
|
||||
env_capabilities = {"torch_version": "2.4.1"}
|
||||
capabilities = {"bf16": True}
|
||||
updated_cfg = validate_config(
|
||||
cfg, capabilities=capabilities, env_capabilities=env_capabilities
|
||||
)
|
||||
|
||||
assert updated_cfg.torch_compile is False
|
||||
|
||||
env_capabilities = {}
|
||||
capabilities = {"bf16": True}
|
||||
updated_cfg = validate_config(
|
||||
cfg, capabilities=capabilities, env_capabilities=env_capabilities
|
||||
)
|
||||
|
||||
assert updated_cfg.torch_compile is False
|
||||
|
||||
|
||||
class TestValidationCheckModelConfig(BaseValidation):
|
||||
"""
|
||||
Test the validation for the config when the model config is available
|
||||
|
||||
Reference in New Issue
Block a user