fix bettertransformers save, force it to skip after saving correctly in callback
This commit is contained in:
@@ -9,7 +9,7 @@ from transformers import (
|
|||||||
TrainerState,
|
TrainerState,
|
||||||
TrainingArguments,
|
TrainingArguments,
|
||||||
)
|
)
|
||||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
||||||
|
|
||||||
|
|
||||||
class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
|
class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
|
||||||
@@ -36,21 +36,33 @@ class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-
|
|||||||
class SaveBetterTransformerModelCallback(
|
class SaveBetterTransformerModelCallback(
|
||||||
TrainerCallback
|
TrainerCallback
|
||||||
): # pylint: disable=too-few-public-methods
|
): # pylint: disable=too-few-public-methods
|
||||||
"""Callback to save the BatterTransformer wrapped model"""
|
"""Callback to save the BetterTransformer wrapped model"""
|
||||||
|
|
||||||
def on_save(
|
def on_step_end(
|
||||||
self,
|
self,
|
||||||
args: TrainingArguments,
|
args: TrainingArguments,
|
||||||
state: TrainerState,
|
state: TrainerState,
|
||||||
control: TrainerControl,
|
control: TrainerControl,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
checkpoint_folder = os.path.join(
|
# Save
|
||||||
args.output_dir,
|
if (
|
||||||
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
|
args.save_strategy == IntervalStrategy.STEPS
|
||||||
)
|
and args.save_steps > 0
|
||||||
|
and state.global_step % args.save_steps == 0
|
||||||
|
):
|
||||||
|
control.should_save = True
|
||||||
|
|
||||||
model = BetterTransformer.reverse(kwargs["model"])
|
if control.should_save:
|
||||||
model.save_pretrained(checkpoint_folder)
|
checkpoint_folder = os.path.join(
|
||||||
|
args.output_dir,
|
||||||
|
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
|
||||||
|
)
|
||||||
|
|
||||||
|
model = BetterTransformer.reverse(kwargs["model"])
|
||||||
|
model.save_pretrained(checkpoint_folder)
|
||||||
|
|
||||||
|
# since we're saving here, we don't need the trainer loop to attempt to save too b/c
|
||||||
|
# the trainer will raise an exception since it can't save a BetterTransformer wrapped model
|
||||||
|
control.should_save = False
|
||||||
return control
|
return control
|
||||||
|
|||||||
@@ -232,6 +232,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
callbacks.append(SavePeftModelCallback)
|
callbacks.append(SavePeftModelCallback)
|
||||||
|
|
||||||
if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
|
if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
|
||||||
|
logging.info("Setting up SaveBetterTransformerModelCallback.")
|
||||||
callbacks.append(SaveBetterTransformerModelCallback)
|
callbacks.append(SaveBetterTransformerModelCallback)
|
||||||
|
|
||||||
data_collator_kwargs = {
|
data_collator_kwargs = {
|
||||||
|
|||||||
@@ -66,9 +66,10 @@ def validate_config(cfg):
|
|||||||
)
|
)
|
||||||
if cfg.fp16 or cfg.bf16:
|
if cfg.fp16 or cfg.bf16:
|
||||||
raise ValueError("AMP is not supported with BetterTransformer")
|
raise ValueError("AMP is not supported with BetterTransformer")
|
||||||
if cfg.float16 is not True:
|
if cfg.float16 is not True and cfg.bloat16 is not True:
|
||||||
logging.warning(
|
logging.warning(
|
||||||
"You should probably set float16 to true to load the model in float16 for BetterTransformers"
|
"You should probably set bfloat16 or float16 to true to "
|
||||||
|
"load the model in float16 for BetterTransformers"
|
||||||
)
|
)
|
||||||
if int(torch.__version__.split(".")[0]) < 2:
|
if int(torch.__version__.split(".")[0]) < 2:
|
||||||
logging.warning("torch>=2.0.0 required")
|
logging.warning("torch>=2.0.0 required")
|
||||||
|
|||||||
Reference in New Issue
Block a user