fix(config): passing gradient_checkpoint_kwargs (#1412)
* fix(config): change default use_reentrant to true * Update trainer_builder.py * fix: make sure to pass kwargs to enable checkpoint * chore: lint
This commit is contained in:
@@ -859,7 +859,7 @@ group_by_length: false
|
|||||||
gradient_checkpointing: false
|
gradient_checkpointing: false
|
||||||
# additional kwargs to pass to the trainer for gradient checkpointing
|
# additional kwargs to pass to the trainer for gradient checkpointing
|
||||||
# gradient_checkpointing_kwargs:
|
# gradient_checkpointing_kwargs:
|
||||||
# use_reentrant: false
|
# use_reentrant: true
|
||||||
|
|
||||||
# Stop training after this many evaluation losses have increased in a row
|
# Stop training after this many evaluation losses have increased in a row
|
||||||
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
|
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
|
||||||
|
|||||||
@@ -970,10 +970,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs[
|
training_arguments_kwargs[
|
||||||
"gradient_checkpointing_kwargs"
|
"gradient_checkpointing_kwargs"
|
||||||
] = self.cfg.gradient_checkpointing_kwargs
|
] = self.cfg.gradient_checkpointing_kwargs
|
||||||
else:
|
|
||||||
training_arguments_kwargs["gradient_checkpointing_kwargs"] = {
|
|
||||||
"use_reentrant": False
|
|
||||||
}
|
|
||||||
if self.cfg.fsdp:
|
if self.cfg.fsdp:
|
||||||
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
|
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
|
||||||
if self.cfg.fsdp_config:
|
if self.cfg.fsdp_config:
|
||||||
|
|||||||
@@ -888,7 +888,9 @@ def load_model(
|
|||||||
|
|
||||||
if cfg.adapter in ["lora", "qlora"]:
|
if cfg.adapter in ["lora", "qlora"]:
|
||||||
if cfg.gradient_checkpointing:
|
if cfg.gradient_checkpointing:
|
||||||
model.gradient_checkpointing_enable()
|
model.gradient_checkpointing_enable(
|
||||||
|
gradient_checkpointing_kwargs=cfg.gradient_checkpointing_kwargs
|
||||||
|
)
|
||||||
if (
|
if (
|
||||||
cfg.load_in_8bit or cfg.load_in_4bit
|
cfg.load_in_8bit or cfg.load_in_4bit
|
||||||
) and not skip_prepare_model_for_kbit_training:
|
) and not skip_prepare_model_for_kbit_training:
|
||||||
|
|||||||
Reference in New Issue
Block a user