def _rename_fsdp_merged_to_adapter(merged_dir: Path) -> None:
    """Rename ``model*.safetensors*`` files in ``merged_dir`` to ``adapter_model*``.

    Every direct child of ``merged_dir`` whose name starts with ``model`` and
    contains ``.safetensors`` (single-file weights, shards, and the shard
    index JSON) gets the leading ``model`` replaced by ``adapter_model``.

    If ``adapter_model.safetensors.index.json`` exists afterwards, the values
    of its ``weight_map`` are rewritten with the same rule so they keep
    pointing at the renamed shard files, and the JSON is written back with
    sorted keys and a trailing newline.

    The function is idempotent: names and index entries that already carry
    the ``adapter_model`` prefix are left untouched, so calling it twice on
    the same directory does not produce ``adapter_adapter_model*`` entries.

    Args:
        merged_dir: Directory holding the merged FSDP safetensors output.
    """
    old_prefix = "model"
    new_prefix = "adapter_model"

    def _adapter_name(name: str) -> str:
        # One rule shared by the on-disk rename and the index rewrite, so the
        # index can never disagree with the actual file names. Anchoring on
        # the prefix (instead of replacing the first "model" substring) keeps
        # already-renamed entries such as "adapter_model-…" unchanged.
        if name.startswith(old_prefix) and ".safetensors" in name:
            return new_prefix + name[len(old_prefix):]
        return name

    # sorted() materializes the listing up front, so renaming entries while
    # looping cannot disturb the directory iteration.
    for file in sorted(merged_dir.iterdir()):
        target_name = _adapter_name(file.name)
        if target_name != file.name:
            # replace() overwrites an existing target on every platform;
            # rename() would raise on Windows but silently overwrite on POSIX.
            file.replace(merged_dir / target_name)

    index = merged_dir / "adapter_model.safetensors.index.json"
    if not index.exists():
        return

    data = json.loads(index.read_text(encoding="utf-8"))
    if "weight_map" in data:
        data["weight_map"] = {
            key: _adapter_name(shard) for key, shard in data["weight_map"].items()
        }
    index.write_text(
        json.dumps(data, indent=2, sort_keys=True) + "\n", encoding="utf-8"
    )
+ is_peft = cfg.adapter and not cfg.relora + if is_peft: + _rename_fsdp_merged_to_adapter(Path(merged_path)) for merged_file in Path(merged_path).iterdir(): - if (Path(cfg.output_dir) / merged_file.name).exists(): - (Path(cfg.output_dir) / merged_file.name).unlink() - shutil.move(str(merged_file), cfg.output_dir) - shutil.rmtree(merged_path) # remove what should be an empty dir + dest = Path(cfg.output_dir) / merged_file.name + if dest.exists(): + dest.unlink() + shutil.move(str(merged_file), dest) + shutil.rmtree(merged_path) # TODO(wing):see https://github.com/huggingface/transformers/pull/40207 # cleanup the FSDP prefix in the model config.json if trainer.accelerator.is_main_process: