feat(doc): note lora kernel incompat with RLHF (#2706) [skip ci]
* feat(doc): note lora kernel incompat with RLHF * fix: add validation following comments * chore: fix typo following suggestion
This commit is contained in:
@@ -84,6 +84,10 @@ lora_qkv_kernel: true
lora_o_kernel: true
```

::: {.callout-note}
Currently, LoRA kernels are not supported for RLHF training, only SFT.
:::

## Requirements

- One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)
@@ -1052,7 +1052,7 @@ class AxolotlInputConfig(
|
||||
|
||||
@model_validator(mode="before")
@classmethod
def check_lora_kernel_8bit(cls, data):
    """Reject enabling any LoRA Triton kernel together with 8-bit LoRA.

    Runs before field parsing (mode="before"), so `data` is the raw input
    mapping. Returns `data` unchanged when the combination is valid.

    Raises:
        ValueError: if any of lora_mlp_kernel / lora_qkv_kernel /
            lora_o_kernel is truthy while adapter == "lora" and
            load_in_8bit is set — the kernels do not support 8-bit LoRA.
    """
    if (
        data.get("lora_mlp_kernel")
        or data.get("lora_qkv_kernel")
        # NOTE(review): this third condition is hidden behind the diff hunk
        # header in the source view; inferred from the sibling RL validator —
        # confirm against the full file.
        or data.get("lora_o_kernel")
    ):
        if data.get("adapter") == "lora" and data.get("load_in_8bit"):
            # Fixed message: the old text repeated "lora_mlp_kernel" for all
            # three option names.
            raise ValueError(
                "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with 8-bit LoRA"
            )
    return data
|
||||
|
||||
@model_validator(mode="before")
@classmethod
def check_lora_kernel_rl(cls, data):
    """Reject enabling any LoRA Triton kernel together with RL training.

    Runs before field parsing (mode="before"), so `data` is the raw input
    mapping. Returns `data` unchanged when the combination is valid.

    Raises:
        ValueError: if any of the three kernel flags is truthy while
            `rl` is set — the kernels currently support SFT only.
    """
    kernel_flags = ("lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel")
    any_kernel_enabled = any(data.get(flag) for flag in kernel_flags)
    if any_kernel_enabled and data.get("rl"):
        raise ValueError(
            "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with RL at the moment."
        )
    return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_lora_axolotl_unsloth(cls, data):
|
||||
|
||||
Reference in New Issue
Block a user