[GPT-OSS] improve FSDP shard merging and documentation for GPT-OSS (#3073)
* improve fsdp shard merging
* improve logging
* update information on merging and inferencing GPT-OSS
* cleanup readme
* automate cleanup of FSDP prefix
* import GRPO only if necessary
* only modify config.json on rank0
* merge final checkpoint at end of training
* prevent circular import
* fix saving for sharded state dict
* devx: move merged checkpoint to output dir
* move import back to top
* fix stuck merge
* fix conditionals from PR feedback and add test
This commit is contained in:
24
tests/utils/test_train.py
Normal file
24
tests/utils/test_train.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""test for train checkpoint utils"""
|
||||
|
||||
import os
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.train import determine_last_checkpoint
|
||||
|
||||
|
||||
def test_determine_last_checkpoint(temp_dir):
    """Verify the highest-numbered checkpoint wins, both with and without
    writing the result back into the config.

    The step values deliberately include 9 and 10/20 so that numeric ordering
    differs from lexicographic ordering ("checkpoint-9" > "checkpoint-20" as
    strings) — a naive string sort would pick the wrong directory.
    """
    cfg = DictDefault(output_dir=temp_dir)

    # Create checkpoint dirs whose numeric order differs from string order.
    for step in (1, 9, 10, 20):
        checkpoint_dir = os.path.join(cfg.output_dir, f"checkpoint-{step}")
        os.makedirs(checkpoint_dir, exist_ok=True)

    expected = os.path.join(cfg.output_dir, "checkpoint-20")

    # update=False: just return the latest checkpoint, leaving cfg untouched.
    assert determine_last_checkpoint(cfg, update=False) == expected

    # update=True: the latest checkpoint is written back into the config.
    cfg.resume_from_checkpoint = None
    cfg.auto_resume_from_checkpoints = True
    determine_last_checkpoint(cfg, update=True)
    assert cfg.resume_from_checkpoint == expected
|
||||
Reference in New Issue
Block a user