diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py
index 35bd5fcbb..cf1226175 100644
--- a/src/axolotl/utils/data/rl.py
+++ b/src/axolotl/utils/data/rl.py
@@ -64,15 +64,57 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer):
     tokenizer = load_tokenizer(cfg)
     ds_transform_fn = partial(ds_transform_fn, tokenizer=tokenizer)
 
+    if isinstance(data_set, DatasetDict):
+        data_set = data_set["train"]
+
     data_set = data_set.map(
         ds_transform_fn,
         desc="Mapping RL Dataset",
     )
-    if isinstance(data_set, DatasetDict):
-        data_set = data_set["train"]
+
     return data_set
 
 
+def drop_long_rl_seq(
+    sample, rl, tokenizer, sequence_len  # pylint: disable=invalid-name
+):
+    if rl in ("dpo", "ipo", "orpo", "simpo"):
+        if not (
+            sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
+        ):
+            raise ValueError(
+                "Prompt, chosen and rejected keys are required for DPO/ORPO datasets"
+            )
+
+        prompt = sample["prompt"]
+        chosen = sample["chosen"]
+        rejected = sample["rejected"]
+
+        len_prompt = len(tokenizer(prompt, add_special_tokens=False)["input_ids"])
+        len_chosen = len(tokenizer(chosen, add_special_tokens=False)["input_ids"])
+        len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"])
+
+        return (len_prompt + len_chosen) <= sequence_len and (
+            len_prompt + len_rejected
+        ) <= sequence_len
+
+    if rl == "kto":
+        if not (sample.get("prompt") and sample.get("completion")):
+            raise ValueError("Prompt and completion keys are required for KTO datasets")
+
+        prompt = sample["prompt"]
+        completion = sample["completion"]
+
+        len_prompt = len(tokenizer(prompt, add_special_tokens=False)["input_ids"])
+        len_completion = len(
+            tokenizer(completion, add_special_tokens=False)["input_ids"]
+        )
+
+        return (len_prompt + len_completion) <= sequence_len
+
+    raise ValueError("Unknown RL type")
+
+
 def load_prepare_dpo_datasets(cfg):
     def load_split(dataset_cfgs, _cfg):
         split_datasets: List[Any] = []
@@ -94,7 +136,7 @@ def load_prepare_dpo_datasets(cfg):
                 )
                 split_datasets.insert(i, ds)
 
-        tokenizer = None
+        tokenizer = load_tokenizer(cfg)
 
         for i, data_set in enumerate(split_datasets):
             _type = dataset_cfgs[i]["type"]
@@ -121,7 +163,28 @@ def load_prepare_dpo_datasets(cfg):
                 # "prompt", "chosen" and "rejected" already preprocessed
                 split_datasets[i] = data_set
 
-        return concatenate_datasets(split_datasets)
+            drop_long = partial(
+                drop_long_rl_seq,
+                rl=_cfg.rl,
+                tokenizer=tokenizer,
+                sequence_len=cfg.sequence_len,
+            )
+
+            prior_len = len(split_datasets[i])
+            split_datasets[i] = split_datasets[i].filter(
+                drop_long,
+                num_proc=cfg.dataset_processes,
+                load_from_cache_file=not cfg.is_preprocess,
+                desc="Dropping Long Sequences",
+            )
+            dropped = prior_len - len(split_datasets[i])
+            if dropped:
+                LOG.warning(f"Dropped {dropped} long samples from dataset index {i}")
+
+        combined_datasets = concatenate_datasets(split_datasets)
+        combined_datasets = combined_datasets.shuffle(seed=cfg.seed)
+
+        return combined_datasets
 
     with zero_first(is_main_process()):
         train_is_preprocessed = False
diff --git a/src/axolotl/utils/tokenization.py b/src/axolotl/utils/tokenization.py
index 97c3c6465..139d50110 100644
--- a/src/axolotl/utils/tokenization.py
+++ b/src/axolotl/utils/tokenization.py
@@ -66,28 +66,47 @@ def process_tokens_for_rl_debug(tokens, color, tokenizer, text_only):
 
 
 def check_rl_example_labels(example, tokenizer, text_only=False):
-    field_prompt, field_chosen, field_rejected = "prompt", "chosen", "rejected"
+    field_prompt, field_chosen, field_rejected, field_completion = (
+        "prompt",
+        "chosen",
+        "rejected",
+        "completion",
+    )
     input_tokens = example[field_prompt]
