From 3dd9c3bf3f45df23110bf5bb84993f16c0a627d9 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sat, 3 May 2025 12:02:26 -0400
Subject: [PATCH] setup hf transfer too and fix auto bf16 when fp16 enabled
 (#2620) [skip ci]

---
 src/axolotl/cli/evaluate.py          | 4 ++--
 src/axolotl/cli/main.py              | 6 ++++--
 src/axolotl/cli/train.py             | 4 ++--
 src/axolotl/utils/__init__.py        | 9 +++++++++
 src/axolotl/utils/config/__init__.py | 2 +-
 5 files changed, 18 insertions(+), 7 deletions(-)

diff --git a/src/axolotl/cli/evaluate.py b/src/axolotl/cli/evaluate.py
index a1859f315..e52da66b7 100644
--- a/src/axolotl/cli/evaluate.py
+++ b/src/axolotl/cli/evaluate.py
@@ -15,7 +15,7 @@ from axolotl.cli.checks import check_accelerate_default_config, check_user_token
 from axolotl.cli.config import load_cfg
 from axolotl.common.datasets import load_datasets, load_preference_datasets
 from axolotl.evaluate import evaluate
-from axolotl.utils import set_pytorch_cuda_alloc_conf
+from axolotl.utils import patch_optimized_env
 from axolotl.utils.dict import DictDefault
 
 LOG = logging.getLogger(__name__)
@@ -32,7 +32,7 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
         cli_args: CLI arguments.
     """
     # Enable expandable segments for cuda allocation to improve VRAM usage
-    set_pytorch_cuda_alloc_conf()
+    patch_optimized_env()
 
     # pylint: disable=duplicate-code
     print_axolotl_text_art()
diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py
index de4fb6cbe..601add709 100644
--- a/src/axolotl/cli/main.py
+++ b/src/axolotl/cli/main.py
@@ -29,7 +29,7 @@ from axolotl.cli.utils import (
     filter_none_kwargs,
 )
 from axolotl.integrations.lm_eval.cli import lm_eval
-from axolotl.utils import set_pytorch_cuda_alloc_conf
+from axolotl.utils import patch_optimized_env
 from axolotl.utils.schemas.config import AxolotlInputConfig
 
 
@@ -55,6 +55,8 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
         kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
             config options.
     """
+    patch_optimized_env()
+
     if cloud:
         from axolotl.cli.cloud import do_cli_preprocess
 
@@ -100,7 +102,7 @@ def train(
             config options.
     """
     # Enable expandable segments for cuda allocation to improve VRAM usage
-    set_pytorch_cuda_alloc_conf()
+    patch_optimized_env()
 
     if "use_ray" in kwargs and kwargs["use_ray"]:
         accelerate = False
diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py
index 4f258313d..9e90cede3 100644
--- a/src/axolotl/cli/train.py
+++ b/src/axolotl/cli/train.py
@@ -18,7 +18,7 @@ from axolotl.cli.config import load_cfg
 from axolotl.common.datasets import load_datasets, load_preference_datasets
 from axolotl.integrations.base import PluginManager
 from axolotl.train import train
-from axolotl.utils import set_pytorch_cuda_alloc_conf
+from axolotl.utils import patch_optimized_env
 from axolotl.utils.config import normalize_config, resolve_dtype
 from axolotl.utils.dict import DictDefault
 
@@ -36,7 +36,7 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
         cli_args: Training-specific CLI arguments.
     """
     # Enable expandable segments for cuda allocation to improve VRAM usage
-    set_pytorch_cuda_alloc_conf()
+    patch_optimized_env()
 
     print_axolotl_text_art()
     check_accelerate_default_config()
diff --git a/src/axolotl/utils/__init__.py b/src/axolotl/utils/__init__.py
index ffa528cc9..3d0ba7c9c 100644
--- a/src/axolotl/utils/__init__.py
+++ b/src/axolotl/utils/__init__.py
@@ -43,3 +43,12 @@ def set_pytorch_cuda_alloc_conf():
         os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
             "expandable_segments:True,roundup_power2_divisions:16"
         )
+
+
+def patch_optimized_env():
+    """
+    Patch environment variables to improve VRAM usage and increase download speed
+    """
+    if os.getenv("HF_HUB_ENABLE_HF_TRANSFER") is None:
+        os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+    set_pytorch_cuda_alloc_conf()
diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py
index 35e742a89..0de87fa5c 100644
--- a/src/axolotl/utils/config/__init__.py
+++ b/src/axolotl/utils/config/__init__.py
@@ -59,7 +59,7 @@ def choose_device(cfg):
 
 def resolve_dtype(cfg):
     if (
-        cfg.bf16 == "auto" and not cfg.use_ray
+        not cfg.fp16 and cfg.bf16 == "auto" and not cfg.use_ray
     ):  # if we use ray we want to defer this check to the worker node
         if is_torch_bf16_gpu_available():
             LOG.debug("bf16 support detected, enabling for this configuration.")