diff --git a/requirements.txt b/requirements.txt index cd7d9f033..82186205b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ peft @ git+https://github.com/huggingface/peft.git transformers @ git+https://github.com/huggingface/transformers.git bitsandbytes>=0.39.0 -accelerate @ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b +accelerate @ git+https://github.com/huggingface/accelerate@b42c65b addict fire PyYAML==6.0 diff --git a/scripts/finetune.py b/scripts/finetune.py index 70b805ecd..488d23c11 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -14,6 +14,7 @@ import torch import yaml # add src to the pythonpath so we don't need to pip install this +from accelerate import Accelerator from optimum.bettertransformer import BetterTransformer from transformers import GenerationConfig, TextStreamer @@ -22,7 +23,11 @@ from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset from axolotl.utils.dict import DictDefault from axolotl.utils.models import load_model, load_tokenizer from axolotl.utils.tokenization import check_dataset_labels -from axolotl.utils.trainer import setup_trainer +from axolotl.utils.trainer import ( + calculate_total_num_steps, + process_datasets_for_packing, + setup_trainer, +) from axolotl.utils.validation import validate_config from axolotl.utils.wandb import setup_wandb_env_vars @@ -168,6 +173,7 @@ def train( prepare_ds_only: bool = False, **kwargs, ): + accelerator = Accelerator() if Path(config).is_dir(): config = choose_config(config) @@ -237,6 +243,21 @@ def train( train_dataset = train_dataset.with_format("torch") eval_dataset = None + if accelerator.is_local_main_process: + # process on rank 0 first so it gets cached so other ranks load from cache + train_dataset, eval_dataset = process_datasets_for_packing( + cfg, train_dataset, eval_dataset + ) + accelerator.wait_for_everyone() + if not accelerator.is_local_main_process: + train_dataset, eval_dataset = process_datasets_for_packing( + cfg, train_dataset, eval_dataset + ) + + train_dataset.cleanup_cache_files() + eval_dataset.cleanup_cache_files() + total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer) + if cfg.debug or "debug" in kwargs: LOG.info("check_dataset_labels...") check_dataset_labels( @@ -286,7 +307,9 @@ def train( model.save_pretrained(cfg.output_dir) return - trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer) + trainer = setup_trainer( + cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps + ) model.config.use_cache = False @@ -345,7 +368,13 @@ def train( # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file if cfg.fsdp: - model.save_pretrained(cfg.output_dir) + with model.summon_full_params(): + model.save_pretrained( + cfg.output_dir, + is_main_process=trainer.accelerator.is_main_process, + save_function=trainer.accelerator.save, + state_dict=trainer.accelerator.get_state_dict(model), + ) elif cfg.local_rank == 0: if cfg.flash_optimum: model = BetterTransformer.reverse(model) diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py index 4755db30b..752e204f7 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py @@ -128,8 +128,8 @@ def xformers_forward( query_states, key_states, value_states, - attn_bias=attention_mask, - # attn_bias=xformers.ops.LowerTriangularMask(), + # attn_bias=attention_mask, + attn_bias=xformers.ops.LowerTriangularMask(), ) attn_weights = None else: diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index ee7f16905..6661f8db9 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -109,13 +109,16 @@ def load_tokenized_prepared_datasets( local_path = Path(d.path) if local_path.exists(): if local_path.is_dir(): - ds = load_dataset( - d.path, - name=d.name, - data_files=d.data_files, - streaming=False, - split=None, - ) + try: + ds = load_from_disk(d.path) + except FileNotFoundError: + ds = load_dataset( + d.path, + name=d.name, + data_files=d.data_files, + streaming=False, + split=None, + ) elif local_path.is_file(): ds = load_dataset( "json", diff --git a/src/axolotl/utils/dataloader.py b/src/axolotl/utils/dataloader.py index f4c18c604..167f3957a 100644 --- a/src/axolotl/utils/dataloader.py +++ b/src/axolotl/utils/dataloader.py @@ -2,7 +2,6 @@ import itertools import logging import math -import os from typing import Any, Callable, List, Union import numba @@ -63,6 +62,14 @@ def ffd_with_result(a: np.ndarray, c: int, start_index: int): def allocate( lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int ): + """ + :param lengths: array of lengths of each sample + :param lengths_cumsum: cumulative sum of consecutive lengths + :param rank: rank for this process + :param c: length of tokens per batch + :param n: number of ranks + :return: + """ # Dynamic batch allocator, similar to Multifit # https://en.wikipedia.org/wiki/Multifit_algorithm # ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len) @@ -98,7 +105,7 @@ def allocate( result.append(batch[rank]) # add total seqs for all ranks result_totseqs.append(tot_seqs) - + # yield batch[rank], tot_seqs, s, len(result) * c * n return result, result_totseqs, s, len(result) * c * n @@ -129,6 +136,7 @@ class MultipackDistributedDataloader: sampler: Union[Sampler, DistributedSampler] = None, packing_efficiency_estimate: float = 1.0, sample_packing_seq_len_multiplier: int = 1, + device_count: int = 1, ): # Dataset self.dataset = dataset @@ -152,6 +160,7 @@ class MultipackDistributedDataloader: self.eff_total_used = 0 self.eff_total_slots = 0 self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0 + self.device_count = device_count def generate_batches(self, set_stats=False): LOG.info("generating packed batches") @@ -233,7 +242,7 @@ class MultipackDistributedDataloader: indices = range(0, len(self.dataset)) lengths = self.lengths[indices] lengths_sum = np.cumsum(lengths)[-1] - lengths_sum_per_device = lengths_sum // int(os.environ.get("WORLD_SIZE", 1)) + lengths_sum_per_device = lengths_sum // self.device_count LOG.info( f"packing_efficiency_estimate: {self.packing_efficiency_estimate} " f"total_num_tokens per device: {lengths_sum_per_device}" diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index b770bb47c..3be591ef7 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -238,7 +238,6 @@ def load_model( load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, torch_dtype=torch_dtype, - device_map="auto" if cfg.world_size == 1 else cfg.device_map, **model_kwargs, ) # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention: @@ -273,7 +272,6 @@ def load_model( load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, torch_dtype=torch_dtype, - device_map=cfg.device_map, trust_remote_code=cfg.trust_remote_code or False, **model_kwargs, ) @@ -304,7 +302,6 @@ def load_model( load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, torch_dtype=torch_dtype, - device_map=cfg.device_map, trust_remote_code=cfg.trust_remote_code or False, **model_kwargs, ) @@ -318,7 +315,6 @@ def load_model( load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, torch_dtype=torch_dtype, - device_map=cfg.device_map, trust_remote_code=cfg.trust_remote_code or False, **model_kwargs, ) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 29b944542..fd14d5cbf 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -4,6 +4,7 @@ import logging import math import os import sys +from contextlib import contextmanager from dataclasses import dataclass, field from functools import partial from pathlib import Path @@ -12,10 +13,10 @@ from typing import Optional, Union import bitsandbytes as bnb import torch.cuda import transformers -from datasets import Dataset +from datasets import Dataset, set_caching_enabled from torch import nn from torch.optim.lr_scheduler import OneCycleLR -from torch.utils.data import DataLoader, DistributedSampler, RandomSampler +from torch.utils.data import DataLoader, RandomSampler from transformers import EarlyStoppingCallback, Trainer, TrainingArguments from transformers.trainer_pt_utils import get_parameter_names @@ -158,20 +159,22 @@ class AxolotlTrainer(Trainer): return super().create_scheduler(num_training_steps, optimizer) return self.lr_scheduler - def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: - if self.args.world_size > 1 and self.args.sample_packing: - return DistributedSampler( - self.train_dataset, - num_replicas=self.args.world_size, - rank=self.args.process_index, - seed=self.args.seed, - ) - return super()._get_train_sampler() + # def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: + # if self.args.world_size > 1 and self.args.sample_packing: + # return DistributedSampler( + # self.train_dataset, + # num_replicas=self.args.world_size, + # rank=self.args.process_index, + # seed=self.args.seed, + # ) + # return super()._get_train_sampler() def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]: if self.args.sample_packing: train_sampler = self._get_train_sampler() - + # If set to True, the dataloader prepared is only iterated through on the + # main process and then the batches are split and broadcast to each process + self.accelerator.dispatch_batches = True return self.accelerator.prepare( MultipackDistributedDataloader( self.train_dataset, @@ -181,6 +184,7 @@ class AxolotlTrainer(Trainer): sampler=train_sampler, packing_efficiency_estimate=self.args.sample_packing_efficiency, sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier, + # device_count=int(os.environ.get("WORLD_SIZE", 1)), ) ) return super().get_train_dataloader() @@ -193,6 +197,9 @@ class AxolotlTrainer(Trainer): eval_dataset if eval_dataset is not None else self.eval_dataset ) eval_sampler = self._get_eval_sampler(eval_dataset) + # If set to True, the datalaoder prepared is only iterated through on the + # main process and then the batches are split and broadcast to each process + self.accelerator.dispatch_batches = True return self.accelerator.prepare( MultipackDistributedDataloader( eval_dataset, @@ -202,6 +209,7 @@ class AxolotlTrainer(Trainer): sampler=eval_sampler, packing_efficiency_estimate=self.args.sample_packing_efficiency, sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier, + # device_count=int(os.environ.get("WORLD_SIZE", 1)), ) ) return super().get_eval_dataloader(eval_dataset) @@ -253,7 +261,16 @@ def drop_long_seq(sample, sequence_len=2048): return len(sample["input_ids"]) <= sequence_len -def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): +@contextmanager +def disable_datasets_caching(): + try: + set_caching_enabled(False) + yield + finally: + set_caching_enabled(True) + + +def process_datasets_for_packing(cfg, train_dataset, eval_dataset): if cfg.sample_packing: drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len) train_dataset = train_dataset.filter(drop_long).map( @@ -263,12 +280,22 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): eval_dataset = eval_dataset.filter(drop_long).map( add_position_ids, num_proc=os.cpu_count() ) + return train_dataset, eval_dataset + + +def calculate_total_num_steps(cfg, train_dataset, tokenizer): + if cfg.sample_packing: + # we have to drop anything longer then sequence len otherwise + # flash attention with position ids fails + total_num_tokens = ( + cfg.total_num_tokens + if cfg.total_num_tokens + else sum(len(s["input_ids"]) for s in train_dataset) + ) + if not cfg.total_num_tokens: + LOG.info(f"📝 UPDATE CONFIG WITH: `total_num_tokens: {total_num_tokens}`") + if cfg.sample_packing_eff_est: - total_num_tokens = ( - cfg.total_num_tokens - if cfg.total_num_tokens - else sum(len(s["input_ids"]) for s in train_dataset) - ) total_num_steps = ( # match count to len est in dataloader ( @@ -300,8 +327,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): sampler=sampler, packing_efficiency_estimate=cfg.sample_packing_eff_est, sample_packing_seq_len_multiplier=cfg.sample_packing_seq_len_multiplier, + device_count=int(os.environ.get("WORLD_SIZE", 1)), ) data_loader_len = len(data_loader) + actual_eff = data_loader.efficiency() LOG.info(f"data_loader_len: {data_loader_len}") total_num_steps = int( math.ceil( @@ -311,10 +340,18 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): / cfg.batch_size ) ) + LOG.info( + f"📝 UPDATE CONFIG WITH: `sample_packing_eff_est: {math.ceil(actual_eff * 100.0) / 100.0}`" + ) else: total_num_steps = int( math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) ) + LOG.info(f"total_num_steps: {total_num_steps}") + return total_num_steps + + +def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps): warmup_steps = ( cfg.warmup_steps if cfg.warmup_steps is not None diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py index 3ea59f391..d063753e6 100644 --- a/src/axolotl/utils/validation.py +++ b/src/axolotl/utils/validation.py @@ -110,6 +110,17 @@ def validate_config(cfg): "push_to_hub_model_id is deprecated. Please use hub_model_id instead." ) + if cfg.sample_packing and cfg.sdp_attention: + # incompatible due to bug w/ accelerate causing 0.0 loss when using llama2 + raise ValueError( + "sample_packing not compatible with sdp_attention. Use flash_attention" + ) + + if cfg.sample_packing and cfg.xformers_attention: + raise ValueError( + "sample_packing not compatible with xformers_attention. Use flash_attention" + ) + # TODO # MPT 7b # https://github.com/facebookresearch/bitsandbytes/issues/25