diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index dbc4172b4..6083e30be 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -494,7 +494,9 @@ def load_prepare_datasets( test_fingerprint = md5(to_hash_test) dataset = dataset.train_test_split( - test_size=cfg.val_set_size, + test_size=int(cfg.val_set_size) + if cfg.val_set_size == int(cfg.val_set_size) + else cfg.val_set_size, shuffle=False, seed=cfg.seed or 42, train_new_fingerprint=train_fingerprint,