activation offloading with cuda streams doesn't work with LoRA (#2927)
@@ -1066,23 +1066,32 @@ class ModelCompatibilityValidationMixin:
             raise ValueError("gradient_checkpointing is not supported for MPT models")
         return self
 
     @model_validator(mode="after")
     def check_offload_grad_checkpointing(self):
         if self.gradient_checkpointing and self.gradient_checkpointing == "unsloth":
             LOG.warning(
                 "`unsloth` is deprecated for gradient_checkpointing, use `offload`"
             )
             self.gradient_checkpointing = "offload"
         return self
 
     @model_validator(mode="after")
     def check_gradient_checkpointing_w_offload(self):
         if self.gradient_checkpointing == "offload":
             LOG.warning(
-                "`offload` is deprecated for gradient_checkpointing, use `activation_offloading: true`"
+                "`offload` is deprecated for gradient_checkpointing, use `activation_offloading: true` or `activation_offloading: legacy`"
             )
             self.gradient_checkpointing = True
-            self.activation_offloading = True
+            if self.adapter and "lora" in self.adapter:
+                LOG.warning(
+                    "offloading with CUDA streams is not supported for LoRA adapters, using the `activation_offloading: legacy` implementation."
+                )
+                self.activation_offloading = "legacy"
+            else:
+                LOG.warning(
+                    "`offload` uses a new stream implementation; to use the previous implementation, use `activation_offloading: legacy`"
+                )
+                self.activation_offloading = True
         if self.gradient_checkpointing == "offload_disk":
             LOG.warning(
                 "`offload_disk` is deprecated for gradient_checkpointing, use `activation_offloading: disk`"
@@ -1091,6 +1100,19 @@ class ModelCompatibilityValidationMixin:
             self.activation_offloading = "disk"
         return self
 
+    @model_validator(mode="after")
+    def check_activation_offloading_w_lora(self):
+        if (
+            self.activation_offloading is True
+            and self.adapter
+            and "lora" in self.adapter
+        ):
+            LOG.warning(
+                "activation_offloading with CUDA streams is not supported for LoRA adapters. Setting `activation_offloading: legacy`"
+            )
+            self.activation_offloading = "legacy"
+        return self
+
     @model_validator(mode="after")
     def check_activation_offloading_wo_gc(self):
         if self.activation_offloading and not self.gradient_checkpointing:
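Net effect of the validators above: the deprecated `gradient_checkpointing: offload` value is rewritten to `gradient_checkpointing: true` plus an `activation_offloading` setting, and stream-based offloading falls back to the legacy implementation whenever a LoRA-family adapter is configured. A minimal standalone sketch of that resolution logic, as a hypothetical helper that is not part of this commit (names mirror the config fields above):

# Hypothetical sketch of the combined resolution behavior; not part of the
# commit. Field names mirror the validators above.
def resolve_offloading(gradient_checkpointing, activation_offloading, adapter):
    """Return (gradient_checkpointing, activation_offloading) after validation."""
    if gradient_checkpointing == "offload":
        gradient_checkpointing = True
        # CUDA-stream offloading is unsupported with LoRA/QLoRA adapters,
        # so fall back to the legacy implementation.
        activation_offloading = "legacy" if adapter and "lora" in adapter else True
    # The dedicated LoRA check also downgrades an explicit `true` setting.
    if activation_offloading is True and adapter and "lora" in adapter:
        activation_offloading = "legacy"
    return gradient_checkpointing, activation_offloading

assert resolve_offloading("offload", None, "qlora") == (True, "legacy")
assert resolve_offloading(True, True, None) == (True, True)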
tests/utils/schemas/validation/test_activation_offloading.py (new file, 91 lines)
@@ -0,0 +1,91 @@
+"""Test for config validation for activation offloading."""
+
+from axolotl.utils.config import validate_config
+from axolotl.utils.dict import DictDefault
+
+
+class TestActivationOffloading:
+    """
+    Test cases for activation offloading schema validation
+    """
+
+    def test_gc_converts_offload_wo_lora(self, min_base_cfg):
+        cfg = (
+            DictDefault(
+                gradient_checkpointing="offload",
+            )
+            | min_base_cfg
+        )
+
+        cfg = validate_config(cfg)
+        assert cfg.gradient_checkpointing is True
+        assert cfg.activation_offloading is True
+
+    def test_gc_converts_offload_w_lora(self, min_base_cfg):
+        cfg = (
+            DictDefault(
+                gradient_checkpointing="offload",
+                adapter="lora",
+            )
+            | min_base_cfg
+        )
+
+        cfg = validate_config(cfg)
+        assert cfg.gradient_checkpointing is True
+        assert cfg.activation_offloading == "legacy"
+
+    def test_gc_converts_offload_w_qlora(self, min_base_cfg):
+        cfg = (
+            DictDefault(
+                gradient_checkpointing="offload",
+                adapter="qlora",
+                load_in_4bit=True,
+            )
+            | min_base_cfg
+        )
+
+        cfg = validate_config(cfg)
+        assert cfg.gradient_checkpointing is True
+        assert cfg.activation_offloading == "legacy"
+
+    def test_ac_impl_changes_w_lora(self, min_base_cfg):
+        cfg = (
+            DictDefault(
+                gradient_checkpointing=True,
+                activation_offloading=True,
+                adapter="lora",
+            )
+            | min_base_cfg
+        )
+
+        cfg = validate_config(cfg)
+        assert cfg.gradient_checkpointing is True
+        assert cfg.activation_offloading == "legacy"
+
+    def test_ac_impl_changes_w_qlora(self, min_base_cfg):
+        cfg = (
+            DictDefault(
+                gradient_checkpointing=True,
+                activation_offloading=True,
+                adapter="qlora",
+                load_in_4bit=True,
+            )
+            | min_base_cfg
+        )
+
+        cfg = validate_config(cfg)
+        assert cfg.gradient_checkpointing is True
+        assert cfg.activation_offloading == "legacy"
+
+    def test_ac_offload_impl_noop_wo_adapter(self, min_base_cfg):
+        cfg = (
+            DictDefault(
+                gradient_checkpointing=True,
+                activation_offloading=True,
+            )
+            | min_base_cfg
+        )
+
+        cfg = validate_config(cfg)
+        assert cfg.gradient_checkpointing is True
+        assert cfg.activation_offloading is True
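Each test merges its overrides into a `min_base_cfg` fixture that is not part of this diff. A rough sketch of the shape such a fixture might take, with hypothetical placeholder values (the real fixture is defined elsewhere in the test suite and may differ):

# Hypothetical conftest.py fixture; placeholder values only. The real
# min_base_cfg fixture lives elsewhere in the test suite and may differ.
import pytest

from axolotl.utils.dict import DictDefault


@pytest.fixture
def min_base_cfg():
    # Minimal fields a config plausibly needs to pass validate_config()
    return DictDefault(
        base_model="HuggingFaceTB/SmolLM2-135M",  # placeholder model
        learning_rate=1e-4,
        micro_batch_size=1,
        gradient_accumulation_steps=1,
        datasets=[
            DictDefault(path="mhenrichsen/alpaca_2k_test", type="alpaca"),
        ],
    )

Because the tests write `DictDefault(...) | min_base_cfg`, the per-test overrides take precedence over whatever defaults the fixture supplies.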