Compare commits
4 Commits
codecov-pu...chore/docs
| Author | SHA1 | Date |
|---|---|---|
| | 159f0531f9 | |
| | 0494359c6c | |
| | 26c39e1ca7 | |
| | 45adf1bfb9 | |
@@ -9,11 +9,11 @@ description: Frequently asked questions

 > A: Usually an issue with the GPUs communicating with each other. See the [NCCL doc](nccl.qmd)

-**Q: Exitcode -9**
+**Q: exitcode: -9**

 > A: This usually happens when you run out of system RAM.

-**Q: Exitcode -7 while using deepspeed**
+**Q: exitcode: -7 while using deepspeed**

 > A: Try upgrading deepspeed w: `pip install -U deepspeed`
@@ -18,7 +18,7 @@ tokenizers>=0.21.1
 accelerate==1.7.0
 datasets==3.6.0
 deepspeed>=0.17.0
-trl==0.18.1
+trl==0.18.2
 hf_xet==1.1.2
 optimum==1.16.2
@@ -26,7 +26,7 @@ from axolotl.utils.mlflow_ import setup_mlflow_env_vars
 from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
 from axolotl.utils.wandb_ import setup_wandb_env_vars

-LOG = get_logger(__name__, use_environ=True)
+LOG = get_logger(__name__)


 def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
@@ -33,7 +33,7 @@ from transformers import PreTrainedModel, Trainer
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.logging import get_logger

-LOG = get_logger(__name__, use_environ=True)
+LOG = get_logger(__name__)

 if TYPE_CHECKING:
     from axolotl.common.datasets import TrainDatasetMeta
@@ -28,7 +28,7 @@ from axolotl.utils.logging import get_logger

 from .args import CutCrossEntropyArgs  # pylint: disable=unused-import. # noqa: F401

-LOG = get_logger(__name__, use_environ=True)
+LOG = get_logger(__name__)

 _CCE_INSTALL_MESSAGE = (
     "Please install cut_cross_entropy with transformers support using "
@@ -27,7 +27,7 @@ from axolotl.utils.logging import get_logger
 from .args import LigerArgs  # pylint: disable=unused-import. # noqa: F401
 from .utils import patch_with_compile_disable

-LOG = get_logger(__name__, use_environ=True)
+LOG = get_logger(__name__)


 class LigerPlugin(BasePlugin):
@@ -15,6 +15,7 @@
 """
 Module for handling LIGER input arguments.
 """

+from typing import Optional

 from pydantic import BaseModel, model_validator
@@ -273,7 +273,7 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
             {"additional_special_tokens": additional_special_tokens}
         )

-    if is_main_process(use_environ=True):
+    if is_main_process():
         LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
         LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
         LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
@@ -13,9 +13,9 @@ import inspect
 import accelerate
 import torch
 import torch.distributed as dist
-from accelerate.logging import get_logger

 from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
+from axolotl.utils.logging import get_logger
 from axolotl.utils.schemas.enums import RingAttnFunc

 LOG = get_logger(__name__)
@@ -4,12 +4,12 @@ import inspect
 import types

 import torch
-from accelerate.logging import get_logger
 from peft import PeftModelForCausalLM
 from torch import nn
 from transformers.models.llama.modeling_llama import LlamaFlashAttention2

 from axolotl.monkeypatch.utils import detab_code
+from axolotl.utils.logging import get_logger

 LOG = get_logger(__name__)
@@ -21,7 +21,7 @@ from axolotl.utils.schemas.config import (
 from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase
 from axolotl.utils.schemas.datasets import DPODataset, KTODataset, SFTDataset

-LOG = get_logger(__name__, use_environ=True)
+LOG = get_logger(__name__)


 def choose_device(cfg):
@@ -1,6 +1,4 @@
-"""
-utility helpers for distributed checks
-"""
+"""Utilities for distributed functionality."""

 import os
 import pickle  # nosec
@@ -19,7 +17,7 @@ from transformers.utils.import_utils import (
 distributed_state = None  # pylint: disable=invalid-name


-def get_device_type():
+def get_device_type() -> torch.device:
     device = torch.device("cpu")
     if is_torch_cuda_available():
         device = torch.device("cuda")
@@ -30,7 +28,7 @@ def get_device_type():
     return device


-def get_device_count():
+def get_device_count() -> int:
     cur_device = get_device_type()
     if "cuda" in str(cur_device):
         return torch.cuda.device_count()
@@ -39,7 +37,7 @@ def get_device_count():
     return 1


-def get_current_device():
+def get_current_device() -> int:
     cur_device = get_device_type()
     if "cuda" in str(cur_device):
         return torch.cuda.current_device()
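For orientation, a hypothetical usage sketch (not part of the diff) showing how the three annotated helpers above compose; `describe_topology` is an illustrative name, not a function in the codebase:

```python
import torch

from axolotl.utils.distributed import (
    get_current_device,
    get_device_count,
    get_device_type,
)


def describe_topology() -> str:
    # e.g. torch.device("cuda") on a GPU box, torch.device("cpu") otherwise
    device = get_device_type()
    # number of visible accelerators, or 1 when running on CPU
    count = get_device_count()
    # this process's device index, 0 when running on CPU
    index = get_current_device()
    return f"{device.type} device {index} of {count}"


if torch.cuda.is_available():
    # pin the current process to its assigned GPU
    torch.cuda.set_device(get_current_device())
```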
@@ -48,12 +46,14 @@ def get_current_device():
     return 0


-def is_distributed():
-    """
-    Check if distributed training is initialized.
-    """
+def get_distributed_state() -> PartialState | None:
+    return distributed_state
+
+
+def is_distributed() -> bool:
+    """Check if distributed training is initialized."""
     global distributed_state  # pylint: disable=global-statement
-    if not distributed_state:
+    if distributed_state is None:
         timeout = int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800))
         distributed_state = PartialState(timeout=timedelta(seconds=timeout))
@@ -69,31 +69,28 @@ def barrier():
     dist.barrier()


-def is_main_process(use_environ=False):
+def is_main_process() -> bool:
     """
     Check if the current process is the main process. If not in distributed mode,
     always return `True`.

-    Args:
-    - use_environ (bool, optional): Use environment variable to determine main process.
-
     Returns:
-    - bool: `True` if the current process is the main process, `False` otherwise.
+        `True` if the current process is the main process, `False` otherwise.
     """
-    if use_environ:
+    if get_distributed_state() is None:
         return os.environ.get("LOCAL_RANK", "0") == "0"
     if not is_distributed():
         return True
     return dist.get_rank() == 0


-def is_local_main_process(use_environ=False):
-    if use_environ:
+def is_local_main_process() -> bool:
+    if get_distributed_state() is None:
         return os.environ.get("LOCAL_RANK", "0") == "0"
     return PartialState().is_local_main_process


-def get_world_size():
+def get_world_size() -> int:
     return int(os.getenv("WORLD_SIZE", "1"))
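The net effect of this hunk: call sites no longer pass `use_environ=True`; when no `PartialState` exists yet, `is_main_process()` falls back to the launcher's `LOCAL_RANK` environment variable on its own. A self-contained sketch of the two code paths, with the module-level state stubbed out for illustration:

```python
import os

# stand-in for the module-level PartialState created by is_distributed()
_distributed_state = None


def is_main_process() -> bool:
    # Path 1: no process group yet -> trust the launcher's environment.
    if _distributed_state is None:
        return os.environ.get("LOCAL_RANK", "0") == "0"
    # Path 2: distributed state exists -> ask for the rank directly
    # (the real implementation goes through torch.distributed).
    return _distributed_state.process_index == 0


# Before any distributed init, rank is inferred from the environment:
os.environ["LOCAL_RANK"] = "1"
assert is_main_process() is False
```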
@@ -115,7 +112,7 @@ def cleanup_distributed():


 @contextmanager
-def zero_first(is_main):
+def zero_first(is_main: bool):
     """
     runs the wrapped context so that rank 0 runs first before other ranks
     """
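The hunk shows only the signature; the typical implementation of this "rank 0 first" pattern looks roughly like the following. This is a sketch of the common idiom, not the file's actual body:

```python
from contextlib import contextmanager

import torch.distributed as dist


@contextmanager
def zero_first(is_main: bool):
    """Run the wrapped context on rank 0 first, then on the other ranks."""
    if not is_main:
        dist.barrier()  # non-main ranks wait until rank 0 has finished
    yield
    if is_main:
        dist.barrier()  # rank 0 releases the waiting ranks
```

This ordering lets rank 0 do one-time work, such as dataset preprocessing or cache population, while the other ranks wait and then reuse the result.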
@@ -5,9 +5,8 @@ module to freeze/unfreeze parameters by name
 import re
 from typing import Callable, List, Tuple, Union

-from accelerate.logging import get_logger
-
 from axolotl.utils.distributed import is_main_process
+from axolotl.utils.logging import get_logger

 LOG = get_logger(__name__)
@@ -1,6 +1,4 @@
-"""
-logging helpers to only log on main process
-"""
+"""Logging helpers to only log on main process."""

 import functools
 import logging
@@ -14,27 +12,18 @@ from axolotl.utils.distributed import is_main_process

 class MultiProcessAdapter(logging.LoggerAdapter):
     """
-    logger adapter for distributed logging, specifically to only log on main process
+    Logger adapter for distributed logging, specifically to only log on main process.
     """

-    def __init__(self, logger, use_environ=False, extra=None):
-        super().__init__(logger, extra)
-        self.use_environ = use_environ
-
     @staticmethod
-    def _should_log(main_process_only, use_environ=False):
-        return not main_process_only or (
-            main_process_only and is_main_process(use_environ=use_environ)
-        )
+    def _should_log(main_process_only: bool):
+        return not main_process_only or is_main_process()

     def log(self, level, msg, *args, **kwargs):
-        use_environ = kwargs.pop("use_environ", self.use_environ)
         main_process_only = kwargs.pop("main_process_only", True)
         kwargs.setdefault("stacklevel", 2)

-        if self.isEnabledFor(level) and self._should_log(
-            main_process_only, use_environ=use_environ
-        ):
+        if self.isEnabledFor(level) and self._should_log(main_process_only):
             msg, kwargs = self.process(msg, kwargs)
             self.logger.log(level, msg, *args, **kwargs)
@@ -50,13 +39,11 @@ class MultiProcessAdapter(logging.LoggerAdapter):
         self.warning(*args, **kwargs)


-def get_logger(
-    name: str, log_level: str | None = None, use_environ: bool = False
-) -> MultiProcessAdapter:
+def get_logger(name: str, log_level: str | None = None) -> MultiProcessAdapter:
     if log_level is None:
         log_level = os.environ.get("AXOLOTL_LOG_LEVEL", None)
     logger = logging.getLogger(name)
     if log_level is not None:
         logger.setLevel(log_level.upper())
         logger.root.setLevel(log_level.upper())
-    return MultiProcessAdapter(logger, use_environ=use_environ, extra={})
+    return MultiProcessAdapter(logger, extra={})
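With the adapter simplified, call sites keep the same surface: logging is main-process-only by default, with a per-call opt-out. A short usage sketch of the API as shown in this hunk:

```python
from axolotl.utils.logging import get_logger

# no use_environ flag anymore; rank detection is handled internally
LOG = get_logger(__name__)

LOG.info("printed once, on the main process only")  # default behavior
LOG.info("printed on every rank", main_process_only=False)
```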
@@ -48,7 +48,7 @@ from axolotl.utils.schemas.trl import TRLConfig
 from axolotl.utils.schemas.validation import ValidationMixin
 from axolotl.utils.schemas.vllm import VllmConfig

-LOG = get_logger(__name__, use_environ=True)
+LOG = get_logger(__name__)


 # pylint: disable=too-many-ancestors
@@ -4,7 +4,7 @@ from pydantic import BaseModel, Field, field_validator

 from axolotl.utils.logging import get_logger

-LOG = get_logger(__name__, use_environ=True)
+LOG = get_logger(__name__)


 class ModelInputConfig(BaseModel):
@@ -11,7 +11,6 @@ from typing import List, Optional
 import numpy as np
 import torch
 import torch.cuda
-from accelerate.logging import get_logger
 from datasets import IterableDataset, disable_caching, enable_caching
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
 from transformers.utils import is_torch_bf16_gpu_available
@@ -19,6 +18,7 @@ from transformers.utils import is_torch_bf16_gpu_available
 from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2
 from axolotl.utils.distributed import reduce_and_broadcast
 from axolotl.utils.environment import check_cuda_p2p_ib_support
+from axolotl.utils.logging import get_logger
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths

 LOG = get_logger(__name__)