# Patch summary (leading commentary; everything from the first "diff --git" header onward is the unmodified patch).
#
# Purpose: move the dataset-preparation logic out of train() in
# scripts/finetune.py into a new helper, prepare_dataset(cfg, tokenizer),
# in src/axolotl/utils/data.py.
#
# scripts/finetune.py
#   - Imports prepare_dataset instead of load_prepare_datasets and
#     load_pretraining_dataset; drops the is_main_process / zero_first import
#     and narrows the axolotl.utils.trainer import to setup_trainer only.
#   - Removes the module-level DEFAULT_DATASET_PREPARED_PATH constant.
#   - In train(), the inline load / packing / step-count code is replaced by
#     a single call:
#       train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
#
# src/axolotl/utils/data.py
#   - Adds an import of calculate_total_num_steps and
#     process_datasets_for_packing from axolotl.utils.trainer.
#   - Defines DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared" here.
#   - Adds prepare_dataset(cfg, tokenizer), which returns the tuple
#     (train_dataset, eval_dataset, total_num_steps):
#       * when cfg.pretraining_dataset is falsy, datasets come from
#         load_prepare_datasets(tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH);
#       * otherwise train_dataset comes from load_pretraining_dataset(...)
#         with max_tokens=cfg.sequence_len and seed=cfg.seed or 42, is given
#         the "torch" format, and eval_dataset is set to None;
#       * process_datasets_for_packing(cfg, train_dataset, eval_dataset) is
#         called under zero_first(is_main_process());
#       * total_num_steps is calculate_total_num_steps(cfg, train_dataset,
#         tokenizer), capped with min(..., cfg.max_steps) and logged when
#         cfg.max_steps is truthy.
#
# NOTE(review): newlines in this patch text have been collapsed to spaces, so
#   the exact indentation/nesting of the added function (e.g. which statements
#   sit inside the "with zero_first(...)" block) cannot be read from this
#   text - check against the real file before relying on the summary above.
# NOTE(review): data.py now imports from axolotl.utils.trainer at module
#   level - confirm axolotl.utils.trainer does not import axolotl.utils.data,
#   otherwise this introduces a circular import.
# NOTE(review): the moved logic looks unchanged from the removed finetune.py
#   code; presumably the pretraining branch yields a streaming dataset -
#   confirm process_datasets_for_packing / calculate_total_num_steps accept it.
diff --git a/scripts/finetune.py b/scripts/finetune.py index 944876a33..fe1d5b8e9 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -19,16 +19,11 @@ from transformers import GenerationConfig, TextStreamer from axolotl.logging_config import configure_logging from axolotl.utils.config import normalize_config, validate_config -from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset +from axolotl.utils.data import prepare_dataset from axolotl.utils.dict import DictDefault -from axolotl.utils.distributed import is_main_process, zero_first from axolotl.utils.models import load_model, load_tokenizer from axolotl.utils.tokenization import check_dataset_labels -from axolotl.utils.trainer import ( - calculate_total_num_steps, - process_datasets_for_packing, - setup_trainer, -) +from axolotl.utils.trainer import setup_trainer from axolotl.utils.wandb import setup_wandb_env_vars project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) @@ -39,7 +34,6 @@ configure_logging() LOG = logging.getLogger("axolotl.scripts") -DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared" os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" @@ -183,32 +177,7 @@ def train( if ( check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference ): # don't need to load dataset for these - if not cfg.pretraining_dataset: - train_dataset, eval_dataset = load_prepare_datasets( - tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH - ) - else: - train_dataset = load_pretraining_dataset( - cfg.pretraining_dataset, - tokenizer, - max_tokens=cfg.sequence_len, - seed=cfg.seed or 42, - ) - # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230 - train_dataset = train_dataset.with_format("torch") - eval_dataset = None - - with zero_first(is_main_process()): - train_dataset, eval_dataset = process_datasets_for_packing( - cfg, train_dataset, eval_dataset - ) - if cfg.max_steps: - 
total_num_steps = min( - calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps - ) - LOG.info(f"Maximum number of steps set at {total_num_steps}") - else: - total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer) + train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer) if cfg.debug or "debug" in kwargs: LOG.info("check_dataset_labels...") diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 6ebfb9d17..d64b06a10 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -42,8 +42,43 @@ from axolotl.prompters import ( SummarizeTLDRPrompter, ) from axolotl.utils.distributed import is_main_process, zero_first +from axolotl.utils.trainer import ( + calculate_total_num_steps, + process_datasets_for_packing, +) LOG = logging.getLogger("axolotl") +DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared" + + +def prepare_dataset(cfg, tokenizer): + if not cfg.pretraining_dataset: + train_dataset, eval_dataset = load_prepare_datasets( + tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH + ) + else: + train_dataset = load_pretraining_dataset( + cfg.pretraining_dataset, + tokenizer, + max_tokens=cfg.sequence_len, + seed=cfg.seed or 42, + ) + # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230 + train_dataset = train_dataset.with_format("torch") + eval_dataset = None + + with zero_first(is_main_process()): + train_dataset, eval_dataset = process_datasets_for_packing( + cfg, train_dataset, eval_dataset + ) + if cfg.max_steps: + total_num_steps = min( + calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps + ) + LOG.info(f"Maximum number of steps set at {total_num_steps}") + else: + total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer) + return train_dataset, eval_dataset, total_num_steps def load_tokenized_prepared_datasets(