wire up gradient checkpointing for 4bit

Wing Lian
2023-04-28 22:27:33 -04:00
parent 4e705eda6d
commit c0f50d9c61

@@ -28,7 +28,13 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     training_arguments_kwargs["warmup_steps"] = warmup_steps
     training_arguments_kwargs["logging_steps"] = logging_steps
     if cfg.gradient_checkpointing is not None:
-        training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
+        if cfg.load_4bit:
+            from alpaca_lora_4bit.gradient_checkpointing import apply_gradient_checkpointing
+
+            gradient_checkpointing_ratio = cfg.gradient_checkpointing_ratio if cfg.gradient_checkpointing_ratio else 1.0
+            apply_gradient_checkpointing(model, checkpoint_ratio=gradient_checkpointing_ratio)
+        else:
+            training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
 
     # deepspeed
     if os.environ.get("ACCELERATE_USE_DEEPSPEED") == "true" and torch.cuda.device_count() > 1:
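
For context on the checkpoint_ratio knob the diff introduces: below is a minimal sketch of what ratio-based gradient checkpointing can look like in general. It is not the alpaca_lora_4bit implementation; the helper name apply_checkpointing_sketch and the assumption that the model exposes its transformer blocks at model.model.layers (a LLaMA-style Hugging Face layout) are illustrative only.

# Minimal sketch of ratio-based gradient checkpointing (illustrative only;
# not the alpaca_lora_4bit implementation). Assumes a LLaMA-style Hugging
# Face model with blocks at model.model.layers and a PyTorch recent enough
# to support use_reentrant=False.
import functools

from torch.utils.checkpoint import checkpoint


def apply_checkpointing_sketch(model, checkpoint_ratio=1.0):
    layers = model.model.layers  # assumption: LLaMA-style block layout
    num_checkpointed = int(len(layers) * checkpoint_ratio)
    for layer in layers[:num_checkpointed]:
        fwd = layer.forward

        @functools.wraps(fwd)
        def wrapped(*args, _fwd=fwd, **kwargs):
            # Recompute this block's activations during backward instead of
            # storing them, trading extra compute for lower memory use.
            return checkpoint(_fwd, *args, use_reentrant=False, **kwargs)

        layer.forward = wrapped

With checkpoint_ratio=0.5, only the first half of the blocks are checkpointed, which is the partial-checkpointing behavior that cfg.gradient_checkpointing_ratio (defaulting to 1.0 in the diff above) controls; the non-4bit branch instead relies on the standard gradient_checkpointing flag in the Hugging Face TrainingArguments.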