Set the PYTORCH_CUDA_ALLOC_CONF environment variable in the train module (#2447)
This commit is contained in:
@@ -17,6 +17,7 @@ from axolotl.cli.config import load_cfg
|
|||||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
|
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||||
from axolotl.utils.config import normalize_config, resolve_dtype
|
from axolotl.utils.config import normalize_config, resolve_dtype
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
@@ -33,6 +34,9 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
|||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
cli_args: Training-specific CLI arguments.
|
cli_args: Training-specific CLI arguments.
|
||||||
"""
|
"""
|
||||||
|
# Enable expandable segments for CUDA allocation to improve VRAM usage
|
||||||
|
set_pytorch_cuda_alloc_conf()
|
||||||
|
|
||||||
print_axolotl_text_art()
|
print_axolotl_text_art()
|
||||||
check_accelerate_default_config()
|
check_accelerate_default_config()
|
||||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||||
|
|||||||
Reference in New Issue
Block a user