limit num_proc when saving datasets to disk (#2948) [skip ci]

* limit num_proc when saving datasets to disk

* enforce at least 1 process in case the value rounds down to 0, and use a sane divisor so that each worker has at least 8 rows to save

* update fixtures to set `dataset_processes`, since it should never be None

* improve reusability for tests
This commit is contained in:
Wing Lian
2025-07-21 11:39:38 -04:00
committed by GitHub
parent 8e5f146701
commit db5f6f4693
7 changed files with 27 additions and 9 deletions

View File

@@ -82,6 +82,7 @@ def fixture_base_cfg():
"ddp_timeout": 1800,
"ddp_bucket_cap_mb": 25,
"ddp_broadcast_buffers": False,
"dataset_processes": 4,
}
)
@@ -440,6 +441,7 @@ def rand_reward_func(prompts, completions) -> list[float]:
]
else:
raise ValueError(f"Unhandled cfg_string: {cfg_string}")
cfg["dataset_processes"] = 4
if cfg_string == "grpo_cfg":
rewards_dir = tmp_path / "rewards_test"

View File

@@ -141,6 +141,7 @@ class TestDatasetPreparation:
"type": "alpaca",
},
],
"dataset_processes": 4,
}
)
@@ -179,6 +180,7 @@ class TestDatasetPreparation:
"type": "alpaca",
},
],
"dataset_processes": 4,
}
)
@@ -217,6 +219,7 @@ class TestDatasetPreparation:
"type": "alpaca",
},
],
"dataset_processes": 4,
}
)
@@ -249,6 +252,7 @@ class TestDatasetPreparation:
"type": "alpaca",
},
],
"dataset_processes": 4,
}
)
@@ -281,6 +285,7 @@ class TestDatasetPreparation:
"type": "alpaca",
},
],
"dataset_processes": 4,
}
)
@@ -365,6 +370,7 @@ class TestDatasetPreparation:
"rl": "dpo",
"chat_template": "llama3",
"datasets": [ALPACA_MESSAGES_CONFIG_REVISION],
"dataset_processes": 4,
}
)
@@ -466,6 +472,7 @@ class TestDatasetPreparation:
"type": "alpaca",
},
],
"dataset_processes": 4,
}
)

View File

@@ -210,6 +210,7 @@ class TestDeduplicateRLDataset:
ALPACA_MESSAGES_CONFIG_REVISION,
ALPACA_MESSAGES_CONFIG_REVISION,
],
"dataset_processes": 4,
}
)
yield fixture

View File

@@ -99,6 +99,7 @@ class TestPacking(unittest.TestCase):
"type": "alpaca",
},
],
"dataset_processes": 4,
"num_epochs": 1,
"max_steps": 20,
"save_steps": 10,