replace zero_only with simpler if statement (#2592)

This commit is contained in:
Wing Lian
2025-04-30 13:11:03 -04:00
committed by GitHub
parent 89ca14d9a0
commit 5e949eaa07
6 changed files with 22 additions and 19 deletions

View File

@@ -9,6 +9,7 @@ on:
- 'pyproject.toml' - 'pyproject.toml'
- '.github/workflows/multi-gpu-e2e.yml' - '.github/workflows/multi-gpu-e2e.yml'
- 'src/axolotl/core/trainers/mixins/sequence_parallel.py' - 'src/axolotl/core/trainers/mixins/sequence_parallel.py'
- 'src/axolotl/utils/distributed.py'
workflow_dispatch: workflow_dispatch:
schedule: schedule:
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday - cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday

View File

@@ -27,6 +27,9 @@ concurrency:
group: ${{ github.workflow }}-${{ github.ref }} group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
env:
TRANSFORMERS_IS_CI: "yes"
jobs: jobs:
pre-commit: pre-commit:
name: pre-commit name: pre-commit

View File

@@ -25,7 +25,7 @@ import torch
from axolotl.integrations.base import BasePlugin from axolotl.integrations.base import BasePlugin
from axolotl.utils import get_pytorch_version from axolotl.utils import get_pytorch_version
from axolotl.utils.distributed import zero_only from axolotl.utils.distributed import is_main_process
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401 from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
@@ -76,7 +76,7 @@ class CutCrossEntropyPlugin(BasePlugin):
cce_patch, cce_patch,
) )
with zero_only(): if is_main_process(use_environ=True):
LOG.info( LOG.info(
f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}" f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}"
) )

View File

@@ -23,8 +23,8 @@ import logging
import sys import sys
from axolotl.integrations.base import BasePlugin from axolotl.integrations.base import BasePlugin
from axolotl.utils.distributed import is_main_process
from ...utils.distributed import zero_only
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401 from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
from .utils import patch_with_compile_disable from .utils import patch_with_compile_disable
@@ -85,7 +85,7 @@ class LigerPlugin(BasePlugin):
kwargs["geglu"] = cfg.liger_glu_activation kwargs["geglu"] = cfg.liger_glu_activation
elif "swiglu" in liger_fn_sig.parameters: elif "swiglu" in liger_fn_sig.parameters:
kwargs["swiglu"] = cfg.liger_glu_activation kwargs["swiglu"] = cfg.liger_glu_activation
with zero_only(): if is_main_process(use_environ=True):
LOG.info( LOG.info(
f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}" f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}"
) )

View File

@@ -69,17 +69,27 @@ def barrier():
dist.barrier() dist.barrier()
def is_main_process(): def is_main_process(use_environ=False):
""" """
Check if the current process is the main process. If not in distributed mode, Check if the current process is the main process. If not in distributed mode,
always return `True`. always return `True`.
Args:
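- use_environ (bool, optional): If `True`, decide from the `LOCAL_RANK` environment variable (treated as "0" when unset) instead of calling `dist.get_rank()`. `LOCAL_RANK` is the rank within a node, so on multi-node runs this is `True` for one process on every node.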
Returns:
- bool: `True` if the current process is the main process, `False` otherwise.
""" """
if use_environ:
return os.environ.get("LOCAL_RANK", "0") == "0"
if not is_distributed(): if not is_distributed():
return True return True
return dist.get_rank() == 0 return dist.get_rank() == 0
def is_local_main_process(): def is_local_main_process(use_environ=False):
if use_environ:
return os.environ.get("LOCAL_RANK", "0") == "0"
return PartialState().is_local_main_process return PartialState().is_local_main_process
@@ -99,17 +109,6 @@ def cleanup_distributed():
torch.distributed.destroy_process_group() torch.distributed.destroy_process_group()
@contextmanager
def zero_only():
"""
Context manager that only runs the enclosed block on the main rank.
"""
if is_main_process():
yield
else:
yield None
@contextmanager @contextmanager
def zero_first(is_main): def zero_first(is_main):
""" """

View File

@@ -68,7 +68,7 @@ from axolotl.utils.distributed import (
get_device_count, get_device_count,
get_device_type, get_device_type,
is_local_main_process, is_local_main_process,
zero_only, is_main_process,
) )
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
from axolotl.utils.lora_embeddings import get_linear_embedding_layers from axolotl.utils.lora_embeddings import get_linear_embedding_layers
@@ -437,7 +437,7 @@ def load_tokenizer(cfg):
{"additional_special_tokens": additional_special_tokens} {"additional_special_tokens": additional_special_tokens}
) )
with zero_only(): if is_main_process(use_environ=True):
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}") LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}") LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}") LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")