[GPT-OSS] improve FSDP shard merging and documentation for GPT-OSS (#3073)
* improve fsdp shard merging
* improve logging
* update information on merging and inferencing GPT-OSS
* cleanup readme
* automate cleanup of FSDP prefix
* import GRPO only if necessary
* only modify config.json on rank0
* merge final checkpoint at end of training
* prevent circular import
* Fix saving for sharded state dict
* devx, move merged to output dir
* move import back to top
* Fix stuck merge
* fix conditionals from pr feedback and add test
@@ -33,13 +33,44 @@ Note: Memory usage taken from `device_mem_reserved(gib)` from logs.

### Training 120B

-On 8xH100s
+On 8xH100s, make sure you have ~3TB of free disk space: each checkpoint clocks in at ~720GB, so with the base
+model and the final model output you need at least 3TB to keep the last two checkpoints.

```bash
# FFT SFT with offloading (8x80GB @ ~49GiB/GPU)
axolotl train examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
```
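
Before kicking off the run, a quick pre-flight check of the disk budget can save a failed multi-hour job. A minimal sketch using only the Python standard library; the `.` path and the 3000 GiB threshold are assumptions based on the numbers above:

```python
import shutil

# ~720GB per checkpoint x 2 kept checkpoints, plus the base model and the
# final merged output, lands around the ~3TB figure quoted above
free_gib = shutil.disk_usage(".").free / 1024**3
if free_gib < 3000:
    print(f"WARNING: only {free_gib:.0f}GiB free; the 120B run may exhaust disk space")
```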

ERRATA: Transformers saves the model architecture prefixed with `FSDP`, which needs to be manually renamed in `config.json`.
See https://github.com/huggingface/transformers/pull/40207 for the status of this issue.

```bash
sed -i 's/FSDPGptOssForCausalLM/GptOssForCausalLM/g' ./outputs/gpt-oss-out/config.json
```
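
Alternatively, if you'd rather not pattern-match the class name with `sed`, you can edit the `architectures` list directly. This is a sketch of the same cleanup this PR automates on rank 0; the path is the example `output_dir` from above:

```python
import json
from pathlib import Path

config_path = Path("./outputs/gpt-oss-out/config.json")
config = json.loads(config_path.read_text(encoding="utf-8"))
# drop the FSDP prefix, e.g. FSDPGptOssForCausalLM -> GptOssForCausalLM
config["architectures"] = [a.removeprefix("FSDP") for a in config["architectures"]]
config_path.write_text(json.dumps(config, indent=2), encoding="utf-8")
```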

When using SHARDED_STATE_DICT with FSDP, the final checkpoint should automatically merge the sharded weights into your
configured `output_dir`. However, if that step fails (for example, due to a disk-space error), you can run the merge
manually. The command below automatically determines the last checkpoint directory and merges the sharded weights to
`{output_dir}/merged`.

```bash
axolotl merge-sharded-fsdp-weights examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
mv ./outputs/gpt-oss-out/merged/* ./outputs/gpt-oss-out/
```
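
After the move, it's worth confirming that every shard referenced by the safetensors index actually landed in the output directory. A small sanity-check sketch, assuming the merge produced `model.safetensors.index.json` (the same file the automatic merge checks for):

```python
import json
from pathlib import Path

out_dir = Path("./outputs/gpt-oss-out")
index = json.loads((out_dir / "model.safetensors.index.json").read_text(encoding="utf-8"))
# weight_map maps each tensor name to the shard file that holds it
missing = {f for f in index["weight_map"].values() if not (out_dir / f).exists()}
assert not missing, f"shards referenced by the index but not on disk: {missing}"
```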


### Inferencing your fine-tuned model

GPT-OSS support in vLLM does not exist in a stable release yet. See https://x.com/MaziyarPanahi/status/1955741905515323425
for more information about using a special vllm-openai Docker image for inference with vLLM.

SGLang has 0-day support on main; see https://github.com/sgl-project/sglang/issues/8833 for information on installing
SGLang from source. Once you've installed SGLang, run the following command to launch an SGLang server:

```bash
python3 -m sglang.launch_server --model ./outputs/gpt-oss-out/ --served-model-name axolotl/gpt-oss-120b --host 0.0.0.0 --port 8888 --tp 8
```
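
Once the server is up, you can smoke-test it through SGLang's OpenAI-compatible API. The port and served model name below match the launch command above; having `requests` installed is assumed:

```python
import requests

response = requests.post(
    "http://localhost:8888/v1/chat/completions",
    json={
        "model": "axolotl/gpt-oss-120b",
        "messages": [{"role": "user", "content": "Say hello in one sentence."}],
        "max_tokens": 64,
    },
    timeout=300,
)
print(response.json()["choices"][0]["message"]["content"])
```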

### Tool use

GPT-OSS has comprehensive tool understanding. Axolotl supports tool-calling datasets for supervised fine-tuning.
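
To give a concrete picture, here is an illustrative tool-calling record in the common OpenAI-style layout. The field names are illustrative only, not necessarily Axolotl's exact dataset schema; consult your dataset config and the Axolotl docs for the required format:

```python
# illustrative record shape only -- check your dataset config for the exact schema
sample = {
    "messages": [
        {"role": "user", "content": "What's the weather in Paris?"},
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "arguments": '{"city": "Paris"}',
                    },
                }
            ],
        },
        {"role": "tool", "content": '{"temp_c": 18}'},
        {"role": "assistant", "content": "It is currently 18°C in Paris."},
    ],
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ],
}
```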

@@ -20,6 +20,7 @@ datasets:
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/
save_total_limit: 2 # the 120B model can use up to 720GB of disk space per checkpoint, so let's only keep the last 2

sequence_len: 4096
sample_packing: true

@@ -10,6 +10,7 @@ import fire
import torch
import torch.distributed.checkpoint as dist_cp
import torch.distributed.checkpoint.format_utils as dist_cp_format_utils
from accelerate import PartialState
from accelerate.utils import (
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
@@ -23,6 +24,7 @@ from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner

from axolotl.cli.config import load_cfg
from axolotl.utils.logging import get_logger
from axolotl.utils.train import determine_last_checkpoint

LOG = get_logger(__name__)

@@ -143,7 +145,6 @@ def merge_fsdp_weights(
        ValueError: If torch version < 2.3.0, or if `checkpoint_dir` does not exist.
    """
    checkpoint_dir_ = Path(checkpoint_dir)
-    from accelerate.state import PartialState

    if not is_torch_version(">=", "2.3.0"):
        raise ValueError("`merge_fsdp_weights` requires PyTorch >= 2.3.0")
@@ -180,7 +181,6 @@ def merge_fsdp_weights(
    if remove_checkpoint_dir:
        LOG.info(f"Removing old checkpoint directory {checkpoint_dir_}")
        shutil.rmtree(checkpoint_dir_)
-    state.wait_for_everyone()


def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
@@ -195,11 +195,32 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
    parsed_cfg = load_cfg(config, **kwargs)

    fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"
    if not fsdp_dir.exists():
        checkpoint_dir = determine_last_checkpoint(parsed_cfg, update=False)
        if checkpoint_dir:
            fsdp_dir = Path(checkpoint_dir) / "pytorch_model_fsdp_0"
        if not fsdp_dir.exists():
            raise ValueError(
                f"Could not find FSDP checkpoint `pytorch_model_fsdp_0` in {checkpoint_dir}"
            )

    output_path = str(Path(parsed_cfg.output_dir) / "merged")
    merge_fsdp_weights(
        checkpoint_dir=str(fsdp_dir),
-        output_path=str(Path(parsed_cfg.output_dir) / "merged"),
+        output_path=output_path,
        safe_serialization=True,
    )
    state = PartialState()
    state.wait_for_everyone()
    LOG.info(
        f"FSDP SHARDED_STATE_DICT weights successfully merged to: {output_path}",
        main_process_only=True,
    )
    LOG.info(
        "Merged weights are only the safetensors and don't include the model configuration "
        f"or tokenizer, which may be found in {parsed_cfg.output_dir}.",
        main_process_only=True,
    )


if __name__ == "__main__":

@@ -5,7 +5,6 @@

from .base import AxolotlTrainer
from .dpo.trainer import AxolotlDPOTrainer
-from .grpo.trainer import AxolotlGRPOSequenceParallelTrainer, AxolotlGRPOTrainer
from .mamba import AxolotlMambaTrainer
from .trl import (
    AxolotlCPOTrainer,

@@ -4,11 +4,14 @@ from __future__ import annotations

import importlib
import inspect
import json
import os
import shutil
import signal
import sys
import typing
import weakref
from collections import OrderedDict
from contextlib import ExitStack
from pathlib import Path
from typing import Any, Dict
@@ -38,6 +41,7 @@ from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.freeze import freeze_layers_except
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import RLType
from axolotl.utils.train import determine_last_checkpoint
from axolotl.utils.trainer import setup_trainer

try:
@@ -46,7 +50,7 @@ except ImportError:
    BetterTransformer = None

if typing.TYPE_CHECKING:
-    from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
+    from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder

LOG = get_logger(__name__)

@@ -124,32 +128,6 @@ def setup_reference_model(
    return model_ref


-def determine_resume_checkpoint(cfg: DictDefault) -> str | None:
-    """
-    Determine the checkpoint to resume from based on configuration.
-
-    Args:
-        cfg: Dictionary mapping `axolotl` config keys to values.
-
-    Returns:
-        Path to the checkpoint to resume from, or `None` if not resuming.
-    """
-    if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
-        possible_checkpoints = [
-            str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
-        ]
-        if len(possible_checkpoints) > 0:
-            sorted_paths = sorted(
-                possible_checkpoints,
-                key=lambda path: int(path.split("-")[-1]),
-            )
-            cfg.resume_from_checkpoint = sorted_paths[-1]
-            LOG.info(
-                f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
-            )
-    return cfg.resume_from_checkpoint


def setup_signal_handler(
    cfg: DictDefault, model: PreTrainedModel, safe_serialization: bool
):
@@ -282,12 +260,49 @@ def save_trained_model(
        else:
            state_dict_type = cfg.fsdp_config.state_dict_type
            trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type)
-            trainer.save_model(cfg.output_dir)
+            trainer.save_model(cfg.output_dir)  # only handles FULL_STATE_DICT
            if state_dict_type == "SHARDED_STATE_DICT":
                LOG.info(
                    "The final model was saved with a sharded state dict. Please ensure you merge "
                    "the sharded weights with `merge-sharded-fsdp-weights`."
                )
                checkpoint_dir = determine_last_checkpoint(cfg, update=False)
                if (
                    not (Path(cfg.output_dir) / "model.safetensors.index.json").exists()
                    and checkpoint_dir
                ):
                    # import here to prevent circular import
                    from axolotl.cli.merge_sharded_fsdp_weights import merge_fsdp_weights

                    fsdp_dir = Path(checkpoint_dir) / "pytorch_model_fsdp_0"
                    merged_path = str(Path(cfg.output_dir) / "merged")
                    merge_fsdp_weights(
                        checkpoint_dir=str(fsdp_dir),
                        output_path=merged_path,
                        safe_serialization=True,
                    )
                    trainer.accelerator.wait_for_everyone()
                    if trainer.accelerator.is_main_process:
                        # move all files in merged_path to cfg.output_dir
                        for merged_file in Path(merged_path).iterdir():
                            shutil.move(str(merged_file), cfg.output_dir)
                        shutil.rmtree(merged_path)  # remove what should be an empty dir
                # TODO(wing): see https://github.com/huggingface/transformers/pull/40207
                # clean up the FSDP prefix in the model config.json
                if trainer.accelerator.is_main_process:
                    with open(
                        Path(cfg.output_dir) / "config.json", "r", encoding="utf-8"
                    ) as config_file_io:
                        # read the model config as an OrderedDict
                        config = json.load(config_file_io, object_pairs_hook=OrderedDict)
                    config["architectures"] = [
                        # removeprefix (unlike lstrip, which strips a character set) drops only the literal prefix
                        name.removeprefix("FSDP") for name in config["architectures"]
                    ]
                    # write the updated model config back
                    with open(
                        os.path.join(cfg.output_dir, "config.json"), "w", encoding="utf-8"
                    ) as config_file_io:
                        json.dump(config, config_file_io, indent=2)
    elif cfg.deepspeed and is_deepspeed_zero3_enabled():
        # Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading
        trainer.accelerator.wait_for_everyone()
@@ -564,7 +579,7 @@ def train(
    setup_model_card(cfg)

    # Execute the training
-    resume_from_checkpoint = determine_resume_checkpoint(cfg)
+    resume_from_checkpoint = determine_last_checkpoint(cfg)
    execute_training(cfg, trainer, resume_from_checkpoint)

    # clear cache

src/axolotl/utils/train.py (new file, +45 lines)
@@ -0,0 +1,45 @@
"""Training utils for checkpoints"""

from pathlib import Path

from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


def determine_last_checkpoint(cfg: DictDefault, update: bool = True) -> str | None:
    """
    Determine the checkpoint to resume from based on configuration.

    Args:
        cfg: Dictionary mapping `axolotl` config keys to values.
        update: Whether to update the config with the determined checkpoint.

    Returns:
        Path to the checkpoint to resume from, or `None` if not resuming.
    """
    last_checkpoint = None
    checkpoints = sorted(
        (
            p
            for p in Path(cfg.output_dir).glob("checkpoint-*")
            if p.name.split("-")[-1].isdigit()
        ),
        key=lambda p: int(p.name.split("-")[-1]),
    )
    if checkpoints:
        last_checkpoint = str(checkpoints[-1])
    if not update:
        return last_checkpoint

    if (
        cfg.resume_from_checkpoint is None
        and cfg.auto_resume_from_checkpoints
        and last_checkpoint is not None
    ):
        cfg.resume_from_checkpoint = last_checkpoint
        LOG.info(
            f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
        )
    return cfg.resume_from_checkpoint

tests/utils/test_train.py (new file, +24 lines)
@@ -0,0 +1,24 @@
"""test for train checkpoint utils"""

import os

from axolotl.utils.dict import DictDefault
from axolotl.utils.train import determine_last_checkpoint


def test_determine_last_checkpoint(temp_dir):
    cfg = DictDefault(
        output_dir=temp_dir,
    )
    for cpt_idx in [1, 9, 10, 20]:
        os.makedirs(
            os.path.join(cfg.output_dir, f"checkpoint-{cpt_idx}"), exist_ok=True
        )

    last_checkpoint = determine_last_checkpoint(cfg, update=False)
    assert last_checkpoint == os.path.join(cfg.output_dir, "checkpoint-20")

    cfg.resume_from_checkpoint = None
    cfg.auto_resume_from_checkpoints = True
    determine_last_checkpoint(cfg, update=True)
    assert cfg.resume_from_checkpoint == os.path.join(cfg.output_dir, "checkpoint-20")