This commit is contained in:
Dan Saunders
2025-02-26 21:03:44 +00:00
parent a3224c7c3c
commit ed5178cd3d

View File

@@ -28,17 +28,11 @@ from axolotl.utils.freeze import freeze_layers_except
from axolotl.utils.models import load_model, load_processor, load_tokenizer from axolotl.utils.models import load_model, load_processor, load_tokenizer
from axolotl.utils.trainer import setup_trainer from axolotl.utils.trainer import setup_trainer
# Optional imports with graceful fallbacks
try: try:
from optimum.bettertransformer import BetterTransformer from optimum.bettertransformer import BetterTransformer
except ImportError: except ImportError:
BetterTransformer = None BetterTransformer = None
# Project setup
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
configure_logging() configure_logging()
LOG = get_logger(__name__) LOG = get_logger(__name__)
@@ -233,9 +227,10 @@ def save_trained_model(
# final model weights have already been saved by `ReLoRACallback.on_train_end` # final model weights have already been saved by `ReLoRACallback.on_train_end`
return return
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.fsdp: if cfg.fsdp:
# TODO: do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple
# processes attempt to write the same file
if ( if (
state_dict_type == "SHARDED_STATE_DICT" state_dict_type == "SHARDED_STATE_DICT"
and cfg.fsdp_config.fsdp_state_dict_type == "SHARDED_STATE_DICT" and cfg.fsdp_config.fsdp_state_dict_type == "SHARDED_STATE_DICT"
@@ -267,7 +262,6 @@ def save_trained_model(
os.remove(os.path.join(cfg.output_dir, "model.safetensors")) os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
except FileNotFoundError: except FileNotFoundError:
pass pass
elif cfg.local_rank == 0: elif cfg.local_rank == 0:
if cfg.flash_optimum and BetterTransformer: if cfg.flash_optimum and BetterTransformer:
model = BetterTransformer.reverse(model) model = BetterTransformer.reverse(model)