chore: lint

This commit is contained in:
Wing Lian
2025-01-13 14:05:56 -05:00
parent e8fceb7091
commit 7232cbdeab
5 changed files with 17 additions and 8 deletions

View File

@@ -13,7 +13,12 @@ class PreprocessCliArgs:
debug_num_examples: int = field(default=1)
prompter: Optional[str] = field(default=None)
download: Optional[bool] = field(default=True)
iterable: Optional[bool] = field(default=None, metadata={"help": "Use IterableDataset for streaming processing of large datasets"})
iterable: Optional[bool] = field(
default=None,
metadata={
"help": "Use IterableDataset for streaming processing of large datasets"
},
)
@dataclass

View File

@@ -3,7 +3,7 @@
import logging
import warnings
from pathlib import Path
from typing import Optional, Union
from typing import Union
import fire
import transformers

View File

@@ -63,7 +63,11 @@ def load_datasets(
"""
tokenizer = load_tokenizer(cfg)
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
preprocess_iterable = hasattr(cli_args, "iterable") and cli_args.iterable is not None and cli_args.iterable
preprocess_iterable = (
hasattr(cli_args, "iterable")
and cli_args.iterable is not None
and cli_args.iterable
)
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
cfg,

View File

@@ -3,7 +3,7 @@
import functools
import logging
from pathlib import Path
from typing import List, Tuple, Union, Optional
from typing import List, Optional, Tuple, Union
from datasets import (
Dataset,

View File

@@ -6,8 +6,8 @@ from pathlib import Path
import pytest
from e2e.utils import check_tensorboard, require_torch_2_5_1
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault
@@ -84,7 +84,7 @@ class TestKnowledgeDistillation:
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
@@ -114,7 +114,7 @@ class TestKnowledgeDistillation:
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"