more fixes and optimizations

This commit is contained in:
Wing Lian
2023-08-08 03:16:00 -04:00
parent 1162b93b6b
commit 4f7c04bae0
4 changed files with 68 additions and 10 deletions

View File

@@ -253,9 +253,7 @@ def train(
train_dataset, eval_dataset = process_datasets_for_packing(
cfg, train_dataset, eval_dataset
)
train_dataset.cleanup_cache_files()
eval_dataset.cleanup_cache_files()
accelerator.wait_for_everyone()
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
if cfg.debug or "debug" in kwargs:
@@ -382,6 +380,10 @@ def train(
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
train_dataset.cleanup_cache_files()
if eval_dataset:
eval_dataset.cleanup_cache_files()
if __name__ == "__main__":
fire.Fire(train)

View File

@@ -1,5 +1,6 @@
"""Module containing data utilities"""
import functools
import hashlib
import itertools
import logging
from hashlib import md5
@@ -7,6 +8,7 @@ from pathlib import Path
from typing import List, Tuple, Union
import torch
from accelerate import Accelerator
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
from huggingface_hub import hf_hub_download
from transformers import PreTrainedTokenizerBase
@@ -37,6 +39,7 @@ from axolotl.prompters import (
)
LOG = logging.getLogger("axolotl")
accelerator = Accelerator()
def load_tokenized_prepared_datasets(
@@ -416,7 +419,51 @@ def load_prepare_datasets(
)
if cfg.val_set_size:
dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
# ensure we end up with the same fingerprint by doing rank0 first and being able to cache
# Build deterministic fingerprints for the train and test splits. Each hash
# input joins, with "|" separators: the parent dataset's fingerprint, the
# validation split size, the split name, and the seed. The resulting md5 hex
# digests are passed to train_test_split() as the new split fingerprints.
#
# cfg.seed is wrapped in str() because concatenating a non-string seed
# (e.g. an int) onto a str raises TypeError; str() leaves an
# already-string seed unchanged.
to_hash_train = (
    dataset._fingerprint  # pylint: disable=protected-access
    + "|"
    + str(cfg.val_set_size)
    + "|"
    + "train"
    + "|"
    + str(cfg.seed)
)
to_hash_test = (
    dataset._fingerprint  # pylint: disable=protected-access
    + "|"
    + str(cfg.val_set_size)
    + "|"
    + "test"
    + "|"
    + str(cfg.seed)
)
# md5 is used only as a cache key here, hence usedforsecurity=False.
train_fingerprint = hashlib.md5(
    to_hash_train.encode(), usedforsecurity=False
).hexdigest()
test_fingerprint = hashlib.md5(
    to_hash_test.encode(), usedforsecurity=False
).hexdigest()
if accelerator.is_local_main_process:
dataset = dataset.train_test_split(
test_size=cfg.val_set_size,
shuffle=False,
seed=cfg.seed,
train_new_fingerprint=train_fingerprint,
test_new_fingerprint=test_fingerprint,
)
accelerator.wait_for_everyone()
if not accelerator.is_local_main_process:
dataset = dataset.train_test_split(
test_size=cfg.val_set_size,
shuffle=False,
seed=cfg.seed,
train_new_fingerprint=train_fingerprint,
test_new_fingerprint=test_fingerprint,
)
accelerator.wait_for_everyone()
train_dataset = dataset["train"]
eval_dataset = dataset["test"]
else:

View File

@@ -263,6 +263,15 @@ class MultipackDistributedDataloader:
def __len__(self):
# Returns the estimate from self._len_est(), clamped to a minimum of 1.
# This is not the actual length: with distributed samplers, not all
# dataloaders get the same share of total tokens.
#
# The stats collection below is intentionally left disabled; the same
# `eff_total_used` check and generate_batches(set_stats=True) call appear
# in len_w_stats().
# if not self.eff_total_used:
# batches, _ = self.generate_batches(set_stats=True)
# LOG.info(
# f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
# f"actual packing efficiency: {self.efficiency()}"
# )
return max(1, self._len_est())
def len_w_stats(self):
if not self.eff_total_used:
batches, _ = self.generate_batches(set_stats=True)
LOG.info(

View File

@@ -175,7 +175,7 @@ class AxolotlTrainer(Trainer):
# If set to True, the dataloader prepared is only iterated through on the
# main process and then the batches are split and broadcast to each process
self.accelerator.dispatch_batches = True
return self.accelerator.prepare(
return self.accelerator.prepare_data_loader(
MultipackDistributedDataloader(
self.train_dataset,
batch_size=self._train_batch_size,
@@ -200,7 +200,7 @@ class AxolotlTrainer(Trainer):
# If set to True, the dataloader prepared is only iterated through on the
# main process and then the batches are split and broadcast to each process
self.accelerator.dispatch_batches = True
return self.accelerator.prepare(
return self.accelerator.prepare_data_loader(
MultipackDistributedDataloader(
eval_dataset,
batch_size=self.args.eval_batch_size,
@@ -305,6 +305,7 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
/ cfg.sample_packing_eff_est
/ 2048
// cfg.batch_size
// int(os.environ.get("WORLD_SIZE", 1))
)
- 1
)
@@ -327,17 +328,16 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
sampler=sampler,
packing_efficiency_estimate=cfg.sample_packing_eff_est,
sample_packing_seq_len_multiplier=cfg.sample_packing_seq_len_multiplier,
device_count=int(os.environ.get("WORLD_SIZE", 1)),
)
data_loader_len = len(data_loader)
data_loader_len = data_loader.len_w_stats()
actual_eff = data_loader.efficiency()
LOG.info(f"data_loader_len: {data_loader_len}")
total_num_steps = int(
math.ceil(
math.floor(
data_loader_len
* cfg.micro_batch_size
* cfg.num_epochs
/ cfg.batch_size
// cfg.batch_size
)
)
LOG.info(