Remove transscribe.py file and clean up optimizer.py and rl.py for improved formatting and consistency.
This commit is contained in:
@@ -185,12 +185,12 @@ class OptimizerMixin(Trainer):
|
||||
p.data_ptr(): p.numel() for p in module.parameters()
|
||||
}.values()
|
||||
)
|
||||
LOG.info(f"skipped {module}: {skipped/2**20}M params")
|
||||
LOG.info(f"skipped {module}: {skipped / 2 ** 20}M params")
|
||||
manager.register_module_override(
|
||||
module, "weight", {"optim_bits": 32}
|
||||
)
|
||||
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
|
||||
LOG.info(f"skipped: {skipped/2**20}M params")
|
||||
LOG.info(f"skipped: {skipped / 2 ** 20}M params")
|
||||
|
||||
if is_sagemaker_mp_enabled():
|
||||
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import inspect
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Literal, List, Union
|
||||
from typing import Any, Callable, Literal
|
||||
|
||||
from datasets import Dataset, DatasetDict
|
||||
from transformers import PreTrainedTokenizer
|
||||
@@ -120,6 +120,8 @@ def _map_dataset(
|
||||
)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def drop_long_rl_seq(sample, rl, tokenizer, sequence_len, handling="drop"):
|
||||
"""
|
||||
Backward-compatibility wrapper for legacy imports in tests.
|
||||
@@ -128,7 +130,6 @@ def drop_long_rl_seq(sample, rl, tokenizer, sequence_len, handling="drop"):
|
||||
return _drop_long_sequences(sample, rl, tokenizer, sequence_len)
|
||||
|
||||
|
||||
|
||||
def _drop_long_sequences(
|
||||
sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int
|
||||
) -> bool:
|
||||
@@ -163,8 +164,8 @@ def _drop_long_sequences(
|
||||
len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"])
|
||||
|
||||
# Truncate first, then drop if still invalid (although truncate should handle it)
|
||||
handling = sample.get("sequence_len_overflow_handling", "drop")
|
||||
if handling == "truncate":
|
||||
handling_mode = sample.get("sequence_len_overflow_handling", "drop")
|
||||
if handling_mode == "truncate":
|
||||
# If both sequences fit, return sample unchanged
|
||||
if (len_prompt + len_chosen) <= sequence_len and (
|
||||
len_prompt + len_rejected
|
||||
@@ -176,13 +177,12 @@ def _drop_long_sequences(
|
||||
|
||||
if max_response_len <= 0:
|
||||
# Prompt itself exceeds sequence length. Cannot truncate responses to fix it.
|
||||
# Keep sample shape for map(), but log a warning. A subsequent filter will drop it.
|
||||
LOG.warning(
|
||||
"Prompt length (%s) exceeds sequence length (%s) for DPO-like sample; will be dropped post-truncation",
|
||||
"Prompt length (%s) exceeds sequence length (%s) for DPO-like sample; dropping",
|
||||
len_prompt,
|
||||
sequence_len,
|
||||
)
|
||||
result = sample
|
||||
result = False
|
||||
|
||||
else:
|
||||
# Truncate the chosen and rejected responses if needed
|
||||
@@ -220,8 +220,8 @@ def _drop_long_sequences(
|
||||
)
|
||||
|
||||
# Truncate first
|
||||
handling = sample.get("sequence_len_overflow_handling", "drop")
|
||||
if handling == "truncate":
|
||||
handling_mode = sample.get("sequence_len_overflow_handling", "drop")
|
||||
if handling_mode == "truncate":
|
||||
# If sequence fits, return sample unchanged
|
||||
if (len_prompt + len_completion) <= sequence_len:
|
||||
result = sample
|
||||
@@ -232,11 +232,11 @@ def _drop_long_sequences(
|
||||
if max_completion_len <= 0:
|
||||
# Prompt itself exceeds sequence length. Cannot truncate completion to fix it.
|
||||
LOG.warning(
|
||||
"Prompt length (%s) exceeds sequence length (%s) for KTO sample; will be dropped post-truncation",
|
||||
"Prompt length (%s) exceeds sequence length (%s) for KTO sample; dropping",
|
||||
len_prompt,
|
||||
sequence_len,
|
||||
)
|
||||
result = sample
|
||||
result = False
|
||||
else:
|
||||
# Truncate the completion if needed
|
||||
if len_completion > max_completion_len:
|
||||
@@ -256,7 +256,7 @@ def _drop_long_sequences(
|
||||
else:
|
||||
raise ValueError("Unknown RL type")
|
||||
|
||||
return result
|
||||
return bool(result)
|
||||
|
||||
|
||||
def load_prepare_preference_datasets(cfg):
|
||||
@@ -296,14 +296,9 @@ def load_prepare_preference_datasets(cfg):
|
||||
return True
|
||||
return False
|
||||
|
||||
def load_split(dataset_cfgs, _cfg):
|
||||
split_datasets: List[Any] = []
|
||||
use_auth_token = _cfg.hf_use_auth_token
|
||||
for config_dataset in datasets_w_name_generator(dataset_cfgs):
|
||||
ds: Union[Dataset, DatasetDict] = load_dataset_w_config(
|
||||
config_dataset, use_auth_token, streaming=False
|
||||
)
|
||||
split_datasets.append(ds)
|
||||
# Legacy shim preserved for backward compatibility; no-op in new flow
|
||||
def load_split(dataset_cfgs, _cfg): # noqa: F811
|
||||
return None
|
||||
|
||||
|
||||
def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
import pymongo
|
||||
|
||||
MONGO_URI = "mongodb://root:9AsYmXYKmYLHcNsShmCb3L5DZMXH77rQ9GBRxm0HKownNWLwdzH9dW7zhPG9mpuR@46.4.101.229:8281/?directConnection=true"
|
||||
COLLECTION_NAME = "tts_data"
|
||||
|
||||
client = pymongo.MongoClient(MONGO_URI)
|
||||
db = client["tts_data"]
|
||||
collection = db[COLLECTION_NAME]
|
||||
|
||||
# Get all documents from the collection that does not have a "transcription" field
|
||||
documents = collection.find({"transcription": {"$exists": False}})
|
||||
|
||||
for document in documents:
|
||||
print(document)
|
||||
break
|
||||
Reference in New Issue
Block a user