Pass weakref to model in the SIGINT handler to free up model post train function (#1581)

* Pass weakref to model in the SIGINT handler to free up model post train()

* Fix lint issues

* chore: lint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
Chirag Jain
2024-05-03 20:35:28 +05:30
committed by GitHub
parent b9bb169602
commit dde02fcb94

View File

@@ -3,6 +3,7 @@
import os import os
import signal import signal
import sys import sys
import weakref
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
@@ -127,14 +128,20 @@ def train(
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if cfg.local_rank == 0: if cfg.local_rank == 0:
def terminate_handler(_, __, model): def terminate_handler(_, __, model_weakref):
if cfg.flash_optimum and BetterTransformer: if model_weakref() is not None:
model = BetterTransformer.reverse(model) _model = model_weakref()
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) if cfg.flash_optimum and BetterTransformer:
_model = BetterTransformer.reverse(_model)
_model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization
)
sys.exit(0) sys.exit(0)
_model_weakref = weakref.ref(model)
signal.signal( signal.signal(
signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model) signal.SIGINT,
lambda signum, frame: terminate_handler(signum, frame, _model_weakref),
) )
badge_markdown = """[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)""" badge_markdown = """[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)"""