feat(doc): note lora kernel incompat with RLHF (#2706) [skip ci]
* feat(doc): note lora kernel incompat with RLHF
* fix: add validation following comments
* chore: fix typo following suggestion
This commit is contained in:
@@ -84,6 +84,10 @@ lora_qkv_kernel: true
|
|||||||
lora_o_kernel: true
|
lora_o_kernel: true
|
||||||
```
|
```
|
||||||
|
|
||||||
|
::: {.callout-note}
|
||||||
|
Currently, LoRA kernels are not supported for RLHF training, only SFT.
|
||||||
|
:::
|
||||||
|
|
||||||
## Requirements
|
## Requirements
|
||||||
|
|
||||||
- One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)
|
- One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)
|
||||||
|
|||||||
@@ -1052,7 +1052,7 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_lora_8bit(cls, data):
|
def check_lora_kernel_8bit(cls, data):
|
||||||
if (
|
if (
|
||||||
data.get("lora_mlp_kernel")
|
data.get("lora_mlp_kernel")
|
||||||
or data.get("lora_qkv_kernel")
|
or data.get("lora_qkv_kernel")
|
||||||
@@ -1060,10 +1060,23 @@ class AxolotlInputConfig(
|
|||||||
):
|
):
|
||||||
if data.get("adapter") == "lora" and data.get("load_in_8bit"):
|
if data.get("adapter") == "lora" and data.get("load_in_8bit"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"lora_mlp_kernel, lora_mlp_kernel, and lora_mlp_kernel are not compatible with 8-bit LoRA"
|
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with 8-bit LoRA"
|
||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_lora_kernel_rl(cls, data):
|
||||||
|
if (
|
||||||
|
data.get("lora_mlp_kernel")
|
||||||
|
or data.get("lora_qkv_kernel")
|
||||||
|
or data.get("lora_o_kernel")
|
||||||
|
) and data.get("rl"):
|
||||||
|
raise ValueError(
|
||||||
|
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with RL at the moment."
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_lora_axolotl_unsloth(cls, data):
|
def check_lora_axolotl_unsloth(cls, data):
|
||||||
|
|||||||
Reference in New Issue
Block a user