fix bettertransformers save, force it to skip after saving correctly in callback

This commit is contained in:
Wing Lian
2023-06-01 00:33:13 -04:00
parent 86bd9fcff4
commit a32cc1d021
3 changed files with 26 additions and 11 deletions

View File

@@ -9,7 +9,7 @@ from transformers import (
TrainerState, TrainerState,
TrainingArguments, TrainingArguments,
) )
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
@@ -36,21 +36,33 @@ class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-
class SaveBetterTransformerModelCallback( class SaveBetterTransformerModelCallback(
TrainerCallback TrainerCallback
): # pylint: disable=too-few-public-methods ): # pylint: disable=too-few-public-methods
"""Callback to save the BatterTransformer wrapped model""" """Callback to save the BetterTransformer wrapped model"""
def on_save( def on_step_end(
self, self,
args: TrainingArguments, args: TrainingArguments,
state: TrainerState, state: TrainerState,
control: TrainerControl, control: TrainerControl,
**kwargs, **kwargs,
): ):
checkpoint_folder = os.path.join( # Save
args.output_dir, if (
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}", args.save_strategy == IntervalStrategy.STEPS
) and args.save_steps > 0
and state.global_step % args.save_steps == 0
):
control.should_save = True
model = BetterTransformer.reverse(kwargs["model"]) if control.should_save:
model.save_pretrained(checkpoint_folder) checkpoint_folder = os.path.join(
args.output_dir,
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
)
model = BetterTransformer.reverse(kwargs["model"])
model.save_pretrained(checkpoint_folder)
# since we're saving here, we don't need the trainer loop to attempt to save too b/c
# the trainer will raise an exception since it can't save a BetterTransformer wrapped model
control.should_save = False
return control return control

View File

@@ -1,6 +1,7 @@
"""Module containing the Trainer class and related functions""" """Module containing the Trainer class and related functions"""
import importlib import importlib
import logging
import math import math
import os import os
import sys import sys
@@ -229,6 +230,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
callbacks.append(SavePeftModelCallback) callbacks.append(SavePeftModelCallback)
if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True: if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
logging.info("Setting up SaveBetterTransformerModelCallback.")
callbacks.append(SaveBetterTransformerModelCallback) callbacks.append(SaveBetterTransformerModelCallback)
data_collator_kwargs = { data_collator_kwargs = {

View File

@@ -57,9 +57,10 @@ def validate_config(cfg):
) )
if cfg.fp16 or cfg.bf16: if cfg.fp16 or cfg.bf16:
raise ValueError("AMP is not supported with BetterTransformer") raise ValueError("AMP is not supported with BetterTransformer")
if cfg.float16 is not True: if cfg.float16 is not True and cfg.bfloat16 is not True:
logging.warning( logging.warning(
"You should probably set float16 to true to load the model in float16 for BetterTransformers" "You should probably set bfloat16 or float16 to true to "
"load the model in float16 for BetterTransformers"
) )
if int(torch.__version__.split(".")[0]) < 2: if int(torch.__version__.split(".")[0]) < 2:
logging.warning("torch>=2.0.0 required") logging.warning("torch>=2.0.0 required")