more fixes and optimizations
This commit is contained in:
@@ -253,9 +253,7 @@ def train(
|
||||
train_dataset, eval_dataset = process_datasets_for_packing(
|
||||
cfg, train_dataset, eval_dataset
|
||||
)
|
||||
|
||||
train_dataset.cleanup_cache_files()
|
||||
eval_dataset.cleanup_cache_files()
|
||||
accelerator.wait_for_everyone()
|
||||
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
|
||||
|
||||
if cfg.debug or "debug" in kwargs:
|
||||
@@ -382,6 +380,10 @@ def train(
|
||||
|
||||
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
|
||||
|
||||
train_dataset.cleanup_cache_files()
|
||||
if eval_dataset:
|
||||
eval_dataset.cleanup_cache_files()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(train)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Module containing data utilities"""
|
||||
import functools
|
||||
import hashlib
|
||||
import itertools
|
||||
import logging
|
||||
from hashlib import md5
|
||||
@@ -7,6 +8,7 @@ from pathlib import Path
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
@@ -37,6 +39,7 @@ from axolotl.prompters import (
|
||||
)
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
accelerator = Accelerator()
|
||||
|
||||
|
||||
def load_tokenized_prepared_datasets(
|
||||
@@ -416,7 +419,51 @@ def load_prepare_datasets(
|
||||
)
|
||||
|
||||
if cfg.val_set_size:
|
||||
dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
|
||||
# ensure we end up with the same fingerprint by doing rank0 first and being able to cache
|
||||
to_hash_train = (
|
||||
dataset._fingerprint # pylint: disable=protected-access
|
||||
+ "|"
|
||||
+ str(cfg.val_set_size)
|
||||
+ "|"
|
||||
+ "train"
|
||||
+ "|"
|
||||
+ cfg.seed
|
||||
)
|
||||
to_hash_test = (
|
||||
dataset._fingerprint # pylint: disable=protected-access
|
||||
+ "|"
|
||||
+ str(cfg.val_set_size)
|
||||
+ "|"
|
||||
+ "test"
|
||||
+ "|"
|
||||
+ cfg.seed
|
||||
)
|
||||
train_fingerprint = hashlib.md5(
|
||||
to_hash_train.encode(), usedforsecurity=False
|
||||
).hexdigest()
|
||||
test_fingerprint = hashlib.md5(
|
||||
to_hash_test.encode(), usedforsecurity=False
|
||||
).hexdigest()
|
||||
|
||||
if accelerator.is_local_main_process:
|
||||
dataset = dataset.train_test_split(
|
||||
test_size=cfg.val_set_size,
|
||||
shuffle=False,
|
||||
seed=cfg.seed,
|
||||
train_new_fingerprint=train_fingerprint,
|
||||
test_new_fingerprint=test_fingerprint,
|
||||
)
|
||||
accelerator.wait_for_everyone()
|
||||
if not accelerator.is_local_main_process:
|
||||
dataset = dataset.train_test_split(
|
||||
test_size=cfg.val_set_size,
|
||||
shuffle=False,
|
||||
seed=cfg.seed,
|
||||
train_new_fingerprint=train_fingerprint,
|
||||
test_new_fingerprint=test_fingerprint,
|
||||
)
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
train_dataset = dataset["train"]
|
||||
eval_dataset = dataset["test"]
|
||||
else:
|
||||
|
||||
@@ -263,6 +263,15 @@ class MultipackDistributedDataloader:
|
||||
def __len__(self):
|
||||
# this doesn't return the actual length b/c with distributed samplers, not all dataloaders get
|
||||
# the same share of total tokens
|
||||
# if not self.eff_total_used:
|
||||
# batches, _ = self.generate_batches(set_stats=True)
|
||||
# LOG.info(
|
||||
# f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
|
||||
# f"actual packing efficiency: {self.efficiency()}"
|
||||
# )
|
||||
return max(1, self._len_est())
|
||||
|
||||
def len_w_stats(self):
|
||||
if not self.eff_total_used:
|
||||
batches, _ = self.generate_batches(set_stats=True)
|
||||
LOG.info(
|
||||
|
||||
@@ -175,7 +175,7 @@ class AxolotlTrainer(Trainer):
|
||||
# If set to True, the dataloader prepared is only iterated through on the
|
||||
# main process and then the batches are split and broadcast to each process
|
||||
self.accelerator.dispatch_batches = True
|
||||
return self.accelerator.prepare(
|
||||
return self.accelerator.prepare_data_loader(
|
||||
MultipackDistributedDataloader(
|
||||
self.train_dataset,
|
||||
batch_size=self._train_batch_size,
|
||||
@@ -200,7 +200,7 @@ class AxolotlTrainer(Trainer):
|
||||
# If set to True, the datalaoder prepared is only iterated through on the
|
||||
# main process and then the batches are split and broadcast to each process
|
||||
self.accelerator.dispatch_batches = True
|
||||
return self.accelerator.prepare(
|
||||
return self.accelerator.prepare_data_loader(
|
||||
MultipackDistributedDataloader(
|
||||
eval_dataset,
|
||||
batch_size=self.args.eval_batch_size,
|
||||
@@ -305,6 +305,7 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
|
||||
/ cfg.sample_packing_eff_est
|
||||
/ 2048
|
||||
// cfg.batch_size
|
||||
// int(os.environ.get("WORLD_SIZE", 1))
|
||||
)
|
||||
- 1
|
||||
)
|
||||
@@ -327,17 +328,16 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
|
||||
sampler=sampler,
|
||||
packing_efficiency_estimate=cfg.sample_packing_eff_est,
|
||||
sample_packing_seq_len_multiplier=cfg.sample_packing_seq_len_multiplier,
|
||||
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
||||
)
|
||||
data_loader_len = len(data_loader)
|
||||
data_loader_len = data_loader.len_w_stats()
|
||||
actual_eff = data_loader.efficiency()
|
||||
LOG.info(f"data_loader_len: {data_loader_len}")
|
||||
total_num_steps = int(
|
||||
math.ceil(
|
||||
math.floor(
|
||||
data_loader_len
|
||||
* cfg.micro_batch_size
|
||||
* cfg.num_epochs
|
||||
/ cfg.batch_size
|
||||
// cfg.batch_size
|
||||
)
|
||||
)
|
||||
LOG.info(
|
||||
|
||||
Reference in New Issue
Block a user