improve handling of the prepared ds path and other cfg defaults (#701)
This commit is contained in:
@@ -14,6 +14,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
|||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
print_axolotl_text_art()
|
print_axolotl_text_art()
|
||||||
parsed_cfg = load_cfg(config, **kwargs)
|
parsed_cfg = load_cfg(config, **kwargs)
|
||||||
|
parsed_cfg.sample_packing = False
|
||||||
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||||
return_remaining_strings=True
|
return_remaining_strings=True
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
"""
|
"""
|
||||||
CLI to run training on a model
|
CLI to run training on a model
|
||||||
"""
|
"""
|
||||||
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
import transformers
|
import transformers
|
||||||
|
from colorama import Fore
|
||||||
|
|
||||||
from axolotl.cli import (
|
from axolotl.cli import (
|
||||||
check_accelerate_default_config,
|
check_accelerate_default_config,
|
||||||
@@ -14,8 +16,11 @@ from axolotl.cli import (
|
|||||||
print_axolotl_text_art,
|
print_axolotl_text_art,
|
||||||
)
|
)
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.cli.train")
|
||||||
|
|
||||||
|
|
||||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
@@ -27,6 +32,14 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
|||||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||||
return_remaining_strings=True
|
return_remaining_strings=True
|
||||||
)
|
)
|
||||||
|
if parsed_cli_args.prepare_ds_only and not parsed_cfg.dataset_prepared_path:
|
||||||
|
msg = (
|
||||||
|
Fore.RED
|
||||||
|
+ "--prepare_ds_only called without dataset_prepared_path set."
|
||||||
|
+ Fore.RESET
|
||||||
|
)
|
||||||
|
LOG.warning(msg)
|
||||||
|
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
||||||
|
|
||||||
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||||
if parsed_cli_args.prepare_ds_only:
|
if parsed_cli_args.prepare_ds_only:
|
||||||
|
|||||||
5
src/axolotl/common/const.py
Normal file
5
src/axolotl/common/const.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
"""
|
||||||
|
Various shared constants
|
||||||
|
"""
|
||||||
|
|
||||||
|
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
||||||
@@ -16,6 +16,7 @@ from datasets import (
|
|||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
|
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||||
from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
|
from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
|
||||||
from axolotl.prompt_strategies import load
|
from axolotl.prompt_strategies import load
|
||||||
from axolotl.prompt_tokenizers import (
|
from axolotl.prompt_tokenizers import (
|
||||||
@@ -44,7 +45,6 @@ from axolotl.utils.trainer import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
|
||||||
|
|
||||||
|
|
||||||
def md5(to_hash: str, encoding: str = "utf-8") -> str:
|
def md5(to_hash: str, encoding: str = "utf-8") -> str:
|
||||||
@@ -357,7 +357,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
if len(datasets) > 1:
|
if len(datasets) > 1:
|
||||||
LOG.info("shuffle merged datasets")
|
LOG.info("shuffle merged datasets")
|
||||||
dataset = dataset.shuffle(seed=seed)
|
dataset = dataset.shuffle(seed=seed)
|
||||||
if cfg.local_rank == 0 and cfg.dataset_prepared_path:
|
if cfg.local_rank == 0:
|
||||||
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
|
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
|
||||||
dataset.save_to_disk(prepared_ds_path)
|
dataset.save_to_disk(prepared_ds_path)
|
||||||
if cfg.push_dataset_to_hub:
|
if cfg.push_dataset_to_hub:
|
||||||
|
|||||||
Reference in New Issue
Block a user