activation offloading with cuda streams doesn't work with LoRA (#2927)

This commit is contained in:
Wing Lian
2025-07-16 11:59:20 -04:00
committed by GitHub
parent 2c408b5c5e
commit 36cbe13d18
2 changed files with 115 additions and 11 deletions

View File

@@ -0,0 +1,91 @@
"""Test for config validation for activation offloading."""
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
class TestActivationOffloading:
"""
Test cases for activation offloading schema validation
"""
def test_gc_converts_offload_wo_lora(self, min_base_cfg):
cfg = (
DictDefault(
gradient_checkpointing="offload",
)
| min_base_cfg
)
cfg = validate_config(cfg)
assert cfg.gradient_checkpointing is True
assert cfg.activation_offloading is True
def test_gc_converts_offload_w_lora(self, min_base_cfg):
cfg = (
DictDefault(
gradient_checkpointing="offload",
adapter="lora",
)
| min_base_cfg
)
cfg = validate_config(cfg)
assert cfg.gradient_checkpointing is True
assert cfg.activation_offloading == "legacy"
def test_gc_converts_offload_w_qlora(self, min_base_cfg):
cfg = (
DictDefault(
gradient_checkpointing="offload",
adapter="qlora",
load_in_4bit=True,
)
| min_base_cfg
)
cfg = validate_config(cfg)
assert cfg.gradient_checkpointing is True
assert cfg.activation_offloading == "legacy"
def test_ac_impl_changes_w_lora(self, min_base_cfg):
cfg = (
DictDefault(
gradient_checkpointing=True,
activation_offloading=True,
adapter="lora",
)
| min_base_cfg
)
cfg = validate_config(cfg)
assert cfg.gradient_checkpointing is True
assert cfg.activation_offloading == "legacy"
def test_ac_impl_changes_w_qlora(self, min_base_cfg):
cfg = (
DictDefault(
gradient_checkpointing=True,
activation_offloading=True,
adapter="qlora",
load_in_4bit=True,
)
| min_base_cfg
)
cfg = validate_config(cfg)
assert cfg.gradient_checkpointing is True
assert cfg.activation_offloading == "legacy"
def test_ac_offload_impl_noop_wo_adapter(self, min_base_cfg):
cfg = (
DictDefault(
gradient_checkpointing=True,
activation_offloading=True,
)
| min_base_cfg
)
cfg = validate_config(cfg)
assert cfg.gradient_checkpointing is True
assert cfg.activation_offloading is True