From b71c0e344730ba21383de4f05e0c7c671522ccaa Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 6 May 2025 11:18:45 -0400 Subject: [PATCH] Print axolotl art if train is called outside of cli: (#2627) [skip ci] --- src/axolotl/cli/art.py | 7 +++++++ src/axolotl/common/datasets.py | 25 ++++++++++++++++--------- src/axolotl/train.py | 3 +++ 3 files changed, 26 insertions(+), 9 deletions(-) diff --git a/src/axolotl/cli/art.py b/src/axolotl/cli/art.py index 6ed22a52d..2051784e9 100644 --- a/src/axolotl/cli/art.py +++ b/src/axolotl/cli/art.py @@ -16,8 +16,15 @@ AXOLOTL_LOGO = """ @@@@ @@@@@@@@@@@@@@@@ """ +HAS_PRINTED_LOGO = False + def print_axolotl_text_art(): """Prints axolotl ASCII art.""" + + global HAS_PRINTED_LOGO # pylint: disable=global-statement + if HAS_PRINTED_LOGO: + return if is_main_process(): + HAS_PRINTED_LOGO = True print(AXOLOTL_LOGO) diff --git a/src/axolotl/common/datasets.py b/src/axolotl/common/datasets.py index 2ab405ef1..9dd62f0f7 100644 --- a/src/axolotl/common/datasets.py +++ b/src/axolotl/common/datasets.py @@ -48,6 +48,7 @@ def load_datasets( *, cfg: DictDefault, cli_args: PreprocessCliArgs | TrainerCliArgs | None = None, + debug: bool = False, ) -> TrainDatasetMeta: """ Loads one or more training or evaluation datasets, calling @@ -56,6 +57,7 @@ def load_datasets( Args: cfg: Dictionary mapping `axolotl` config keys to values. cli_args: Command-specific CLI arguments. + debug: Whether to print out tokenization of sample Returns: Dataclass with fields for training and evaluation datasets and the computed @@ -77,20 +79,25 @@ def load_datasets( preprocess_iterable=preprocess_iterable, ) - if cli_args and ( - cli_args.debug - or cfg.debug - or cli_args.debug_text_only - or int(cli_args.debug_num_examples) > 0 - ): + if ( # pylint: disable=too-many-boolean-expressions + cli_args + and ( + cli_args.debug + or cfg.debug + or cli_args.debug_text_only + or int(cli_args.debug_num_examples) > 0 + ) + ) or debug: LOG.info("check_dataset_labels...") - train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples) + num_examples = cli_args.debug_num_examples if cli_args else 1 + text_only = cli_args.debug_text_only if cli_args else False + train_samples = sample_dataset(train_dataset, num_examples) check_dataset_labels( train_samples, tokenizer, - num_examples=cli_args.debug_num_examples, - text_only=cli_args.debug_text_only, + num_examples=num_examples, + text_only=text_only, ) LOG.info("printing prompters...") diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 30d26b706..a5fd5e2e0 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -21,6 +21,7 @@ from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.trainer import Trainer +from axolotl.cli.art import print_axolotl_text_art from axolotl.common.datasets import TrainDatasetMeta from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module fix_untrained_tokens, @@ -516,6 +517,8 @@ def train( Returns: Tuple of (model, tokenizer) after training """ + print_axolotl_text_art() + # Setup model, tokenizer, (causal or RLHF) trainer, etc. ( trainer,