diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py
index ab197304c..64bf48664 100644
--- a/src/axolotl/utils/callbacks.py
+++ b/src/axolotl/utils/callbacks.py
@@ -9,7 +9,7 @@ from transformers import (
     TrainerState,
     TrainingArguments,
 )
-from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
 
 
 class SavePeftModelCallback(TrainerCallback):  # pylint: disable=too-few-public-methods
@@ -36,21 +36,33 @@ class SavePeftModelCallback(TrainerCallback):  # pylint: disable=too-few-public-
 class SaveBetterTransformerModelCallback(
     TrainerCallback
 ):  # pylint: disable=too-few-public-methods
-    """Callback to save the BatterTransformer wrapped model"""
+    """Callback to save the BetterTransformer wrapped model"""
 
-    def on_save(
+    def on_step_end(
         self,
         args: TrainingArguments,
         state: TrainerState,
         control: TrainerControl,
         **kwargs,
     ):
-        checkpoint_folder = os.path.join(
-            args.output_dir,
-            f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
-        )
+        # Save
+        if (
+            args.save_strategy == IntervalStrategy.STEPS
+            and args.save_steps > 0
+            and state.global_step % args.save_steps == 0
+        ):
+            control.should_save = True
 
-        model = BetterTransformer.reverse(kwargs["model"])
-        model.save_pretrained(checkpoint_folder)
+        if control.should_save:
+            checkpoint_folder = os.path.join(
+                args.output_dir,
+                f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
+            )
+            model = BetterTransformer.reverse(kwargs["model"])
+            model.save_pretrained(checkpoint_folder)
+
+            # since we're saving here, we don't need the trainer loop to attempt to save too b/c
+            # the trainer will raise an exception since it can't save a BetterTransformer wrapped model
+            control.should_save = False
 
         return control
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index b7823fea4..59b1dc803 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -232,6 +232,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         callbacks.append(SavePeftModelCallback)
 
     if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
+        logging.info("Setting up SaveBetterTransformerModelCallback.")
         callbacks.append(SaveBetterTransformerModelCallback)
 
     data_collator_kwargs = {
diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py
index abaaba8d0..396036621 100644
--- a/src/axolotl/utils/validation.py
+++ b/src/axolotl/utils/validation.py
@@ -66,9 +66,10 @@ def validate_config(cfg):
             )
         if cfg.fp16 or cfg.bf16:
             raise ValueError("AMP is not supported with BetterTransformer")
-        if cfg.float16 is not True:
+        if cfg.float16 is not True and cfg.bfloat16 is not True:
             logging.warning(
-                "You should probably set float16 to true to load the model in float16 for BetterTransformers"
+                "You should probably set bfloat16 or float16 to true to "
+                "load the model in float16 for BetterTransformers"
             )
         if int(torch.__version__.split(".")[0]) < 2:
             logging.warning("torch>=2.0.0 required")
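
Note (not part of the diff): a minimal, hypothetical sketch of how the new callback ends up attached to a Hugging Face Trainer. The `model` and `train_dataset` names are placeholders for objects built elsewhere in axolotl; the output directory and save interval values are made up. axolotl's setup_trainer does roughly the equivalent when model.use_bettertransformer is True.

    # Hypothetical usage sketch, not code from this PR. `model` and `train_dataset`
    # are assumed to exist already (a BetterTransformer-wrapped model and a tokenized dataset).
    from transformers import Trainer, TrainingArguments

    from axolotl.utils.callbacks import SaveBetterTransformerModelCallback

    training_args = TrainingArguments(
        output_dir="./out",
        save_strategy="steps",  # on_step_end only forces extra saves for step-based saving
        save_steps=500,
    )

    trainer = Trainer(
        model=model,  # placeholder: BetterTransformer-wrapped model
        args=training_args,
        train_dataset=train_dataset,  # placeholder: tokenized dataset
        callbacks=[SaveBetterTransformerModelCallback],  # Trainer instantiates callback classes
    )
    trainer.train()

Because the callback sets control.should_save back to False after writing the reversed checkpoint, the Trainer's own save path is skipped and never sees the BetterTransformer-wrapped model it cannot serialize.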