* improve fsdp shard merging * improve logging * update information on merging and inferencing GPT-OSS * cleanup readme * automate cleanup of FSDP prefix * import GRPO only if necessary * only modify config.json on rank0 * merge final checkpoint at end of training * prevent circular import * Fix saving for sharded state dict * devx, move merged to output dir * move import back to top * Fix stuck merge * fix conditionals from pr feedback and add test
25 lines
766 B
Python
25 lines
766 B
Python
"""test for train checkpoint utils"""
|
|
|
|
import os
|
|
|
|
from axolotl.utils.dict import DictDefault
|
|
from axolotl.utils.train import determine_last_checkpoint
|
|
|
|
|
|
def test_determine_last_checkpoint(temp_dir):
    """Check that determine_last_checkpoint picks the highest-numbered
    checkpoint directory, and that update=True records it on the config.

    Indices 9 vs 20 would order incorrectly under plain string comparison
    ("checkpoint-9" > "checkpoint-20"), so this also guards against
    lexicographic sorting of checkpoint names.
    """
    cfg = DictDefault(output_dir=temp_dir)

    # Lay down a handful of checkpoint dirs with non-contiguous step numbers.
    for step in (1, 9, 10, 20):
        os.makedirs(os.path.join(cfg.output_dir, f"checkpoint-{step}"), exist_ok=True)

    expected = os.path.join(cfg.output_dir, "checkpoint-20")

    # update=False: the latest checkpoint is returned without touching cfg.
    assert determine_last_checkpoint(cfg, update=False) == expected

    # update=True with auto-resume enabled: cfg.resume_from_checkpoint is
    # populated in place with the latest checkpoint path.
    cfg.resume_from_checkpoint = None
    cfg.auto_resume_from_checkpoints = True
    determine_last_checkpoint(cfg, update=True)
    assert cfg.resume_from_checkpoint == expected