don't automatically enable lora kernels for RL training (#2600)
This commit is contained in:
@@ -1319,6 +1319,9 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def check_auto_enable_lora_kernels(cls, data):
|
def check_auto_enable_lora_kernels(cls, data):
|
||||||
# Only proceed if using LoRA or QLoRA adapter
|
# Only proceed if using LoRA or QLoRA adapter
|
||||||
|
if data.get("rl"):
|
||||||
|
# RL trainers not tested so don't enable kernels by default
|
||||||
|
return data
|
||||||
if data.get("adapter") in ["lora", "qlora"]:
|
if data.get("adapter") in ["lora", "qlora"]:
|
||||||
# Skip if already set, using unsloth optimizations, or using 8-bit
|
# Skip if already set, using unsloth optimizations, or using 8-bit
|
||||||
unsloth_fields = ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]
|
unsloth_fields = ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]
|
||||||
|
|||||||
Reference in New Issue
Block a user