diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 2f9a1afec..0375cf9db 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -112,7 +112,7 @@ def load_tokenized_prepared_datasets( raise Exception("unhandled dataset load") # support for using a subset of the data if d.shards: - ds = ds.shuffle(seed=42)["train"].shard(num_shards=cfg.shards, index=0) + ds = ds.shuffle(seed=42)["train"].shard(num_shards=d.shards, index=0) d_type = d.type d_type_split = d_type.split(":") d_base_type = d_type_split[0]