activation offloading with cuda streams doesn't work with LoRA (#2927)
This commit is contained in:
@@ -1066,23 +1066,23 @@ class ModelCompatibilityValidationMixin:
|
|||||||
raise ValueError("gradient_checkpointing is not supported for MPT models")
|
raise ValueError("gradient_checkpointing is not supported for MPT models")
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@model_validator(mode="after")
|
|
||||||
def check_offload_grad_checkpointing(self):
|
|
||||||
if self.gradient_checkpointing and self.gradient_checkpointing == "unsloth":
|
|
||||||
LOG.warning(
|
|
||||||
"`unsloth` is deprecated for gradient_checkpointing, use `offload`"
|
|
||||||
)
|
|
||||||
self.gradient_checkpointing = "offload"
|
|
||||||
return self
|
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def check_gradient_checkpointing_w_offload(self):
|
def check_gradient_checkpointing_w_offload(self):
|
||||||
if self.gradient_checkpointing == "offload":
|
if self.gradient_checkpointing == "offload":
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"`offload` is deprecated for gradient_checkpointing, use `activation_offloading: true`"
|
"`offload` is deprecated for gradient_checkpointing, use `activation_offloading: true` or `activation_offloading: legacy`"
|
||||||
)
|
)
|
||||||
self.gradient_checkpointing = True
|
self.gradient_checkpointing = True
|
||||||
self.activation_offloading = True
|
if self.adapter and "lora" in self.adapter:
|
||||||
|
LOG.warning(
|
||||||
|
"offloading with CUDA streams is not supported for LoRA adapters, using the `activation_offloading: legacy` implementation."
|
||||||
|
)
|
||||||
|
self.activation_offloading = "legacy"
|
||||||
|
else:
|
||||||
|
LOG.warning(
|
||||||
|
"`offload` uses a new stream implementation; to use the previous implementation, use `activation_offloading: legacy`"
|
||||||
|
)
|
||||||
|
self.activation_offloading = True
|
||||||
if self.gradient_checkpointing == "offload_disk":
|
if self.gradient_checkpointing == "offload_disk":
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"`offload_disk` is deprecated for gradient_checkpointing, use `activation_offloading: disk`"
|
"`offload_disk` is deprecated for gradient_checkpointing, use `activation_offloading: disk`"
|
||||||
@@ -1091,6 +1091,19 @@ class ModelCompatibilityValidationMixin:
|
|||||||
self.activation_offloading = "disk"
|
self.activation_offloading = "disk"
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def check_activation_offloading_w_lora(self):
|
||||||
|
if (
|
||||||
|
self.activation_offloading is True
|
||||||
|
and self.adapter
|
||||||
|
and "lora" in self.adapter
|
||||||
|
):
|
||||||
|
LOG.warning(
|
||||||
|
"activation_offloading with CUDA streams is not supported for LoRA adapters. Setting `activation_offloading: legacy`"
|
||||||
|
)
|
||||||
|
self.activation_offloading = "legacy"
|
||||||
|
return self
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def check_activation_offloading_wo_gc(self):
|
def check_activation_offloading_wo_gc(self):
|
||||||
if self.activation_offloading and not self.gradient_checkpointing:
|
if self.activation_offloading and not self.gradient_checkpointing:
|
||||||
|
|||||||
91
tests/utils/schemas/validation/test_activation_offloading.py
Normal file
91
tests/utils/schemas/validation/test_activation_offloading.py
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
"""Test for config validation for activation offloading."""
|
||||||
|
|
||||||
|
from axolotl.utils.config import validate_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
|
class TestActivationOffloading:
|
||||||
|
"""
|
||||||
|
Test cases for activation offloading schema validation
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_gc_converts_offload_wo_lora(self, min_base_cfg):
|
||||||
|
cfg = (
|
||||||
|
DictDefault(
|
||||||
|
gradient_checkpointing="offload",
|
||||||
|
)
|
||||||
|
| min_base_cfg
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
assert cfg.gradient_checkpointing is True
|
||||||
|
assert cfg.activation_offloading is True
|
||||||
|
|
||||||
|
def test_gc_converts_offload_w_lora(self, min_base_cfg):
|
||||||
|
cfg = (
|
||||||
|
DictDefault(
|
||||||
|
gradient_checkpointing="offload",
|
||||||
|
adapter="lora",
|
||||||
|
)
|
||||||
|
| min_base_cfg
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
assert cfg.gradient_checkpointing is True
|
||||||
|
assert cfg.activation_offloading == "legacy"
|
||||||
|
|
||||||
|
def test_gc_converts_offload_w_qlora(self, min_base_cfg):
|
||||||
|
cfg = (
|
||||||
|
DictDefault(
|
||||||
|
gradient_checkpointing="offload",
|
||||||
|
adapter="qlora",
|
||||||
|
load_in_4bit=True,
|
||||||
|
)
|
||||||
|
| min_base_cfg
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
assert cfg.gradient_checkpointing is True
|
||||||
|
assert cfg.activation_offloading == "legacy"
|
||||||
|
|
||||||
|
def test_ac_impl_changes_w_lora(self, min_base_cfg):
|
||||||
|
cfg = (
|
||||||
|
DictDefault(
|
||||||
|
gradient_checkpointing=True,
|
||||||
|
activation_offloading=True,
|
||||||
|
adapter="lora",
|
||||||
|
)
|
||||||
|
| min_base_cfg
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
assert cfg.gradient_checkpointing is True
|
||||||
|
assert cfg.activation_offloading == "legacy"
|
||||||
|
|
||||||
|
def test_ac_impl_changes_w_qlora(self, min_base_cfg):
|
||||||
|
cfg = (
|
||||||
|
DictDefault(
|
||||||
|
gradient_checkpointing=True,
|
||||||
|
activation_offloading=True,
|
||||||
|
adapter="qlora",
|
||||||
|
load_in_4bit=True,
|
||||||
|
)
|
||||||
|
| min_base_cfg
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
assert cfg.gradient_checkpointing is True
|
||||||
|
assert cfg.activation_offloading == "legacy"
|
||||||
|
|
||||||
|
def test_ac_offload_impl_noop_wo_adapter(self, min_base_cfg):
|
||||||
|
cfg = (
|
||||||
|
DictDefault(
|
||||||
|
gradient_checkpointing=True,
|
||||||
|
activation_offloading=True,
|
||||||
|
)
|
||||||
|
| min_base_cfg
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
assert cfg.gradient_checkpointing is True
|
||||||
|
assert cfg.activation_offloading is True
|
||||||
Reference in New Issue
Block a user