diff --git a/src/axolotl/integrations/liger/args.py b/src/axolotl/integrations/liger/args.py index d05f08f9a..94ba83dd5 100644 --- a/src/axolotl/integrations/liger/args.py +++ b/src/axolotl/integrations/liger/args.py @@ -53,3 +53,12 @@ class LigerArgs(BaseModel): ) data["liger_glu_activation"] = data.pop("liger_swiglu") return data + + @model_validator(mode="before") + @classmethod + def check_tiled_mlp_conflict(cls, data): + if data.get("liger_glu_activation") is True and data.get("tiled_mlp") is True: + raise ValueError( + "You cannot have both `liger_glu_activation` and `tiled_mlp` set." + ) + return data diff --git a/src/axolotl/loaders/tokenizer.py b/src/axolotl/loaders/tokenizer.py index 2f0ccbcbb..1889fa168 100644 --- a/src/axolotl/loaders/tokenizer.py +++ b/src/axolotl/loaders/tokenizer.py @@ -295,4 +295,9 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer: LOG.info( "No Chat template selected. Consider adding a chat template for easier inference." ) + + # make the tokenizer.pad call quieter 🤐 + if hasattr(tokenizer, "deprecation_warnings"): + tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True + return tokenizer diff --git a/src/axolotl/monkeypatch/tiled_mlp.py b/src/axolotl/monkeypatch/tiled_mlp.py index 3818c6b35..02bb3a8cb 100644 --- a/src/axolotl/monkeypatch/tiled_mlp.py +++ b/src/axolotl/monkeypatch/tiled_mlp.py @@ -7,6 +7,9 @@ import torch import torch.distributed as dist from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None): @@ -63,6 +66,10 @@ def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None): mlp_cls.forward = tiled_mlp_forward mlp_cls._compute_params = [] # pylint: disable=protected-access + LOG.info( + f"Successfully monkey-patched TiledMLP for model_type: {model_type}", + main_process_only=True, + ) except (ImportError, AttributeError) as e: raise RuntimeError( f"Could not import MLP class for model_type: {model_type}. " diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index 292159bb8..64dbb2529 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -3,7 +3,6 @@ # pylint: disable=too-many-boolean-expressions import json -import logging import tempfile from pathlib import Path @@ -13,11 +12,12 @@ from pydantic import ( ) from transformers.utils.import_utils import is_torch_npu_available +from axolotl.utils.logging import get_logger from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType # pylint: disable=too-many-lines -LOG = logging.getLogger(__name__) +LOG = get_logger(__name__) SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"} @@ -116,7 +116,8 @@ class DatasetValidationMixin: and not data.get("eval_table_size") ): LOG.info( - "explicitly setting `eval_sample_packing` to match `sample_packing`" + "explicitly setting `eval_sample_packing` to match `sample_packing`", + main_process_only=True, ) data["eval_sample_packing"] = True