chore: lint
This commit is contained in:
@@ -13,7 +13,12 @@ class PreprocessCliArgs:
|
|||||||
debug_num_examples: int = field(default=1)
|
debug_num_examples: int = field(default=1)
|
||||||
prompter: Optional[str] = field(default=None)
|
prompter: Optional[str] = field(default=None)
|
||||||
download: Optional[bool] = field(default=True)
|
download: Optional[bool] = field(default=True)
|
||||||
iterable: Optional[bool] = field(default=None, metadata={"help": "Use IterableDataset for streaming processing of large datasets"})
|
iterable: Optional[bool] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "Use IterableDataset for streaming processing of large datasets"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Union
|
from typing import Union
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
import transformers
|
import transformers
|
||||||
|
|||||||
@@ -63,7 +63,11 @@ def load_datasets(
|
|||||||
"""
|
"""
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
|
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
|
||||||
preprocess_iterable = hasattr(cli_args, "iterable") and cli_args.iterable is not None and cli_args.iterable
|
preprocess_iterable = (
|
||||||
|
hasattr(cli_args, "iterable")
|
||||||
|
and cli_args.iterable is not None
|
||||||
|
and cli_args.iterable
|
||||||
|
)
|
||||||
|
|
||||||
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
|
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
|
||||||
cfg,
|
cfg,
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Tuple, Union, Optional
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
from datasets import (
|
from datasets import (
|
||||||
Dataset,
|
Dataset,
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ from pathlib import Path
|
|||||||
import pytest
|
import pytest
|
||||||
from e2e.utils import check_tensorboard, require_torch_2_5_1
|
from e2e.utils import check_tensorboard, require_torch_2_5_1
|
||||||
|
|
||||||
from axolotl.cli import load_datasets
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, prepare_plugins
|
from axolotl.utils.config import normalize_config, prepare_plugins
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -84,7 +84,7 @@ class TestKnowledgeDistillation:
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
|
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
|
||||||
@@ -114,7 +114,7 @@ class TestKnowledgeDistillation:
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
|
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
|
||||||
|
|||||||
Reference in New Issue
Block a user