diff --git a/.github/workflows/multi-gpu-e2e.yml b/.github/workflows/multi-gpu-e2e.yml index edc1c7d79..ffb3577ea 100644 --- a/.github/workflows/multi-gpu-e2e.yml +++ b/.github/workflows/multi-gpu-e2e.yml @@ -9,6 +9,7 @@ on: - 'pyproject.toml' - '.github/workflows/multi-gpu-e2e.yml' - 'src/axolotl/core/trainers/mixins/sequence_parallel.py' + - 'src/axolotl/utils/distributed.py' workflow_dispatch: schedule: - cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index e54850af7..e633215e4 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -27,6 +27,9 @@ concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} +env: + TRANSFORMERS_IS_CI: "yes" + jobs: pre-commit: name: pre-commit diff --git a/src/axolotl/integrations/cut_cross_entropy/__init__.py b/src/axolotl/integrations/cut_cross_entropy/__init__.py index 4792adb6f..7420674fa 100644 --- a/src/axolotl/integrations/cut_cross_entropy/__init__.py +++ b/src/axolotl/integrations/cut_cross_entropy/__init__.py @@ -25,7 +25,7 @@ import torch from axolotl.integrations.base import BasePlugin from axolotl.utils import get_pytorch_version -from axolotl.utils.distributed import zero_only +from axolotl.utils.distributed import is_main_process from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401 @@ -76,7 +76,7 @@ class CutCrossEntropyPlugin(BasePlugin): cce_patch, ) - with zero_only(): + if is_main_process(use_environ=True): LOG.info( f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}" ) diff --git a/src/axolotl/integrations/liger/__init__.py b/src/axolotl/integrations/liger/__init__.py index 4e8d00552..50df7a408 100644 --- a/src/axolotl/integrations/liger/__init__.py +++ b/src/axolotl/integrations/liger/__init__.py @@ -23,8 +23,8 @@ import logging import sys from axolotl.integrations.base import BasePlugin +from axolotl.utils.distributed import is_main_process -from ...utils.distributed import zero_only from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401 from .utils import patch_with_compile_disable @@ -85,7 +85,7 @@ class LigerPlugin(BasePlugin): kwargs["geglu"] = cfg.liger_glu_activation elif "swiglu" in liger_fn_sig.parameters: kwargs["swiglu"] = cfg.liger_glu_activation - with zero_only(): + if is_main_process(use_environ=True): LOG.info( f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}" ) diff --git a/src/axolotl/utils/distributed.py b/src/axolotl/utils/distributed.py index acd2566d0..8c52102c8 100644 --- a/src/axolotl/utils/distributed.py +++ b/src/axolotl/utils/distributed.py @@ -69,17 +69,27 @@ def barrier(): dist.barrier() -def is_main_process(): +def is_main_process(use_environ=False): """ Check if the current process is the main process. If not in distributed mode, always return `True`. + + Args: + - use_environ (bool, optional): Use environment variable to determine main process. + + Returns: + - bool: `True` if the current process is the main process, `False` otherwise. """ + if use_environ: + return os.environ.get("LOCAL_RANK", "0") == "0" if not is_distributed(): return True return dist.get_rank() == 0 -def is_local_main_process(): +def is_local_main_process(use_environ=False): + if use_environ: + return os.environ.get("LOCAL_RANK", "0") == "0" return PartialState().is_local_main_process @@ -99,17 +109,6 @@ def cleanup_distributed(): torch.distributed.destroy_process_group() -@contextmanager -def zero_only(): - """ - Context manager that only runs the enclosed block on the main rank. - """ - if is_main_process(): - yield - else: - yield None - - @contextmanager def zero_first(is_main): """ diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index ab4cc19bb..e88de1bad 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -68,7 +68,7 @@ from axolotl.utils.distributed import ( get_device_count, get_device_type, is_local_main_process, - zero_only, + is_main_process, ) from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper from axolotl.utils.lora_embeddings import get_linear_embedding_layers @@ -437,7 +437,7 @@ def load_tokenizer(cfg): {"additional_special_tokens": additional_special_tokens} ) - with zero_only(): + if is_main_process(use_environ=True): LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}") LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}") LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")