fix merge conflicts
This commit is contained in:
@@ -18,6 +18,7 @@ from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import is_main_process, zero_first
|
||||
from axolotl.utils.models import load_tokenizer
|
||||
from axolotl.utils.schemas.enums import RLType
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
@@ -86,7 +87,7 @@ def drop_long_rl_seq(
|
||||
):
|
||||
result = None
|
||||
|
||||
if rl in ("dpo", "ipo", "orpo", "simpo"):
|
||||
if rl in (RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO):
|
||||
if not (
|
||||
sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
|
||||
):
|
||||
@@ -160,7 +161,7 @@ def drop_long_rl_seq(
|
||||
len_prompt + len_rejected
|
||||
) <= sequence_len
|
||||
|
||||
elif rl == "kto":
|
||||
elif rl == RLType.KTO:
|
||||
if not (sample.get("prompt") and sample.get("completion")):
|
||||
raise ValueError("Prompt and completion keys are required for KTO datasets")
|
||||
|
||||
@@ -197,7 +198,7 @@ def drop_long_rl_seq(
|
||||
else: # handling == "drop"
|
||||
result = (len_prompt + len_completion) <= sequence_len
|
||||
|
||||
elif rl == "grpo":
|
||||
elif rl == RLType.GRPO:
|
||||
# GRPO doesn't involve sequence length checks in the same way?
|
||||
# The original code returned True for drop. What should it return for truncate?
|
||||
# Let's assume for now it always passes.
|
||||
|
||||
Reference in New Issue
Block a user