diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py
index 56a70ec48..292159bb8 100644
--- a/src/axolotl/utils/schemas/validation.py
+++ b/src/axolotl/utils/schemas/validation.py
@@ -1066,23 +1066,23 @@ class ModelCompatibilityValidationMixin:
             raise ValueError("gradient_checkpointing is not supported for MPT models")
         return self
 
-    @model_validator(mode="after")
-    def check_offload_grad_checkpointing(self):
-        if self.gradient_checkpointing and self.gradient_checkpointing == "unsloth":
-            LOG.warning(
-                "`unsloth` is deprecated for gradient_checkpointing, use `offload`"
-            )
-            self.gradient_checkpointing = "offload"
-        return self
-
     @model_validator(mode="after")
     def check_gradient_checkpointing_w_offload(self):
         if self.gradient_checkpointing == "offload":
             LOG.warning(
-                "`offload` is deprecated for gradient_checkpointing, use `activation_offloading: true`"
+                "`offload` is deprecated for gradient_checkpointing, use `activation_offloading: true` or `activation_offloading: legacy`"
             )
             self.gradient_checkpointing = True
-            self.activation_offloading = True
+            if self.adapter and "lora" in self.adapter:
+                LOG.warning(
+                    "offloading with CUDA streams is not supported for LoRA adapters, using the `activation_offloading: legacy` implementation."
+                )
+                self.activation_offloading = "legacy"
+            else:
+                LOG.warning(
+                    "`offload` uses a new stream implementation; to use the previous implementation, use `activation_offloading: legacy`"
+                )
+                self.activation_offloading = True
         if self.gradient_checkpointing == "offload_disk":
             LOG.warning(
                 "`offload_disk` is deprecated for gradient_checkpointing, use `activation_offloading: disk`"
@@ -1091,6 +1091,19 @@ class ModelCompatibilityValidationMixin:
             self.activation_offloading = "disk"
         return self
 
+    @model_validator(mode="after")
+    def check_activation_offloading_w_lora(self):
+        if (
+            self.activation_offloading is True
+            and self.adapter
+            and "lora" in self.adapter
+        ):
+            LOG.warning(
+                "activation_offloading with CUDA streams is not supported for LoRA adapters. Setting `activation_offloading: legacy`"
+            )
+            self.activation_offloading = "legacy"
+        return self
+
     @model_validator(mode="after")
     def check_activation_offloading_wo_gc(self):
         if self.activation_offloading and not self.gradient_checkpointing:
diff --git a/tests/utils/schemas/validation/test_activation_offloading.py b/tests/utils/schemas/validation/test_activation_offloading.py
new file mode 100644
index 000000000..92ac8f45c
--- /dev/null
+++ b/tests/utils/schemas/validation/test_activation_offloading.py
@@ -0,0 +1,91 @@
+"""Test for config validation for activation offloading."""
+
+from axolotl.utils.config import validate_config
+from axolotl.utils.dict import DictDefault
+
+
+class TestActivationOffloading:
+    """
+    Test cases for activation offloading schema validation
+    """
+
+    def test_gc_converts_offload_wo_lora(self, min_base_cfg):
+        cfg = (
+            DictDefault(
+                gradient_checkpointing="offload",
+            )
+            | min_base_cfg
+        )
+
+        cfg = validate_config(cfg)
+        assert cfg.gradient_checkpointing is True
+        assert cfg.activation_offloading is True
+
+    def test_gc_converts_offload_w_lora(self, min_base_cfg):
+        cfg = (
+            DictDefault(
+                gradient_checkpointing="offload",
+                adapter="lora",
+            )
+            | min_base_cfg
+        )
+
+        cfg = validate_config(cfg)
+        assert cfg.gradient_checkpointing is True
+        assert cfg.activation_offloading == "legacy"
+
+    def test_gc_converts_offload_w_qlora(self, min_base_cfg):
+        cfg = (
+            DictDefault(
+                gradient_checkpointing="offload",
+                adapter="qlora",
+                load_in_4bit=True,
+            )
+            | min_base_cfg
+        )
+
+        cfg = validate_config(cfg)
+        assert cfg.gradient_checkpointing is True
+        assert cfg.activation_offloading == "legacy"
+
+    def test_ac_impl_changes_w_lora(self, min_base_cfg):
+        cfg = (
+            DictDefault(
+                gradient_checkpointing=True,
+                activation_offloading=True,
+                adapter="lora",
+            )
+            | min_base_cfg
+        )
+
+        cfg = validate_config(cfg)
+        assert cfg.gradient_checkpointing is True
+        assert cfg.activation_offloading == "legacy"
+
+    def test_ac_impl_changes_w_qlora(self, min_base_cfg):
+        cfg = (
+            DictDefault(
+                gradient_checkpointing=True,
+                activation_offloading=True,
+                adapter="qlora",
+                load_in_4bit=True,
+            )
+            | min_base_cfg
+        )
+
+        cfg = validate_config(cfg)
+        assert cfg.gradient_checkpointing is True
+        assert cfg.activation_offloading == "legacy"
+
+    def test_ac_offload_impl_noop_wo_adapter(self, min_base_cfg):
+        cfg = (
+            DictDefault(
+                gradient_checkpointing=True,
+                activation_offloading=True,
+            )
+            | min_base_cfg
+        )
+
+        cfg = validate_config(cfg)
+        assert cfg.gradient_checkpointing is True
+        assert cfg.activation_offloading is True
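
Taken together, the two validators define a migration path for the deprecated `gradient_checkpointing: offload` spelling. A minimal sketch of the resulting behavior, in the style of the tests above; it assumes the same `min_base_cfg` fixture (a minimal otherwise-valid config), and `sketch_offload_migration` is a hypothetical helper name, not part of the change:

from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault


def sketch_offload_migration(min_base_cfg):
    # Full-model training: the deprecated spelling maps onto the new
    # CUDA-stream offloading implementation.
    cfg = validate_config(
        DictDefault(gradient_checkpointing="offload") | min_base_cfg
    )
    assert cfg.gradient_checkpointing is True
    assert cfg.activation_offloading is True

    # LoRA/QLoRA: stream-based offloading is unsupported, so the
    # validators fall back to the legacy implementation instead.
    cfg = validate_config(
        DictDefault(gradient_checkpointing="offload", adapter="lora") | min_base_cfg
    )
    assert cfg.gradient_checkpointing is True
    assert cfg.activation_offloading == "legacy"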