Pass a weakref to the model in the SIGINT handler so the model can be freed after the train function returns (#1581)
* Pass weakref to model in the SIGINT handler to free up model post train()
* Fix lint issues
* chore: lint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
@@ -3,6 +3,7 @@
|
|||||||
import os
|
import os
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
|
import weakref
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
@@ -127,14 +128,20 @@ def train(
|
|||||||
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
||||||
if cfg.local_rank == 0:
|
if cfg.local_rank == 0:
|
||||||
|
|
||||||
def terminate_handler(_, __, model):
|
def terminate_handler(_, __, model_weakref):
|
||||||
if cfg.flash_optimum and BetterTransformer:
|
if model_weakref() is not None:
|
||||||
model = BetterTransformer.reverse(model)
|
_model = model_weakref()
|
||||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
if cfg.flash_optimum and BetterTransformer:
|
||||||
|
_model = BetterTransformer.reverse(_model)
|
||||||
|
_model.save_pretrained(
|
||||||
|
cfg.output_dir, safe_serialization=safe_serialization
|
||||||
|
)
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
|
_model_weakref = weakref.ref(model)
|
||||||
signal.signal(
|
signal.signal(
|
||||||
signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
|
signal.SIGINT,
|
||||||
|
lambda signum, frame: terminate_handler(signum, frame, _model_weakref),
|
||||||
)
|
)
|
||||||
|
|
||||||
badge_markdown = """[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)"""
|
badge_markdown = """[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)"""
|
||||||
|
|||||||
Reference in New Issue
Block a user