add torch_compile_mode options (#1763) [skip ci]
* add torch_compile_mode options
* make sure n_gpu is an int
This commit is contained in:
@@ -375,7 +375,7 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
|||||||
cfg,
|
cfg,
|
||||||
capabilities={
|
capabilities={
|
||||||
"bf16": is_torch_bf16_gpu_available(),
|
"bf16": is_torch_bf16_gpu_available(),
|
||||||
"n_gpu": os.environ.get("WORLD_SIZE", 1),
|
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
|
||||||
"compute_capability": gpu_version,
|
"compute_capability": gpu_version,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1287,6 +1287,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs[
|
training_arguments_kwargs[
|
||||||
"torch_compile_backend"
|
"torch_compile_backend"
|
||||||
] = self.cfg.torch_compile_backend
|
] = self.cfg.torch_compile_backend
|
||||||
|
if self.cfg.torch_compile_mode:
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"torch_compile_mode"
|
||||||
|
] = self.cfg.torch_compile_mode
|
||||||
|
|
||||||
# DDP Config
|
# DDP Config
|
||||||
if self.cfg.ddp_timeout:
|
if self.cfg.ddp_timeout:
|
||||||
|
|||||||
@@ -608,6 +608,9 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
torch_compile: Optional[bool] = None
|
torch_compile: Optional[bool] = None
|
||||||
torch_compile_backend: Optional[str] = None
|
torch_compile_backend: Optional[str] = None
|
||||||
|
torch_compile_mode: Optional[
|
||||||
|
Literal["default", "reduce-overhead", "max-autotune"]
|
||||||
|
] = None
|
||||||
|
|
||||||
max_steps: Optional[int] = None
|
max_steps: Optional[int] = None
|
||||||
warmup_steps: Optional[int] = None
|
warmup_steps: Optional[int] = None
|
||||||
@@ -1161,6 +1164,15 @@ class AxolotlInputConfig(
|
|||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_torch_compile_deepspeed(cls, data):
|
||||||
|
if data.get("deepspeed") and data.get("torch_compile"):
|
||||||
|
raise ValueError(
|
||||||
|
"torch_compile should be set within your deepspeed config file"
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||||
"""wrapper to validate gpu capabilities with the configured options"""
|
"""wrapper to validate gpu capabilities with the configured options"""
|
||||||
|
|||||||
Reference in New Issue
Block a user