Compare commits

...

3 Commits

Author SHA1 Message Date
Sunny Liu
ae8738aa87 skip check_datasets_label during debug for grpo 2025-04-17 09:47:14 -04:00
Sunny Liu
ec52561a0c import from filepath if can't import_module 2025-04-17 09:47:14 -04:00
Sunny Liu
eadb16c709 test import-wihtin-import relative path 2025-04-17 09:47:14 -04:00
2 changed files with 37 additions and 11 deletions

View File

@@ -129,17 +129,21 @@ def load_preference_datasets(
total_num_steps = None
if cli_args.debug or cfg.debug:
LOG.info("check_dataset_labels...")
if cfg.rl == "grpo":
LOG.info("skip check_dataset_labels during debug for grpo")
else:
LOG.info("check_dataset_labels...")
tokenizer = load_tokenizer(cfg)
train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples)
check_dataset_labels(
train_samples,
tokenizer,
num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only,
rl_mode=True,
)
tokenizer = load_tokenizer(cfg)
train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples)
check_dataset_labels(
train_samples,
tokenizer,
num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only,
rl_mode=True,
)
return TrainDatasetMeta(
train_dataset=train_dataset,

View File

@@ -4,10 +4,24 @@ module for base dataset transform strategies
import importlib
import logging
import sys
LOG = logging.getLogger("axolotl")
def import_from_path(module_name, file_path):
spec = importlib.util.spec_from_file_location(module_name, file_path)
if spec is None:
raise ImportError(f"Could not create module spec for: {file_path}")
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
loader = importlib.machinery.SourceFileLoader(module_name, file_path)
spec.loader = loader
loader.exec_module(module)
return module
def load(strategy, cfg, module_base=None, **kwargs):
try:
if len(strategy.split(".")) == 1:
@@ -22,7 +36,15 @@ def load(strategy, cfg, module_base=None, **kwargs):
module_base = ".".join(strategy.split(".")[:-2])
strategy = strategy.split(".")[-2]
except ModuleNotFoundError:
strategy = "." + ".".join(strategy.split(".")[:-1])
try:
file_path = "/".join(strategy.split(".")[:-1]) + ".py"
module_name = strategy.split(".")[-2]
mod = import_from_path(module_name, file_path)
func = getattr(mod, load_fn)
return func(cfg, **kwargs)
except FileNotFoundError:
strategy = "." + ".".join(strategy.split(".")[:-1])
else:
strategy = "." + ".".join(strategy.split(".")[:-1])
mod = importlib.import_module(strategy, module_base)