chore: lint

This commit is contained in:
Wing Lian
2025-01-13 14:05:56 -05:00
parent e8fceb7091
commit 7232cbdeab
5 changed files with 17 additions and 8 deletions

View File

@@ -13,7 +13,12 @@ class PreprocessCliArgs:
debug_num_examples: int = field(default=1) debug_num_examples: int = field(default=1)
prompter: Optional[str] = field(default=None) prompter: Optional[str] = field(default=None)
download: Optional[bool] = field(default=True) download: Optional[bool] = field(default=True)
iterable: Optional[bool] = field(default=None, metadata={"help": "Use IterableDataset for streaming processing of large datasets"}) iterable: Optional[bool] = field(
default=None,
metadata={
"help": "Use IterableDataset for streaming processing of large datasets"
},
)
@dataclass @dataclass

View File

@@ -3,7 +3,7 @@
import logging import logging
import warnings import warnings
from pathlib import Path from pathlib import Path
from typing import Optional, Union from typing import Union
import fire import fire
import transformers import transformers

View File

@@ -63,7 +63,11 @@ def load_datasets(
""" """
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
preprocess_iterable = hasattr(cli_args, "iterable") and cli_args.iterable is not None and cli_args.iterable preprocess_iterable = (
hasattr(cli_args, "iterable")
and cli_args.iterable is not None
and cli_args.iterable
)
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset( train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
cfg, cfg,

View File

@@ -3,7 +3,7 @@
import functools import functools
import logging import logging
from pathlib import Path from pathlib import Path
from typing import List, Tuple, Union, Optional from typing import List, Optional, Tuple, Union
from datasets import ( from datasets import (
Dataset, Dataset,

View File

@@ -6,8 +6,8 @@ from pathlib import Path
import pytest import pytest
from e2e.utils import check_tensorboard, require_torch_2_5_1 from e2e.utils import check_tensorboard, require_torch_2_5_1
from axolotl.cli import load_datasets from axolotl.cli.args import TrainerCliArgs
from axolotl.common.cli import TrainerCliArgs from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -84,7 +84,7 @@ class TestKnowledgeDistillation:
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists() assert (Path(temp_dir) / "model.safetensors").exists()
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high" temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
@@ -114,7 +114,7 @@ class TestKnowledgeDistillation:
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.safetensors").exists() assert (Path(temp_dir) / "adapter_model.safetensors").exists()
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high" temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"