Print axolotl art if train is called outside of cli: (#2627) [skip ci]

This commit is contained in:
Wing Lian
2025-05-06 11:18:45 -04:00
committed by GitHub
parent ddaebf8309
commit b71c0e3447
3 changed files with 26 additions and 9 deletions

View File

@@ -16,8 +16,15 @@ AXOLOTL_LOGO = """
@@@@ @@@@@@@@@@@@@@@@
"""
HAS_PRINTED_LOGO = False
def print_axolotl_text_art():
"""Prints axolotl ASCII art."""
global HAS_PRINTED_LOGO # pylint: disable=global-statement
if HAS_PRINTED_LOGO:
return
if is_main_process():
HAS_PRINTED_LOGO = True
print(AXOLOTL_LOGO)

View File

@@ -48,6 +48,7 @@ def load_datasets(
*,
cfg: DictDefault,
cli_args: PreprocessCliArgs | TrainerCliArgs | None = None,
debug: bool = False,
) -> TrainDatasetMeta:
"""
Loads one or more training or evaluation datasets, calling
@@ -56,6 +57,7 @@ def load_datasets(
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Command-specific CLI arguments.
debug: Whether to print out tokenization of sample
Returns:
Dataclass with fields for training and evaluation datasets and the computed
@@ -77,20 +79,25 @@ def load_datasets(
preprocess_iterable=preprocess_iterable,
)
if cli_args and (
cli_args.debug
or cfg.debug
or cli_args.debug_text_only
or int(cli_args.debug_num_examples) > 0
):
if ( # pylint: disable=too-many-boolean-expressions
cli_args
and (
cli_args.debug
or cfg.debug
or cli_args.debug_text_only
or int(cli_args.debug_num_examples) > 0
)
) or debug:
LOG.info("check_dataset_labels...")
train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples)
num_examples = cli_args.debug_num_examples if cli_args else 1
text_only = cli_args.debug_text_only if cli_args else False
train_samples = sample_dataset(train_dataset, num_examples)
check_dataset_labels(
train_samples,
tokenizer,
num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only,
num_examples=num_examples,
text_only=text_only,
)
LOG.info("printing prompters...")

View File

@@ -21,6 +21,7 @@ from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.trainer import Trainer
from axolotl.cli.art import print_axolotl_text_art
from axolotl.common.datasets import TrainDatasetMeta
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
fix_untrained_tokens,
@@ -516,6 +517,8 @@ def train(
Returns:
Tuple of (model, tokenizer) after training
"""
print_axolotl_text_art()
# Setup model, tokenizer, (causal or RLHF) trainer, etc.
(
trainer,