Replace zero_only with a simpler if statement (#2592)

This commit is contained in:
Wing Lian
2025-04-30 13:11:03 -04:00
committed by GitHub
parent 89ca14d9a0
commit 5e949eaa07
6 changed files with 22 additions and 19 deletions

View File

@@ -9,6 +9,7 @@ on:
- 'pyproject.toml'
- '.github/workflows/multi-gpu-e2e.yml'
- 'src/axolotl/core/trainers/mixins/sequence_parallel.py'
- 'src/axolotl/utils/distributed.py'
workflow_dispatch:
schedule:
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday

View File

@@ -27,6 +27,9 @@ concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
env:
TRANSFORMERS_IS_CI: "yes"
jobs:
pre-commit:
name: pre-commit

View File

@@ -25,7 +25,7 @@ import torch
from axolotl.integrations.base import BasePlugin
from axolotl.utils import get_pytorch_version
from axolotl.utils.distributed import zero_only
from axolotl.utils.distributed import is_main_process
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
@@ -76,7 +76,7 @@ class CutCrossEntropyPlugin(BasePlugin):
cce_patch,
)
with zero_only():
if is_main_process(use_environ=True):
LOG.info(
f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}"
)

View File

@@ -23,8 +23,8 @@ import logging
import sys
from axolotl.integrations.base import BasePlugin
from axolotl.utils.distributed import is_main_process
from ...utils.distributed import zero_only
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
from .utils import patch_with_compile_disable
@@ -85,7 +85,7 @@ class LigerPlugin(BasePlugin):
kwargs["geglu"] = cfg.liger_glu_activation
elif "swiglu" in liger_fn_sig.parameters:
kwargs["swiglu"] = cfg.liger_glu_activation
with zero_only():
if is_main_process(use_environ=True):
LOG.info(
f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}"
)

View File

@@ -69,17 +69,27 @@ def barrier():
dist.barrier()
def is_main_process():
def is_main_process(use_environ=False):
"""
Check if the current process is the main process. If not in distributed mode,
always return `True`.
Args:
- use_environ (bool, optional): Use environment variable to determine main process.
Returns:
- bool: `True` if the current process is the main process, `False` otherwise.
"""
if use_environ:
return os.environ.get("LOCAL_RANK", "0") == "0"
if not is_distributed():
return True
return dist.get_rank() == 0
def is_local_main_process():
def is_local_main_process(use_environ=False):
if use_environ:
return os.environ.get("LOCAL_RANK", "0") == "0"
return PartialState().is_local_main_process
@@ -99,17 +109,6 @@ def cleanup_distributed():
torch.distributed.destroy_process_group()
@contextmanager
def zero_only():
"""
Context manager that only runs the enclosed block on the main rank.
"""
if is_main_process():
yield
else:
yield None
@contextmanager
def zero_first(is_main):
"""

View File

@@ -68,7 +68,7 @@ from axolotl.utils.distributed import (
get_device_count,
get_device_type,
is_local_main_process,
zero_only,
is_main_process,
)
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
@@ -437,7 +437,7 @@ def load_tokenizer(cfg):
{"additional_special_tokens": additional_special_tokens}
)
with zero_only():
if is_main_process(use_environ=True):
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")