diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py index 1e8682235..ae940cd52 100644 --- a/src/axolotl/utils/data/rl.py +++ b/src/axolotl/utils/data/rl.py @@ -18,6 +18,7 @@ from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5 from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import is_main_process, zero_first from axolotl.utils.models import load_tokenizer +from axolotl.utils.schemas.enums import RLType LOG = logging.getLogger("axolotl") @@ -86,7 +87,7 @@ def drop_long_rl_seq( ): result = None - if rl in ("dpo", "ipo", "orpo", "simpo"): + if rl in (RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO): if not ( sample.get("prompt") and sample.get("chosen") and sample.get("rejected") ): @@ -160,7 +161,7 @@ def drop_long_rl_seq( len_prompt + len_rejected ) <= sequence_len - elif rl == "kto": + elif rl == RLType.KTO: if not (sample.get("prompt") and sample.get("completion")): raise ValueError("Prompt and completion keys are required for KTO datasets") @@ -197,7 +198,7 @@ def drop_long_rl_seq( else: # handling == "drop" result = (len_prompt + len_completion) <= sequence_len - elif rl == "grpo": + elif rl == RLType.GRPO: # GRPO doesn't involve sequence length checks in the same way? # The original code returned True for drop. What should it return for truncate? # Let's assume for now it always passes.