Compare commits
3 Commits
tui
...
ae8738aa87
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ae8738aa87 | ||
|
|
ec52561a0c | ||
|
|
eadb16c709 |
@@ -129,17 +129,21 @@ def load_preference_datasets(
|
|||||||
total_num_steps = None
|
total_num_steps = None
|
||||||
|
|
||||||
if cli_args.debug or cfg.debug:
|
if cli_args.debug or cfg.debug:
|
||||||
LOG.info("check_dataset_labels...")
|
if cfg.rl == "grpo":
|
||||||
|
LOG.info("skip check_dataset_labels during debug for grpo")
|
||||||
|
else:
|
||||||
|
LOG.info("check_dataset_labels...")
|
||||||
|
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples)
|
train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples)
|
||||||
check_dataset_labels(
|
|
||||||
train_samples,
|
check_dataset_labels(
|
||||||
tokenizer,
|
train_samples,
|
||||||
num_examples=cli_args.debug_num_examples,
|
tokenizer,
|
||||||
text_only=cli_args.debug_text_only,
|
num_examples=cli_args.debug_num_examples,
|
||||||
rl_mode=True,
|
text_only=cli_args.debug_text_only,
|
||||||
)
|
rl_mode=True,
|
||||||
|
)
|
||||||
|
|
||||||
return TrainDatasetMeta(
|
return TrainDatasetMeta(
|
||||||
train_dataset=train_dataset,
|
train_dataset=train_dataset,
|
||||||
|
|||||||
@@ -4,10 +4,24 @@ module for base dataset transform strategies
|
|||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import logging
|
import logging
|
||||||
|
import sys
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
|
def import_from_path(module_name, file_path):
|
||||||
|
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
||||||
|
if spec is None:
|
||||||
|
raise ImportError(f"Could not create module spec for: {file_path}")
|
||||||
|
module = importlib.util.module_from_spec(spec)
|
||||||
|
|
||||||
|
sys.modules[module_name] = module
|
||||||
|
loader = importlib.machinery.SourceFileLoader(module_name, file_path)
|
||||||
|
spec.loader = loader
|
||||||
|
loader.exec_module(module)
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
def load(strategy, cfg, module_base=None, **kwargs):
|
def load(strategy, cfg, module_base=None, **kwargs):
|
||||||
try:
|
try:
|
||||||
if len(strategy.split(".")) == 1:
|
if len(strategy.split(".")) == 1:
|
||||||
@@ -22,7 +36,15 @@ def load(strategy, cfg, module_base=None, **kwargs):
|
|||||||
module_base = ".".join(strategy.split(".")[:-2])
|
module_base = ".".join(strategy.split(".")[:-2])
|
||||||
strategy = strategy.split(".")[-2]
|
strategy = strategy.split(".")[-2]
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
strategy = "." + ".".join(strategy.split(".")[:-1])
|
try:
|
||||||
|
file_path = "/".join(strategy.split(".")[:-1]) + ".py"
|
||||||
|
module_name = strategy.split(".")[-2]
|
||||||
|
mod = import_from_path(module_name, file_path)
|
||||||
|
func = getattr(mod, load_fn)
|
||||||
|
return func(cfg, **kwargs)
|
||||||
|
except FileNotFoundError:
|
||||||
|
strategy = "." + ".".join(strategy.split(".")[:-1])
|
||||||
|
|
||||||
else:
|
else:
|
||||||
strategy = "." + ".".join(strategy.split(".")[:-1])
|
strategy = "." + ".".join(strategy.split(".")[:-1])
|
||||||
mod = importlib.import_module(strategy, module_base)
|
mod = importlib.import_module(strategy, module_base)
|
||||||
|
|||||||
Reference in New Issue
Block a user