diff --git a/src/axolotl/cli/args.py b/src/axolotl/cli/args.py
index f3a0d539b..d451b45a6 100644
--- a/src/axolotl/cli/args.py
+++ b/src/axolotl/cli/args.py
@@ -13,7 +13,12 @@ class PreprocessCliArgs:
     debug_num_examples: int = field(default=1)
     prompter: Optional[str] = field(default=None)
     download: Optional[bool] = field(default=True)
-    iterable: Optional[bool] = field(default=None, metadata={"help": "Use IterableDataset for streaming processing of large datasets"})
+    iterable: Optional[bool] = field(
+        default=None,
+        metadata={
+            "help": "Use IterableDataset for streaming processing of large datasets"
+        },
+    )
 
 
 @dataclass
diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py
index 627a95f8f..5585c88a7 100644
--- a/src/axolotl/cli/preprocess.py
+++ b/src/axolotl/cli/preprocess.py
@@ -3,7 +3,7 @@
 import logging
 import warnings
 from pathlib import Path
-from typing import Optional, Union
+from typing import Union
 
 import fire
 import transformers
diff --git a/src/axolotl/common/datasets.py b/src/axolotl/common/datasets.py
index 8694f0986..0c6704dc7 100644
--- a/src/axolotl/common/datasets.py
+++ b/src/axolotl/common/datasets.py
@@ -63,7 +63,11 @@ def load_datasets(
     """
     tokenizer = load_tokenizer(cfg)
     processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
-    preprocess_iterable = hasattr(cli_args, "iterable") and cli_args.iterable is not None and cli_args.iterable
+    preprocess_iterable = (
+        hasattr(cli_args, "iterable")
+        and cli_args.iterable is not None
+        and cli_args.iterable
+    )
 
     train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
         cfg,
diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py
index 59d862a7f..4706b5a31 100644
--- a/src/axolotl/utils/data/sft.py
+++ b/src/axolotl/utils/data/sft.py
@@ -3,7 +3,7 @@
 import functools
 import logging
 from pathlib import Path
-from typing import List, Tuple, Union, Optional
+from typing import List, Optional, Tuple, Union
 
 from datasets import (
     Dataset,
diff --git a/tests/e2e/integrations/test_kd.py b/tests/e2e/integrations/test_kd.py
index 9e4103ecd..18987b68b 100644
--- a/tests/e2e/integrations/test_kd.py
+++ b/tests/e2e/integrations/test_kd.py
@@ -6,8 +6,8 @@ from pathlib import Path
 
 import pytest
 from e2e.utils import check_tensorboard, require_torch_2_5_1
-from axolotl.cli import load_datasets
-from axolotl.common.cli import TrainerCliArgs
+from axolotl.cli.args import TrainerCliArgs
+from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, prepare_plugins
 from axolotl.utils.dict import DictDefault
@@ -84,7 +84,7 @@ class TestKnowledgeDistillation:
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
-        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        train(cfg=cfg, dataset_meta=dataset_meta)
         assert (Path(temp_dir) / "model.safetensors").exists()
         check_tensorboard(
             temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
         )
@@ -114,7 +114,7 @@ class TestKnowledgeDistillation:
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
-        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        train(cfg=cfg, dataset_meta=dataset_meta)
         assert (Path(temp_dir) / "adapter_model.safetensors").exists()
         check_tensorboard(
             temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
         )