handle torch_compile set to auto (#2172) [skip ci]
* handle torch_compile set to auto * update docs [skip ci] * add tests
This commit is contained in:
@@ -337,7 +337,8 @@ comet_experiment_config: # Dictionary for additional configuration settings, see
|
||||
output_dir: ./completed-model
|
||||
|
||||
# Whether to use torch.compile and which backend to use
|
||||
torch_compile: # bool
|
||||
# setting to `auto` enables torch.compile when torch>=2.5.1 and disables it otherwise (including when the torch version is unknown)
|
||||
torch_compile: # Optional[Union[Literal["auto"], bool]]
|
||||
torch_compile_backend: # Optional[str]
|
||||
|
||||
# Training hyperparameters
|
||||
|
||||
@@ -245,8 +245,8 @@ def validate_config(
|
||||
) = merge_input_args()
|
||||
|
||||
if capabilities or env_capabilities:
|
||||
if (capabilities and not env_capabilities) or (
|
||||
env_capabilities and not capabilities
|
||||
if (capabilities and env_capabilities is None) or (
|
||||
env_capabilities and capabilities is None
|
||||
):
|
||||
raise ValueError(
|
||||
"Both capabilities and env_capabilities must be provided or not provided."
|
||||
|
||||
@@ -741,7 +741,7 @@ class AxolotlInputConfig(
|
||||
special_tokens: Optional[SpecialTokensConfig] = None
|
||||
tokens: Optional[List[str]] = None
|
||||
|
||||
torch_compile: Optional[bool] = None
|
||||
torch_compile: Optional[Union[Literal["auto"], bool]] = None
|
||||
torch_compile_backend: Optional[str] = None
|
||||
torch_compile_mode: Optional[
|
||||
Literal["default", "reduce-overhead", "max-autotune"]
|
||||
@@ -1582,3 +1582,22 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
"ADOPT optimizer is incompatible with torch version < 2.5.1"
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
@classmethod
def check_torch_compile_auto(cls, data):
    """Resolve ``torch_compile: auto`` into a concrete boolean.

    When ``torch_compile`` is the string ``"auto"``, it is replaced with
    ``True`` if ``env_capabilities["torch_version"]`` parses to a version
    >= 2.5.1, and with ``False`` otherwise (older torch, or no version
    reported). Any other ``torch_compile`` value is left untouched.

    Args:
        data: The raw (pre-validation) config mapping.

    Returns:
        The same mapping, with ``torch_compile`` possibly rewritten.
    """
    if data.get("torch_compile") != "auto":
        return data

    # `or {}` also covers an explicit `env_capabilities: None`; the
    # `.get(key, {})` default only applies when the key is absent and would
    # otherwise hand back None, which has no `.get`.
    env_capabilities = data.get("env_capabilities") or {}
    torch_version = env_capabilities.get("torch_version")

    # A missing/empty version is treated as "cannot confirm support".
    compile_supported = bool(torch_version) and version.parse(
        torch_version
    ) >= version.parse("2.5.1")

    if compile_supported:
        LOG.info("torch.compile is available, setting torch_compile to True")
    data["torch_compile"] = compile_supported
    return data
|
||||
|
||||
@@ -1196,6 +1196,46 @@ class TestValidation(BaseValidation):
|
||||
)
|
||||
|
||||
|
||||
class TestTorchCompileValidation(BaseValidation):
    """
    Validation tests for resolving `torch_compile: auto`
    """

    def test_torch_compile_auto(self, minimal_cfg):
        cfg = DictDefault({"torch_compile": "auto"}) | minimal_cfg

        # (reported env capabilities, expected resolved torch_compile);
        # run in the same order as the checks were originally written.
        scenarios = [
            ({"torch_version": "2.5.1"}, True),
            ({"torch_version": "2.4.1"}, False),
            ({}, False),
        ]

        for env_capabilities, expected in scenarios:
            # A fresh capabilities dict is built for every call.
            updated_cfg = validate_config(
                cfg,
                capabilities={"bf16": True},
                env_capabilities=env_capabilities,
            )

            assert updated_cfg.torch_compile is expected
|
||||
|
||||
|
||||
class TestValidationCheckModelConfig(BaseValidation):
|
||||
"""
|
||||
Test the validation for the config when the model config is available
|
||||
|
||||
Reference in New Issue
Block a user