-    labels_chosen, labels_rejected = example[field_chosen], example[field_rejected]
+
+    labels_chosen = example.get(field_chosen)
+    labels_rejected = example.get(field_rejected)
+    labels_completion = example.get(field_completion)
+
+    # Create a delimiter based on text_only flag
+    delimiter = "" if text_only else " "
 
     # Process and color each type of token
     colored_tokens = process_tokens_for_rl_debug(
         input_tokens, "yellow", tokenizer, text_only
     )
-    colored_chosens = process_tokens_for_rl_debug(
-        labels_chosen, "green", tokenizer, text_only
-    )
-    colored_rejecteds = process_tokens_for_rl_debug(
-        labels_rejected, "red", tokenizer, text_only
-    )
 
-    # Create a delimiter based on text_only flag
-    delimiter = "" if text_only else " "
+    # Process tokens
+    if labels_completion is None:
+        colored_chosens = process_tokens_for_rl_debug(
+            labels_chosen, "green", tokenizer, text_only
+        )
+        colored_rejecteds = process_tokens_for_rl_debug(
+            labels_rejected, "red", tokenizer, text_only
+        )
+    else:
+        colored_completion = process_tokens_for_rl_debug(
+            labels_completion, "green", tokenizer, text_only
+        )
 
     # Logging information
     LOG.info(f"INPUT PROMPT: {delimiter.join(colored_tokens)}\n\n")
-    LOG.info(f"CHOSEN RESPONSE: {delimiter.join(colored_chosens)}\n\n")
-    LOG.info(f"REJECTED RESPONSE: {delimiter.join(colored_rejecteds)}\n\n\n")
+
+    if labels_completion is None:
+        LOG.info(f"CHOSEN RESPONSE: {delimiter.join(colored_chosens)}\n\n")
+        LOG.info(f"REJECTED RESPONSE: {delimiter.join(colored_rejecteds)}\n\n\n")
+    else:
+        LOG.info(f"COMPLETION RESPONSE: {delimiter.join(colored_completion)}\n\n\n")
 
     return delimiter.join(colored_tokens)
 
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 2d3a6944f..32e54c9a8 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -203,37 +203,59 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
     if eval_dataset and "token_type_ids" in eval_dataset.column_names:
         eval_dataset = eval_dataset.remove_columns("token_type_ids")
 
+    prior_len = len(train_dataset)
     train_dataset = train_dataset.filter(
         drop_long,
         num_proc=cfg.dataset_processes,
         load_from_cache_file=not cfg.is_preprocess,
         desc="Dropping Long Sequences",
     )
+    dropped = prior_len - len(train_dataset)
+    if dropped:
+        LOG.warning(f"Dropped {dropped} long samples from train dataset")
+
     if eval_dataset:
+        prior_len = len(eval_dataset)
         eval_dataset = eval_dataset.filter(
             drop_long,
             num_proc=cfg.dataset_processes,
             load_from_cache_file=not cfg.is_preprocess,
             desc="Dropping Long Sequences",
         )
+        dropped = prior_len - len(eval_dataset)
+        if dropped:
+            LOG.warning(f"Dropped {dropped} long samples from eval dataset")
 
     # drop samples with where the number of elements with labels not equal to -100 is zero
     def drop_no_trainable_tokens(sample):
         return np.sum(np.array(sample["labels"]) != -100) > 0
 
+    prior_len = len(train_dataset)
     train_dataset = train_dataset.filter(
         drop_no_trainable_tokens,
         num_proc=cfg.dataset_processes,
         load_from_cache_file=not cfg.is_preprocess,
         desc="Drop Samples with Zero Trainable Tokens",
    )
+    dropped = prior_len - len(train_dataset)
+    if dropped:
+        LOG.warning(
+            f"Dropped {dropped} samples with no trainable tokens from train dataset"
+        )
+
     if eval_dataset:
+        prior_len = len(eval_dataset)
         eval_dataset = eval_dataset.filter(
             drop_no_trainable_tokens,
             num_proc=cfg.dataset_processes,
             load_from_cache_file=not cfg.is_preprocess,
             desc="Drop Samples with Zero Trainable Tokens",
         )
+        dropped = prior_len - len(eval_dataset)
+        if dropped:
+            LOG.warning(
+                f"Dropped {dropped} samples with no trainable tokens from eval dataset"
+            )
 
     if cfg.group_by_length:
         train_dataset = train_dataset.map(
@@ -493,7 +515,7 @@ def prepare_opinionated_env(cfg):
 def setup_trainer(
     cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
 ):
-    if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]:
+    if cfg.rl in ("dpo", "ipo", "orpo", "kto", "simpo"):
         trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
         trainer_builder.model_ref = model[1]
         trainer_builder.peft_config = model[2]