fix: rename model to adapter_model for fsdp sharded final model (#3585)
* fix: rename model to adapter_model for fsdp sharded final model
* fix: follow upstream transformers shard size
* fix: handle multiple model files
* fix: remove redundant condition, tighten match to safetensors files, keep shard size small
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -229,6 +229,28 @@ def execute_training(
|
||||
PLUGIN_MANAGER.post_train(cfg, trainer.model)
|
||||
|
||||
|
||||
def _rename_fsdp_merged_to_adapter(merged_dir: Path):
|
||||
"""Rename model*.safetensors files to adapter_model* in place.
|
||||
|
||||
Also rewrites the index JSON weight_map if sharded output was produced.
|
||||
"""
|
||||
for file in sorted(merged_dir.iterdir()):
|
||||
if file.name.startswith("model") and ".safetensors" in file.name:
|
||||
file.rename(merged_dir / file.name.replace("model", "adapter_model", 1))
|
||||
|
||||
index = merged_dir / "adapter_model.safetensors.index.json"
|
||||
if index.exists():
|
||||
data = json.loads(index.read_text(encoding="utf-8"))
|
||||
if "weight_map" in data:
|
||||
data["weight_map"] = {
|
||||
k: v.replace("model", "adapter_model", 1)
|
||||
for k, v in data["weight_map"].items()
|
||||
}
|
||||
index.write_text(
|
||||
json.dumps(data, indent=2, sort_keys=True) + "\n", encoding="utf-8"
|
||||
)
|
||||
|
||||
|
||||
def save_trained_model(
|
||||
cfg: DictDefault,
|
||||
trainer: Any,
|
||||
@@ -298,12 +320,17 @@ def save_trained_model(
|
||||
)
|
||||
trainer.accelerator.wait_for_everyone()
|
||||
if trainer.accelerator.is_main_process:
|
||||
# move all files in merged_path to cfg.output_dir
|
||||
# FSDP checkpoints for PEFT only contain adapter weights;
|
||||
# rename model* → adapter_model* so it loads correctly.
|
||||
is_peft = cfg.adapter and not cfg.relora
|
||||
if is_peft:
|
||||
_rename_fsdp_merged_to_adapter(Path(merged_path))
|
||||
for merged_file in Path(merged_path).iterdir():
|
||||
if (Path(cfg.output_dir) / merged_file.name).exists():
|
||||
(Path(cfg.output_dir) / merged_file.name).unlink()
|
||||
shutil.move(str(merged_file), cfg.output_dir)
|
||||
shutil.rmtree(merged_path) # remove what should be an empty dir
|
||||
dest = Path(cfg.output_dir) / merged_file.name
|
||||
if dest.exists():
|
||||
dest.unlink()
|
||||
shutil.move(str(merged_file), dest)
|
||||
shutil.rmtree(merged_path)
|
||||
# TODO(wing):see https://github.com/huggingface/transformers/pull/40207
|
||||
# cleanup the FSDP prefix in the model config.json
|
||||
if trainer.accelerator.is_main_process:
|
||||
|
||||
Reference in New Issue
Block a user