Data loader refactor (#2707)

* data loading refactor (wip)

* updates

* progress

* pytest

* pytest fix

* lint

* zero_first -> filelock, more simplifications

* small simplification

* import change

* nit

* lint

* simplify dedup

* couldnt resist

* review comments WIP

* continued wip

* minor changes

* fix; remove contrived test

* further refactor

* set default seed in pydantic config

* lint

* continued simplication

* lint

* renaming and nits

* filelock tests

* fix

* fix

* lint

* remove nullable arg

* remove unnecessary code

* moving dataset save fn to shared module

* remove debug print

* matching var naming

* fn name change

* coderabbit comments

* naming nit

* fix test
This commit is contained in:
Dan Saunders
2025-06-10 19:53:07 -04:00
committed by GitHub
parent 52a0452acb
commit 00cda8cc70
62 changed files with 2125 additions and 1436 deletions

View File

@@ -12,7 +12,7 @@ from axolotl.common.datasets import load_datasets
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.loaders import ModelLoader, load_tokenizer
from axolotl.utils.config import normalize_config
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.data import prepare_preference_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.schemas.enums import RLType
@@ -451,15 +451,19 @@ def rand_reward_func(prompts, completions) -> list[float]:
# Only use mock for the commented out configs
if dataset_name is not None:
with patch(
"axolotl.utils.data.rl.load_dataset_w_config"
"axolotl.utils.data.rl.load_dataset_with_config"
) as mock_load_dataset:
mock_load_dataset.return_value = request.getfixturevalue(
dataset_name
)
train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
train_dataset, eval_dataset = prepare_preference_datasets(
cfg, tokenizer
)
else:
# Load actual datasets for orpo_cfg and kto_cfg
train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
train_dataset, eval_dataset = prepare_preference_datasets(
cfg, tokenizer
)
builder.train_dataset = train_dataset
builder.eval_dataset = eval_dataset