fix(config): passing gradient_checkpoint_kwargs (#1412)
* fix(config): change default use_reentrant to true * Update trainer_builder.py * fix: make sure to pass kwargs to enable checkpoint * chore: lint
This commit is contained in:
@@ -888,7 +888,9 @@ def load_model(
|
||||
|
||||
if cfg.adapter in ["lora", "qlora"]:
|
||||
if cfg.gradient_checkpointing:
|
||||
model.gradient_checkpointing_enable()
|
||||
model.gradient_checkpointing_enable(
|
||||
gradient_checkpointing_kwargs=cfg.gradient_checkpointing_kwargs
|
||||
)
|
||||
if (
|
||||
cfg.load_in_8bit or cfg.load_in_4bit
|
||||
) and not skip_prepare_model_for_kbit_training:
|
||||
|
||||
Reference in New Issue
Block a user