# File-viewer metadata: 92 lines, 2.6 KiB, Python
"""Test for config validation for activation offloading."""
|
|
|
|
from axolotl.utils.config import validate_config
|
|
from axolotl.utils.dict import DictDefault
|
|
|
|
|
|
class TestActivationOffloading:
    """
    Test cases for activation offloading schema validation
    """

    @staticmethod
    def _validated(min_base_cfg, **overrides):
        """Return ``validate_config`` applied to the overrides merged with the base config.

        The overrides are wrapped in a ``DictDefault`` and used as the left
        operand of ``|`` with ``min_base_cfg`` on the right, so every test
        builds its input the same way.
        """
        merged = DictDefault(**overrides) | min_base_cfg
        return validate_config(merged)

    def test_gc_converts_offload_wo_lora(self, min_base_cfg):
        """``gradient_checkpointing="offload"`` without an adapter yields both flags as ``True``."""
        validated = self._validated(
            min_base_cfg,
            gradient_checkpointing="offload",
        )

        assert validated.gradient_checkpointing is True
        assert validated.activation_offloading is True

    def test_gc_converts_offload_w_lora(self, min_base_cfg):
        """``gradient_checkpointing="offload"`` with a LoRA adapter yields ``"legacy"`` offloading."""
        validated = self._validated(
            min_base_cfg,
            gradient_checkpointing="offload",
            adapter="lora",
        )

        assert validated.gradient_checkpointing is True
        assert validated.activation_offloading == "legacy"

    def test_gc_converts_offload_w_qlora(self, min_base_cfg):
        """``gradient_checkpointing="offload"`` with a 4-bit QLoRA adapter yields ``"legacy"`` offloading."""
        validated = self._validated(
            min_base_cfg,
            gradient_checkpointing="offload",
            adapter="qlora",
            load_in_4bit=True,
        )

        assert validated.gradient_checkpointing is True
        assert validated.activation_offloading == "legacy"

    def test_ac_impl_changes_w_lora(self, min_base_cfg):
        """``activation_offloading=True`` with a LoRA adapter is validated to ``"legacy"``."""
        validated = self._validated(
            min_base_cfg,
            gradient_checkpointing=True,
            activation_offloading=True,
            adapter="lora",
        )

        assert validated.gradient_checkpointing is True
        assert validated.activation_offloading == "legacy"

    def test_ac_impl_changes_w_qlora(self, min_base_cfg):
        """``activation_offloading=True`` with a 4-bit QLoRA adapter is validated to ``"legacy"``."""
        validated = self._validated(
            min_base_cfg,
            gradient_checkpointing=True,
            activation_offloading=True,
            adapter="qlora",
            load_in_4bit=True,
        )

        assert validated.gradient_checkpointing is True
        assert validated.activation_offloading == "legacy"

    def test_ac_offload_impl_noop_wo_adapter(self, min_base_cfg):
        """``activation_offloading=True`` without an adapter stays ``True`` after validation."""
        validated = self._validated(
            min_base_cfg,
            gradient_checkpointing=True,
            activation_offloading=True,
        )

        assert validated.gradient_checkpointing is True
        assert validated.activation_offloading is True
|