Print axolotl art if train is called outside of CLI (#2627) [skip ci]

This commit is contained in:
Wing Lian
2025-05-06 11:18:45 -04:00
parent 5d61169f7c
commit 8c0303aa5e
3 changed files with 26 additions and 9 deletions

View File

@@ -16,8 +16,15 @@ AXOLOTL_LOGO = """
@@@@ @@@@@@@@@@@@@@@@ @@@@ @@@@@@@@@@@@@@@@
""" """
HAS_PRINTED_LOGO = False
def print_axolotl_text_art(): def print_axolotl_text_art():
"""Prints axolotl ASCII art.""" """Prints axolotl ASCII art."""
global HAS_PRINTED_LOGO # pylint: disable=global-statement
if HAS_PRINTED_LOGO:
return
if is_main_process(): if is_main_process():
HAS_PRINTED_LOGO = True
print(AXOLOTL_LOGO) print(AXOLOTL_LOGO)

View File

@@ -48,6 +48,7 @@ def load_datasets(
*, *,
cfg: DictDefault, cfg: DictDefault,
cli_args: PreprocessCliArgs | TrainerCliArgs | None = None, cli_args: PreprocessCliArgs | TrainerCliArgs | None = None,
debug: bool = False,
) -> TrainDatasetMeta: ) -> TrainDatasetMeta:
""" """
Loads one or more training or evaluation datasets, calling Loads one or more training or evaluation datasets, calling
@@ -56,6 +57,7 @@ def load_datasets(
Args: Args:
cfg: Dictionary mapping `axolotl` config keys to values. cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Command-specific CLI arguments. cli_args: Command-specific CLI arguments.
debug: Whether to print out tokenization of sample
Returns: Returns:
Dataclass with fields for training and evaluation datasets and the computed Dataclass with fields for training and evaluation datasets and the computed
@@ -77,20 +79,25 @@ def load_datasets(
preprocess_iterable=preprocess_iterable, preprocess_iterable=preprocess_iterable,
) )
if cli_args and ( if ( # pylint: disable=too-many-boolean-expressions
cli_args.debug cli_args
or cfg.debug and (
or cli_args.debug_text_only cli_args.debug
or int(cli_args.debug_num_examples) > 0 or cfg.debug
): or cli_args.debug_text_only
or int(cli_args.debug_num_examples) > 0
)
) or debug:
LOG.info("check_dataset_labels...") LOG.info("check_dataset_labels...")
train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples) num_examples = cli_args.debug_num_examples if cli_args else 1
text_only = cli_args.debug_text_only if cli_args else False
train_samples = sample_dataset(train_dataset, num_examples)
check_dataset_labels( check_dataset_labels(
train_samples, train_samples,
tokenizer, tokenizer,
num_examples=cli_args.debug_num_examples, num_examples=num_examples,
text_only=cli_args.debug_text_only, text_only=text_only,
) )
LOG.info("printing prompters...") LOG.info("printing prompters...")

View File

@@ -21,6 +21,7 @@ from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.trainer import Trainer from transformers.trainer import Trainer
from axolotl.cli.art import print_axolotl_text_art
from axolotl.common.datasets import TrainDatasetMeta from axolotl.common.datasets import TrainDatasetMeta
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
fix_untrained_tokens, fix_untrained_tokens,
@@ -501,6 +502,8 @@ def train(
Returns: Returns:
Tuple of (model, tokenizer) after training Tuple of (model, tokenizer) after training
""" """
print_axolotl_text_art()
# Setup model, tokenizer, (causal or RLHF) trainer, etc. # Setup model, tokenizer, (causal or RLHF) trainer, etc.
( (
trainer, trainer,