From 5ff1322f32ea09544e3dd755e5992a66d8216298 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Fri, 10 Jan 2025 17:27:03 +0000 Subject: [PATCH] review comments --- src/axolotl/cli/args.py | 42 ++++++++++ src/axolotl/cli/art.py | 35 -------- src/axolotl/cli/checks.py | 4 +- src/axolotl/cli/config.py | 4 +- src/axolotl/cli/evaluate.py | 8 +- src/axolotl/cli/inference.py | 6 +- src/axolotl/cli/main.py | 2 +- src/axolotl/cli/merge_lora.py | 8 +- src/axolotl/cli/merge_sharded_fsdp_weights.py | 2 +- src/axolotl/cli/preprocess.py | 8 +- src/axolotl/cli/train.py | 8 +- src/axolotl/cli/utils.py | 54 +++++++++++-- src/axolotl/common/cli.py | 79 ------------------- src/axolotl/common/datasets.py | 8 +- src/axolotl/evaluate.py | 2 +- tests/e2e/test_dpo.py | 18 ++--- 16 files changed, 130 insertions(+), 158 deletions(-) create mode 100644 src/axolotl/cli/args.py delete mode 100644 src/axolotl/common/cli.py diff --git a/src/axolotl/cli/args.py b/src/axolotl/cli/args.py new file mode 100644 index 000000000..9ddef5230 --- /dev/null +++ b/src/axolotl/cli/args.py @@ -0,0 +1,42 @@ +"""Module for axolotl CLI command arguments.""" + +from dataclasses import dataclass, field + + +@dataclass +class PreprocessCliArgs: + """Dataclass with CLI arguments for `axolotl preprocess` command.""" + + debug: bool = field(default=False) + debug_text_only: bool = field(default=False) + debug_num_examples: int = field(default=1) + prompter: str | None = field(default=None) + download: bool | None = field(default=True) + + +@dataclass +class TrainerCliArgs: + """Dataclass with CLI arguments for `axolotl train` command.""" + + debug: bool = field(default=False) + debug_text_only: bool = field(default=False) + debug_num_examples: int = field(default=0) + merge_lora: bool = field(default=False) + prompter: str | None = field(default=None) + shard: bool = field(default=False) + + +@dataclass +class EvaluateCliArgs: + """Dataclass with CLI arguments for `axolotl evaluate` command.""" + + debug: bool = field(default=False) + debug_text_only: bool = field(default=False) + debug_num_examples: int = field(default=0) + + +@dataclass +class InferenceCliArgs: + """Dataclass with CLI arguments for `axolotl inference` command.""" + + prompter: str | None = field(default=None) diff --git a/src/axolotl/cli/art.py b/src/axolotl/cli/art.py index d6b225b9f..6ed22a52d 100644 --- a/src/axolotl/cli/art.py +++ b/src/axolotl/cli/art.py @@ -1,8 +1,5 @@ """Axolotl ASCII logo utils.""" -from art import text2art -from transformers.utils.import_utils import _is_package_available - from axolotl.utils.distributed import is_main_process AXOLOTL_LOGO = """ @@ -20,38 +17,6 @@ AXOLOTL_LOGO = """ """ -def print_dep_versions(): - """Prints versions of various axolotl dependencies.""" - packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"] - max_len = max(len(pkg) for pkg in packages) - if is_main_process(): - print("*" * 40) - print("**** Axolotl Dependency Versions *****") - for pkg in packages: - pkg_version = _is_package_available(pkg, return_version=True) - print(f"{pkg: >{max_len}}: {pkg_version[1]: <15}") - print("*" * 40) - - -def print_legacy_axolotl_text_art(suffix=None): - """ - Prints axolotl ASCII art and dependency versions. - - Args: - suffix: Text to append to ASCII art text. - """ - font = "nancyj" - ascii_text = " axolotl" - if suffix: - ascii_text += f" x {suffix}" - ascii_art = text2art(ascii_text, font=font) - - if is_main_process(): - print(ascii_art) - - print_dep_versions() - - def print_axolotl_text_art(): """Prints axolotl ASCII art.""" if is_main_process(): diff --git a/src/axolotl/cli/checks.py b/src/axolotl/cli/checks.py index c450a1cf6..cc3ed0d9f 100644 --- a/src/axolotl/cli/checks.py +++ b/src/axolotl/cli/checks.py @@ -14,7 +14,7 @@ configure_logging() LOG = logging.getLogger(__name__) -def check_accelerate_default_config(): +def check_accelerate_default_config() -> None: """Logs at warning level if no accelerate config file is found.""" if Path(config_args.default_yaml_config_file).exists(): LOG.warning( @@ -22,7 +22,7 @@ def check_accelerate_default_config(): ) -def check_user_token(): +def check_user_token() -> bool: """Checks for HF user info. Check is skipped if HF_HUB_OFFLINE=1. Returns: diff --git a/src/axolotl/cli/config.py b/src/axolotl/cli/config.py index 1df664ff8..166a67670 100644 --- a/src/axolotl/cli/config.py +++ b/src/axolotl/cli/config.py @@ -94,7 +94,7 @@ def check_remote_config(config: Union[str, Path]) -> Union[str, Path]: raise err -def choose_config(path: Path): +def choose_config(path: Path) -> str: """ Helper method for choosing a `axolotl` config YAML file (considering only files ending with `.yml` or `.yaml`). If more than one config file exists in the passed @@ -152,7 +152,7 @@ def prepare_plugins(cfg: DictDefault): plugin_manager.register(plugin_name) -def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs): +def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefault: """ Loads the `axolotl` configuration stored at `config`, validates it, and performs various setup. diff --git a/src/axolotl/cli/evaluate.py b/src/axolotl/cli/evaluate.py index 981d0fa64..5b430b31e 100644 --- a/src/axolotl/cli/evaluate.py +++ b/src/axolotl/cli/evaluate.py @@ -8,11 +8,11 @@ import fire from dotenv import load_dotenv from transformers.hf_argparser import HfArgumentParser +from axolotl.cli.args import TrainerCliArgs from axolotl.cli.art import print_axolotl_text_art from axolotl.cli.checks import check_accelerate_default_config, check_user_token from axolotl.cli.config import load_cfg -from axolotl.common.cli import TrainerCliArgs -from axolotl.common.datasets import load_datasets, load_rl_datasets +from axolotl.common.datasets import load_datasets, load_dpo_datasets from axolotl.evaluate import evaluate from axolotl.utils.dict import DictDefault @@ -34,8 +34,8 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None: check_accelerate_default_config() check_user_token() - if cfg.rl: # and cfg.rl != "orpo": - dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) + if cfg.rl: + dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args) else: dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) diff --git a/src/axolotl/cli/inference.py b/src/axolotl/cli/inference.py index 88cd6e64b..e11a39bd6 100644 --- a/src/axolotl/cli/inference.py +++ b/src/axolotl/cli/inference.py @@ -13,9 +13,10 @@ import transformers from dotenv import load_dotenv from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer +from axolotl.cli.args import InferenceCliArgs from axolotl.cli.art import print_axolotl_text_art from axolotl.cli.config import load_cfg -from axolotl.common.cli import InferenceCliArgs, load_model_and_tokenizer +from axolotl.cli.utils import load_model_and_tokenizer from axolotl.utils.chat_templates import ( get_chat_template, get_chat_template_from_config, @@ -33,10 +34,11 @@ def get_multi_line_input() -> str: Possibly multi-line, possibly empty stdin input as a string. """ print("Give me an instruction (Ctrl + D to submit): ") + instruction = "" for line in sys.stdin: instruction += line # pylint: disable=consider-using-join - # instruction = pathlib.Path("/proc/self/fd/0").read_text() + return instruction diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py index ac55501a4..43e2de3db 100644 --- a/src/axolotl/cli/main.py +++ b/src/axolotl/cli/main.py @@ -7,6 +7,7 @@ from typing import Optional import click import axolotl +from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs from axolotl.cli.utils import ( add_options_from_config, add_options_from_dataclass, @@ -14,7 +15,6 @@ from axolotl.cli.utils import ( fetch_from_github, filter_none_kwargs, ) -from axolotl.common.cli import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs from axolotl.utils import set_pytorch_cuda_alloc_conf from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py index 2cada4862..595eb3eab 100644 --- a/src/axolotl/cli/merge_lora.py +++ b/src/axolotl/cli/merge_lora.py @@ -8,9 +8,10 @@ import fire import transformers from dotenv import load_dotenv +from axolotl.cli.args import TrainerCliArgs from axolotl.cli.art import print_axolotl_text_art from axolotl.cli.config import load_cfg -from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer +from axolotl.cli.utils import load_model_and_tokenizer from axolotl.utils.dict import DictDefault LOG = logging.getLogger(__name__) @@ -31,10 +32,7 @@ def do_merge_lora(*, cfg: DictDefault) -> None: LOG.info("Running merge of LoRA with base model...") model = model.merge_and_unload(progressbar=True) - try: - model.to(dtype=cfg.torch_dtype) - except RuntimeError: - pass + model.to(dtype=cfg.torch_dtype) model.generation_config.do_sample = True if cfg.local_rank == 0: diff --git a/src/axolotl/cli/merge_sharded_fsdp_weights.py b/src/axolotl/cli/merge_sharded_fsdp_weights.py index 9b297b872..d4b36d92c 100644 --- a/src/axolotl/cli/merge_sharded_fsdp_weights.py +++ b/src/axolotl/cli/merge_sharded_fsdp_weights.py @@ -24,9 +24,9 @@ from huggingface_hub import split_torch_state_dict_into_shards from safetensors.torch import save_file as safe_save_file from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner +from axolotl.cli.args import TrainerCliArgs from axolotl.cli.art import print_axolotl_text_art from axolotl.cli.config import load_cfg -from axolotl.common.cli import TrainerCliArgs LOG = logging.getLogger(__name__) diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py index 924695277..18f87acf5 100644 --- a/src/axolotl/cli/preprocess.py +++ b/src/axolotl/cli/preprocess.py @@ -12,12 +12,12 @@ from colorama import Fore from dotenv import load_dotenv from transformers import AutoModelForCausalLM +from axolotl.cli.args import PreprocessCliArgs from axolotl.cli.art import print_axolotl_text_art from axolotl.cli.checks import check_accelerate_default_config, check_user_token from axolotl.cli.config import load_cfg -from axolotl.common.cli import PreprocessCliArgs from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH -from axolotl.common.datasets import load_datasets, load_rl_datasets +from axolotl.common.datasets import load_datasets, load_dpo_datasets from axolotl.utils.dict import DictDefault from axolotl.utils.trainer import disable_datasets_caching @@ -47,8 +47,8 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None: cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH with disable_datasets_caching(): - if cfg.rl: # and cfg.rl != "orpo": - load_rl_datasets(cfg=cfg, cli_args=cli_args) + if cfg.rl: + load_dpo_datasets(cfg=cfg, cli_args=cli_args) else: load_datasets(cfg=cfg, cli_args=cli_args) diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index 9bd04919e..320a40153 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -8,11 +8,11 @@ import fire from dotenv import load_dotenv from transformers.hf_argparser import HfArgumentParser +from axolotl.cli.args import TrainerCliArgs from axolotl.cli.art import print_axolotl_text_art from axolotl.cli.checks import check_accelerate_default_config, check_user_token from axolotl.cli.config import load_cfg -from axolotl.common.cli import TrainerCliArgs -from axolotl.common.datasets import load_datasets, load_rl_datasets +from axolotl.common.datasets import load_datasets, load_dpo_datasets from axolotl.integrations.base import PluginManager from axolotl.train import train from axolotl.utils.dict import DictDefault @@ -34,8 +34,8 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None: check_accelerate_default_config() check_user_token() - if cfg.rl: # and cfg.rl != "orpo": - dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) + if cfg.rl: + dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args) else: dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) diff --git a/src/axolotl/cli/utils.py b/src/axolotl/cli/utils.py index f04304759..435637688 100644 --- a/src/axolotl/cli/utils.py +++ b/src/axolotl/cli/utils.py @@ -8,16 +8,34 @@ import logging from functools import wraps from pathlib import Path from types import NoneType -from typing import Any, Dict, List, Optional, Tuple, Type, Union, get_args, get_origin +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Tuple, + Type, + Union, + get_args, + get_origin, +) import click import requests from pydantic import BaseModel +from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast +import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401 +from axolotl.logging_config import configure_logging +from axolotl.utils.dict import DictDefault +from axolotl.utils.models import load_model, load_tokenizer + +configure_logging() LOG = logging.getLogger(__name__) -def filter_none_kwargs(func): +def filter_none_kwargs(func: Callable) -> Callable: """ Wraps function to remove `None`-valued `kwargs`. @@ -29,15 +47,16 @@ def filter_none_kwargs(func): """ @wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args, **kwargs) -> Callable: """Filters out `None`-valued `kwargs`.""" filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None} + return func(*args, **filtered_kwargs) return wrapper -def add_options_from_dataclass(config_class: Type[Any]): +def add_options_from_dataclass(config_class: Type[Any]) -> Callable: """ Create Click options from the fields of a dataclass. @@ -75,7 +94,7 @@ def add_options_from_dataclass(config_class: Type[Any]): return decorator -def add_options_from_config(config_class: Type[BaseModel]): +def add_options_from_config(config_class: Type[BaseModel]) -> Callable: """ Create Click options from the fields of a Pydantic model. @@ -256,3 +275,28 @@ def fetch_from_github( LOG.info(f"Unchanged files: {len(files_processed['unchanged'])}") if files_processed["error"]: LOG.info(f"Failed files: {len(files_processed['error'])}") + + +def load_model_and_tokenizer( + *, + cfg: DictDefault, + inference: bool = False, +) -> Tuple[PreTrainedModel, PreTrainedTokenizer | PreTrainedTokenizerFast | Any]: + """ + Helper function for loading a model and tokenizer specified in the given `axolotl` + config. + + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + inference: Boolean denoting inference mode. + + Returns: + `transformers` model and tokenizer. + """ + LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}") + tokenizer = load_tokenizer(cfg) + + LOG.info("loading model...") + model, _ = load_model(cfg, tokenizer, inference=inference) + + return model, tokenizer diff --git a/src/axolotl/common/cli.py b/src/axolotl/common/cli.py deleted file mode 100644 index f714d7d4b..000000000 --- a/src/axolotl/common/cli.py +++ /dev/null @@ -1,79 +0,0 @@ -"""Shared module for CLI specific utilities.""" - -import logging -from dataclasses import dataclass, field -from typing import Any, Optional, Tuple - -from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast - -import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401 -from axolotl.logging_config import configure_logging -from axolotl.utils.dict import DictDefault -from axolotl.utils.models import load_model, load_tokenizer - -configure_logging() -LOG = logging.getLogger(__name__) - - -@dataclass -class PreprocessCliArgs: - """Dataclass with CLI arguments for `axolotl preprocess` command.""" - - debug: bool = field(default=False) - debug_text_only: bool = field(default=False) - debug_num_examples: int = field(default=1) - prompter: Optional[str] = field(default=None) - download: Optional[bool] = field(default=True) - - -@dataclass -class TrainerCliArgs: - """Dataclass with CLI arguments for `axolotl train` command.""" - - debug: bool = field(default=False) - debug_text_only: bool = field(default=False) - debug_num_examples: int = field(default=0) - merge_lora: bool = field(default=False) - prompter: Optional[str] = field(default=None) - shard: bool = field(default=False) - - -@dataclass -class EvaluateCliArgs: - """Dataclass with CLI arguments for `axolotl evaluate` command.""" - - debug: bool = field(default=False) - debug_text_only: bool = field(default=False) - debug_num_examples: int = field(default=0) - - -@dataclass -class InferenceCliArgs: - """Dataclass with CLI arguments for `axolotl inference` command.""" - - prompter: Optional[str] = field(default=None) - - -def load_model_and_tokenizer( - *, - cfg: DictDefault, - inference: bool = False, -) -> Tuple[PreTrainedModel, PreTrainedTokenizer | PreTrainedTokenizerFast | Any]: - """ - Helper function for loading a model and tokenizer specified in the given `axolotl` - config. - - Args: - cfg: Dictionary mapping `axolotl` config keys to values. - inference: Boolean denoting inference mode. - - Returns: - `transformers` model and tokenizer. - """ - LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}") - tokenizer = load_tokenizer(cfg) - - LOG.info("loading model...") - model, _ = load_model(cfg, tokenizer, inference=inference) - - return model, tokenizer diff --git a/src/axolotl/common/datasets.py b/src/axolotl/common/datasets.py index a4d568c94..ebe82ac64 100644 --- a/src/axolotl/common/datasets.py +++ b/src/axolotl/common/datasets.py @@ -8,7 +8,7 @@ from typing import Optional, Union from datasets import Dataset -from axolotl.common.cli import PreprocessCliArgs, TrainerCliArgs +from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs from axolotl.utils.data import prepare_dataset from axolotl.utils.data.rl import load_prepare_dpo_datasets from axolotl.utils.dict import DictDefault @@ -27,7 +27,7 @@ class TrainDatasetMeta: total_num_steps: Optional[int] = None -def sample_dataset(dataset: Dataset, num_samples: int): +def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset: """ Randomly sample `num_samples` samples from `dataset`. @@ -96,7 +96,7 @@ def load_datasets( ) -def load_rl_datasets( +def load_dpo_datasets( *, cfg: DictDefault, cli_args: Union[ @@ -104,7 +104,7 @@ def load_rl_datasets( ], # pylint: disable=unused-argument ) -> TrainDatasetMeta: """ - Loads one or more training or evaluation datasets for RL training, calling + Loads one or more training or evaluation datasets for DPO training, calling `axolotl.utils.data.rl.load_prepare_dpo_datasets`. Optionally, logs out debug information. diff --git a/src/axolotl/evaluate.py b/src/axolotl/evaluate.py index fec7e9ecb..8d9ddc6ab 100644 --- a/src/axolotl/evaluate.py +++ b/src/axolotl/evaluate.py @@ -66,7 +66,7 @@ def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, f Evaluate a model on training and validation datasets Args: - cfg: Config dictionary. + cfg: Dictionary mapping `axolotl` config keys to values. dataset_meta: Dataset metadata containing training and evaluation datasets. Returns: diff --git a/tests/e2e/test_dpo.py b/tests/e2e/test_dpo.py index 3a6c00ab0..45c38ecff 100644 --- a/tests/e2e/test_dpo.py +++ b/tests/e2e/test_dpo.py @@ -9,8 +9,8 @@ from pathlib import Path import pytest -from axolotl.common.cli import TrainerCliArgs -from axolotl.common.datasets import load_rl_datasets +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_dpo_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -65,7 +65,7 @@ class TestDPOLlamaLora(unittest.TestCase): ) normalize_config(cfg) cli_args = TrainerCliArgs() - dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) @@ -110,7 +110,7 @@ class TestDPOLlamaLora(unittest.TestCase): ) normalize_config(cfg) cli_args = TrainerCliArgs() - dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) @@ -155,7 +155,7 @@ class TestDPOLlamaLora(unittest.TestCase): ) normalize_config(cfg) cli_args = TrainerCliArgs() - dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) @@ -200,7 +200,7 @@ class TestDPOLlamaLora(unittest.TestCase): ) normalize_config(cfg) cli_args = TrainerCliArgs() - dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) @@ -244,7 +244,7 @@ class TestDPOLlamaLora(unittest.TestCase): ) normalize_config(cfg) cli_args = TrainerCliArgs() - dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) @@ -291,7 +291,7 @@ class TestDPOLlamaLora(unittest.TestCase): ) normalize_config(cfg) cli_args = TrainerCliArgs() - dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) @@ -355,7 +355,7 @@ class TestDPOLlamaLora(unittest.TestCase): ) normalize_config(cfg) cli_args = TrainerCliArgs() - dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)