activation offloading with cuda streams doesn't work with LoRA (#2927)

Wing Lian
2025-07-16 11:59:20 -04:00
committed by GitHub
parent 2c408b5c5e
commit 36cbe13d18
2 changed files with 115 additions and 11 deletions


@@ -1066,23 +1066,32 @@ class ModelCompatibilityValidationMixin:
             raise ValueError("gradient_checkpointing is not supported for MPT models")
         return self
 
     @model_validator(mode="after")
     def check_offload_grad_checkpointing(self):
         if self.gradient_checkpointing and self.gradient_checkpointing == "unsloth":
             LOG.warning(
                 "`unsloth` is deprecated for gradient_checkpointing, use `offload`"
             )
             self.gradient_checkpointing = "offload"
         return self
 
     @model_validator(mode="after")
     def check_gradient_checkpointing_w_offload(self):
         if self.gradient_checkpointing == "offload":
             LOG.warning(
-                "`offload` is deprecated for gradient_checkpointing, use `activation_offloading: true`"
+                "`offload` is deprecated for gradient_checkpointing, use `activation_offloading: true` or `activation_offloading: legacy`"
             )
             self.gradient_checkpointing = True
-            self.activation_offloading = True
+            if self.adapter and "lora" in self.adapter:
+                LOG.warning(
+                    "offloading with CUDA streams is not supported for LoRA adapters, using the `activation_offloading: legacy` implementation."
+                )
+                self.activation_offloading = "legacy"
+            else:
+                LOG.warning(
+                    "`offload` uses a new stream implementation; to use the previous implementation, use `activation_offloading: legacy`"
+                )
+                self.activation_offloading = True
         if self.gradient_checkpointing == "offload_disk":
             LOG.warning(
                 "`offload_disk` is deprecated for gradient_checkpointing, use `activation_offloading: disk`"
@@ -1091,6 +1091,19 @@ class ModelCompatibilityValidationMixin:
             self.activation_offloading = "disk"
         return self
 
+    @model_validator(mode="after")
+    def check_activation_offloading_w_lora(self):
+        if (
+            self.activation_offloading is True
+            and self.adapter
+            and "lora" in self.adapter
+        ):
+            LOG.warning(
+                "activation_offloading with CUDA streams is not supported for LoRA adapters. Setting `activation_offloading: legacy`"
+            )
+            self.activation_offloading = "legacy"
+        return self
+
     @model_validator(mode="after")
     def check_activation_offloading_wo_gc(self):
         if self.activation_offloading and not self.gradient_checkpointing:


@@ -0,0 +1,91 @@
+"""Test for config validation for activation offloading."""
+
+from axolotl.utils.config import validate_config
+from axolotl.utils.dict import DictDefault
+
+
+class TestActivationOffloading:
+    """
+    Test cases for activation offloading schema validation
+    """
+
+    def test_gc_converts_offload_wo_lora(self, min_base_cfg):
+        cfg = (
+            DictDefault(
+                gradient_checkpointing="offload",
+            )
+            | min_base_cfg
+        )
+
+        cfg = validate_config(cfg)
+        assert cfg.gradient_checkpointing is True
+        assert cfg.activation_offloading is True
+
+    def test_gc_converts_offload_w_lora(self, min_base_cfg):
+        cfg = (
+            DictDefault(
+                gradient_checkpointing="offload",
+                adapter="lora",
+            )
+            | min_base_cfg
+        )
+
+        cfg = validate_config(cfg)
+        assert cfg.gradient_checkpointing is True
+        assert cfg.activation_offloading == "legacy"
+
+    def test_gc_converts_offload_w_qlora(self, min_base_cfg):
+        cfg = (
+            DictDefault(
+                gradient_checkpointing="offload",
+                adapter="qlora",
+                load_in_4bit=True,
+            )
+            | min_base_cfg
+        )
+
+        cfg = validate_config(cfg)
+        assert cfg.gradient_checkpointing is True
+        assert cfg.activation_offloading == "legacy"
+
+    def test_ac_impl_changes_w_lora(self, min_base_cfg):
+        cfg = (
+            DictDefault(
+                gradient_checkpointing=True,
+                activation_offloading=True,
+                adapter="lora",
+            )
+            | min_base_cfg
+        )
+
+        cfg = validate_config(cfg)
+        assert cfg.gradient_checkpointing is True
+        assert cfg.activation_offloading == "legacy"
+
+    def test_ac_impl_changes_w_qlora(self, min_base_cfg):
+        cfg = (
+            DictDefault(
+                gradient_checkpointing=True,
+                activation_offloading=True,
+                adapter="qlora",
+                load_in_4bit=True,
+            )
+            | min_base_cfg
+        )
+
+        cfg = validate_config(cfg)
+        assert cfg.gradient_checkpointing is True
+        assert cfg.activation_offloading == "legacy"
+
+    def test_ac_offload_impl_noop_wo_adapter(self, min_base_cfg):
+        cfg = (
+            DictDefault(
+                gradient_checkpointing=True,
+                activation_offloading=True,
+            )
+            | min_base_cfg
+        )
+
+        cfg = validate_config(cfg)
+        assert cfg.gradient_checkpointing is True
+        assert cfg.activation_offloading is True
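
The new tests cover the `offload` shorthand and the explicit `activation_offloading: true` path, but not the deprecated `unsloth` and `offload_disk` values, which flow through the same validator chain shown in the first file. Below is a minimal sketch (not part of this commit) of what tests for those paths could look like; it reuses the `min_base_cfg` fixture assumed by the tests above, and the test names are hypothetical.

from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault


def test_gc_unsloth_chain(min_base_cfg):  # hypothetical, not in this commit
    # `unsloth` is first coerced to `offload` by check_offload_grad_checkpointing,
    # which then enables stream-based offloading since no adapter is configured.
    cfg = validate_config(
        DictDefault(gradient_checkpointing="unsloth") | min_base_cfg
    )
    assert cfg.gradient_checkpointing is True
    assert cfg.activation_offloading is True


def test_gc_offload_disk(min_base_cfg):  # hypothetical, not in this commit
    # `offload_disk` maps to disk-based activation offloading.
    cfg = validate_config(
        DictDefault(gradient_checkpointing="offload_disk") | min_base_cfg
    )
    assert cfg.activation_offloading == "disk"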