make the initial call to tokenizer.pad not spam the console (#2946) [skip ci]

* make the initial call to tokenizer.pad not spam the console

* add guard from feedback

* make another common console output less verbose

* more logging fixes
Wing Lian
2025-07-19 13:53:35 -04:00
committed by GitHub
parent 170322a1f0
commit 109d9c7442
4 changed files with 25 additions and 3 deletions


```diff
@@ -53,3 +53,12 @@ class LigerArgs(BaseModel):
             )
             data["liger_glu_activation"] = data.pop("liger_swiglu")
         return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_tiled_mlp_conflict(cls, data):
+        if data.get("liger_glu_activation") is True and data.get("tiled_mlp") is True:
+            raise ValueError(
+                "You cannot have both `liger_glu_activation` and `tiled_mlp` set."
+            )
+        return data
```
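The guard behaves like any pydantic v2 `mode="before"` validator: it sees the raw input dict and can reject it before field parsing. A standalone sketch of the behavior (the two fields are assumed to be optional booleans here; the real `LigerArgs` model carries more fields):

```python
from pydantic import BaseModel, ValidationError, model_validator


class LigerArgs(BaseModel):
    liger_glu_activation: bool | None = None
    tiled_mlp: bool | None = None

    @model_validator(mode="before")
    @classmethod
    def check_tiled_mlp_conflict(cls, data):
        # Runs on the raw input dict, before field validation.
        if data.get("liger_glu_activation") is True and data.get("tiled_mlp") is True:
            raise ValueError(
                "You cannot have both `liger_glu_activation` and `tiled_mlp` set."
            )
        return data


try:
    LigerArgs(liger_glu_activation=True, tiled_mlp=True)
except ValidationError as err:
    print(err)  # pydantic surfaces the ValueError raised by the validator
```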


```diff
@@ -295,4 +295,9 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
         LOG.info(
             "No Chat template selected. Consider adding a chat template for easier inference."
         )
+
+    # make the tokenizer.pad call quieter 🤐
+    if hasattr(tokenizer, "deprecation_warnings"):
+        tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
+
     return tokenizer
```
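This works because `PreTrainedTokenizerBase` tracks its one-time warnings in a `deprecation_warnings` dict keyed by name; setting `"Asking-to-pad-a-fast-tokenizer"` to `True` tells `transformers` the warning was already emitted, so the first `tokenizer.pad(...)` call stays quiet. A quick sketch of the effect (the checkpoint is just an example):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any fast tokenizer
tokenizer.pad_token = tokenizer.eos_token

# Pre-mark the warning as seen, mirroring the guard in the diff above.
if hasattr(tokenizer, "deprecation_warnings"):
    tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True

# Without the guard, this first pad() call logs a hint to use __call__ instead.
batch = tokenizer.pad(
    [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}],
    padding=True,
)
print(batch["input_ids"])
```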


```diff
@@ -7,6 +7,9 @@ import torch
 import torch.distributed as dist
 
 from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix
+from axolotl.utils.logging import get_logger
+
+LOG = get_logger(__name__)
 
 
 def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
@@ -63,6 +66,10 @@ def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
         mlp_cls.forward = tiled_mlp_forward
         mlp_cls._compute_params = []  # pylint: disable=protected-access
+        LOG.info(
+            f"Successfully monkey-patched TiledMLP for model_type: {model_type}",
+            main_process_only=True,
+        )
     except (ImportError, AttributeError) as e:
         raise RuntimeError(
             f"Could not import MLP class for model_type: {model_type}. "
```


@@ -3,7 +3,6 @@
# pylint: disable=too-many-boolean-expressions
import json
import logging
import tempfile
from pathlib import Path
@@ -13,11 +12,12 @@ from pydantic import (
)
from transformers.utils.import_utils import is_torch_npu_available
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType
# pylint: disable=too-many-lines
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}
@@ -116,7 +116,8 @@ class DatasetValidationMixin:
and not data.get("eval_table_size")
):
LOG.info(
"explicitly setting `eval_sample_packing` to match `sample_packing`"
"explicitly setting `eval_sample_packing` to match `sample_packing`",
main_process_only=True,
)
data["eval_sample_packing"] = True