Compare commits

...

4 Commits

Author         SHA1         Message                                                                  Date
NanoCode012    a65dbe779f   fix: suspected eval vram increased usage                                 2025-06-23 18:44:03 +07:00
Wing Lian      0494359c6c   update trl to 0.18.2 (#2814)                                             2025-06-19 11:27:59 -04:00
NanoCode012    26c39e1ca7   fix(doc): address exitcode formatting to help search (#2809) [skip ci]  2025-06-19 11:19:52 -04:00
Dan Saunders   45adf1bfb9   get_logger use_environ fix (#2808)                                       2025-06-19 11:16:52 -04:00

Full commit message of 45adf1bfb9:

* get_logger use_environ fix
* rethinking
* replacing old logger imports
* simplify
* fix boolean cond
18 changed files with 43 additions and 60 deletions

View File

@@ -9,11 +9,11 @@ description: Frequently asked questions
 > A: Usually an issue with the GPUs communicating with each other. See the [NCCL doc](nccl.qmd)
-**Q: Exitcode -9**
+**Q: exitcode: -9**
 > A: This usually happens when you run out of system RAM.
-**Q: Exitcode -7 while using deepspeed**
+**Q: exitcode: -7 while using deepspeed**
 > A: Try upgrading deepspeed w: `pip install -U deepspeed`

View File

@@ -18,7 +18,7 @@ tokenizers>=0.21.1
 accelerate==1.7.0
 datasets==3.6.0
 deepspeed>=0.17.0
-trl==0.18.1
+trl==0.18.2
 hf_xet==1.1.2
 optimum==1.16.2
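
Note: the pin bump above corresponds to commit 0494359c6c ("update trl to 0.18.2"). A quick stdlib-only way to confirm an environment matches the new pin (a sketch, not part of this PR):

```python
# Verify the installed trl version against the updated requirements pin.
from importlib.metadata import version

assert version("trl") == "0.18.2", f"unexpected trl version: {version('trl')}"
```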

View File

@@ -26,7 +26,7 @@ from axolotl.utils.mlflow_ import setup_mlflow_env_vars
 from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
 from axolotl.utils.wandb_ import setup_wandb_env_vars
-LOG = get_logger(__name__, use_environ=True)
+LOG = get_logger(__name__)
 def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
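
Note: this is the first of the #2808 call-site updates: the `use_environ=True` flag is dropped and `get_logger(__name__)` becomes the entire contract. A minimal usage sketch, based only on the signatures visible in this diff:

```python
# Updated logging pattern; the module path is taken from the hunks in this PR.
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)  # no use_environ argument anymore

LOG.info("visible on the main process only")  # default main_process_only=True
LOG.info("visible on every rank", main_process_only=False)
```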

View File

@@ -215,10 +215,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             self.cfg.auto_find_batch_size
         )
-        training_arguments_kwargs["eval_accumulation_steps"] = (
-            self.cfg.gradient_accumulation_steps
-        )
         training_arguments_kwargs["load_best_model_at_end"] = (
             (
                 self.cfg.load_best_model_at_end is not False
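
Note: for context on the removal above, `eval_accumulation_steps` in `transformers.TrainingArguments` controls how many prediction steps accumulate on device before eval outputs are offloaded to CPU. Axolotl previously forced it to `gradient_accumulation_steps`, which commit a65dbe779f suspects of increasing eval VRAM usage; removing the assignment simply leaves the field unset. An illustrative sketch of the underlying knob (values are placeholders, not from this PR):

```python
# Illustrative only: the TrainingArguments field the removed code used to set.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./out",
    per_device_eval_batch_size=1,
    eval_accumulation_steps=4,  # offload accumulated eval outputs every 4 steps
)
```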

View File

@@ -33,7 +33,7 @@ from transformers import PreTrainedModel, Trainer
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.logging import get_logger
-LOG = get_logger(__name__, use_environ=True)
+LOG = get_logger(__name__)
 if TYPE_CHECKING:
     from axolotl.common.datasets import TrainDatasetMeta

View File

@@ -28,7 +28,7 @@ from axolotl.utils.logging import get_logger
 from .args import CutCrossEntropyArgs  # pylint: disable=unused-import. # noqa: F401
-LOG = get_logger(__name__, use_environ=True)
+LOG = get_logger(__name__)
 _CCE_INSTALL_MESSAGE = (
     "Please install cut_cross_entropy with transformers support using "

View File

@@ -27,7 +27,7 @@ from axolotl.utils.logging import get_logger
 from .args import LigerArgs  # pylint: disable=unused-import. # noqa: F401
 from .utils import patch_with_compile_disable
-LOG = get_logger(__name__, use_environ=True)
+LOG = get_logger(__name__)
 class LigerPlugin(BasePlugin):

View File

@@ -15,6 +15,7 @@
""" """
Module for handling LIGER input arguments. Module for handling LIGER input arguments.
""" """
from typing import Optional from typing import Optional
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator

View File

@@ -273,7 +273,7 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
{"additional_special_tokens": additional_special_tokens} {"additional_special_tokens": additional_special_tokens}
) )
if is_main_process(use_environ=True): if is_main_process():
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}") LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}") LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}") LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")

View File

@@ -13,9 +13,9 @@ import inspect
 import accelerate
 import torch
 import torch.distributed as dist
-from accelerate.logging import get_logger
 from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
+from axolotl.utils.logging import get_logger
 from axolotl.utils.schemas.enums import RingAttnFunc
 LOG = get_logger(__name__)

View File

@@ -4,12 +4,12 @@ import inspect
 import types
 import torch
-from accelerate.logging import get_logger
 from peft import PeftModelForCausalLM
 from torch import nn
 from transformers.models.llama.modeling_llama import LlamaFlashAttention2
 from axolotl.monkeypatch.utils import detab_code
+from axolotl.utils.logging import get_logger
 LOG = get_logger(__name__)

View File

@@ -21,7 +21,7 @@ from axolotl.utils.schemas.config import (
 from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase
 from axolotl.utils.schemas.datasets import DPODataset, KTODataset, SFTDataset
-LOG = get_logger(__name__, use_environ=True)
+LOG = get_logger(__name__)
 def choose_device(cfg):

View File

@@ -1,6 +1,4 @@
""" """Utilities for distributed functionality."""
utility helpers for distributed checks
"""
import os import os
import pickle # nosec import pickle # nosec
@@ -19,7 +17,7 @@ from transformers.utils.import_utils import (
 distributed_state = None  # pylint: disable=invalid-name
-def get_device_type():
+def get_device_type() -> torch.device:
     device = torch.device("cpu")
     if is_torch_cuda_available():
         device = torch.device("cuda")
@@ -30,7 +28,7 @@ def get_device_type():
     return device
-def get_device_count():
+def get_device_count() -> int:
     cur_device = get_device_type()
     if "cuda" in str(cur_device):
         return torch.cuda.device_count()
@@ -39,7 +37,7 @@ def get_device_count():
     return 1
-def get_current_device():
+def get_current_device() -> int:
     cur_device = get_device_type()
     if "cuda" in str(cur_device):
         return torch.cuda.current_device()
@@ -48,12 +46,14 @@ def get_current_device():
     return 0
-def is_distributed():
-    """
-    Check if distributed training is initialized.
-    """
+def get_distributed_state() -> PartialState | None:
+    return distributed_state
+def is_distributed() -> bool:
+    """Check if distributed training is initialized."""
     global distributed_state  # pylint: disable=global-statement
-    if not distributed_state:
+    if distributed_state is None:
         timeout = int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800))
         distributed_state = PartialState(timeout=timedelta(seconds=timeout))
@@ -69,31 +69,31 @@ def barrier():
     dist.barrier()
-def is_main_process(use_environ=False):
+def is_main_process() -> bool:
     """
     Check if the current process is the main process. If not in distributed mode,
     always return `True`.
-    Args:
-        - use_environ (bool, optional): Use environment variable to determine main process.
+    We use a simpler logic when the distributed state is not initialized: we just log
+    on the 0-th local rank.
     Returns:
-        - bool: `True` if the current process is the main process, `False` otherwise.
+        `True` if the current process is the main process, `False` otherwise.
     """
-    if use_environ:
+    if get_distributed_state() is None:
         return os.environ.get("LOCAL_RANK", "0") == "0"
     if not is_distributed():
         return True
     return dist.get_rank() == 0
-def is_local_main_process(use_environ=False):
-    if use_environ:
+def is_local_main_process() -> bool:
+    if get_distributed_state() is None:
         return os.environ.get("LOCAL_RANK", "0") == "0"
     return PartialState().is_local_main_process
-def get_world_size():
+def get_world_size() -> int:
     return int(os.getenv("WORLD_SIZE", "1"))
@@ -115,7 +115,7 @@ def cleanup_distributed():
 @contextmanager
-def zero_first(is_main):
+def zero_first(is_main: bool):
     """
     runs the wrapped context so that rank 0 runs first before other ranks
     """

View File

@@ -5,9 +5,8 @@ module to freeze/unfreeze parameters by name
 import re
 from typing import Callable, List, Tuple, Union
-from accelerate.logging import get_logger
 from axolotl.utils.distributed import is_main_process
+from axolotl.utils.logging import get_logger
 LOG = get_logger(__name__)

View File

@@ -1,6 +1,4 @@
""" """Logging helpers to only log on main process."""
logging helpers to only log on main process
"""
import functools import functools
import logging import logging
@@ -14,27 +12,18 @@ from axolotl.utils.distributed import is_main_process
 class MultiProcessAdapter(logging.LoggerAdapter):
     """
-    logger adapter for distributed logging, specifically to only log on main process
+    Logger adapter for distributed logging, specifically to only log on main process.
     """
-    def __init__(self, logger, use_environ=False, extra=None):
-        super().__init__(logger, extra)
-        self.use_environ = use_environ
     @staticmethod
-    def _should_log(main_process_only, use_environ=False):
-        return not main_process_only or (
-            main_process_only and is_main_process(use_environ=use_environ)
-        )
+    def _should_log(main_process_only: bool):
+        return not main_process_only or is_main_process()
     def log(self, level, msg, *args, **kwargs):
-        use_environ = kwargs.pop("use_environ", self.use_environ)
         main_process_only = kwargs.pop("main_process_only", True)
         kwargs.setdefault("stacklevel", 2)
-        if self.isEnabledFor(level) and self._should_log(
-            main_process_only, use_environ=use_environ
-        ):
+        if self.isEnabledFor(level) and self._should_log(main_process_only):
             msg, kwargs = self.process(msg, kwargs)
             self.logger.log(level, msg, *args, **kwargs)
@@ -50,13 +39,11 @@ class MultiProcessAdapter(logging.LoggerAdapter):
         self.warning(*args, **kwargs)
-def get_logger(
-    name: str, log_level: str | None = None, use_environ: bool = False
-) -> MultiProcessAdapter:
+def get_logger(name: str, log_level: str | None = None) -> MultiProcessAdapter:
     if log_level is None:
         log_level = os.environ.get("AXOLOTL_LOG_LEVEL", None)
     logger = logging.getLogger(name)
     if log_level is not None:
         logger.setLevel(log_level.upper())
         logger.root.setLevel(log_level.upper())
-    return MultiProcessAdapter(logger, use_environ=use_environ, extra={})
+    return MultiProcessAdapter(logger, extra={})
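
Note: as the hunk above shows, `get_logger` still resolves its level from the `AXOLOTL_LOG_LEVEL` environment variable when `log_level` is not passed. A short sketch of that behavior:

```python
# AXOLOTL_LOG_LEVEL is read when log_level is None and upper-cased internally,
# per the get_logger body above; set it before the first get_logger call.
import os

os.environ["AXOLOTL_LOG_LEVEL"] = "debug"

from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)  # logger level becomes DEBUG
```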

View File

@@ -48,7 +48,7 @@ from axolotl.utils.schemas.trl import TRLConfig
 from axolotl.utils.schemas.validation import ValidationMixin
 from axolotl.utils.schemas.vllm import VllmConfig
-LOG = get_logger(__name__, use_environ=True)
+LOG = get_logger(__name__)
 # pylint: disable=too-many-ancestors

View File

@@ -4,7 +4,7 @@ from pydantic import BaseModel, Field, field_validator
 from axolotl.utils.logging import get_logger
-LOG = get_logger(__name__, use_environ=True)
+LOG = get_logger(__name__)
 class ModelInputConfig(BaseModel):

View File

@@ -11,7 +11,6 @@ from typing import List, Optional
 import numpy as np
 import torch
 import torch.cuda
-from accelerate.logging import get_logger
 from datasets import IterableDataset, disable_caching, enable_caching
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
 from transformers.utils import is_torch_bf16_gpu_available
@@ -19,6 +18,7 @@ from transformers.utils import is_torch_bf16_gpu_available
 from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2
 from axolotl.utils.distributed import reduce_and_broadcast
 from axolotl.utils.environment import check_cuda_p2p_ib_support
+from axolotl.utils.logging import get_logger
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 LOG = get_logger(__name__)