more packing and dataset optimizations and fixes
This commit is contained in:
@@ -14,6 +14,7 @@ import torch
|
||||
import yaml
|
||||
|
||||
# add src to the pythonpath so we don't need to pip install this
|
||||
from accelerate import Accelerator
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
from transformers import GenerationConfig, TextStreamer
|
||||
|
||||
@@ -22,7 +23,11 @@ from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import load_model, load_tokenizer
|
||||
from axolotl.utils.tokenization import check_dataset_labels
|
||||
from axolotl.utils.trainer import setup_trainer
|
||||
from axolotl.utils.trainer import (
|
||||
calculate_total_num_steps,
|
||||
process_datasets_for_packing,
|
||||
setup_trainer,
|
||||
)
|
||||
from axolotl.utils.validation import validate_config
|
||||
from axolotl.utils.wandb import setup_wandb_env_vars
|
||||
|
||||
@@ -168,6 +173,7 @@ def train(
|
||||
prepare_ds_only: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
accelerator = Accelerator()
|
||||
if Path(config).is_dir():
|
||||
config = choose_config(config)
|
||||
|
||||
@@ -237,6 +243,21 @@ def train(
|
||||
train_dataset = train_dataset.with_format("torch")
|
||||
eval_dataset = None
|
||||
|
||||
if accelerator.is_local_main_process:
|
||||
# process on rank 0 first so it gets cached so other ranks load from cache
|
||||
train_dataset, eval_dataset = process_datasets_for_packing(
|
||||
cfg, train_dataset, eval_dataset
|
||||
)
|
||||
accelerator.wait_for_everyone()
|
||||
if not accelerator.is_local_main_process:
|
||||
train_dataset, eval_dataset = process_datasets_for_packing(
|
||||
cfg, train_dataset, eval_dataset
|
||||
)
|
||||
|
||||
train_dataset.cleanup_cache_files()
|
||||
eval_dataset.cleanup_cache_files()
|
||||
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
|
||||
|
||||
if cfg.debug or "debug" in kwargs:
|
||||
LOG.info("check_dataset_labels...")
|
||||
check_dataset_labels(
|
||||
@@ -286,7 +307,9 @@ def train(
|
||||
model.save_pretrained(cfg.output_dir)
|
||||
return
|
||||
|
||||
trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
|
||||
trainer = setup_trainer(
|
||||
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
|
||||
)
|
||||
|
||||
model.config.use_cache = False
|
||||
|
||||
@@ -345,7 +368,13 @@ def train(
|
||||
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
||||
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
||||
if cfg.fsdp:
|
||||
model.save_pretrained(cfg.output_dir)
|
||||
with model.summon_full_params():
|
||||
model.save_pretrained(
|
||||
cfg.output_dir,
|
||||
is_main_process=trainer.accelerator.is_main_process,
|
||||
save_function=trainer.accelerator.save,
|
||||
state_dict=trainer.accelerator.get_state_dict(model),
|
||||
)
|
||||
elif cfg.local_rank == 0:
|
||||
if cfg.flash_optimum:
|
||||
model = BetterTransformer.reverse(model)
|
||||
|
||||
Reference in New Issue
Block a user