Data loader refactor (#2707)

* data loading refactor (wip)

* updates

* progress

* pytest

* pytest fix

* lint

* zero_first -> filelock, more simplifications

* small simplification

* import change

* nit

* lint

* simplify dedup

* couldnt resist

* review comments WIP

* continued wip

* minor changes

* fix; remove contrived test

* further refactor

* set default seed in pydantic config

* lint

* continued simplication

* lint

* renaming and nits

* filelock tests

* fix

* fix

* lint

* remove nullable arg

* remove unnecessary code

* moving dataset save fn to shared module

* remove debug print

* matching var naming

* fn name change

* coderabbit comments

* naming nit

* fix test
This commit is contained in:
Dan Saunders
2025-06-10 19:53:07 -04:00
committed by GitHub
parent 52a0452acb
commit 00cda8cc70
62 changed files with 2125 additions and 1436 deletions

View File

@@ -7,7 +7,6 @@ import unittest
import torch
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -67,8 +66,7 @@ class TestMixtral(unittest.TestCase):
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
@@ -123,8 +121,7 @@ class TestMixtral(unittest.TestCase):
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
@@ -182,8 +179,7 @@ class TestMixtral(unittest.TestCase):
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
@@ -241,8 +237,7 @@ class TestMixtral(unittest.TestCase):
cfg.bf16 = True
else:
cfg.fp16 = True
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
@@ -287,8 +282,7 @@ class TestMixtral(unittest.TestCase):
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)