axolotl/src/axolotl/utils/train.py
Wing Lian ecbe8b2b61 [GPT-OSS] improve FSDP shard merging and documentation for GPT-OSS (#3073)
* improve fsdp shard merging

* improve logging

* update information on merging and inferencing GPT-OSS

* cleanup readme

* automate cleanup of FSDP prefix

* import GRPO only if necessary

* only modify config.json on rank0

* merge final checkpoint at end of training

* prevent circular import

* Fix saving for sharded state dict

* devx, move merged to output dir

* move import back to top

* Fix stuck merge

* fix conditionals from pr feedback and add test
2025-08-15 21:25:01 -04:00

"""Training utils for checkpoints"""
from pathlib import Path
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def determine_last_checkpoint(cfg: DictDefault, update: bool = True) -> str | None:
"""
Determine the checkpoint to resume from based on configuration.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
update: Whether to update the config with the determined checkpoint
Returns:
Path to the checkpoint to resume from, or `None` if not resuming.
"""
last_checkpoint = None
checkpoints = sorted(
(
p
for p in Path(cfg.output_dir).glob("checkpoint-*")
if p.name.split("-")[-1].isdigit()
),
key=lambda p: int(p.name.split("-")[-1]),
)
if checkpoints:
last_checkpoint = str(checkpoints[-1])
if not update:
return last_checkpoint
if (
cfg.resume_from_checkpoint is None
and cfg.auto_resume_from_checkpoints
and last_checkpoint is not None
):
cfg.resume_from_checkpoint = last_checkpoint
LOG.info(
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
)
return cfg.resume_from_checkpoint
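
A minimal usage sketch, assuming a config built with DictDefault; the output directory and config values below are hypothetical, not taken from any particular run:

# Hypothetical example of calling determine_last_checkpoint; the
# output_dir and flag values are illustrative assumptions.
from axolotl.utils.dict import DictDefault
from axolotl.utils.train import determine_last_checkpoint

cfg = DictDefault(
    {
        "output_dir": "./outputs/my-run",      # hypothetical output directory
        "resume_from_checkpoint": None,        # no checkpoint pinned explicitly
        "auto_resume_from_checkpoints": True,  # opt in to auto-resume
    }
)

# update=False only inspects the filesystem and returns the newest
# checkpoint path (or None); the config is left unchanged.
latest = determine_last_checkpoint(cfg, update=False)

# The default update=True also writes the result back into the config,
# so the returned value and cfg.resume_from_checkpoint agree.
resume = determine_last_checkpoint(cfg)
assert resume == cfg.resume_from_checkpoint

Passing update=False turns the helper into a side-effect-free query, which is convenient for callers that only want to know whether a resumable checkpoint exists without mutating the shared config.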