fix bettertransformers save, force it to skip after saving correctly in callback

Wing Lian
2023-06-01 00:33:13 -04:00
parent 1210dc8fd5
commit 1a82082e91
3 changed files with 25 additions and 11 deletions


@@ -9,7 +9,7 @@ from transformers import (
     TrainerState,
     TrainingArguments,
 )
-from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
 class SavePeftModelCallback(TrainerCallback):  # pylint: disable=too-few-public-methods
@@ -36,21 +36,33 @@ class SavePeftModelCallback(TrainerCallback):  # pylint: disable=too-few-public-
 class SaveBetterTransformerModelCallback(
     TrainerCallback
 ):  # pylint: disable=too-few-public-methods
-    """Callback to save the BatterTransformer wrapped model"""
+    """Callback to save the BetterTransformer wrapped model"""

-    def on_save(
+    def on_step_end(
         self,
         args: TrainingArguments,
         state: TrainerState,
         control: TrainerControl,
         **kwargs,
     ):
-        checkpoint_folder = os.path.join(
-            args.output_dir,
-            f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
-        )
+        # Save
+        if (
+            args.save_strategy == IntervalStrategy.STEPS
+            and args.save_steps > 0
+            and state.global_step % args.save_steps == 0
+        ):
+            control.should_save = True

-        model = BetterTransformer.reverse(kwargs["model"])
-        model.save_pretrained(checkpoint_folder)
+        if control.should_save:
+            checkpoint_folder = os.path.join(
+                args.output_dir,
+                f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
+            )
+            model = BetterTransformer.reverse(kwargs["model"])
+            model.save_pretrained(checkpoint_folder)
+            # since we're saving here, we don't need the trainer loop to attempt to save too b/c
+            # the trainer will raise an exception since it can't save a BetterTransformer wrapped model
+            control.should_save = False
         return control
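For context on why the callback clears the flag it just set: in the usual transformers callback flow, the trainer's own checkpoint save runs only while control.should_save is still True after callbacks fire, and that path would call save_pretrained on the still-wrapped model and fail. A minimal sketch of that interplay, using stand-in classes rather than the real transformers.Trainer, TrainerControl, or BetterTransformer:

from dataclasses import dataclass


@dataclass
class FakeControl:
    # stand-in for transformers.TrainerControl; only the flag used here
    should_save: bool = False


class FakeWrappedModel:
    # stand-in for a BetterTransformer-wrapped model, which cannot be saved as-is
    def save_pretrained(self, path):
        raise RuntimeError(f"refusing to save a wrapped model to {path}")


def simulate_step(control: FakeControl, callback_handles_save: bool) -> FakeControl:
    # 1. a step-based save is due, so the flag is set (as the callback above does)
    control.should_save = True
    # 2. the callback saves the reversed model itself and clears the flag
    if callback_handles_save:
        control.should_save = False
    # 3. the trainer's own save only runs if the flag survived; on a wrapped
    #    model that path would raise
    if control.should_save:
        FakeWrappedModel().save_pretrained("checkpoint-10")
    return control


simulate_step(FakeControl(), callback_handles_save=True)    # completes quietly
# simulate_step(FakeControl(), callback_handles_save=False)  # would raise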


@@ -232,6 +232,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         callbacks.append(SavePeftModelCallback)
     if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
+        logging.info("Setting up SaveBetterTransformerModelCallback.")
         callbacks.append(SaveBetterTransformerModelCallback)
     data_collator_kwargs = {
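The use_bettertransformer attribute checked in this hunk is presumably set when the model is wrapped with optimum earlier in the pipeline (not shown in this commit). A hedged sketch of that wrap/unwrap round trip, with a placeholder model name and dtype, using the real optimum API:

import torch
from optimum.bettertransformer import BetterTransformer
from transformers import AutoModelForCausalLM

# wrap the model for the BetterTransformer fastpath kernels
model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float16)
model = BetterTransformer.transform(model)

# ... training happens here ...

# reverse before saving, exactly as the callback does at each checkpoint
restored = BetterTransformer.reverse(model)
restored.save_pretrained("final-checkpoint")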


@@ -66,9 +66,10 @@ def validate_config(cfg):
             )
         if cfg.fp16 or cfg.bf16:
             raise ValueError("AMP is not supported with BetterTransformer")
-        if cfg.float16 is not True:
+        if cfg.float16 is not True and cfg.bloat16 is not True:
             logging.warning(
-                "You should probably set float16 to true to load the model in float16 for BetterTransformers"
+                "You should probably set bfloat16 or float16 to true to "
+                "load the model in float16 for BetterTransformers"
             )
         if int(torch.__version__.split(".")[0]) < 2:
             logging.warning("torch>=2.0.0 required")