lint
This commit is contained in:
@@ -120,13 +120,21 @@ def drop_long_rl_seq(
|
|||||||
# Truncate the chosen and rejected responses if needed
|
# Truncate the chosen and rejected responses if needed
|
||||||
if len_chosen > max_response_len:
|
if len_chosen > max_response_len:
|
||||||
# Tokenize, truncate, and decode
|
# Tokenize, truncate, and decode
|
||||||
chosen_tokens = tokenizer(chosen, add_special_tokens=False)["input_ids"][:max_response_len]
|
chosen_tokens = tokenizer(chosen, add_special_tokens=False)[
|
||||||
sample["chosen"] = tokenizer.decode(chosen_tokens, skip_special_tokens=True)
|
"input_ids"
|
||||||
|
][:max_response_len]
|
||||||
|
sample["chosen"] = tokenizer.decode(
|
||||||
|
chosen_tokens, skip_special_tokens=True
|
||||||
|
)
|
||||||
|
|
||||||
if len_rejected > max_response_len:
|
if len_rejected > max_response_len:
|
||||||
# Tokenize, truncate, and decode
|
# Tokenize, truncate, and decode
|
||||||
rejected_tokens = tokenizer(rejected, add_special_tokens=False)["input_ids"][:max_response_len]
|
rejected_tokens = tokenizer(rejected, add_special_tokens=False)[
|
||||||
sample["rejected"] = tokenizer.decode(rejected_tokens, skip_special_tokens=True)
|
"input_ids"
|
||||||
|
][:max_response_len]
|
||||||
|
sample["rejected"] = tokenizer.decode(
|
||||||
|
rejected_tokens, skip_special_tokens=True
|
||||||
|
)
|
||||||
|
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
@@ -159,8 +167,12 @@ def drop_long_rl_seq(
|
|||||||
# Truncate the completion if needed
|
# Truncate the completion if needed
|
||||||
if len_completion > max_completion_len:
|
if len_completion > max_completion_len:
|
||||||
# Tokenize, truncate, and decode
|
# Tokenize, truncate, and decode
|
||||||
completion_tokens = tokenizer(completion, add_special_tokens=False)["input_ids"][:max_completion_len]
|
completion_tokens = tokenizer(completion, add_special_tokens=False)[
|
||||||
sample["completion"] = tokenizer.decode(completion_tokens, skip_special_tokens=True)
|
"input_ids"
|
||||||
|
][:max_completion_len]
|
||||||
|
sample["completion"] = tokenizer.decode(
|
||||||
|
completion_tokens, skip_special_tokens=True
|
||||||
|
)
|
||||||
|
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
@@ -245,7 +257,9 @@ def load_prepare_preference_datasets(cfg):
|
|||||||
load_from_cache_file=not cfg.is_preprocess,
|
load_from_cache_file=not cfg.is_preprocess,
|
||||||
desc="Truncating Long Sequences",
|
desc="Truncating Long Sequences",
|
||||||
)
|
)
|
||||||
LOG.info(f"Truncated long samples in dataset index {i} to {cfg.sequence_len} tokens")
|
LOG.info(
|
||||||
|
f"Truncated long samples in dataset index {i} to {cfg.sequence_len} tokens"
|
||||||
|
)
|
||||||
|
|
||||||
combined_datasets = concatenate_datasets(split_datasets)
|
combined_datasets = concatenate_datasets(split_datasets)
|
||||||
combined_datasets = combined_datasets.shuffle(seed=cfg.seed)
|
combined_datasets = combined_datasets.shuffle(seed=cfg.seed)
|
||||||
|
|||||||
@@ -188,7 +188,9 @@ class AxolotlInputConfig(
|
|||||||
sequence_len: int = Field(default=512)
|
sequence_len: int = Field(default=512)
|
||||||
excess_token_handling: Literal["drop", "truncate"] = Field(
|
excess_token_handling: Literal["drop", "truncate"] = Field(
|
||||||
default="drop",
|
default="drop",
|
||||||
json_schema_extra={"description": "how to handle tokens exceeding max sequence length - drop the sample or truncate"},
|
json_schema_extra={
|
||||||
|
"description": "how to handle tokens exceeding max sequence length - drop the sample or truncate"
|
||||||
|
},
|
||||||
)
|
)
|
||||||
min_sample_len: int | None = None
|
min_sample_len: int | None = None
|
||||||
max_prompt_len: int = Field(
|
max_prompt_len: int = Field(
|
||||||
|
|||||||
@@ -235,7 +235,9 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
def truncate_or_drop_long_seq(sample, sequence_len=2048, min_sequence_len=2, handling="drop"):
|
def truncate_or_drop_long_seq(
|
||||||
|
sample, sequence_len=2048, min_sequence_len=2, handling="drop"
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Either drop or truncate samples whose sequence length is either too long (> sequence_len)
|
Either drop or truncate samples whose sequence length is either too long (> sequence_len)
|
||||||
or too short (< min_sequence_len).
|
or too short (< min_sequence_len).
|
||||||
@@ -292,7 +294,9 @@ def truncate_or_drop_long_seq(sample, sequence_len=2048, min_sequence_len=2, han
|
|||||||
return sample
|
return sample
|
||||||
|
|
||||||
# For drop mode or if the sample doesn't exceed max length
|
# For drop mode or if the sample doesn't exceed max length
|
||||||
return min_sequence_len <= length <= sequence_len if handling == "drop" else sample
|
return (
|
||||||
|
min_sequence_len <= length <= sequence_len if handling == "drop" else sample
|
||||||
|
)
|
||||||
|
|
||||||
# Batched (input_ids is a list of lists)
|
# Batched (input_ids is a list of lists)
|
||||||
if handling == "drop":
|
if handling == "drop":
|
||||||
@@ -316,7 +320,9 @@ def truncate_or_drop_long_seq(sample, sequence_len=2048, min_sequence_len=2, han
|
|||||||
|
|
||||||
# Also truncate attention_mask if present
|
# Also truncate attention_mask if present
|
||||||
if "attention_mask" in sample:
|
if "attention_mask" in sample:
|
||||||
sample["attention_mask"][i] = sample["attention_mask"][i][:sequence_len]
|
sample["attention_mask"][i] = sample["attention_mask"][i][
|
||||||
|
:sequence_len
|
||||||
|
]
|
||||||
|
|
||||||
# Also truncate labels if present
|
# Also truncate labels if present
|
||||||
if "labels" in sample:
|
if "labels" in sample:
|
||||||
@@ -468,7 +474,11 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
|||||||
|
|
||||||
|
|
||||||
def process_pretraining_datasets_for_packing(
|
def process_pretraining_datasets_for_packing(
|
||||||
train_dataset, sequence_len, skip_position_ids=True, drop_attention_mask=False, handling="drop"
|
train_dataset,
|
||||||
|
sequence_len,
|
||||||
|
skip_position_ids=True,
|
||||||
|
drop_attention_mask=False,
|
||||||
|
handling="drop",
|
||||||
):
|
):
|
||||||
drop_long_fn = partial(drop_long_seq, sequence_len=sequence_len)
|
drop_long_fn = partial(drop_long_seq, sequence_len=sequence_len)
|
||||||
|
|
||||||
@@ -480,7 +490,9 @@ def process_pretraining_datasets_for_packing(
|
|||||||
load_from_cache_file=False,
|
load_from_cache_file=False,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
truncate_fn = partial(truncate_or_drop_long_seq, sequence_len=sequence_len, handling=handling)
|
truncate_fn = partial(
|
||||||
|
truncate_or_drop_long_seq, sequence_len=sequence_len, handling=handling
|
||||||
|
)
|
||||||
train_dataset = train_dataset.map(
|
train_dataset = train_dataset.map(
|
||||||
truncate_fn,
|
truncate_fn,
|
||||||
desc="Truncating Long Sequences",
|
desc="Truncating Long Sequences",
|
||||||
|
|||||||
Reference in New Issue
Block a user