feat(doc): note lora kernel incompat with RLHF (#2706) [skip ci]

* feat(doc): note lora kernel incompat with RLHF

* fix: add validation following comments

* chore: fix typo following suggestion
This commit is contained in:
NanoCode012
2025-05-28 15:48:40 +07:00
committed by GitHub
parent 3e6948be97
commit e33f225434
2 changed files with 19 additions and 2 deletions

View File

@@ -84,6 +84,10 @@ lora_qkv_kernel: true
lora_o_kernel: true
```
::: {.callout-note}
Currently, LoRA kernels are not supported for RLHF training, only SFT.
:::
## Requirements
- One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)

View File

@@ -1052,7 +1052,7 @@ class AxolotlInputConfig(
@model_validator(mode="before")
@classmethod
def check_lora_8bit(cls, data):
def check_lora_kernel_8bit(cls, data):
if (
data.get("lora_mlp_kernel")
or data.get("lora_qkv_kernel")
@@ -1060,10 +1060,23 @@ class AxolotlInputConfig(
):
if data.get("adapter") == "lora" and data.get("load_in_8bit"):
raise ValueError(
"lora_mlp_kernel, lora_mlp_kernel, and lora_mlp_kernel are not compatible with 8-bit LoRA"
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with 8-bit LoRA"
)
return data
@model_validator(mode="before")
@classmethod
def check_lora_kernel_rl(cls, data):
    """Reject the custom LoRA Triton kernels when an RL trainer is configured.

    Per the docs note added in this same commit, the LoRA kernels
    (``lora_mlp_kernel``, ``lora_qkv_kernel``, ``lora_o_kernel``) currently
    support SFT only, not RLHF training.

    Raises:
        ValueError: if any LoRA kernel flag is set together with ``rl``.

    Returns:
        The (unmodified) raw config mapping.
    """
    # mode="before" validator: `data` is the raw pre-parse dict, hence
    # .get() lookups rather than attribute access.
    if (
        data.get("lora_mlp_kernel")
        or data.get("lora_qkv_kernel")
        or data.get("lora_o_kernel")
    ) and data.get("rl"):
        # NOTE(review): `rl` presumably names the RLHF method (truthy when
        # any RL trainer is selected) — confirm against the field's schema.
        raise ValueError(
            "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with RL at the moment."
        )
    return data
@model_validator(mode="before")
@classmethod
def check_lora_axolotl_unsloth(cls, data):