diff --git a/docs/lora_optims.qmd b/docs/lora_optims.qmd
index 56d56e9fc..7cdf53975 100644
--- a/docs/lora_optims.qmd
+++ b/docs/lora_optims.qmd
@@ -84,6 +84,10 @@ lora_qkv_kernel: true
 lora_o_kernel: true
 ```
 
+::: {.callout-note}
+Currently, LoRA kernels are not supported for RLHF training, only SFT.
+:::
+
 ## Requirements
 
 - One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index cc5f54ac4..e68185323 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -1052,7 +1052,7 @@ class AxolotlInputConfig(
 
     @model_validator(mode="before")
     @classmethod
-    def check_lora_8bit(cls, data):
+    def check_lora_kernel_8bit(cls, data):
         if (
             data.get("lora_mlp_kernel")
             or data.get("lora_qkv_kernel")
@@ -1060,10 +1060,23 @@ class AxolotlInputConfig(
         ):
             if data.get("adapter") == "lora" and data.get("load_in_8bit"):
                 raise ValueError(
-                    "lora_mlp_kernel, lora_mlp_kernel, and lora_mlp_kernel are not compatible with 8-bit LoRA"
+                    "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with 8-bit LoRA"
                 )
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_lora_kernel_rl(cls, data):
+        if (
+            data.get("lora_mlp_kernel")
+            or data.get("lora_qkv_kernel")
+            or data.get("lora_o_kernel")
+        ) and data.get("rl"):
+            raise ValueError(
+                "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with RL at the moment."
+            )
+        return data
+
     @model_validator(mode="before")
     @classmethod
     def check_lora_axolotl_unsloth(cls, data):