Compare commits

..

1 Commits

Author SHA1 Message Date
NanoCode012
348409c2ff fix: num_items_in_batch wrong type in kd trainer loss 2025-05-20 16:56:24 +07:00
2 changed files with 31 additions and 27 deletions

View File

@@ -74,6 +74,9 @@ class AxolotlKDTrainer(AxolotlTrainer):
target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous() target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
target_mask_for_loss = target_mask[..., 1:, :].contiguous() target_mask_for_loss = target_mask[..., 1:, :].contiguous()
if num_items_in_batch is None:
num_items_in_batch = -1
if self.args.kd_zscore_base_temp: if self.args.kd_zscore_base_temp:
loss_kd = topk_kd_loss_with_zscore( loss_kd = topk_kd_loss_with_zscore(
shift_logits, shift_logits,

View File

@@ -53,7 +53,7 @@ from axolotl.utils.data.utils import (
retry_on_request_exceptions, retry_on_request_exceptions,
) )
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_local_main_process from axolotl.utils.distributed import is_local_main_process, zero_first
from axolotl.utils.trainer import ( from axolotl.utils.trainer import (
calculate_total_num_steps, calculate_total_num_steps,
process_datasets_for_packing, process_datasets_for_packing,
@@ -66,31 +66,32 @@ LOG = logging.getLogger(__name__)
def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None): def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
prompters = [] prompters = []
if not cfg.pretraining_dataset: if not cfg.pretraining_dataset:
if cfg.test_datasets: with zero_first(is_local_main_process()):
train_dataset, _, prompters = load_prepare_datasets( if cfg.test_datasets:
tokenizer, train_dataset, _, prompters = load_prepare_datasets(
cfg, tokenizer,
DEFAULT_DATASET_PREPARED_PATH, cfg,
split="train", DEFAULT_DATASET_PREPARED_PATH,
processor=processor, split="train",
preprocess_iterable=preprocess_iterable, processor=processor,
) preprocess_iterable=preprocess_iterable,
_, eval_dataset, _ = load_prepare_datasets( )
tokenizer, _, eval_dataset, _ = load_prepare_datasets(
cfg, tokenizer,
DEFAULT_DATASET_PREPARED_PATH, cfg,
split="test", DEFAULT_DATASET_PREPARED_PATH,
processor=processor, split="test",
preprocess_iterable=preprocess_iterable, processor=processor,
) preprocess_iterable=preprocess_iterable,
else: )
train_dataset, eval_dataset, prompters = load_prepare_datasets( else:
tokenizer, train_dataset, eval_dataset, prompters = load_prepare_datasets(
cfg, tokenizer,
DEFAULT_DATASET_PREPARED_PATH, cfg,
processor=processor, DEFAULT_DATASET_PREPARED_PATH,
preprocess_iterable=preprocess_iterable, processor=processor,
) preprocess_iterable=preprocess_iterable,
)
else: else:
# Load streaming dataset if pretraining_dataset is given # Load streaming dataset if pretraining_dataset is given
path = cfg.pretraining_dataset path = cfg.pretraining_dataset
@@ -271,7 +272,7 @@ def load_tokenized_prepared_datasets(
LOG.info("Loading raw datasets...") LOG.info("Loading raw datasets...")
if not cfg.is_preprocess: if not cfg.is_preprocess:
LOG.warning( LOG.warning(
"Processing datasets during training can lead to VRAM instability. Please use `axolotl preprocess` to prepare your dataset." "Processing datasets during training can lead to VRAM instability. Please pre-process your dataset."
) )
if cfg.seed: if cfg.seed: