diff --git a/src/axolotl/cli/config.py b/src/axolotl/cli/config.py
index d55448da4..0fb13bfd3 100644
--- a/src/axolotl/cli/config.py
+++ b/src/axolotl/cli/config.py
@@ -26,7 +26,7 @@ from axolotl.utils.mlflow_ import setup_mlflow_env_vars
 from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
 from axolotl.utils.wandb_ import setup_wandb_env_vars
 
-LOG = get_logger(__name__, use_environ=True)
+LOG = get_logger(__name__)
 
 
 def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py
index 9162bc745..7d9b6a6f9 100644
--- a/src/axolotl/integrations/base.py
+++ b/src/axolotl/integrations/base.py
@@ -33,7 +33,7 @@ from transformers import PreTrainedModel, Trainer
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.logging import get_logger
 
-LOG = get_logger(__name__, use_environ=True)
+LOG = get_logger(__name__)
 
 if TYPE_CHECKING:
     from axolotl.common.datasets import TrainDatasetMeta
diff --git a/src/axolotl/integrations/cut_cross_entropy/__init__.py b/src/axolotl/integrations/cut_cross_entropy/__init__.py
index a7e94e363..9b155ca8a 100644
--- a/src/axolotl/integrations/cut_cross_entropy/__init__.py
+++ b/src/axolotl/integrations/cut_cross_entropy/__init__.py
@@ -28,7 +28,7 @@ from axolotl.utils.logging import get_logger
 
 from .args import CutCrossEntropyArgs  # pylint: disable=unused-import. # noqa: F401
 
