add torch_compile_mode options (#1763) [skip ci]

* add torch_compile_mode options

* make sure n_gpu is an int
This commit is contained in:
Wing Lian
2024-07-17 15:38:07 -04:00
committed by GitHub
parent 976f85195a
commit 8619b2d855
3 changed files with 17 additions and 1 deletions

View File

@@ -375,7 +375,7 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
cfg,
capabilities={
"bf16": is_torch_bf16_gpu_available(),
"n_gpu": os.environ.get("WORLD_SIZE", 1),
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
"compute_capability": gpu_version,
},
)

View File

@@ -1287,6 +1287,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs[
"torch_compile_backend"
] = self.cfg.torch_compile_backend
if self.cfg.torch_compile_mode:
training_arguments_kwargs[
"torch_compile_mode"
] = self.cfg.torch_compile_mode
# DDP Config
if self.cfg.ddp_timeout:

View File

@@ -608,6 +608,9 @@ class AxolotlInputConfig(
torch_compile: Optional[bool] = None
torch_compile_backend: Optional[str] = None
# Optional torch.compile mode; when set it must be one of the literal
# values listed below. Defaults to None (unset).
torch_compile_mode: Optional[
Literal["default", "reduce-overhead", "max-autotune"]
] = None
max_steps: Optional[int] = None
warmup_steps: Optional[int] = None
@@ -1161,6 +1164,15 @@ class AxolotlInputConfig(
)
return data
@model_validator(mode="before")
@classmethod
def check_torch_compile_deepspeed(cls, data):
    """Reject configs that enable both ``deepspeed`` and ``torch_compile``.

    Registered as a "before" model validator, so it receives the raw input
    ``data`` (assumed to be dict-like and to support ``.get`` -- confirm
    against callers).

    Returns:
        ``data`` unchanged when at most one of the two options is truthy.

    Raises:
        ValueError: if both ``deepspeed`` and ``torch_compile`` are truthy.
    """
    # Guard clause: nothing to check unless both options are enabled.
    if not data.get("deepspeed") or not data.get("torch_compile"):
        return data

    raise ValueError(
        "torch_compile should be set within your deepspeed config file"
    )
class AxolotlConfigWCapabilities(AxolotlInputConfig):
"""wrapper to valdiate gpu capabilities with the configured options"""