Ax art (#405)
* axolotl text art :D * only print art on rank0 * lint and pr feedback
This commit is contained in:
@@ -21,6 +21,7 @@ from axolotl.logging_config import configure_logging
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.data import prepare_dataset
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import is_main_process
|
||||
from axolotl.utils.models import load_model, load_tokenizer
|
||||
from axolotl.utils.tokenization import check_dataset_labels
|
||||
from axolotl.utils.trainer import setup_trainer
|
||||
@@ -33,10 +34,23 @@ sys.path.insert(0, src_dir)
|
||||
configure_logging()
|
||||
LOG = logging.getLogger("axolotl.scripts")
|
||||
|
||||
|
||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
||||
|
||||
|
||||
def print_axolotl_text_art():
|
||||
ascii_art = """
|
||||
dP dP dP
|
||||
88 88 88
|
||||
.d8888b. dP. .dP .d8888b. 88 .d8888b. d8888P 88
|
||||
88' `88 `8bd8' 88' `88 88 88' `88 88 88
|
||||
88. .88 .d88b. 88. .88 88 88. .88 88 88
|
||||
`88888P8 dP' `dP `88888P' dP `88888P' dP dP
|
||||
"""
|
||||
|
||||
if is_main_process():
|
||||
print(ascii_art)
|
||||
|
||||
|
||||
def get_multi_line_input() -> Optional[str]:
|
||||
print("Give me an instruction (Ctrl + D to finish): ")
|
||||
instruction = ""
|
||||
@@ -146,6 +160,7 @@ def train(
|
||||
prepare_ds_only: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
print_axolotl_text_art()
|
||||
if Path(config).is_dir():
|
||||
config = choose_config(config)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user