Data loader refactor (#2707)

* data loading refactor (wip)

* updates

* progress

* pytest

* pytest fix

* lint

* zero_first -> filelock, more simplifications

* small simplification

* import change

* nit

* lint

* simplify dedup

* couldnt resist

* review comments WIP

* continued wip

* minor changes

* fix; remove contrived test

* further refactor

* set default seed in pydantic config

* lint

* continued simplication

* lint

* renaming and nits

* filelock tests

* fix

* fix

* lint

* remove nullable arg

* remove unnecessary code

* moving dataset save fn to shared module

* remove debug print

* matching var naming

* fn name change

* coderabbit comments

* naming nit

* fix test
This commit is contained in:
Dan Saunders
2025-06-10 19:53:07 -04:00
committed by GitHub
parent 52a0452acb
commit 00cda8cc70
62 changed files with 2125 additions and 1436 deletions

View File

@@ -4,7 +4,6 @@ Simple end-to-end test for Cut Cross Entropy integration
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils import get_pytorch_version
@@ -59,8 +58,7 @@ class TestCutCrossEntropyIntegration:
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
major, minor, _ = get_pytorch_version()
if (major, minor) < (2, 4):
@@ -105,8 +103,7 @@ class TestCutCrossEntropyIntegration:
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
major, minor, _ = get_pytorch_version()
if (major, minor) < (2, 4):
@@ -134,8 +131,7 @@ class TestCutCrossEntropyIntegration:
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
major, minor, _ = get_pytorch_version()
if (major, minor) < (2, 4):