Compare commits
1 commit: no-zero-ds ... fix/kd-tra

| Author | SHA1 | Date |
|---|---|---|
|  | 348409c2ff |  |

The single commit touches the `AxolotlKDTrainer` loss path and the dataset-preparation utilities.
In the KD trainer, the commit defaults `num_items_in_batch` to a `-1` sentinel before the top-k KD loss is computed:

```diff
@@ -74,6 +74,9 @@ class AxolotlKDTrainer(AxolotlTrainer):
         target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
         target_mask_for_loss = target_mask[..., 1:, :].contiguous()
 
+        if num_items_in_batch is None:
+            num_items_in_batch = -1
+
         if self.args.kd_zscore_base_temp:
             loss_kd = topk_kd_loss_with_zscore(
                 shift_logits,
```
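The `-1` default reads like a sentinel for "no global token count available", with the loss helper presumably falling back to mask-based normalization. The sketch below is illustrative only: the `_sketch` names, the signature, and the `-1` semantics are assumptions, not axolotl's actual `topk_kd_loss_with_zscore`. It shows the technique the helper's name points at: z-score-standardize teacher and student logits over the teacher's top-k token ids, then take a masked forward KL divergence.

```python
# Illustrative sketch only -- not axolotl's implementation. The signature
# and the -1 sentinel handling are assumptions based on the diff above.
import torch
import torch.nn.functional as F


def topk_kd_loss_with_zscore_sketch(
    student_logits: torch.Tensor,    # (batch, seq, vocab), already shifted
    target_token_ids: torch.Tensor,  # (batch, seq, k) teacher top-k ids (long)
    target_logits: torch.Tensor,     # (batch, seq, k) teacher logits at those ids
    target_mask: torch.Tensor,       # (batch, seq, k), 1 where the entry is valid
    num_items_in_batch: int = -1,    # -1 => normalize by the local mask sum
) -> torch.Tensor:
    # Pull the student's logits at the teacher's top-k token ids.
    student_topk = torch.gather(student_logits, dim=-1, index=target_token_ids)

    # Z-score both logit sets along the top-k axis so teacher and student
    # distributions share a scale before the softmax.
    def zscore(x: torch.Tensor) -> torch.Tensor:
        return (x - x.mean(dim=-1, keepdim=True)) / (
            x.std(dim=-1, keepdim=True) + 1e-6
        )

    teacher_logprobs = F.log_softmax(zscore(target_logits), dim=-1)
    student_logprobs = F.log_softmax(zscore(student_topk), dim=-1)

    # Forward KL on the top-k support, zeroing out padded entries.
    kl = teacher_logprobs.exp() * (teacher_logprobs - student_logprobs)
    kl = kl * target_mask

    if num_items_in_batch is None or num_items_in_batch < 0:
        # No global count from the trainer: average over valid entries.
        return kl.sum() / target_mask.sum().clamp(min=1)
    return kl.sum() / num_items_in_batch
```

Relative to the diff, the detail that matters is only the last branch: with `num_items_in_batch = -1`, the loss normalizes by the local mask instead of a trainer-supplied token count.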
In the dataset utilities, `zero_first` is added to the distributed imports:

```diff
@@ -53,7 +53,7 @@ from axolotl.utils.data.utils import (
     retry_on_request_exceptions,
 )
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.distributed import is_local_main_process
+from axolotl.utils.distributed import is_local_main_process, zero_first
 from axolotl.utils.trainer import (
     calculate_total_num_steps,
     process_datasets_for_packing,
```
`prepare_dataset` then wraps the non-streaming dataset loading in that context manager; the body is re-indented one level, with no other logic changes:

```diff
@@ -66,31 +66,32 @@ LOG = logging.getLogger(__name__)
 def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
     prompters = []
     if not cfg.pretraining_dataset:
-        if cfg.test_datasets:
-            train_dataset, _, prompters = load_prepare_datasets(
-                tokenizer,
-                cfg,
-                DEFAULT_DATASET_PREPARED_PATH,
-                split="train",
-                processor=processor,
-                preprocess_iterable=preprocess_iterable,
-            )
-            _, eval_dataset, _ = load_prepare_datasets(
-                tokenizer,
-                cfg,
-                DEFAULT_DATASET_PREPARED_PATH,
-                split="test",
-                processor=processor,
-                preprocess_iterable=preprocess_iterable,
-            )
-        else:
-            train_dataset, eval_dataset, prompters = load_prepare_datasets(
-                tokenizer,
-                cfg,
-                DEFAULT_DATASET_PREPARED_PATH,
-                processor=processor,
-                preprocess_iterable=preprocess_iterable,
-            )
+        with zero_first(is_local_main_process()):
+            if cfg.test_datasets:
+                train_dataset, _, prompters = load_prepare_datasets(
+                    tokenizer,
+                    cfg,
+                    DEFAULT_DATASET_PREPARED_PATH,
+                    split="train",
+                    processor=processor,
+                    preprocess_iterable=preprocess_iterable,
+                )
+                _, eval_dataset, _ = load_prepare_datasets(
+                    tokenizer,
+                    cfg,
+                    DEFAULT_DATASET_PREPARED_PATH,
+                    split="test",
+                    processor=processor,
+                    preprocess_iterable=preprocess_iterable,
+                )
+            else:
+                train_dataset, eval_dataset, prompters = load_prepare_datasets(
+                    tokenizer,
+                    cfg,
+                    DEFAULT_DATASET_PREPARED_PATH,
+                    processor=processor,
+                    preprocess_iterable=preprocess_iterable,
+                )
     else:
         # Load streaming dataset if pretraining_dataset is given
         path = cfg.pretraining_dataset
```
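The effect of the wrapper: on multi-GPU runs, the local main process executes `load_prepare_datasets` first, downloading and tokenizing once, and the remaining ranks proceed only afterwards, reading the now-populated cache instead of racing it. Below is a minimal sketch of how a `zero_first`-style context manager is commonly built on `torch.distributed` barriers; axolotl's real `zero_first` lives in `axolotl.utils.distributed` and may differ.

```python
# Minimal sketch of a zero_first-style context manager. This is an
# assumption about the mechanism, not axolotl's actual implementation.
from contextlib import contextmanager

import torch.distributed as dist


@contextmanager
def zero_first_sketch(is_main: bool):
    """Run the wrapped block on the main process first; others wait."""
    if dist.is_available() and dist.is_initialized():
        if not is_main:
            # Non-main ranks block here until the main rank is done.
            dist.barrier()
        yield
        if is_main:
            # The main rank releases the waiting ranks.
            dist.barrier()
    else:
        # Single-process run: nothing to coordinate.
        yield
```

Usage mirrors the diff: `with zero_first_sketch(is_local_main_process()): ...`. Each rank calls `barrier()` exactly once, so the process group stays in sync while rank 0 gets a head start on the shared cache.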
Finally, the preprocessing warning in `load_tokenized_prepared_datasets` is reworded:

```diff
@@ -271,7 +272,7 @@ def load_tokenized_prepared_datasets(
     LOG.info("Loading raw datasets...")
     if not cfg.is_preprocess:
         LOG.warning(
-            "Processing datasets during training can lead to VRAM instability. Please use `axolotl preprocess` to prepare your dataset."
+            "Processing datasets during training can lead to VRAM instability. Please pre-process your dataset."
         )
 
     if cfg.seed:
```