more packing and dataset optimizations and fixes
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
peft @ git+https://github.com/huggingface/peft.git
|
peft @ git+https://github.com/huggingface/peft.git
|
||||||
transformers @ git+https://github.com/huggingface/transformers.git
|
transformers @ git+https://github.com/huggingface/transformers.git
|
||||||
bitsandbytes>=0.39.0
|
bitsandbytes>=0.39.0
|
||||||
accelerate @ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b
|
accelerate @ git+https://github.com/huggingface/accelerate@b42c65b
|
||||||
addict
|
addict
|
||||||
fire
|
fire
|
||||||
PyYAML==6.0
|
PyYAML==6.0
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import torch
|
|||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
# add src to the pythonpath so we don't need to pip install this
|
# add src to the pythonpath so we don't need to pip install this
|
||||||
|
from accelerate import Accelerator
|
||||||
from optimum.bettertransformer import BetterTransformer
|
from optimum.bettertransformer import BetterTransformer
|
||||||
from transformers import GenerationConfig, TextStreamer
|
from transformers import GenerationConfig, TextStreamer
|
||||||
|
|
||||||
@@ -22,7 +23,11 @@ from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
|
|||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.models import load_model, load_tokenizer
|
from axolotl.utils.models import load_model, load_tokenizer
|
||||||
from axolotl.utils.tokenization import check_dataset_labels
|
from axolotl.utils.tokenization import check_dataset_labels
|
||||||
from axolotl.utils.trainer import setup_trainer
|
from axolotl.utils.trainer import (
|
||||||
|
calculate_total_num_steps,
|
||||||
|
process_datasets_for_packing,
|
||||||
|
setup_trainer,
|
||||||
|
)
|
||||||
from axolotl.utils.validation import validate_config
|
from axolotl.utils.validation import validate_config
|
||||||
from axolotl.utils.wandb import setup_wandb_env_vars
|
from axolotl.utils.wandb import setup_wandb_env_vars
|
||||||
|
|
||||||
@@ -168,6 +173,7 @@ def train(
|
|||||||
prepare_ds_only: bool = False,
|
prepare_ds_only: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
accelerator = Accelerator()
|
||||||
if Path(config).is_dir():
|
if Path(config).is_dir():
|
||||||
config = choose_config(config)
|
config = choose_config(config)
|
||||||
|
|
||||||
@@ -237,6 +243,21 @@ def train(
|
|||||||
train_dataset = train_dataset.with_format("torch")
|
train_dataset = train_dataset.with_format("torch")
|
||||||
eval_dataset = None
|
eval_dataset = None
|
||||||
|
|
||||||
|
if accelerator.is_local_main_process:
|
||||||
|
# process on rank 0 first so it gets cached so other ranks load from cache
|
||||||
|
train_dataset, eval_dataset = process_datasets_for_packing(
|
||||||
|
cfg, train_dataset, eval_dataset
|
||||||
|
)
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
if not accelerator.is_local_main_process:
|
||||||
|
train_dataset, eval_dataset = process_datasets_for_packing(
|
||||||
|
cfg, train_dataset, eval_dataset
|
||||||
|
)
|
||||||
|
|
||||||
|
train_dataset.cleanup_cache_files()
|
||||||
|
eval_dataset.cleanup_cache_files()
|
||||||
|
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
|
||||||
|
|
||||||
if cfg.debug or "debug" in kwargs:
|
if cfg.debug or "debug" in kwargs:
|
||||||
LOG.info("check_dataset_labels...")
|
LOG.info("check_dataset_labels...")
|
||||||
check_dataset_labels(
|
check_dataset_labels(
|
||||||
@@ -286,7 +307,9 @@ def train(
|
|||||||
model.save_pretrained(cfg.output_dir)
|
model.save_pretrained(cfg.output_dir)
|
||||||
return
|
return
|
||||||
|
|
||||||
trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
|
trainer = setup_trainer(
|
||||||
|
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
|
||||||
|
)
|
||||||
|
|
||||||
model.config.use_cache = False
|
model.config.use_cache = False
|
||||||
|
|
||||||
@@ -345,7 +368,13 @@ def train(
|
|||||||
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
||||||
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
||||||
if cfg.fsdp:
|
if cfg.fsdp:
|
||||||
model.save_pretrained(cfg.output_dir)
|
with model.summon_full_params():
|
||||||
|
model.save_pretrained(
|
||||||
|
cfg.output_dir,
|
||||||
|
is_main_process=trainer.accelerator.is_main_process,
|
||||||
|
save_function=trainer.accelerator.save,
|
||||||
|
state_dict=trainer.accelerator.get_state_dict(model),
|
||||||
|
)
|
||||||
elif cfg.local_rank == 0:
|
elif cfg.local_rank == 0:
|
||||||
if cfg.flash_optimum:
|
if cfg.flash_optimum:
|
||||||
model = BetterTransformer.reverse(model)
|
model = BetterTransformer.reverse(model)
|
||||||
|
|||||||
@@ -128,8 +128,8 @@ def xformers_forward(
|
|||||||
query_states,
|
query_states,
|
||||||
key_states,
|
key_states,
|
||||||
value_states,
|
value_states,
|
||||||
attn_bias=attention_mask,
|
# attn_bias=attention_mask,
|
||||||
# attn_bias=xformers.ops.LowerTriangularMask(),
|
attn_bias=xformers.ops.LowerTriangularMask(),
|
||||||
)
|
)
|
||||||
attn_weights = None
|
attn_weights = None
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -109,13 +109,16 @@ def load_tokenized_prepared_datasets(
|
|||||||
local_path = Path(d.path)
|
local_path = Path(d.path)
|
||||||
if local_path.exists():
|
if local_path.exists():
|
||||||
if local_path.is_dir():
|
if local_path.is_dir():
|
||||||
ds = load_dataset(
|
try:
|
||||||
d.path,
|
ds = load_from_disk(d.path)
|
||||||
name=d.name,
|
except FileNotFoundError:
|
||||||
data_files=d.data_files,
|
ds = load_dataset(
|
||||||
streaming=False,
|
d.path,
|
||||||
split=None,
|
name=d.name,
|
||||||
)
|
data_files=d.data_files,
|
||||||
|
streaming=False,
|
||||||
|
split=None,
|
||||||
|
)
|
||||||
elif local_path.is_file():
|
elif local_path.is_file():
|
||||||
ds = load_dataset(
|
ds = load_dataset(
|
||||||
"json",
|
"json",
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
import itertools
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
|
||||||
from typing import Any, Callable, List, Union
|
from typing import Any, Callable, List, Union
|
||||||
|
|
||||||
import numba
|
import numba
|
||||||
@@ -63,6 +62,14 @@ def ffd_with_result(a: np.ndarray, c: int, start_index: int):
|
|||||||
def allocate(
|
def allocate(
|
||||||
lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
|
lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
:param lengths: array of lengths of each sample
|
||||||
|
:param lengths_cumsum: cumulative sum of consecutive lengths
|
||||||
|
:param rank: rank for this process
|
||||||
|
:param c: length of tokens per batch
|
||||||
|
:param n: number of ranks
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
# Dynamic batch allocator, similar to Multifit
|
# Dynamic batch allocator, similar to Multifit
|
||||||
# https://en.wikipedia.org/wiki/Multifit_algorithm
|
# https://en.wikipedia.org/wiki/Multifit_algorithm
|
||||||
# ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
|
# ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
|
||||||
@@ -98,7 +105,7 @@ def allocate(
|
|||||||
result.append(batch[rank])
|
result.append(batch[rank])
|
||||||
# add total seqs for all ranks
|
# add total seqs for all ranks
|
||||||
result_totseqs.append(tot_seqs)
|
result_totseqs.append(tot_seqs)
|
||||||
|
# yield batch[rank], tot_seqs, s, len(result) * c * n
|
||||||
return result, result_totseqs, s, len(result) * c * n
|
return result, result_totseqs, s, len(result) * c * n
|
||||||
|
|
||||||
|
|
||||||
@@ -129,6 +136,7 @@ class MultipackDistributedDataloader:
|
|||||||
sampler: Union[Sampler, DistributedSampler] = None,
|
sampler: Union[Sampler, DistributedSampler] = None,
|
||||||
packing_efficiency_estimate: float = 1.0,
|
packing_efficiency_estimate: float = 1.0,
|
||||||
sample_packing_seq_len_multiplier: int = 1,
|
sample_packing_seq_len_multiplier: int = 1,
|
||||||
|
device_count: int = 1,
|
||||||
):
|
):
|
||||||
# Dataset
|
# Dataset
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
@@ -152,6 +160,7 @@ class MultipackDistributedDataloader:
|
|||||||
self.eff_total_used = 0
|
self.eff_total_used = 0
|
||||||
self.eff_total_slots = 0
|
self.eff_total_slots = 0
|
||||||
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
||||||
|
self.device_count = device_count
|
||||||
|
|
||||||
def generate_batches(self, set_stats=False):
|
def generate_batches(self, set_stats=False):
|
||||||
LOG.info("generating packed batches")
|
LOG.info("generating packed batches")
|
||||||
@@ -233,7 +242,7 @@ class MultipackDistributedDataloader:
|
|||||||
indices = range(0, len(self.dataset))
|
indices = range(0, len(self.dataset))
|
||||||
lengths = self.lengths[indices]
|
lengths = self.lengths[indices]
|
||||||
lengths_sum = np.cumsum(lengths)[-1]
|
lengths_sum = np.cumsum(lengths)[-1]
|
||||||
lengths_sum_per_device = lengths_sum // int(os.environ.get("WORLD_SIZE", 1))
|
lengths_sum_per_device = lengths_sum // self.device_count
|
||||||
LOG.info(
|
LOG.info(
|
||||||
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
|
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
|
||||||
f"total_num_tokens per device: {lengths_sum_per_device}"
|
f"total_num_tokens per device: {lengths_sum_per_device}"
|
||||||
|
|||||||
@@ -238,7 +238,6 @@ def load_model(
|
|||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
device_map="auto" if cfg.world_size == 1 else cfg.device_map,
|
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
|
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
|
||||||
@@ -273,7 +272,6 @@ def load_model(
|
|||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
device_map=cfg.device_map,
|
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
@@ -304,7 +302,6 @@ def load_model(
|
|||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
device_map=cfg.device_map,
|
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
@@ -318,7 +315,6 @@ def load_model(
|
|||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
device_map=cfg.device_map,
|
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import logging
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -12,10 +13,10 @@ from typing import Optional, Union
|
|||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
import torch.cuda
|
import torch.cuda
|
||||||
import transformers
|
import transformers
|
||||||
from datasets import Dataset
|
from datasets import Dataset, set_caching_enabled
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.optim.lr_scheduler import OneCycleLR
|
from torch.optim.lr_scheduler import OneCycleLR
|
||||||
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
|
from torch.utils.data import DataLoader, RandomSampler
|
||||||
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
||||||
from transformers.trainer_pt_utils import get_parameter_names
|
from transformers.trainer_pt_utils import get_parameter_names
|
||||||
|
|
||||||
@@ -158,20 +159,22 @@ class AxolotlTrainer(Trainer):
|
|||||||
return super().create_scheduler(num_training_steps, optimizer)
|
return super().create_scheduler(num_training_steps, optimizer)
|
||||||
return self.lr_scheduler
|
return self.lr_scheduler
|
||||||
|
|
||||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
# def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||||
if self.args.world_size > 1 and self.args.sample_packing:
|
# if self.args.world_size > 1 and self.args.sample_packing:
|
||||||
return DistributedSampler(
|
# return DistributedSampler(
|
||||||
self.train_dataset,
|
# self.train_dataset,
|
||||||
num_replicas=self.args.world_size,
|
# num_replicas=self.args.world_size,
|
||||||
rank=self.args.process_index,
|
# rank=self.args.process_index,
|
||||||
seed=self.args.seed,
|
# seed=self.args.seed,
|
||||||
)
|
# )
|
||||||
return super()._get_train_sampler()
|
# return super()._get_train_sampler()
|
||||||
|
|
||||||
def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
|
def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
|
||||||
if self.args.sample_packing:
|
if self.args.sample_packing:
|
||||||
train_sampler = self._get_train_sampler()
|
train_sampler = self._get_train_sampler()
|
||||||
|
# If set to True, the dataloader prepared is only iterated through on the
|
||||||
|
# main process and then the batches are split and broadcast to each process
|
||||||
|
self.accelerator.dispatch_batches = True
|
||||||
return self.accelerator.prepare(
|
return self.accelerator.prepare(
|
||||||
MultipackDistributedDataloader(
|
MultipackDistributedDataloader(
|
||||||
self.train_dataset,
|
self.train_dataset,
|
||||||
@@ -181,6 +184,7 @@ class AxolotlTrainer(Trainer):
|
|||||||
sampler=train_sampler,
|
sampler=train_sampler,
|
||||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||||
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
|
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
|
||||||
|
# device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return super().get_train_dataloader()
|
return super().get_train_dataloader()
|
||||||
@@ -193,6 +197,9 @@ class AxolotlTrainer(Trainer):
|
|||||||
eval_dataset if eval_dataset is not None else self.eval_dataset
|
eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||||
)
|
)
|
||||||
eval_sampler = self._get_eval_sampler(eval_dataset)
|
eval_sampler = self._get_eval_sampler(eval_dataset)
|
||||||
|
# If set to True, the datalaoder prepared is only iterated through on the
|
||||||
|
# main process and then the batches are split and broadcast to each process
|
||||||
|
self.accelerator.dispatch_batches = True
|
||||||
return self.accelerator.prepare(
|
return self.accelerator.prepare(
|
||||||
MultipackDistributedDataloader(
|
MultipackDistributedDataloader(
|
||||||
eval_dataset,
|
eval_dataset,
|
||||||
@@ -202,6 +209,7 @@ class AxolotlTrainer(Trainer):
|
|||||||
sampler=eval_sampler,
|
sampler=eval_sampler,
|
||||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||||
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
|
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
|
||||||
|
# device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return super().get_eval_dataloader(eval_dataset)
|
return super().get_eval_dataloader(eval_dataset)
|
||||||
@@ -253,7 +261,16 @@ def drop_long_seq(sample, sequence_len=2048):
|
|||||||
return len(sample["input_ids"]) <= sequence_len
|
return len(sample["input_ids"]) <= sequence_len
|
||||||
|
|
||||||
|
|
||||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
@contextmanager
|
||||||
|
def disable_datasets_caching():
|
||||||
|
try:
|
||||||
|
set_caching_enabled(False)
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
set_caching_enabled(True)
|
||||||
|
|
||||||
|
|
||||||
|
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||||
if cfg.sample_packing:
|
if cfg.sample_packing:
|
||||||
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
|
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
|
||||||
train_dataset = train_dataset.filter(drop_long).map(
|
train_dataset = train_dataset.filter(drop_long).map(
|
||||||
@@ -263,12 +280,22 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
eval_dataset = eval_dataset.filter(drop_long).map(
|
eval_dataset = eval_dataset.filter(drop_long).map(
|
||||||
add_position_ids, num_proc=os.cpu_count()
|
add_position_ids, num_proc=os.cpu_count()
|
||||||
)
|
)
|
||||||
|
return train_dataset, eval_dataset
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_total_num_steps(cfg, train_dataset, tokenizer):
|
||||||
|
if cfg.sample_packing:
|
||||||
|
# we have to drop anything longer then sequence len otherwise
|
||||||
|
# flash attention with position ids fails
|
||||||
|
total_num_tokens = (
|
||||||
|
cfg.total_num_tokens
|
||||||
|
if cfg.total_num_tokens
|
||||||
|
else sum(len(s["input_ids"]) for s in train_dataset)
|
||||||
|
)
|
||||||
|
if not cfg.total_num_tokens:
|
||||||
|
LOG.info(f"📝 UPDATE CONFIG WITH: `total_num_tokens: {total_num_tokens}`")
|
||||||
|
|
||||||
if cfg.sample_packing_eff_est:
|
if cfg.sample_packing_eff_est:
|
||||||
total_num_tokens = (
|
|
||||||
cfg.total_num_tokens
|
|
||||||
if cfg.total_num_tokens
|
|
||||||
else sum(len(s["input_ids"]) for s in train_dataset)
|
|
||||||
)
|
|
||||||
total_num_steps = (
|
total_num_steps = (
|
||||||
# match count to len est in dataloader
|
# match count to len est in dataloader
|
||||||
(
|
(
|
||||||
@@ -300,8 +327,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
packing_efficiency_estimate=cfg.sample_packing_eff_est,
|
packing_efficiency_estimate=cfg.sample_packing_eff_est,
|
||||||
sample_packing_seq_len_multiplier=cfg.sample_packing_seq_len_multiplier,
|
sample_packing_seq_len_multiplier=cfg.sample_packing_seq_len_multiplier,
|
||||||
|
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
||||||
)
|
)
|
||||||
data_loader_len = len(data_loader)
|
data_loader_len = len(data_loader)
|
||||||
|
actual_eff = data_loader.efficiency()
|
||||||
LOG.info(f"data_loader_len: {data_loader_len}")
|
LOG.info(f"data_loader_len: {data_loader_len}")
|
||||||
total_num_steps = int(
|
total_num_steps = int(
|
||||||
math.ceil(
|
math.ceil(
|
||||||
@@ -311,10 +340,18 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
/ cfg.batch_size
|
/ cfg.batch_size
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
LOG.info(
|
||||||
|
f"📝 UPDATE CONFIG WITH: `sample_packing_eff_est: {math.ceil(actual_eff * 100.0) / 100.0}`"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
total_num_steps = int(
|
total_num_steps = int(
|
||||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||||
)
|
)
|
||||||
|
LOG.info(f"total_num_steps: {total_num_steps}")
|
||||||
|
return total_num_steps
|
||||||
|
|
||||||
|
|
||||||
|
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
||||||
warmup_steps = (
|
warmup_steps = (
|
||||||
cfg.warmup_steps
|
cfg.warmup_steps
|
||||||
if cfg.warmup_steps is not None
|
if cfg.warmup_steps is not None
|
||||||
|
|||||||
@@ -110,6 +110,17 @@ def validate_config(cfg):
|
|||||||
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
|
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if cfg.sample_packing and cfg.sdp_attention:
|
||||||
|
# incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
|
||||||
|
raise ValueError(
|
||||||
|
"sample_packing not compatible with sdp_attention. Use flash_attention"
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.sample_packing and cfg.xformers_attention:
|
||||||
|
raise ValueError(
|
||||||
|
"sample_packing not compatible with xformers_attention. Use flash_attention"
|
||||||
|
)
|
||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
# MPT 7b
|
# MPT 7b
|
||||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||||
|
|||||||
Reference in New Issue
Block a user