fix merge conflicts

This commit is contained in:
mhenrhcsen
2025-05-14 13:33:42 +02:00
parent fea6649518
commit 9c5b8da22f

View File

@@ -18,6 +18,7 @@ from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.models import load_tokenizer
+from axolotl.utils.schemas.enums import RLType
LOG = logging.getLogger("axolotl")
@@ -86,7 +87,7 @@ def drop_long_rl_seq(
):
result = None
-if rl in ("dpo", "ipo", "orpo", "simpo"):
+if rl in (RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO):
if not (
sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
):
@@ -160,7 +161,7 @@ def drop_long_rl_seq(
len_prompt + len_rejected
) <= sequence_len
-elif rl == "kto":
+elif rl == RLType.KTO:
if not (sample.get("prompt") and sample.get("completion")):
raise ValueError("Prompt and completion keys are required for KTO datasets")
@@ -197,7 +198,7 @@ def drop_long_rl_seq(
else: # handling == "drop"
result = (len_prompt + len_completion) <= sequence_len
-elif rl == "grpo":
+elif rl == RLType.GRPO:
# GRPO doesn't involve sequence length checks in the same way?
# The original code returned True for drop. What should it return for truncate?
# Let's assume for now it always passes.