[GPT-OSS] improve FSDP shard merging and documentation for GPT-OSS (#3073)
* improve fsdp shard merging * improve logging * update information on merging and inferencing GPT-OSS * cleanup readme * automate cleanup of FSDP prefix * import GRPO only if necessary * only modify config.json on rank0 * merge final checkpoint at end of training * prevent circular import * Fix saving for sharded state dict * devx, move merged to output dir * move import back to top * Fix stuck merge * fix conditionals from pr feedback and add test
This commit is contained in:
@@ -33,13 +33,44 @@ Note: Memory usage taken from `device_mem_reserved(gib)` from logs.
|
|||||||
|
|
||||||
### Training 120B
|
### Training 120B
|
||||||
|
|
||||||
On 8xH100s
|
On 8xH100s, make sure you have ~3TB of free disk space. With each checkpoint clocking in at ~720GB, along with the base
|
||||||
|
model, and final model output, you may need at least 3TB of free disk space to keep at least 2 checkpoints.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# FFT SFT with offloading (8x80GB @ ~49GiB/GPU)
|
# FFT SFT with offloading (8x80GB @ ~49GiB/GPU)
|
||||||
axolotl train examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
|
axolotl train examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
|
ERRATA: Transformers saves the model Architecture prefixed with `FSDP` which needs to be manually renamed in `config.json`.
|
||||||
|
See https://github.com/huggingface/transformers/pull/40207 for the status of this issue.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
sed -i 's/FSDPGptOssForCausalLM/GptOssForCausalLM/g' ./outputs/gpt-oss-out/config.json
|
||||||
|
```
|
||||||
|
|
||||||
|
When using SHARDED_STATE_DICT with FSDP, the final checkpoint should automatically merge the sharded weights to your
|
||||||
|
configured `output_dir`. However, if that step fails due to a disk space error, you can take an additional step to
|
||||||
|
merge the sharded weights. This step will automatically determine the last checkpoint directory and merge the sharded
|
||||||
|
weights to `{output_dir}/merged`.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
axolotl merge-sharded-fsdp-weights examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
|
||||||
|
mv ./outputs/gpt-oss-out/merged/* ./outputs/gpt-oss-out/
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
### Inferencing your fine-tuned model
|
||||||
|
|
||||||
|
GPT-OSS support in vLLM does not exist in a stable release yet. See https://x.com/MaziyarPanahi/status/1955741905515323425
|
||||||
|
for more information about using a special vllm-openai docker image for inferencing with vLLM.
|
||||||
|
|
||||||
|
SGLang has 0-day support in main, see https://github.com/sgl-project/sglang/issues/8833 for infomation on installing
|
||||||
|
SGLang from source. Once you've installed SGLang, run the following command to launch a SGLang server:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 -m sglang.launch_server --model ./outputs/gpt-oss-out/ --served-model-name axolotl/gpt-oss-120b --host 0.0.0.0 --port 8888 --tp 8
|
||||||
|
```
|
||||||
|
|
||||||
### Tool use
|
### Tool use
|
||||||
|
|
||||||
GPT-OSS has a comprehensive tool understanding. Axolotl supports tool calling datasets for Supervised Fine-tuning.
|
GPT-OSS has a comprehensive tool understanding. Axolotl supports tool calling datasets for Supervised Fine-tuning.
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ datasets:
|
|||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0
|
val_set_size: 0
|
||||||
output_dir: ./outputs/gpt-oss-out/
|
output_dir: ./outputs/gpt-oss-out/
|
||||||
|
save_total_limit: 2 # the 120B model can use up to 720GB of disk space per checkpoint, so let's only keep the last 2
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import fire
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed.checkpoint as dist_cp
|
import torch.distributed.checkpoint as dist_cp
|
||||||
import torch.distributed.checkpoint.format_utils as dist_cp_format_utils
|
import torch.distributed.checkpoint.format_utils as dist_cp_format_utils
|
||||||
|
from accelerate import PartialState
|
||||||
from accelerate.utils import (
|
from accelerate.utils import (
|
||||||
SAFE_WEIGHTS_INDEX_NAME,
|
SAFE_WEIGHTS_INDEX_NAME,
|
||||||
SAFE_WEIGHTS_NAME,
|
SAFE_WEIGHTS_NAME,
|
||||||
@@ -23,6 +24,7 @@ from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
|
|||||||
|
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
from axolotl.utils.train import determine_last_checkpoint
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
@@ -143,7 +145,6 @@ def merge_fsdp_weights(
|
|||||||
ValueError: If torch version < 2.3.0, or if `checkpoint_dir` does not exist.
|
ValueError: If torch version < 2.3.0, or if `checkpoint_dir` does not exist.
|
||||||
"""
|
"""
|
||||||
checkpoint_dir_ = Path(checkpoint_dir)
|
checkpoint_dir_ = Path(checkpoint_dir)
|
||||||
from accelerate.state import PartialState
|
|
||||||
|
|
||||||
if not is_torch_version(">=", "2.3.0"):
|
if not is_torch_version(">=", "2.3.0"):
|
||||||
raise ValueError("`merge_fsdp_weights` requires PyTorch >= 2.3.0`")
|
raise ValueError("`merge_fsdp_weights` requires PyTorch >= 2.3.0`")
|
||||||
@@ -180,7 +181,6 @@ def merge_fsdp_weights(
|
|||||||
if remove_checkpoint_dir:
|
if remove_checkpoint_dir:
|
||||||
LOG.info(f"Removing old checkpoint directory {checkpoint_dir_}")
|
LOG.info(f"Removing old checkpoint directory {checkpoint_dir_}")
|
||||||
shutil.rmtree(checkpoint_dir_)
|
shutil.rmtree(checkpoint_dir_)
|
||||||
state.wait_for_everyone()
|
|
||||||
|
|
||||||
|
|
||||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||||
@@ -195,11 +195,32 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
|||||||
parsed_cfg = load_cfg(config, **kwargs)
|
parsed_cfg = load_cfg(config, **kwargs)
|
||||||
|
|
||||||
fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"
|
fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"
|
||||||
|
if not fsdp_dir.exists():
|
||||||
|
checkpoint_dir = determine_last_checkpoint(parsed_cfg, update=False)
|
||||||
|
if checkpoint_dir:
|
||||||
|
fsdp_dir = Path(checkpoint_dir) / "pytorch_model_fsdp_0"
|
||||||
|
if not fsdp_dir.exists():
|
||||||
|
raise ValueError(
|
||||||
|
f"Could not find FSDP checkpoint `pytorch_model_fsdp_0` in {checkpoint_dir}"
|
||||||
|
)
|
||||||
|
|
||||||
|
output_path = str(Path(parsed_cfg.output_dir) / "merged")
|
||||||
merge_fsdp_weights(
|
merge_fsdp_weights(
|
||||||
checkpoint_dir=str(fsdp_dir),
|
checkpoint_dir=str(fsdp_dir),
|
||||||
output_path=str(Path(parsed_cfg.output_dir) / "merged"),
|
output_path=output_path,
|
||||||
safe_serialization=True,
|
safe_serialization=True,
|
||||||
)
|
)
|
||||||
|
state = PartialState()
|
||||||
|
state.wait_for_everyone()
|
||||||
|
LOG.info(
|
||||||
|
f"FSDP SHARDED_STATE_DICT weights successfully merged to: {output_path}",
|
||||||
|
main_process_only=True,
|
||||||
|
)
|
||||||
|
LOG.info(
|
||||||
|
"Merged weights are only the safetensors and doesn't include the model configuration "
|
||||||
|
f"or tokenizer which may be found in {parsed_cfg.output_dir}.",
|
||||||
|
main_process_only=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -5,7 +5,6 @@
|
|||||||
|
|
||||||
from .base import AxolotlTrainer
|
from .base import AxolotlTrainer
|
||||||
from .dpo.trainer import AxolotlDPOTrainer
|
from .dpo.trainer import AxolotlDPOTrainer
|
||||||
from .grpo.trainer import AxolotlGRPOSequenceParallelTrainer, AxolotlGRPOTrainer
|
|
||||||
from .mamba import AxolotlMambaTrainer
|
from .mamba import AxolotlMambaTrainer
|
||||||
from .trl import (
|
from .trl import (
|
||||||
AxolotlCPOTrainer,
|
AxolotlCPOTrainer,
|
||||||
|
|||||||
@@ -4,11 +4,14 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import inspect
|
import inspect
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
|
import shutil
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
import typing
|
import typing
|
||||||
import weakref
|
import weakref
|
||||||
|
from collections import OrderedDict
|
||||||
from contextlib import ExitStack
|
from contextlib import ExitStack
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
@@ -38,6 +41,7 @@ from axolotl.utils.distributed import cleanup_distributed
|
|||||||
from axolotl.utils.freeze import freeze_layers_except
|
from axolotl.utils.freeze import freeze_layers_except
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.schemas.enums import RLType
|
from axolotl.utils.schemas.enums import RLType
|
||||||
|
from axolotl.utils.train import determine_last_checkpoint
|
||||||
from axolotl.utils.trainer import setup_trainer
|
from axolotl.utils.trainer import setup_trainer
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -46,7 +50,7 @@ except ImportError:
|
|||||||
BetterTransformer = None
|
BetterTransformer = None
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
if typing.TYPE_CHECKING:
|
||||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
@@ -124,32 +128,6 @@ def setup_reference_model(
|
|||||||
return model_ref
|
return model_ref
|
||||||
|
|
||||||
|
|
||||||
def determine_resume_checkpoint(cfg: DictDefault) -> str | None:
|
|
||||||
"""
|
|
||||||
Determine the checkpoint to resume from based on configuration.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Path to the checkpoint to resume from, or `None` if not resuming.
|
|
||||||
"""
|
|
||||||
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
|
|
||||||
possible_checkpoints = [
|
|
||||||
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
|
|
||||||
]
|
|
||||||
if len(possible_checkpoints) > 0:
|
|
||||||
sorted_paths = sorted(
|
|
||||||
possible_checkpoints,
|
|
||||||
key=lambda path: int(path.split("-")[-1]),
|
|
||||||
)
|
|
||||||
cfg.resume_from_checkpoint = sorted_paths[-1]
|
|
||||||
LOG.info(
|
|
||||||
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
|
|
||||||
)
|
|
||||||
return cfg.resume_from_checkpoint
|
|
||||||
|
|
||||||
|
|
||||||
def setup_signal_handler(
|
def setup_signal_handler(
|
||||||
cfg: DictDefault, model: PreTrainedModel, safe_serialization: bool
|
cfg: DictDefault, model: PreTrainedModel, safe_serialization: bool
|
||||||
):
|
):
|
||||||
@@ -282,12 +260,49 @@ def save_trained_model(
|
|||||||
else:
|
else:
|
||||||
state_dict_type = cfg.fsdp_config.state_dict_type
|
state_dict_type = cfg.fsdp_config.state_dict_type
|
||||||
trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type)
|
trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type)
|
||||||
trainer.save_model(cfg.output_dir)
|
trainer.save_model(cfg.output_dir) # only handles FULL_STATE_DICT
|
||||||
if state_dict_type == "SHARDED_STATE_DICT":
|
if state_dict_type == "SHARDED_STATE_DICT":
|
||||||
LOG.info(
|
LOG.info(
|
||||||
"The final model was saved with a sharded state dict. Please ensure you merge "
|
"The final model was saved with a sharded state dict. Please ensure you merge "
|
||||||
"the sharded weights with `merge-sharded-fsdp-weights`."
|
"the sharded weights with `merge-sharded-fsdp-weights`."
|
||||||
)
|
)
|
||||||
|
checkpoint_dir = determine_last_checkpoint(cfg, update=False)
|
||||||
|
if (
|
||||||
|
not (Path(cfg.output_dir) / "model.safetensors.index.json").exists()
|
||||||
|
and checkpoint_dir
|
||||||
|
):
|
||||||
|
# import here to prevent circular import
|
||||||
|
from axolotl.cli.merge_sharded_fsdp_weights import merge_fsdp_weights
|
||||||
|
|
||||||
|
fsdp_dir = Path(checkpoint_dir) / "pytorch_model_fsdp_0"
|
||||||
|
merged_path = str(Path(cfg.output_dir) / "merged")
|
||||||
|
merge_fsdp_weights(
|
||||||
|
checkpoint_dir=str(fsdp_dir),
|
||||||
|
output_path=merged_path,
|
||||||
|
safe_serialization=True,
|
||||||
|
)
|
||||||
|
trainer.accelerator.wait_for_everyone()
|
||||||
|
if trainer.accelerator.is_main_process:
|
||||||
|
# move all files in merged_path to cfg.output_dir
|
||||||
|
for merged_file in Path(merged_path).iterdir():
|
||||||
|
shutil.move(str(merged_file), cfg.output_dir)
|
||||||
|
shutil.rmtree(merged_path) # remove what should be an empty dir
|
||||||
|
# TODO(wing):see https://github.com/huggingface/transformers/pull/40207
|
||||||
|
# cleanup the FSDP prefix in the model config.json
|
||||||
|
if trainer.accelerator.is_main_process:
|
||||||
|
with open(
|
||||||
|
Path(cfg.output_dir) / "config.json", "r", encoding="utf-8"
|
||||||
|
) as config_file_io:
|
||||||
|
# read the model config as an OrderedDict
|
||||||
|
config = json.load(config_file_io, object_pairs_hook=OrderedDict)
|
||||||
|
config["architectures"] = [
|
||||||
|
name.lstrip("FSDP") for name in config["architectures"]
|
||||||
|
]
|
||||||
|
# write the updated model config back
|
||||||
|
with open(
|
||||||
|
os.path.join(cfg.output_dir, "config.json"), "w", encoding="utf-8"
|
||||||
|
) as config_file_io:
|
||||||
|
json.dump(config, config_file_io, indent=2)
|
||||||
elif cfg.deepspeed and is_deepspeed_zero3_enabled():
|
elif cfg.deepspeed and is_deepspeed_zero3_enabled():
|
||||||
# Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading
|
# Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading
|
||||||
trainer.accelerator.wait_for_everyone()
|
trainer.accelerator.wait_for_everyone()
|
||||||
@@ -564,7 +579,7 @@ def train(
|
|||||||
setup_model_card(cfg)
|
setup_model_card(cfg)
|
||||||
|
|
||||||
# Execute the training
|
# Execute the training
|
||||||
resume_from_checkpoint = determine_resume_checkpoint(cfg)
|
resume_from_checkpoint = determine_last_checkpoint(cfg)
|
||||||
execute_training(cfg, trainer, resume_from_checkpoint)
|
execute_training(cfg, trainer, resume_from_checkpoint)
|
||||||
|
|
||||||
# clear cache
|
# clear cache
|
||||||
|
|||||||
45
src/axolotl/utils/train.py
Normal file
45
src/axolotl/utils/train.py
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
"""Training utils for checkpoints"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def determine_last_checkpoint(cfg: DictDefault, update: bool = True) -> str | None:
|
||||||
|
"""
|
||||||
|
Determine the checkpoint to resume from based on configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
|
update: Whether to update the config with the determined checkpoint
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to the checkpoint to resume from, or `None` if not resuming.
|
||||||
|
"""
|
||||||
|
last_checkpoint = None
|
||||||
|
checkpoints = sorted(
|
||||||
|
(
|
||||||
|
p
|
||||||
|
for p in Path(cfg.output_dir).glob("checkpoint-*")
|
||||||
|
if p.name.split("-")[-1].isdigit()
|
||||||
|
),
|
||||||
|
key=lambda p: int(p.name.split("-")[-1]),
|
||||||
|
)
|
||||||
|
if checkpoints:
|
||||||
|
last_checkpoint = str(checkpoints[-1])
|
||||||
|
if not update:
|
||||||
|
return last_checkpoint
|
||||||
|
|
||||||
|
if (
|
||||||
|
cfg.resume_from_checkpoint is None
|
||||||
|
and cfg.auto_resume_from_checkpoints
|
||||||
|
and last_checkpoint is not None
|
||||||
|
):
|
||||||
|
cfg.resume_from_checkpoint = last_checkpoint
|
||||||
|
LOG.info(
|
||||||
|
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
|
||||||
|
)
|
||||||
|
return cfg.resume_from_checkpoint
|
||||||
24
tests/utils/test_train.py
Normal file
24
tests/utils/test_train.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
"""test for train checkpoint utils"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.train import determine_last_checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
def test_determine_last_checkpoint(temp_dir):
|
||||||
|
cfg = DictDefault(
|
||||||
|
output_dir=temp_dir,
|
||||||
|
)
|
||||||
|
for cpt_idx in [1, 9, 10, 20]:
|
||||||
|
os.makedirs(
|
||||||
|
os.path.join(cfg.output_dir, f"checkpoint-{cpt_idx}"), exist_ok=True
|
||||||
|
)
|
||||||
|
|
||||||
|
last_checkpoint = determine_last_checkpoint(cfg, update=False)
|
||||||
|
assert last_checkpoint == os.path.join(cfg.output_dir, "checkpoint-20")
|
||||||
|
|
||||||
|
cfg.resume_from_checkpoint = None
|
||||||
|
cfg.auto_resume_from_checkpoints = True
|
||||||
|
determine_last_checkpoint(cfg, update=True)
|
||||||
|
assert cfg.resume_from_checkpoint == os.path.join(cfg.output_dir, "checkpoint-20")
|
||||||
Reference in New Issue
Block a user