make the initial call to tokenizer.pad not spam the console (#2946) [skip ci]
* make the initial call to tokenizer.pad not spam the console
* add guard from feedback
* make another common console output less verbose
* more logging fixes
This commit is contained in:
@@ -53,3 +53,12 @@ class LigerArgs(BaseModel):
|
||||
)
|
||||
data["liger_glu_activation"] = data.pop("liger_swiglu")
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
@classmethod
def check_tiled_mlp_conflict(cls, data):
    """Reject configurations that enable both Liger GLU activation and tiled MLP.

    The two features patch the same MLP forward path, so only one may be
    active at a time. Runs before field validation on the raw input mapping.

    Raises:
        ValueError: if both `liger_glu_activation` and `tiled_mlp` are True.
    """
    glu_enabled = data.get("liger_glu_activation") is True
    tiled_enabled = data.get("tiled_mlp") is True
    if glu_enabled and tiled_enabled:
        raise ValueError(
            "You cannot have both `liger_glu_activation` and `tiled_mlp` set."
        )
    return data
|
||||
|
||||
@@ -295,4 +295,9 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
|
||||
LOG.info(
|
||||
"No Chat template selected. Consider adding a chat template for easier inference."
|
||||
)
|
||||
|
||||
# make the tokenizer.pad call quieter 🤐
|
||||
if hasattr(tokenizer, "deprecation_warnings"):
|
||||
tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
|
||||
|
||||
return tokenizer
|
||||
|
||||
@@ -7,6 +7,9 @@ import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
|
||||
@@ -63,6 +66,10 @@ def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
|
||||
|
||||
mlp_cls.forward = tiled_mlp_forward
|
||||
mlp_cls._compute_params = [] # pylint: disable=protected-access
|
||||
LOG.info(
|
||||
f"Successfully monkey-patched TiledMLP for model_type: {model_type}",
|
||||
main_process_only=True,
|
||||
)
|
||||
except (ImportError, AttributeError) as e:
|
||||
raise RuntimeError(
|
||||
f"Could not import MLP class for model_type: {model_type}. "
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
# pylint: disable=too-many-boolean-expressions
|
||||
|
||||
import json
|
||||
import logging
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
@@ -13,11 +12,12 @@ from pydantic import (
|
||||
)
|
||||
from transformers.utils.import_utils import is_torch_npu_available
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType
|
||||
|
||||
# pylint: disable=too-many-lines
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}
|
||||
|
||||
@@ -116,7 +116,8 @@ class DatasetValidationMixin:
|
||||
and not data.get("eval_table_size")
|
||||
):
|
||||
LOG.info(
|
||||
"explicitly setting `eval_sample_packing` to match `sample_packing`"
|
||||
"explicitly setting `eval_sample_packing` to match `sample_packing`",
|
||||
main_process_only=True,
|
||||
)
|
||||
data["eval_sample_packing"] = True
|
||||
|
||||
|
||||
Reference in New Issue
Block a user