-LOG = get_logger(__name__, use_environ=True)
+LOG = get_logger(__name__)
 
 _CCE_INSTALL_MESSAGE = (
     "Please install cut_cross_entropy with transformers support using "
diff --git a/src/axolotl/integrations/liger/__init__.py b/src/axolotl/integrations/liger/__init__.py
index 1c17ab2b5..8de94c78b 100644
--- a/src/axolotl/integrations/liger/__init__.py
+++ b/src/axolotl/integrations/liger/__init__.py
@@ -27,7 +27,7 @@ from axolotl.utils.logging import get_logger
 from .args import LigerArgs  # pylint: disable=unused-import. # noqa: F401
 from .utils import patch_with_compile_disable
 
-LOG = get_logger(__name__, use_environ=True)
+LOG = get_logger(__name__)
 
 
 class LigerPlugin(BasePlugin):
diff --git a/src/axolotl/integrations/liger/args.py b/src/axolotl/integrations/liger/args.py
index 7c9eb23d5..d05f08f9a 100644
--- a/src/axolotl/integrations/liger/args.py
+++ b/src/axolotl/integrations/liger/args.py
@@ -15,6 +15,7 @@
 """
 Module for handling LIGER input arguments.
 """
+
 from typing import Optional
 
 from pydantic import BaseModel, model_validator
""" + from typing import Optional from pydantic import BaseModel, model_validator diff --git a/src/axolotl/loaders/tokenizer.py b/src/axolotl/loaders/tokenizer.py index 4f9a60a69..9fdb7d5cc 100644 --- a/src/axolotl/loaders/tokenizer.py +++ b/src/axolotl/loaders/tokenizer.py @@ -273,7 +273,7 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer: {"additional_special_tokens": additional_special_tokens} ) - if is_main_process(use_environ=True): + if is_main_process(): LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}") LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}") LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}") diff --git a/src/axolotl/monkeypatch/ring_attn/patch.py b/src/axolotl/monkeypatch/ring_attn/patch.py index 7d733cfc1..d83476e5a 100644 --- a/src/axolotl/monkeypatch/ring_attn/patch.py +++ b/src/axolotl/monkeypatch/ring_attn/patch.py @@ -13,9 +13,9 @@ import inspect import accelerate import torch import torch.distributed as dist -from accelerate.logging import get_logger from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids +from axolotl.utils.logging import get_logger from axolotl.utils.schemas.enums import RingAttnFunc LOG = get_logger(__name__) diff --git a/src/axolotl/monkeypatch/unsloth_.py b/src/axolotl/monkeypatch/unsloth_.py index 61f4eeea0..146047e95 100644 --- a/src/axolotl/monkeypatch/unsloth_.py +++ b/src/axolotl/monkeypatch/unsloth_.py @@ -4,12 +4,12 @@ import inspect import types import torch -from accelerate.logging import get_logger from peft import PeftModelForCausalLM from torch import nn from transformers.models.llama.modeling_llama import LlamaFlashAttention2 from axolotl.monkeypatch.utils import detab_code +from axolotl.utils.logging import get_logger LOG = get_logger(__name__) diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index e0eaf9ac9..745a0c8ce 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -21,7 +21,7 @@ from axolotl.utils.schemas.config import ( from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase from axolotl.utils.schemas.datasets import DPODataset, KTODataset, SFTDataset -LOG = get_logger(__name__, use_environ=True) +LOG = get_logger(__name__) def choose_device(cfg): diff --git a/src/axolotl/utils/distributed.py b/src/axolotl/utils/distributed.py index 0673c6e95..b509ad0ca 100644 --- a/src/axolotl/utils/distributed.py +++ b/src/axolotl/utils/distributed.py @@ -1,6 +1,4 @@ -""" -utility helpers for distributed checks -""" +"""Utilities for distributed functionality.""" import os import pickle # nosec @@ -19,7 +17,7 @@ from transformers.utils.import_utils import ( distributed_state = None # pylint: disable=invalid-name -def get_device_type(): +def get_device_type() -> torch.device: device = torch.device("cpu") if is_torch_cuda_available(): device = torch.device("cuda") @@ -30,7 +28,7 @@ def get_device_type(): return device -def get_device_count(): +def get_device_count() -> int: cur_device = get_device_type() if "cuda" in str(cur_device): return torch.cuda.device_count() @@ -39,7 +37,7 @@ def get_device_count(): return 1 -def get_current_device(): +def get_current_device() -> int: cur_device = get_device_type() if "cuda" in str(cur_device): return torch.cuda.current_device() @@ -48,12 +46,14 @@ def get_current_device(): return 0 -def is_distributed(): - """ - Check if distributed training is initialized. 
- """ +def get_distributed_state() -> PartialState | None: + return distributed_state + + +def is_distributed() -> bool: + """Check if distributed training is initialized.""" global distributed_state # pylint: disable=global-statement - if not distributed_state: + if distributed_state is None: timeout = int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800)) distributed_state = PartialState(timeout=timedelta(seconds=timeout)) @@ -69,31 +69,31 @@ def barrier(): dist.barrier() -def is_main_process(use_environ=False): +def is_main_process() -> bool: """ Check if the current process is the main process. If not in distributed mode, always return `True`. - Args: - - use_environ (bool, optional): Use environment variable to determine main process. + We use a simpler logic when the distributed state is not initialized: we just log + on the 0-th local rank. Returns: - - bool: `True` if the current process is the main process, `False` otherwise. + `True` if the current process is the main process, `False` otherwise. """ - if use_environ: + if get_distributed_state() is None: return os.environ.get("LOCAL_RANK", "0") == "0" if not is_distributed(): return True return dist.get_rank() == 0 -def is_local_main_process(use_environ=False): - if use_environ: +def is_local_main_process() -> bool: + if get_distributed_state() is None: return os.environ.get("LOCAL_RANK", "0") == "0" return PartialState().is_local_main_process -def get_world_size(): +def get_world_size() -> int: return int(os.getenv("WORLD_SIZE", "1")) @@ -115,7 +115,7 @@ def cleanup_distributed(): @contextmanager -def zero_first(is_main): +def zero_first(is_main: bool): """ runs the wrapped context so that rank 0 runs first before other ranks """ diff --git a/src/axolotl/utils/freeze.py b/src/axolotl/utils/freeze.py index 65ca62137..936708f04 100644 --- a/src/axolotl/utils/freeze.py +++ b/src/axolotl/utils/freeze.py @@ -5,9 +5,8 @@ module to freeze/unfreeze parameters by name import re from typing import Callable, List, Tuple, Union -from accelerate.logging import get_logger - from axolotl.utils.distributed import is_main_process +from axolotl.utils.logging import get_logger LOG = get_logger(__name__) diff --git a/src/axolotl/utils/logging.py b/src/axolotl/utils/logging.py index 80daab4ea..7cc3530ae 100644 --- a/src/axolotl/utils/logging.py +++ b/src/axolotl/utils/logging.py @@ -1,6 +1,4 @@ -""" -logging helpers to only log on main process -""" +"""Logging helpers to only log on main process.""" import functools import logging @@ -14,27 +12,18 @@ from axolotl.utils.distributed import is_main_process class MultiProcessAdapter(logging.LoggerAdapter): """ - logger adapter for distributed logging, specifically to only log on main process + Logger adapter for distributed logging, specifically to only log on main process. 
""" - def __init__(self, logger, use_environ=False, extra=None): - super().__init__(logger, extra) - self.use_environ = use_environ - @staticmethod - def _should_log(main_process_only, use_environ=False): - return not main_process_only or ( - main_process_only and is_main_process(use_environ=use_environ) - ) + def _should_log(main_process_only: bool): + return not main_process_only or is_main_process() def log(self, level, msg, *args, **kwargs): - use_environ = kwargs.pop("use_environ", self.use_environ) main_process_only = kwargs.pop("main_process_only", True) kwargs.setdefault("stacklevel", 2) - if self.isEnabledFor(level) and self._should_log( - main_process_only, use_environ=use_environ - ): + if self.isEnabledFor(level) and self._should_log(main_process_only): msg, kwargs = self.process(msg, kwargs) self.logger.log(level, msg, *args, **kwargs) @@ -50,13 +39,11 @@ class MultiProcessAdapter(logging.LoggerAdapter): self.warning(*args, **kwargs) -def get_logger( - name: str, log_level: str | None = None, use_environ: bool = False -) -> MultiProcessAdapter: +def get_logger(name: str, log_level: str | None = None) -> MultiProcessAdapter: if log_level is None: log_level = os.environ.get("AXOLOTL_LOG_LEVEL", None) logger = logging.getLogger(name) if log_level is not None: logger.setLevel(log_level.upper()) logger.root.setLevel(log_level.upper()) - return MultiProcessAdapter(logger, use_environ=use_environ, extra={}) + return MultiProcessAdapter(logger, extra={}) diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 259daa56f..460043272 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -48,7 +48,7 @@ from axolotl.utils.schemas.trl import TRLConfig from axolotl.utils.schemas.validation import ValidationMixin from axolotl.utils.schemas.vllm import VllmConfig -LOG = get_logger(__name__, use_environ=True) +LOG = get_logger(__name__) # pylint: disable=too-many-ancestors diff --git a/src/axolotl/utils/schemas/model.py b/src/axolotl/utils/schemas/model.py index 6f995996d..5eea11444 100644 --- a/src/axolotl/utils/schemas/model.py +++ b/src/axolotl/utils/schemas/model.py @@ -4,7 +4,7 @@ from pydantic import BaseModel, Field, field_validator from axolotl.utils.logging import get_logger -LOG = get_logger(__name__, use_environ=True) +LOG = get_logger(__name__) class ModelInputConfig(BaseModel): diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 33ddadf78..e996cd62b 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -11,7 +11,6 @@ from typing import List, Optional import numpy as np import torch import torch.cuda -from accelerate.logging import get_logger from datasets import IterableDataset, disable_caching, enable_caching from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from transformers.utils import is_torch_bf16_gpu_available @@ -19,6 +18,7 @@ from transformers.utils import is_torch_bf16_gpu_available from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2 from axolotl.utils.distributed import reduce_and_broadcast from axolotl.utils.environment import check_cuda_p2p_ib_support +from axolotl.utils.logging import get_logger from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths LOG = get_logger(__name__)