From ecbe8b2b61bd24c8a6de662ad0ec3ca733feb4b1 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 15 Aug 2025 21:25:01 -0400
Subject: [PATCH] [GPT-OSS] improve FSDP shard merging and documentation for
 GPT-OSS (#3073)

* improve fsdp shard merging

* improve logging

* update information on merging and inferencing GPT-OSS

* cleanup readme

* automate cleanup of FSDP prefix

* import GRPO only if necessary

* only modify config.json on rank0

* merge final checkpoint at end of training

* prevent circular import

* Fix saving for sharded state dict

* devx, move merged to output dir

* move import back to top

* Fix stuck merge

* fix conditionals from pr feedback and add test
---
 examples/gpt-oss/README.md                    | 33 ++++++++-
 .../gpt-oss-120b-fft-fsdp2-offload.yaml       |  1 +
 src/axolotl/cli/merge_sharded_fsdp_weights.py | 27 ++++++-
 src/axolotl/core/trainers/__init__.py         |  1 -
 src/axolotl/train.py                          | 73 +++++++++++--------
 src/axolotl/utils/train.py                    | 51 +++++++++++++
 tests/utils/test_train.py                     | 35 ++++++++
 7 files changed, 187 insertions(+), 34 deletions(-)
 create mode 100644 src/axolotl/utils/train.py
 create mode 100644 tests/utils/test_train.py

diff --git a/examples/gpt-oss/README.md b/examples/gpt-oss/README.md
index 6dadb8230..9db5e9887 100644
--- a/examples/gpt-oss/README.md
+++ b/examples/gpt-oss/README.md
@@ -33,13 +33,44 @@ Note: Memory usage taken from `device_mem_reserved(gib)` from logs.
 
 ### Training 120B
 
-On 8xH100s
+On 8xH100s, make sure you have ~3TB of free disk space. Each checkpoint clocks in at ~720GB, so between the base
+model, the checkpoints, and the final model output, you need at least 3TB free to keep the last 2 checkpoints.
 
 ```bash
 # FFT SFT with offloading (8x80GB @ ~49GiB/GPU)
 axolotl train examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
 ```
 
+ERRATA: Transformers saves the model architecture prefixed with `FSDP`, which needs to be manually renamed in `config.json`.
+See https://github.com/huggingface/transformers/pull/40207 for the status of this issue.
+
+```bash
+sed -i 's/FSDPGptOssForCausalLM/GptOssForCausalLM/g' ./outputs/gpt-oss-out/config.json
+```
+
+When using SHARDED_STATE_DICT with FSDP, the final checkpoint should automatically merge the sharded weights to your
+configured `output_dir`. However, if that step fails due to a disk space error, you can take an additional step to
+merge the sharded weights. This step automatically determines the last checkpoint directory and merges the sharded
+weights to `{output_dir}/merged`.
+
+```bash
+axolotl merge-sharded-fsdp-weights examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
+mv ./outputs/gpt-oss-out/merged/* ./outputs/gpt-oss-out/
+```
+
+
+### Inferencing your fine-tuned model
+
+GPT-OSS support in vLLM is not yet available in a stable release. See https://x.com/MaziyarPanahi/status/1955741905515323425
+for more information about using a special vllm-openai docker image for inference with vLLM.
+
+SGLang has day-0 support on main; see https://github.com/sgl-project/sglang/issues/8833 for information on installing
+SGLang from source. Once you've installed SGLang, run the following command to launch an SGLang server:
+
+```bash
+python3 -m sglang.launch_server --model ./outputs/gpt-oss-out/ --served-model-name axolotl/gpt-oss-120b --host 0.0.0.0 --port 8888 --tp 8
+```
+
 ### Tool use
 
 GPT-OSS has a comprehensive tool understanding. Axolotl supports tool calling datasets for Supervised Fine-tuning.
diff --git a/examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml b/examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
index 4a9d51fdf..4b4fbd89b 100644
--- a/examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
+++ b/examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
@@ -20,6 +20,7 @@ datasets:
 dataset_prepared_path: last_run_prepared
 val_set_size: 0
 output_dir: ./outputs/gpt-oss-out/
+save_total_limit: 2 # the 120B model can use up to 720GB of disk space per checkpoint, so let's only keep the last 2
 
 sequence_len: 4096
 sample_packing: true
diff --git a/src/axolotl/cli/merge_sharded_fsdp_weights.py b/src/axolotl/cli/merge_sharded_fsdp_weights.py
index c08d30ec8..c99f37fb1 100644
--- a/src/axolotl/cli/merge_sharded_fsdp_weights.py
+++ b/src/axolotl/cli/merge_sharded_fsdp_weights.py
@@ -10,6 +10,7 @@ import fire
 import torch
 import torch.distributed.checkpoint as dist_cp
 import torch.distributed.checkpoint.format_utils as dist_cp_format_utils
+from accelerate import PartialState
 from accelerate.utils import (
     SAFE_WEIGHTS_INDEX_NAME,
     SAFE_WEIGHTS_NAME,
@@ -23,6 +24,7 @@ from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
 
 from axolotl.cli.config import load_cfg
 from axolotl.utils.logging import get_logger
+from axolotl.utils.train import determine_last_checkpoint
 
 LOG = get_logger(__name__)
 
@@ -143,7 +145,6 @@ def merge_fsdp_weights(
         ValueError: If torch version < 2.3.0, or if `checkpoint_dir` does not exist.
     """
     checkpoint_dir_ = Path(checkpoint_dir)
-    from accelerate.state import PartialState
 
     if not is_torch_version(">=", "2.3.0"):
         raise ValueError("`merge_fsdp_weights` requires PyTorch >= 2.3.0`")
@@ -180,7 +181,6 @@ def merge_fsdp_weights(
         if remove_checkpoint_dir:
             LOG.info(f"Removing old checkpoint directory {checkpoint_dir_}")
             shutil.rmtree(checkpoint_dir_)
-        state.wait_for_everyone()
 
 
 def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
@@ -195,11 +195,32 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
     parsed_cfg = load_cfg(config, **kwargs)
 
     fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"
+    if not fsdp_dir.exists():
+        checkpoint_dir = determine_last_checkpoint(parsed_cfg, update=False)
+        if checkpoint_dir:
+            fsdp_dir = Path(checkpoint_dir) / "pytorch_model_fsdp_0"
+        if not fsdp_dir.exists():
+            raise ValueError(
+                f"Could not find FSDP checkpoint `pytorch_model_fsdp_0` in {checkpoint_dir}"
+            )
+
+    output_path = str(Path(parsed_cfg.output_dir) / "merged")
     merge_fsdp_weights(
         checkpoint_dir=str(fsdp_dir),
-        output_path=str(Path(parsed_cfg.output_dir) / "merged"),
+        output_path=output_path,
         safe_serialization=True,
     )
+    state = PartialState()
+    state.wait_for_everyone()
+    LOG.info(
+        f"FSDP SHARDED_STATE_DICT weights successfully merged to: {output_path}",
+        main_process_only=True,
+    )
+    LOG.info(
+        "The merged weights only include the safetensors and do not include the model "
+        f"configuration or tokenizer, which may be found in {parsed_cfg.output_dir}.",
+        main_process_only=True,
+    )
 
 
 if __name__ == "__main__":
diff --git a/src/axolotl/core/trainers/__init__.py b/src/axolotl/core/trainers/__init__.py
index 5f97e387a..a9cda4efc 100644
--- a/src/axolotl/core/trainers/__init__.py
+++ b/src/axolotl/core/trainers/__init__.py
@@ -5,7 +5,6 @@
 
 from .base import AxolotlTrainer
 from .dpo.trainer import AxolotlDPOTrainer
-from .grpo.trainer import AxolotlGRPOSequenceParallelTrainer, AxolotlGRPOTrainer
 from .mamba import AxolotlMambaTrainer
 from .trl import (
     AxolotlCPOTrainer,
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index e8a2cbabe..8005389f1 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -4,11 +4,14 @@ from __future__ import annotations
 
 import importlib
 import inspect
+import json
 import os
+import shutil
 import signal
 import sys
 import typing
 import weakref
+from collections import OrderedDict
 from contextlib import ExitStack
 from pathlib import Path
 from typing import Any, Dict
@@ -38,6 +41,7 @@ from axolotl.utils.distributed import cleanup_distributed
 from axolotl.utils.freeze import freeze_layers_except
 from axolotl.utils.logging import get_logger
 from axolotl.utils.schemas.enums import RLType
+from axolotl.utils.train import determine_last_checkpoint
 from axolotl.utils.trainer import setup_trainer
 
 try:
@@ -46,7 +50,7 @@ except ImportError:
     BetterTransformer = None
 
 if typing.TYPE_CHECKING:
-    from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
+    from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
 
 LOG = get_logger(__name__)
 
@@ -124,32 +128,6 @@ def setup_reference_model(
     return model_ref
 
 
-def determine_resume_checkpoint(cfg: DictDefault) -> str | None:
-    """
-    Determine the checkpoint to resume from based on configuration.
-
-    Args:
-        cfg: Dictionary mapping `axolotl` config keys to values.
-
-    Returns:
-        Path to the checkpoint to resume from, or `None` if not resuming.
-    """
-    if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
-        possible_checkpoints = [
-            str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
-        ]
-        if len(possible_checkpoints) > 0:
-            sorted_paths = sorted(
-                possible_checkpoints,
-                key=lambda path: int(path.split("-")[-1]),
-            )
-            cfg.resume_from_checkpoint = sorted_paths[-1]
-            LOG.info(
-                f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
-            )
-    return cfg.resume_from_checkpoint
-
-
 def setup_signal_handler(
     cfg: DictDefault, model: PreTrainedModel, safe_serialization: bool
 ):
@@ -282,12 +260,49 @@ def save_trained_model(
     else:
         state_dict_type = cfg.fsdp_config.state_dict_type
         trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type)
-        trainer.save_model(cfg.output_dir)
+        trainer.save_model(cfg.output_dir)  # only handles FULL_STATE_DICT
         if state_dict_type == "SHARDED_STATE_DICT":
             LOG.info(
                 "The final model was saved with a sharded state dict. Please ensure you merge "
                 "the sharded weights with `merge-sharded-fsdp-weights`."
             )
+            checkpoint_dir = determine_last_checkpoint(cfg, update=False)
+            if (
+                not (Path(cfg.output_dir) / "model.safetensors.index.json").exists()
+                and checkpoint_dir
+            ):
+                # import here to prevent circular import
+                from axolotl.cli.merge_sharded_fsdp_weights import merge_fsdp_weights
+
+                fsdp_dir = Path(checkpoint_dir) / "pytorch_model_fsdp_0"
+                merged_path = str(Path(cfg.output_dir) / "merged")
+                merge_fsdp_weights(
+                    checkpoint_dir=str(fsdp_dir),
+                    output_path=merged_path,
+                    safe_serialization=True,
+                )
+                trainer.accelerator.wait_for_everyone()
+                if trainer.accelerator.is_main_process:
+                    # move all files in merged_path to cfg.output_dir
+                    for merged_file in Path(merged_path).iterdir():
+                        shutil.move(str(merged_file), cfg.output_dir)
+                    shutil.rmtree(merged_path)  # remove what should be an empty dir
+            # TODO(wing): see https://github.com/huggingface/transformers/pull/40207
+            # cleanup the FSDP prefix in the model config.json
+            if trainer.accelerator.is_main_process:
+                with open(
+                    Path(cfg.output_dir) / "config.json", "r", encoding="utf-8"
+                ) as config_file_io:
+                    # read the model config as an OrderedDict
+                    config = json.load(config_file_io, object_pairs_hook=OrderedDict)
+                config["architectures"] = [
+                    # removeprefix, not lstrip: lstrip strips a character set and
+                    # would over-strip names that continue with F/S/D/P
+                    name.removeprefix("FSDP")
+                    for name in config["architectures"]
+                ]
+                # write the updated model config back
+                with open(
+                    os.path.join(cfg.output_dir, "config.json"), "w", encoding="utf-8"
+                ) as config_file_io:
+                    json.dump(config, config_file_io, indent=2)
     elif cfg.deepspeed and is_deepspeed_zero3_enabled():
         # Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading
         trainer.accelerator.wait_for_everyone()
@@ -564,7 +579,7 @@ def train(
     setup_model_card(cfg)
 
     # Execute the training
-    resume_from_checkpoint = determine_resume_checkpoint(cfg)
+    resume_from_checkpoint = determine_last_checkpoint(cfg)
     execute_training(cfg, trainer, resume_from_checkpoint)
 
     # clear cache
diff --git a/src/axolotl/utils/train.py b/src/axolotl/utils/train.py
new file mode 100644
index 000000000..1393459d9
--- /dev/null
+++ b/src/axolotl/utils/train.py
@@ -0,0 +1,51 @@
+"""Training utils for checkpoints"""
+
+from pathlib import Path
+
+from axolotl.utils.dict import DictDefault
+from axolotl.utils.logging import get_logger
+
+LOG = get_logger(__name__)
+
+
+def determine_last_checkpoint(cfg: DictDefault, update: bool = True) -> str | None:
+    """
+    Determine the checkpoint to resume from based on configuration.
+
+    Args:
+        cfg: Dictionary mapping `axolotl` config keys to values.
+        update: Whether to update the config with the determined checkpoint.
+
+    Returns:
+        Path to the checkpoint to resume from, or `None` if not resuming.
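+
+    Example (illustrative; assumes `cfg.output_dir` contains directories
+    `checkpoint-100` and `checkpoint-250`):
+
+        last = determine_last_checkpoint(cfg, update=False)
+        # last == ".../checkpoint-250"; None when no checkpoint-* dirs exist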
+    """
+    last_checkpoint = None
+    checkpoints = sorted(
+        (
+            p
+            for p in Path(cfg.output_dir).glob("checkpoint-*")
+            if p.name.split("-")[-1].isdigit()
+        ),
+        key=lambda p: int(p.name.split("-")[-1]),
+    )
+    if checkpoints:
+        last_checkpoint = str(checkpoints[-1])
+    if not update:
+        return last_checkpoint
+
+    if (
+        cfg.resume_from_checkpoint is None
+        and cfg.auto_resume_from_checkpoints
+        and last_checkpoint is not None
+    ):
+        cfg.resume_from_checkpoint = last_checkpoint
+        LOG.info(
+            f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
+        )
+    return cfg.resume_from_checkpoint
diff --git a/tests/utils/test_train.py b/tests/utils/test_train.py
new file mode 100644
index 000000000..a1f6f6088
--- /dev/null
+++ b/tests/utils/test_train.py
@@ -0,0 +1,35 @@
+"""test for train checkpoint utils"""
+
+import os
+
+from axolotl.utils.dict import DictDefault
+from axolotl.utils.train import determine_last_checkpoint
+
+
+def test_determine_last_checkpoint(temp_dir):
+    cfg = DictDefault(
+        output_dir=temp_dir,
+    )
+    for cpt_idx in [1, 9, 10, 20]:
+        os.makedirs(
+            os.path.join(cfg.output_dir, f"checkpoint-{cpt_idx}"), exist_ok=True
+        )
+
+    last_checkpoint = determine_last_checkpoint(cfg, update=False)
+    assert last_checkpoint == os.path.join(cfg.output_dir, "checkpoint-20")
+
+    cfg.resume_from_checkpoint = None
+    cfg.auto_resume_from_checkpoints = True
+    determine_last_checkpoint(cfg, update=True)
+    assert cfg.resume_from_checkpoint == os.path.join(cfg.output_dir, "checkpoint-20")
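+
+
+def test_determine_last_checkpoint_ignores_non_numeric(temp_dir):
+    # sketch of an additional case: `determine_last_checkpoint` filters
+    # checkpoint dirs with `isdigit()`, so `checkpoint-best` should be skipped
+    cfg = DictDefault(output_dir=temp_dir)
+    os.makedirs(os.path.join(cfg.output_dir, "checkpoint-best"), exist_ok=True)
+    os.makedirs(os.path.join(cfg.output_dir, "checkpoint-5"), exist_ok=True)
+
+    last_checkpoint = determine_last_checkpoint(cfg, update=False)
+    assert last_checkpoint == os.path.join(cfg.output_dir, "checkpoint-5")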