handle torch_compile set to auto (#2172) [skip ci]
* handle torch_compile set to auto * update docs [skip ci] * add tests
This commit is contained in:
@@ -337,7 +337,8 @@ comet_experiment_config: # Dictionary for additional configuration settings, see
|
|||||||
output_dir: ./completed-model
|
output_dir: ./completed-model
|
||||||
|
|
||||||
# Whether to use torch.compile and which backend to use
|
# Whether to use torch.compile and which backend to use
|
||||||
torch_compile: # bool
|
# setting to `auto` will enable torch compile when torch>=2.5.1
|
||||||
|
torch_compile: # Optional[Union[Literal["auto"], bool]]
|
||||||
torch_compile_backend: # Optional[str]
|
torch_compile_backend: # Optional[str]
|
||||||
|
|
||||||
# Training hyperparameters
|
# Training hyperparameters
|
||||||
|
|||||||
@@ -245,8 +245,8 @@ def validate_config(
|
|||||||
) = merge_input_args()
|
) = merge_input_args()
|
||||||
|
|
||||||
if capabilities or env_capabilities:
|
if capabilities or env_capabilities:
|
||||||
if (capabilities and not env_capabilities) or (
|
if (capabilities and env_capabilities is None) or (
|
||||||
env_capabilities and not capabilities
|
env_capabilities and capabilities is None
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Both capabilities and env_capabilities must be provided or not provided."
|
"Both capabilities and env_capabilities must be provided or not provided."
|
||||||
|
|||||||
@@ -741,7 +741,7 @@ class AxolotlInputConfig(
|
|||||||
special_tokens: Optional[SpecialTokensConfig] = None
|
special_tokens: Optional[SpecialTokensConfig] = None
|
||||||
tokens: Optional[List[str]] = None
|
tokens: Optional[List[str]] = None
|
||||||
|
|
||||||
torch_compile: Optional[bool] = None
|
torch_compile: Optional[Union[Literal["auto"], bool]] = None
|
||||||
torch_compile_backend: Optional[str] = None
|
torch_compile_backend: Optional[str] = None
|
||||||
torch_compile_mode: Optional[
|
torch_compile_mode: Optional[
|
||||||
Literal["default", "reduce-overhead", "max-autotune"]
|
Literal["default", "reduce-overhead", "max-autotune"]
|
||||||
@@ -1582,3 +1582,22 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
"ADOPT optimizer is incompatible with torch version < 2.5.1"
|
"ADOPT optimizer is incompatible with torch version < 2.5.1"
|
||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
@classmethod
def check_torch_compile_auto(cls, data):
    """Resolve ``torch_compile: auto`` into a concrete boolean.

    When ``torch_compile`` is the string ``"auto"``, it is replaced with
    ``True`` if ``env_capabilities["torch_version"]`` is present and parses
    to at least 2.5.1, and with ``False`` otherwise (older torch, or no
    version reported). Any other ``torch_compile`` value is left untouched.

    Args:
        data: The raw (pre-validation) config mapping.

    Returns:
        The same mapping, with ``torch_compile`` possibly rewritten.
    """
    if data.get("torch_compile") != "auto":
        return data

    # `dict.get(key, default)` only uses the default when the key is missing;
    # a key that is present but set to None must also be treated as "no
    # capabilities" instead of raising AttributeError on the lookup below.
    env_capabilities = data.get("env_capabilities") or {}
    torch_version = env_capabilities.get("torch_version")

    compile_supported = bool(torch_version) and version.parse(
        torch_version
    ) >= version.parse("2.5.1")
    if compile_supported:
        LOG.info("torch.compile is available, setting torch_compile to True")

    data["torch_compile"] = compile_supported
    return data
|
||||||
|
|||||||
@@ -1196,6 +1196,46 @@ class TestValidation(BaseValidation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestTorchCompileValidation(BaseValidation):
    """
    Checks how validation resolves torch_compile when it is set to 'auto'
    """

    def test_torch_compile_auto(self, minimal_cfg):
        # The same config object is validated once per scenario below.
        cfg = DictDefault({"torch_compile": "auto"}) | minimal_cfg

        # (reported environment capabilities, expected resolved torch_compile)
        scenarios = (
            ({"torch_version": "2.5.1"}, True),  # minimum supported version
            ({"torch_version": "2.4.1"}, False),  # torch too old
            ({}, False),  # no torch version reported
        )

        for env_capabilities, expected in scenarios:
            capabilities = {"bf16": True}
            updated_cfg = validate_config(
                cfg, capabilities=capabilities, env_capabilities=env_capabilities
            )

            # Identity check: the value must be the actual bool, not "auto".
            assert updated_cfg.torch_compile is expected
|
||||||
|
|
||||||
|
|
||||||
class TestValidationCheckModelConfig(BaseValidation):
|
class TestValidationCheckModelConfig(BaseValidation):
|
||||||
"""
|
"""
|
||||||
Test the validation for the config when the model config is available
|
Test the validation for the config when the model config is available
|
||||||
|
|||||||
Reference in New Issue
Block a user