chore(callback): Remove old peft saving code (#510)
This commit is contained in:
@@ -43,29 +43,6 @@ LOG = logging.getLogger("axolotl.callbacks")
|
|||||||
IGNORE_INDEX = -100
|
IGNORE_INDEX = -100
|
||||||
|
|
||||||
|
|
||||||
class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
|
|
||||||
"""Callback to save the PEFT adapter"""
|
|
||||||
|
|
||||||
def on_save(
|
|
||||||
self,
|
|
||||||
args: TrainingArguments,
|
|
||||||
state: TrainerState,
|
|
||||||
control: TrainerControl,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
checkpoint_folder = os.path.join(
|
|
||||||
args.output_dir,
|
|
||||||
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
|
|
||||||
)
|
|
||||||
|
|
||||||
peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
|
|
||||||
kwargs["model"].save_pretrained(
|
|
||||||
peft_model_path, save_safetensors=args.save_safetensors
|
|
||||||
)
|
|
||||||
|
|
||||||
return control
|
|
||||||
|
|
||||||
|
|
||||||
class EvalFirstStepCallback(
|
class EvalFirstStepCallback(
|
||||||
TrainerCallback
|
TrainerCallback
|
||||||
): # pylint: disable=too-few-public-methods disable=unused-argument
|
): # pylint: disable=too-few-public-methods disable=unused-argument
|
||||||
|
|||||||
@@ -31,7 +31,6 @@ from axolotl.utils.callbacks import (
|
|||||||
EvalFirstStepCallback,
|
EvalFirstStepCallback,
|
||||||
GPUStatsCallback,
|
GPUStatsCallback,
|
||||||
SaveBetterTransformerModelCallback,
|
SaveBetterTransformerModelCallback,
|
||||||
SavePeftModelCallback,
|
|
||||||
bench_eval_callback_factory,
|
bench_eval_callback_factory,
|
||||||
log_prediction_callback_factory,
|
log_prediction_callback_factory,
|
||||||
)
|
)
|
||||||
@@ -711,12 +710,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
if cfg.relora_steps:
|
if cfg.relora_steps:
|
||||||
callbacks.append(ReLoRACallback(cfg))
|
callbacks.append(ReLoRACallback(cfg))
|
||||||
|
|
||||||
if cfg.local_rank == 0 and cfg.adapter in [
|
|
||||||
"lora",
|
|
||||||
"qlora",
|
|
||||||
]: # only save in rank 0
|
|
||||||
callbacks.append(SavePeftModelCallback)
|
|
||||||
|
|
||||||
if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
|
if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
|
||||||
callbacks.append(SaveBetterTransformerModelCallback)
|
callbacks.append(SaveBetterTransformerModelCallback)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user