fix: rename model to adapter_model for fsdp sharded final model (#3585)

* fix: rename model to adapter_model for fsdp sharded final model

* fix: follow upstream transformer shard size

* fix: handle multiple model files

* fix redundant condition, tighten to safetensors, keep shard size small

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
NanoCode012
2026-04-13 07:51:30 +07:00
committed by GitHub
parent 63a58cfec1
commit 6990478163

View File

@@ -229,6 +229,28 @@ def execute_training(
PLUGIN_MANAGER.post_train(cfg, trainer.model)
def _rename_fsdp_merged_to_adapter(merged_dir: Path):
"""Rename model*.safetensors files to adapter_model* in place.
Also rewrites the index JSON weight_map if sharded output was produced.
"""
for file in sorted(merged_dir.iterdir()):
if file.name.startswith("model") and ".safetensors" in file.name:
file.rename(merged_dir / file.name.replace("model", "adapter_model", 1))
index = merged_dir / "adapter_model.safetensors.index.json"
if index.exists():
data = json.loads(index.read_text(encoding="utf-8"))
if "weight_map" in data:
data["weight_map"] = {
k: v.replace("model", "adapter_model", 1)
for k, v in data["weight_map"].items()
}
index.write_text(
json.dumps(data, indent=2, sort_keys=True) + "\n", encoding="utf-8"
)
def save_trained_model(
cfg: DictDefault,
trainer: Any,
@@ -298,12 +320,17 @@ def save_trained_model(
)
trainer.accelerator.wait_for_everyone()
if trainer.accelerator.is_main_process:
# move all files in merged_path to cfg.output_dir
# FSDP checkpoints for PEFT only contain adapter weights;
# rename model* → adapter_model* so it loads correctly.
is_peft = cfg.adapter and not cfg.relora
if is_peft:
_rename_fsdp_merged_to_adapter(Path(merged_path))
for merged_file in Path(merged_path).iterdir():
if (Path(cfg.output_dir) / merged_file.name).exists():
(Path(cfg.output_dir) / merged_file.name).unlink()
shutil.move(str(merged_file), cfg.output_dir)
shutil.rmtree(merged_path) # remove what should be an empty dir
dest = Path(cfg.output_dir) / merged_file.name
if dest.exists():
dest.unlink()
shutil.move(str(merged_file), dest)
shutil.rmtree(merged_path)
# TODO(wing):see https://github.com/huggingface/transformers/pull/40207
# cleanup the FSDP prefix in the model config.json
if trainer.accelerator.is_main_process: