diff --git a/README.md b/README.md
index e9f0aaefb..268d5a752 100644
--- a/README.md
+++ b/README.md
@@ -375,7 +375,14 @@ dataset_shard_idx:
 sequence_len: 2048
 # max sequence length to concatenate training samples together up to
 # inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
+# FutureWarning: This will soon be DEPRECATED
 max_packed_sequence_len: 1024
+# use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommended to set to 'true'
+sample_packing:
+# you can set these packing optimizations AFTER starting a training run at least once.
+# The trainer will provide recommended values for these settings.
+sample_packing_eff_est:
+total_num_tokens:

 # if you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
 adapter: lora
@@ -421,6 +428,7 @@ learning_rate: 0.00003
 logging_steps:
 save_steps:
 eval_steps:
+save_total_limit:

 # save model as safetensors (require safetensors package)
 save_safetensors:
@@ -534,7 +542,7 @@ accelerate launch scripts/finetune.py configs/your_config.yml

 #### Multi-GPU

-It is recommended to pre-tokenize dataset with the following before finetuning:
+You can optionally pre-tokenize the dataset with the following before finetuning:
 ```bash
 CUDA_VISIBLE_DEVICES="" accelerate ... --prepare_ds_only
 ```
diff --git a/requirements.txt b/requirements.txt
index 33bfb94a8..ae4eca8f3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,6 +13,8 @@ einops
 xformers
 optimum
 hf_transfer
+numba
+numpy==1.24.4
 # qlora things
 bert-score==0.3.13
 evaluate==0.4.0
diff --git a/scripts/finetune.py b/scripts/finetune.py
index e6fea4456..a7fee5ec8 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -21,9 +21,14 @@ from axolotl.logging_config import configure_logging
 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
 from axolotl.utils.dict import DictDefault
+from axolotl.utils.distributed import barrier, is_main_process
 from axolotl.utils.models import load_model, load_tokenizer
 from axolotl.utils.tokenization import check_dataset_labels
-from axolotl.utils.trainer import setup_trainer
+from axolotl.utils.trainer import (
+    calculate_total_num_steps,
+    process_datasets_for_packing,
+    setup_trainer,
+)
 from axolotl.utils.validation import validate_config
 from axolotl.utils.wandb import setup_wandb_env_vars

@@ -232,12 +237,25 @@ def train(
             cfg.pretraining_dataset,
             tokenizer,
             max_tokens=cfg.sequence_len,
-            seed=cfg.seed,
+            seed=cfg.seed or 42,
         )
         # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
         train_dataset = train_dataset.with_format("torch")
         eval_dataset = None
+    if is_main_process():
+        # process on rank 0 first so it gets cached so other ranks load from cache
+        train_dataset, eval_dataset = process_datasets_for_packing(
+            cfg, train_dataset, eval_dataset
+        )
+    barrier()
+    if not is_main_process():
+        train_dataset, eval_dataset = process_datasets_for_packing(
+            cfg, train_dataset, eval_dataset
+        )
+    barrier()
+    total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
+
     if cfg.debug or "debug" in kwargs:
         LOG.info("check_dataset_labels...")
         check_dataset_labels(
@@ -254,7 +272,7 @@ def train(
     log_gpu_memory_usage(LOG, "baseline", cfg.device)

     # Load the model and tokenizer
-    LOG.info("loading model and peft_config...")
+    LOG.info("loading model and (optionally) peft_config...")
     model,
peft_config = load_model(cfg, tokenizer) safe_serialization = cfg.save_safetensors is True @@ -288,7 +306,9 @@ def train( model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) return - trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer) + trainer = setup_trainer( + cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps + ) model.config.use_cache = False @@ -347,14 +367,12 @@ def train( # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file if cfg.fsdp: - model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) + trainer.save_model(cfg.output_dir) elif cfg.local_rank == 0: if cfg.flash_optimum: model = BetterTransformer.reverse(model) model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) - # trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time - if __name__ == "__main__": fire.Fire(train) diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py index fd82db6cb..75d8432da 100644 --- a/src/axolotl/datasets.py +++ b/src/axolotl/datasets.py @@ -5,7 +5,7 @@ import os from typing import List import torch -from datasets import IterableDataset +from datasets import Dataset, IterableDataset from .prompt_tokenizers import PromptTokenizingStrategy @@ -18,9 +18,9 @@ from .prompt_tokenizers import PromptTokenizingStrategy LOG = logging.getLogger("axolotl") -class TokenizedPromptDataset(IterableDataset): +class TokenizedPromptDataset(Dataset): """ - Iterable dataset that returns tokenized prompts from a stream of text files. + Dataset that returns tokenized prompts from a stream of text files. Args: prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for proccessing the data. dataset (dataset.Dataset): Dataset with text files. 
@@ -30,19 +30,18 @@ class TokenizedPromptDataset(IterableDataset): self, prompt_tokenizer: PromptTokenizingStrategy, dataset: IterableDataset, + **kwargs, ): self.prompt_tokenizer = prompt_tokenizer - self.dataset = dataset + super().__init__(self.process(dataset).data, **kwargs) - def __iter__(self): - features = self.dataset.features.keys() - num_proc = os.cpu_count() - return iter( - self.dataset.map( - self.prompt_tokenizer.tokenize_prompt, - num_proc=num_proc, - remove_columns=features, - ) + def process(self, dataset): + features = dataset.features.keys() + num_proc = min(64, os.cpu_count()) + return dataset.map( + self.prompt_tokenizer.tokenize_prompt, + num_proc=num_proc, + remove_columns=features, ) @@ -77,14 +76,21 @@ class ConstantLengthDataset(IterableDataset): self.tokens_dtype = torch.int64 def __iter__(self): - buffer = {"input_ids": [], "attention_mask": [], "labels": []} + buffer = { + "input_ids": [], + "attention_mask": [], + "labels": [], + "position_ids": [], + } buffer_len = 0 for dataset in self.datasets: + idx = 0 iterator = iter(dataset) more_examples = True while more_examples: try: example = next(iterator) + idx += 1 except StopIteration: more_examples = False example = None @@ -106,6 +112,9 @@ class ConstantLengthDataset(IterableDataset): attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[ : self.seq_length ] + position_ids = torch.cat(buffer["position_ids"], dim=-1)[ + : self.seq_length + ] labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length] if labels.size() == input_ids.size() and ( attention_mask.size() == input_ids.size() @@ -114,6 +123,7 @@ class ConstantLengthDataset(IterableDataset): "input_ids": input_ids, "labels": labels, "attention_mask": attention_mask, + "position_ids": position_ids, } else: LOG.warning( @@ -123,8 +133,10 @@ class ConstantLengthDataset(IterableDataset): "input_ids": [], "attention_mask": [], "labels": [], + "position_ids": [], } buffer_len = 0 + idx = 1 if example: # FIXME @@ -133,11 +145,6 @@ class ConstantLengthDataset(IterableDataset): input_ids = example["input_ids"] attention_mask = example["attention_mask"] labels = example["labels"] - if ( - buffer["input_ids"] - and input_ids[0] == self.tokenizer.bos_token_id - ): - attention_mask[0] = 0 if add_concat_token: input_ids.append(self.concat_token_id) @@ -148,13 +155,17 @@ class ConstantLengthDataset(IterableDataset): input_ids, dtype=self.tokens_dtype ) attention_mask_with_concat = torch.tensor( - attention_mask, dtype=self.tokens_dtype + [idx * m for m in attention_mask], dtype=torch.int16 ) labels_with_concat = torch.tensor( labels, dtype=self.tokens_dtype ) + position_ids = torch.arange( + len(input_ids), dtype=self.tokens_dtype + ) buffer["input_ids"].append(input_ids_with_concat) buffer["attention_mask"].append(attention_mask_with_concat) buffer["labels"].append(labels_with_concat) + buffer["position_ids"].append(position_ids) buffer_len += len(input_ids) diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index 073786882..d900e897d 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -8,9 +8,18 @@ import torch import transformers from einops import rearrange from flash_attn.bert_padding import pad_input, unpad_input -from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func + +try: + from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func +except ImportError: + from 
flash_attn.flash_attn_interface import ( + flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func, + ) + from transformers.models.llama.modeling_llama import apply_rotary_pos_emb +from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids + def forward( self, @@ -79,6 +88,16 @@ def forward( dtype=torch.int32, device=qkv.device, ) + output = flash_attn_varlen_qkvpacked_func( + qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True + ) + output = rearrange(output, "(b s) ... -> b s ...", b=bsz) + elif position_ids.shape[0] == 1: + # special handling using sample packing + qkv = rearrange(qkv, "b s ... -> (b s) ...") + cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids) + cu_q_lens = cu_q_lens.squeeze() + output = flash_attn_varlen_qkvpacked_func( qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True ) @@ -113,6 +132,7 @@ def forward( "b s (h d) -> b s h d", h=nheads, ) + return ( self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py index 02525b7f5..752e204f7 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py @@ -128,6 +128,7 @@ def xformers_forward( query_states, key_states, value_states, + # attn_bias=attention_mask, attn_bias=xformers.ops.LowerTriangularMask(), ) attn_weights = None diff --git a/src/axolotl/monkeypatch/llama_expand_mask.py b/src/axolotl/monkeypatch/llama_expand_mask.py new file mode 100644 index 000000000..d69433baa --- /dev/null +++ b/src/axolotl/monkeypatch/llama_expand_mask.py @@ -0,0 +1,52 @@ +""" +expands the binary attention mask per 3.2.2 of https://arxiv.org/pdf/2107.02027.pdf +""" +from typing import Optional + +import torch + + +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + This expansion handles packed sequences so that sequences share the same attention mask integer value + when they attend to each other within that sequence. + This expansion transforms the mask to lower triangular form to prevent future peeking. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + mask = mask.unsqueeze(1).unsqueeze(2) + mask = mask.expand(bsz, 1, tgt_len, src_len) + + # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one + binary_mask = torch.where( + mask != 0, + torch.tensor(1).to(dtype), + torch.tensor(0).to(dtype), + ) + + # Create a block-diagonal mask. 
+ # we multiply by the binary mask so that 0's in the original mask are correctly excluded + zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask + + # Now let's create a lower triangular mask of ones that will zero out the upper triangular part + lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to( + mask.device + ) + + # Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask + masked_zero_one_mask = zero_one_mask * lower_triangular_ones + inverted_mask = 1.0 - masked_zero_one_mask + + return inverted_mask.masked_fill( + inverted_mask.to(torch.bool), torch.finfo(dtype).min + ) + + +def hijack_expand_mask(): + import transformers + + transformers.models.llama.modeling_llama._expand_mask = ( # pylint: disable=protected-access + _expand_mask + ) diff --git a/src/axolotl/monkeypatch/utils.py b/src/axolotl/monkeypatch/utils.py new file mode 100644 index 000000000..3b007e05d --- /dev/null +++ b/src/axolotl/monkeypatch/utils.py @@ -0,0 +1,103 @@ +""" +Shared utils for the monkeypatches +""" +import torch + + +def get_cu_seqlens(attn_mask): + """generate a cumulative sequence length mask for flash attention using attn mask""" + if len(attn_mask.shape) == 1: + attn_mask = attn_mask.unsqueeze(0) + + device = attn_mask.device + results = [] + max_seq_lens = [] + + for row in attn_mask: + # Exclude zeros to avoid adding their positions to the mask + t_non_zeros = row[row != 0] + # Find where the sequence number changes (including the first position) + seq_change = torch.cat( + [ + torch.tensor([1], dtype=torch.int32, device=device), + t_non_zeros[1:] != t_non_zeros[:-1], + ] + ) + # Get the indices where the sequence changes + change_indices = torch.cat( + [ + (seq_change == 1).nonzero(as_tuple=True)[0], + torch.tensor([len(t_non_zeros)], dtype=torch.int32, device=device), + ] + ) + # Calculate the sequence lengths + seq_lengths = change_indices[1:] - change_indices[:-1] + # Calculate the length of the final sequence or padding + final_seq_length = len(row) - change_indices[-1] + # Append the length of the final sequence or padding to seq_lengths + if final_seq_length.item(): + seq_lengths = torch.cat( + [ + seq_lengths, + torch.tensor( + [final_seq_length.item()], dtype=torch.int32, device=device + ), + ] + ) + # Calculate the cumulative sequence lengths + cu_seqlens = torch.cat( + [torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)] + ) + max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + results.append(cu_seqlens) + max_seq_lens.append(max_seq_len) + + return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens) + + +def get_cu_seqlens_from_pos_ids(position_ids): + """generate a cumulative sequence length mask for flash attention using pos ids""" + if len(position_ids.shape) == 1: + position_ids = position_ids.unsqueeze(0) + + device = position_ids.device + results = [] + max_seq_lens = [] + + for row in position_ids: + # Count the number of consecutive zeros from the right side + padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item() + + # Adjust the row to exclude padding + adjusted_row = row[:-padding_length] if padding_length else row.clone() + + # Find where the position resets to 0 (indicating a new sequence) + seq_starts = torch.cat( + [ + torch.tensor([True], dtype=torch.bool, device=device), + adjusted_row[1:] == 0, + ] + ) + # Get the indices where the sequence starts + start_indices = torch.cat( + [ + 
(seq_starts).nonzero(as_tuple=True)[0], + torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device), + ] + ) + # Calculate the sequence lengths + seq_lengths = start_indices[1:] - start_indices[:-1] + # Calculate the cumulative sequence lengths + cu_seqlens = torch.cat( + [torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)] + ) + # Append the padding length to the cumulative sequence lengths + if padding_length: + cu_seqlens = torch.cat( + [cu_seqlens, torch.tensor([len(row)], dtype=torch.int32, device=device)] + ) + max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + results.append(cu_seqlens) + max_seq_lens.append(max_seq_len) + + return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens) diff --git a/src/axolotl/prompt_strategies/alpaca_w_system.py b/src/axolotl/prompt_strategies/alpaca_w_system.py index ea7151366..d56520bd7 100644 --- a/src/axolotl/prompt_strategies/alpaca_w_system.py +++ b/src/axolotl/prompt_strategies/alpaca_w_system.py @@ -66,7 +66,11 @@ class SystemDataPrompter(AlpacaPrompter): ) -> Generator[str, None, None]: # returns the full prompt from instruction and optional input # if a label (=response, =output) is provided, it's also appended. - formatted_sys_prompt = f"### System:\n{system}\n\n" if system else "" + formatted_sys_prompt = ( + self.system_format.format(system=system) + if system and self.system_format + else "" + ) if input: res = formatted_sys_prompt + self.turn_format.format( instruction=instruction, input=input @@ -86,12 +90,20 @@ class OpenOrcaSystemDataPrompter(SystemDataPrompter): """ def match_prompt_style(self): + # pylint: disable=duplicate-code if self.prompt_style == PromptStyle.INSTRUCT.value: self.turn_format = "### User:\n{instruction}\n\n### Additional Context:\n{input}\n\n### Assistant:\n" self.turn_no_input_format = "### User:\n{instruction}\n\n### Assistant:\n" if self.prompt_style == PromptStyle.CHAT.value: self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:" self.turn_no_input_format = "USER: {instruction}\nASSISTANT:" + self.system_format = "SYSTEM: {system}\n" + if self.prompt_style == PromptStyle.CHATML.value: + self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n" + self.turn_no_input_format = ( + "<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n" + ) + self.system_format = "<|im_start|>system\n{system}<|im_end|>\n" class OpenOrcaPromptTokenizingStrategy(InstructionWSystemPromptTokenizingStrategy): @@ -137,3 +149,12 @@ def load_open_orca(tokenizer, cfg): cfg.train_on_inputs, cfg.sequence_len, ) + + +def load_open_orca_chatml(tokenizer, cfg): + return OpenOrcaPromptTokenizingStrategy( + OpenOrcaSystemDataPrompter(PromptStyle.CHATML.value), + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index 2da4ff112..216283582 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -16,6 +16,7 @@ class PromptStyle(Enum): INSTRUCT = "instruct" CHAT = "chat" + CHATML = "chatml" class AlpacaPrompter: @@ -25,6 +26,7 @@ class AlpacaPrompter: system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n" system_no_input_prompt = "Below is an instruction that describes a task. 
Write a response that appropriately completes the request.\n\n" + system_format: str turn_format: str turn_no_input_format: str prompt_style: Optional[PromptStyle] = None @@ -34,14 +36,23 @@ class AlpacaPrompter: self.match_prompt_style() def match_prompt_style(self): + # pylint: disable=duplicate-code if self.prompt_style == PromptStyle.INSTRUCT.value: self.turn_format = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" self.turn_no_input_format = ( "### Instruction:\n{instruction}\n\n### Response:\n" ) + self.system_format = "### System:\n{system}\n\n" if self.prompt_style == PromptStyle.CHAT.value: self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:" self.turn_no_input_format = "USER: {instruction}\nASSISTANT:" + self.system_format = "SYSTEM: {system}\n" + if self.prompt_style == PromptStyle.CHATML.value: + self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n" + self.turn_no_input_format = ( + "<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n" + ) + self.system_format = "<|im_start|>system\n{system}<|im_end|>\n" def build_prompt( self, diff --git a/src/axolotl/utils/collators.py b/src/axolotl/utils/collators.py new file mode 100644 index 000000000..d7acdc977 --- /dev/null +++ b/src/axolotl/utils/collators.py @@ -0,0 +1,121 @@ +""" +DataCollator for axolotl to pad labels and position_ids for packed sequences +""" +from dataclasses import dataclass +from typing import Any, Optional, Union + +import numpy as np +from transformers import PreTrainedTokenizerBase +from transformers.utils import PaddingStrategy + + +@dataclass +class DataCollatorForSeq2Seq: + """ + Data collator that will dynamically pad the inputs received, as well as the labels and position_ids + + Args: + tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]): + The tokenizer used for encoding the data. + model ([`PreTrainedModel`]): + The model that is being trained. If set and has the *prepare_decoder_input_ids_from_labels*, use it to + prepare the *decoder_input_ids* + + This is useful when using *label_smoothing* to avoid calculating loss twice. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding index) + among: + + - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single + sequence is provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths). + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + pad_to_multiple_of (`int`, *optional*): + If set will pad the sequence to a multiple of the provided value. + + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= + 7.5 (Volta). + label_pad_token_id (`int`, *optional*, defaults to -100): + The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions). + return_tensors (`str`): + The type of Tensor to return. Allowable values are "np", "pt" and "tf". 
+ """ + + tokenizer: PreTrainedTokenizerBase + model: Optional[Any] = None + padding: Union[bool, str, PaddingStrategy] = True + max_length: Optional[int] = None + pad_to_multiple_of: Optional[int] = None + label_pad_token_id: int = -100 + position_pad_token_id: int = 0 + return_tensors: str = "pt" + + def __call__(self, features, return_tensors=None): + labels = None + if return_tensors is None: + return_tensors = self.return_tensors + + for feature_name, pad_token_id in [ + ("labels", self.label_pad_token_id), + ("position_ids", self.position_pad_token_id), + ]: + feat = ( + [feature[feature_name] for feature in features] + if feature_name in features[0].keys() + else None + ) + labels = feat if feat and feature_name == "labels" else labels + # We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the + # same length to return tensors. + if feat is not None: + max_feature_length = max(len(l) for l in feat) # noqa: E741 + if self.pad_to_multiple_of is not None: + max_feature_length = ( + (max_feature_length + self.pad_to_multiple_of - 1) + // self.pad_to_multiple_of + * self.pad_to_multiple_of + ) + + padding_side = self.tokenizer.padding_side + for feature in features: + remainder = [pad_token_id] * ( + max_feature_length - len(feature[feature_name]) + ) + if isinstance(feature[feature_name], list): + feature[feature_name] = ( + feature[feature_name] + remainder + if padding_side == "right" + else remainder + feature[feature_name] + ) + elif padding_side == "right": + feature[feature_name] = np.concatenate( + [feature[feature_name], remainder] + ).astype(np.int64) + else: + feature[feature_name] = np.concatenate( + [remainder, feature[feature_name]] + ).astype(np.int64) + + features = self.tokenizer.pad( + features, + padding=self.padding, + max_length=self.max_length, + pad_to_multiple_of=self.pad_to_multiple_of, + return_tensors=return_tensors, + ) + + # prepare decoder_input_ids + if ( + labels is not None + and self.model is not None + and hasattr(self.model, "prepare_decoder_input_ids_from_labels") + ): + decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels( + labels=features["labels"] + ) + features["decoder_input_ids"] = decoder_input_ids + + return features diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index ee7f16905..fc30c4ce3 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -1,13 +1,19 @@ """Module containing data utilities""" import functools -import itertools +import hashlib import logging from hashlib import md5 from pathlib import Path -from typing import List, Tuple, Union +from typing import Tuple, Union import torch -from datasets import Dataset, DatasetDict, load_dataset, load_from_disk +from datasets import ( + Dataset, + DatasetDict, + concatenate_datasets, + load_dataset, + load_from_disk, +) from huggingface_hub import hf_hub_download from transformers import PreTrainedTokenizerBase @@ -35,6 +41,7 @@ from axolotl.prompters import ( ShareGPTPrompter, SummarizeTLDRPrompter, ) +from axolotl.utils.distributed import barrier, is_main_process LOG = logging.getLogger("axolotl") @@ -109,6 +116,7 @@ def load_tokenized_prepared_datasets( local_path = Path(d.path) if local_path.exists(): if local_path.is_dir(): + # TODO dirs with arrow or parquet files could be loaded with `load_from_disk` ds = load_dataset( d.path, name=d.name, @@ -262,20 +270,12 @@ def load_tokenized_prepared_datasets( raise ValueError( f"unhandled prompt tokenization strategy: {d.type} {suffix}" 
) - LOG.info("tokenizing, merging, and shuffling master dataset") + LOG.info("merging datasets") + dataset = concatenate_datasets(datasets) - samples: List[int] = [] - chunk_size = 1000 - for d in datasets: - d_iter = iter(d) - while True: - chunk = list(itertools.islice(d_iter, chunk_size)) - if not chunk: - break - samples.extend(chunk) - - LOG.info("shuffle") - dataset = Dataset.from_list(samples).shuffle(seed=seed) + if len(datasets) > 1: + LOG.info("shuffle merged datasets") + dataset = dataset.shuffle(seed=seed) if cfg.local_rank == 0: LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}") dataset.save_to_disk(prepared_ds_path) @@ -374,6 +374,7 @@ def load_prepare_datasets( dataset = Dataset.from_list(list(constant_len_dataset)) # filter out bad data + # TODO convert to dataset.filter(...) dataset = Dataset.from_list( [ d @@ -413,7 +414,51 @@ def load_prepare_datasets( ) if cfg.val_set_size: - dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False) + # ensure we end up with the same fingerprint by doing rank0 first and being able to cache + to_hash_train = ( + dataset._fingerprint # pylint: disable=protected-access + + "|" + + str(cfg.val_set_size) + + "|" + + "train" + + "|" + + str(cfg.seed or 42) + ) + to_hash_test = ( + dataset._fingerprint # pylint: disable=protected-access + + "|" + + str(cfg.val_set_size) + + "|" + + "test" + + "|" + + str(cfg.seed or 42) + ) + train_fingerprint = hashlib.md5( + to_hash_train.encode(), usedforsecurity=False + ).hexdigest() + test_fingerprint = hashlib.md5( + to_hash_test.encode(), usedforsecurity=False + ).hexdigest() + + if is_main_process(): + dataset = dataset.train_test_split( + test_size=cfg.val_set_size, + shuffle=False, + seed=cfg.seed or 42, + train_new_fingerprint=train_fingerprint, + test_new_fingerprint=test_fingerprint, + ) + barrier() + if not is_main_process(): + dataset = dataset.train_test_split( + test_size=cfg.val_set_size, + shuffle=False, + seed=cfg.seed or 42, + train_new_fingerprint=train_fingerprint, + test_new_fingerprint=test_fingerprint, + ) + barrier() + train_dataset = dataset["train"] eval_dataset = dataset["test"] else: diff --git a/src/axolotl/utils/dataloader.py b/src/axolotl/utils/dataloader.py new file mode 100644 index 000000000..dc3261d63 --- /dev/null +++ b/src/axolotl/utils/dataloader.py @@ -0,0 +1,288 @@ +# pylint: skip-file +import hashlib +import itertools +import logging +import math +from typing import Any, Callable, List, Union + +import numba +import numpy as np +from torch.utils.data import DistributedSampler, Sampler + +LOG = logging.getLogger("axolotl.utils.dataloader") + + +@numba.njit +def ffd_check(a: np.ndarray, c: int, n: int): + # First-fit-decreasing bin packing + # Check if a[] could fit in n bins with capacity c + # https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing + + a = np.sort(a)[::-1] + bins = np.full((n,), c, dtype=a.dtype) + for size in a: + not_found = True + for idx in range(n): + if bins[idx] >= size: + bins[idx] -= size + not_found = False + break + + if not_found: + return False + + return True + + +@numba.njit +def ffd_with_result(a: np.ndarray, c: int, start_index: int): + # First-fit-decreasing bin packing (with result return) + + indices = np.argsort(a)[::-1] + a = a[indices] + + bins: List[Any] = [] + bins_result: List[Any] = [] + for a_id, size in enumerate(a): + add_new = True + for idx in range(len(bins)): + if bins[idx] >= size: + bins[idx] -= size + bins_result[idx].append(indices[a_id] + start_index) + add_new = 
False + break + + if add_new: + bins.append(c - size) + bins_result.append([indices[a_id] + start_index]) + + return bins_result, len(a) + + +@numba.njit +def allocate( + lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int +): + """ + :param lengths: array of lengths of each sample + :param lengths_cumsum: cumulative sum of consecutive lengths + :param rank: rank for this process + :param c: length of tokens per batch + :param n: number of ranks + :return: + """ + # Dynamic batch allocator, similar to Multifit + # https://en.wikipedia.org/wiki/Multifit_algorithm + # ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len) + + s = 0 + start_index = 0 + result = [] + result_totseqs = [] + + while True: + # binary search [left, right) + left = 1 + right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right") + + while right - left > 1: + mid = (left + right) // 2 + if ffd_check(lengths[start_index : start_index + mid], c, n): + left = mid + else: + right = mid + + # use length left + batch, tot_seqs = ffd_with_result( + lengths[start_index : start_index + left], c, start_index + ) + if len(batch) < n: + break + + start_index += left + s = lengths_cumsum[start_index - 1] + + # add local rank + result.append(batch[rank]) + # add total seqs for all ranks + result_totseqs.append(tot_seqs) + # yield batch[rank], tot_seqs, s, len(result) * c * n + return result, result_totseqs, s, len(result) * c * n + + +def chunk(iterable, n): + """ + Chunk data into tuples of length n + """ + # batched('ABCDEFG', 3) --> ABC DEF G + if n < 1: + raise ValueError("n must be at least one") + it = iter(iterable) + while batch := tuple(itertools.islice(it, n)): + yield batch + + +def hash_indices(lst: List[int]) -> str: + # Convert the list of integers to a string representation + concatenated = ",".join(map(str, lst)) + + # Generate the hash + sha256 = hashlib.sha256() + sha256.update(concatenated.encode()) + + return sha256.hexdigest() + + +class MultipackDistributedDataloader: + """Unpadded data loading using Multipack. + Adapted from https://github.com/imoneoi/openchat/blob/v3_fix_mle_loss/ochat/training_deepspeed/multipack_dataloader.py + Approximate (at most ~1.22x) the optimal solution of the identical-machines scheduling problem, which is NP-hard. 
+ """ + + def __init__( + self, + dataset: Any, + collate_fn: Callable, + seq_max_length: int = 2048, + batch_size: int = 1, + sampler: Union[Sampler, DistributedSampler] = None, + packing_efficiency_estimate: float = 1.0, + sample_packing_seq_len_multiplier: int = 1, + device_count: int = 1, + ): + # Dataset + self.dataset = dataset + self.lengths = ( + dataset.data.column("position_ids") + .to_pandas() + .apply(lambda x: x[-1] + 1) + .values + ) + assert isinstance(self.lengths, np.ndarray) + assert batch_size % sample_packing_seq_len_multiplier == 0 + assert batch_size >= sample_packing_seq_len_multiplier + self.sampler = sampler + self.batch_size = batch_size + self.sample_packing_seq_len_multiplier = sample_packing_seq_len_multiplier + self.seq_max_length = seq_max_length + self.batch_max_length = batch_size * seq_max_length + self.collate_fn = collate_fn + + self.num_replicas = 1 + self.rank = 0 + + # statistics + self.eff_total_used = 0 + self.eff_total_slots = 0 + self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0 + self.device_count = device_count + + def generate_batches(self, set_stats=False): + LOG.info("generating packed batches") + if self.sampler: + indices = [idx for idx in self.sampler] + else: + indices = range(0, len(self.dataset)) + + LOG.info(hash_indices(indices)) + lengths = self.lengths[indices] + lengths_cumsum = np.cumsum(lengths) + + batches, totseqs, total_used, total_slots = allocate( + lengths=lengths, + lengths_cumsum=lengths_cumsum, + rank=self.rank, + # c=self.batch_max_length, + c=self.seq_max_length * self.sample_packing_seq_len_multiplier, + n=self.num_replicas, + ) + + batches = [[indices[b_idx] for b_idx in batch] for batch in batches] + + # statistics + if set_stats: + self.eff_total_used += total_used + self.eff_total_slots += total_slots + + return batches, totseqs + + def __iter__(self): + if hasattr(self.sampler, "set_epoch"): + new_epoch = self.sampler.epoch + 1 + self.sampler.set_epoch(new_epoch) + LOG.info(f"calling sampler.set_epoch({new_epoch})") + all_batches, _ = self.generate_batches(set_stats=True) + features = self.dataset.features.keys() + len_remaining = self._len_est() + for batches in chunk( + all_batches, self.batch_size // self.sample_packing_seq_len_multiplier + ): + chunked_data = [] + attn_mask_cum_idx = 0 + for batch in batches: + concatenated = {} + batched_data = [self.dataset[batch_idx] for batch_idx in batch] + for feature in features: + if feature == "attention_mask": + arrays = [ + (attn_mask_cum_idx + idx + 1) * np.array(item[feature]) + for idx, item in enumerate(batched_data) + if feature in item + ] + attn_mask_cum_idx += len(batched_data) + concatenated[feature] = np.concatenate(arrays) + else: + arrays = [ + np.array(item[feature]) + for item in batched_data + if feature in item + ] + concatenated[feature] = np.concatenate(arrays) + chunked_data.append(concatenated) + yield self.collate_fn(chunked_data) + len_remaining -= 1 + if not len_remaining: + return + + def _len_est(self): + lengths_sum = np.sum(self.lengths) + lengths_sum_per_device = lengths_sum // self.device_count + LOG.info( + f"packing_efficiency_estimate: {self.packing_efficiency_estimate} " + f"total_num_tokens per device: {lengths_sum_per_device}" + ) + + # shave off 1% + 1 for dealing with variance in packing from random sampler to sampler + return ( + math.floor( + 0.99 + * lengths_sum_per_device + / self.packing_efficiency_estimate + // self.seq_max_length + // self.batch_size + ) + - 1 + ) + + def __len__(self): + # this 
doesn't return the actual length b/c with distributed samplers, not all dataloaders get + # the same share of total tokens + # if not self.eff_total_used: + # batches, _ = self.generate_batches(set_stats=True) + # LOG.info( + # f"packing_efficiency_estimate: {self.packing_efficiency_estimate} " + # f"actual packing efficiency: {self.efficiency()}" + # ) + return max(1, self._len_est()) + + def len_w_stats(self): + if not self.eff_total_used: + batches, _ = self.generate_batches(set_stats=True) + LOG.info( + f"packing_efficiency_estimate: {self.packing_efficiency_estimate} " + f"actual packing efficiency: {self.efficiency()}" + ) + return max(1, self._len_est()) + + def efficiency(self): + return self.eff_total_used / self.eff_total_slots diff --git a/src/axolotl/utils/distributed.py b/src/axolotl/utils/distributed.py new file mode 100644 index 000000000..345b9640c --- /dev/null +++ b/src/axolotl/utils/distributed.py @@ -0,0 +1,41 @@ +""" +utility helpers for distributed checks +""" +import torch.distributed as dist +from accelerate import Accelerator + +accelerate = None # pylint: disable=invalid-name + + +def load_accelerate(): + global accelerate # pylint: disable=global-statement + accelerate = Accelerator() + + +def is_distributed(): + """ + Check if distributed training is initialized. + """ + global accelerate # pylint: disable=global-statement + if not accelerate: + accelerate = Accelerator() + return dist.is_available() and dist.is_initialized() + + +def barrier(): + """ + Acts as a barrier to wait for all processes. This ensures that all processes + reach the barrier before proceeding further. + """ + if is_distributed(): + dist.barrier() + + +def is_main_process(): + """ + Check if the current process is the main process. + If not in distributed mode, always return True. 
+ """ + if not is_distributed(): + return True + return dist.get_rank() == 0 diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 6abbd7265..cfd85f9e5 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -37,20 +37,26 @@ def load_tokenizer( tokenizer_type, cfg, ): + tokenizer_kwargs = {} use_fast = True # this is the default if cfg.tokenizer_use_fast is not None: use_fast = cfg.tokenizer_use_fast + if cfg.tokenizer_legacy is not None: + # True is the default w/ https://github.com/huggingface/transformers/pull/25224 + tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy if tokenizer_type: tokenizer = getattr(transformers, tokenizer_type).from_pretrained( tokenizer_config, trust_remote_code=cfg.trust_remote_code or False, use_fast=use_fast, + **tokenizer_kwargs, ) else: tokenizer = AutoTokenizer.from_pretrained( tokenizer_config, trust_remote_code=cfg.trust_remote_code or False, use_fast=use_fast, + **tokenizer_kwargs, ) LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}") @@ -90,8 +96,10 @@ def load_model( # TODO refactor as a kwarg load_in_8bit = cfg.load_in_8bit - cfg.is_llama_derived_model = "llama" in base_model or ( - cfg.model_type and "llama" in cfg.model_type.lower() + cfg.is_llama_derived_model = ( + "llama" in base_model + or (cfg.model_type and "llama" in cfg.model_type.lower()) + or cfg.is_llama_derived_model ) if cfg.is_llama_derived_model and cfg.flash_attention: @@ -136,6 +144,14 @@ def load_model( LOG.info("patching with xpos rope") replace_llama_rope_with_xpos_rope() + if cfg.is_llama_derived_model and ( + cfg.max_packed_sequence_len or cfg.sample_packing + ): + from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask + + LOG.info("patching _expand_mask") + hijack_expand_mask() + if cfg.bf16 or cfg.bfloat16: torch_dtype = torch.bfloat16 elif cfg.load_in_8bit or cfg.fp16 or cfg.float16: @@ -228,7 +244,6 @@ def load_model( load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, torch_dtype=torch_dtype, - device_map="auto" if cfg.world_size == 1 else cfg.device_map, **model_kwargs, ) # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention: @@ -263,7 +278,6 @@ def load_model( load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, torch_dtype=torch_dtype, - device_map=cfg.device_map, trust_remote_code=cfg.trust_remote_code or False, **model_kwargs, ) @@ -294,7 +308,6 @@ def load_model( load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, torch_dtype=torch_dtype, - device_map=cfg.device_map, trust_remote_code=cfg.trust_remote_code or False, **model_kwargs, ) @@ -308,7 +321,6 @@ def load_model( load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, torch_dtype=torch_dtype, - device_map=cfg.device_map, trust_remote_code=cfg.trust_remote_code or False, **model_kwargs, ) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index a5d2ea74e..25d0b1e82 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -1,19 +1,23 @@ """Module containing the Trainer class and related functions""" - import importlib import logging import math import os import sys +from contextlib import contextmanager from dataclasses import dataclass, field +from functools import partial from pathlib import Path -from typing import Optional 
+from typing import Optional, Union import bitsandbytes as bnb +import numpy as np import torch.cuda import transformers +from datasets import Dataset, set_caching_enabled from torch import nn from torch.optim.lr_scheduler import OneCycleLR +from torch.utils.data import DataLoader, DistributedSampler, RandomSampler from transformers import EarlyStoppingCallback, Trainer, TrainingArguments from transformers.trainer_pt_utils import get_parameter_names @@ -22,6 +26,8 @@ from axolotl.utils.callbacks import ( SaveBetterTransformerModelCallback, SavePeftModelCallback, ) +from axolotl.utils.collators import DataCollatorForSeq2Seq +from axolotl.utils.dataloader import MultipackDistributedDataloader from axolotl.utils.schedulers import ( InterpolatingLogScheduler, get_cosine_schedule_with_quadratic_warmup, @@ -30,6 +36,68 @@ from axolotl.utils.schedulers import ( LOG = logging.getLogger("axolotl") +@torch.jit.script +def weighted_cross_entropy( + logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor +): + # Flatten the logits, labels, and weights tensors + logits = logits.view( + -1, logits.size(-1) + ) # logits becomes of shape [batch_size*sequence_length, vocab_size] + labels = labels.view(-1) # labels becomes of shape [batch_size*sequence_length] + weights = weights.view(-1) # weights becomes of shape [batch_size*sequence_length] + + # Compute the unweighted cross entropy loss + losses = torch.nn.functional.cross_entropy(logits, labels, reduction="none") + + # Apply the weights to the losses and compute their sum + return (weights * losses).sum() + + +@torch.jit.script +def create_weighted_mask(labels: torch.Tensor): + # Check if the tensor is 2D. If not, unsqueeze it to make it 2D + if len(labels.shape) == 1: + labels = labels.unsqueeze(0) + + weights = torch.zeros_like(labels).float() + for i in range(labels.shape[0]): + mask = labels[i] != -100 + + # Create a tensor to track group ids + group_ids = torch.zeros_like(labels[i]).int() + curr_group_id = 0 + + for j in range(1, len(labels[i])): + if mask[j] and not mask[j - 1]: # switch from masked to unmasked label + curr_group_id += 1 # start new group + group_ids[j] = ( + curr_group_id if mask[j] else 0 + ) # assign group id if unmasked label + + # Count only unmasked labels in each group + group_counts = torch.bincount(group_ids[mask]) + + mask_weights = torch.zeros_like(labels[i]).float() + mask_weights[mask] = 1.0 / group_counts[group_ids[mask]] + + weights[i] = mask_weights + + return weights.squeeze() # squeeze the output to match the input dimension + + +def trainer_weighted_loss(model_output, labels, shift_labels=True): + logits = ( + model_output["logits"] if isinstance(model_output, dict) else model_output[0] + ) + if shift_labels: + logits = logits[..., :-1, :].contiguous() + labels = labels[..., 1:].contiguous() + + weights = create_weighted_mask(labels) + return weighted_cross_entropy(logits, labels, weights) + + @dataclass class AxolotlTrainingArguments(TrainingArguments): """ @@ -40,6 +108,22 @@ class AxolotlTrainingArguments(TrainingArguments): default=False, metadata={"help": "Use quadratic warmup for cosine scheduling."}, ) + sample_packing: bool = field( + default=False, + metadata={"help": "Use sample packing for efficient training."}, + ) + sample_packing_efficiency: float = field( + default=1.0, + metadata={"help": "Sample packing efficiency for calculating batch length."}, + ) + max_seq_length: int = field( + default=2048, + metadata={"help": "The maximum sequence length the model can handle"}, + ) + 
sample_packing_seq_len_multiplier: int = field( + default=1, + metadata={"help": "the multiplier for the max len for packed sequences"}, + ) class AxolotlTrainer(Trainer): @@ -77,6 +161,64 @@ class AxolotlTrainer(Trainer): return super().create_scheduler(num_training_steps, optimizer) return self.lr_scheduler + def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: + if self.args.world_size > 1 and self.args.sample_packing: + return DistributedSampler( + self.train_dataset, + num_replicas=self.args.world_size, + rank=self.args.process_index, + seed=self.args.seed, + ) + return super()._get_train_sampler() + + def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]: + if self.args.sample_packing: + train_sampler = self._get_train_sampler() + return self.accelerator.prepare( + MultipackDistributedDataloader( + self.train_dataset, + batch_size=self._train_batch_size, + seq_max_length=self.args.max_seq_length, + collate_fn=self.data_collator, + sampler=train_sampler, + packing_efficiency_estimate=self.args.sample_packing_efficiency, + sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier, + device_count=int(os.environ.get("WORLD_SIZE", 1)), + ) + ) + return super().get_train_dataloader() + + def get_eval_dataloader( + self, eval_dataset: Optional[Dataset] = None + ) -> Union[DataLoader, MultipackDistributedDataloader]: + if self.args.sample_packing: + eval_dataset = ( + eval_dataset if eval_dataset is not None else self.eval_dataset + ) + eval_sampler = self._get_eval_sampler(eval_dataset) + return self.accelerator.prepare( + MultipackDistributedDataloader( + eval_dataset, + batch_size=self.args.eval_batch_size, + seq_max_length=self.args.max_seq_length, + collate_fn=self.data_collator, + sampler=eval_sampler, + packing_efficiency_estimate=self.args.sample_packing_efficiency, + sample_packing_seq_len_multiplier=self.args.eval_batch_size, + device_count=int(os.environ.get("WORLD_SIZE", 1)), + ) + ) + return super().get_eval_dataloader(eval_dataset) + + def compute_loss(self, model, inputs, return_outputs=False): + # use one's weighted cross entropy loss calc + # if self.args.sample_packing: + # labels = inputs.pop("labels") + # outputs = model(**inputs) + # loss = trainer_weighted_loss(outputs, labels, shift_labels=True) + # return (loss, outputs) if return_outputs else loss + return super().compute_loss(model, inputs, return_outputs=return_outputs) + class OneCycleLRSchedulerTrainer(AxolotlTrainer): """ @@ -107,10 +249,121 @@ class OneCycleLRSchedulerTrainer(AxolotlTrainer): return self.lr_scheduler -def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): - total_num_steps = int( - math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) - ) +def add_position_ids(sample): + sample["position_ids"] = torch.arange(len(sample["input_ids"])) + return sample + + +def drop_long_seq(sample, sequence_len=2048): + return len(sample["input_ids"]) <= sequence_len + + +@contextmanager +def disable_datasets_caching(): + try: + set_caching_enabled(False) + yield + finally: + set_caching_enabled(True) + + +def process_datasets_for_packing(cfg, train_dataset, eval_dataset): + if cfg.sample_packing: + drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len) + train_dataset = train_dataset.filter(drop_long, num_proc=os.cpu_count()).map( + add_position_ids, num_proc=os.cpu_count() + ) + if eval_dataset: + eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count()).map( + add_position_ids, 
num_proc=os.cpu_count() + ) + return train_dataset, eval_dataset + + +def calculate_total_num_steps(cfg, train_dataset, tokenizer): + if cfg.sample_packing: + # we have to drop anything longer then sequence len otherwise + # flash attention with position ids fails + if not cfg.total_num_tokens: + LOG.info("calculating total_num_tokens") + total_num_tokens = np.sum( + train_dataset.data.column("input_ids") + .to_pandas() + .apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda + .values + ) + LOG.info(f"📝 UPDATE CONFIG WITH: `total_num_tokens: {total_num_tokens}`") + cfg.total_num_tokens = total_num_tokens + + if cfg.sample_packing_eff_est: + total_num_steps = ( + # match count to len est in dataloader + ( + math.floor( + 0.99 + * cfg.total_num_tokens + / cfg.sample_packing_eff_est + / cfg.sequence_len + // cfg.batch_size + // int(os.environ.get("WORLD_SIZE", 1)) + ) + - 1 + ) + * cfg.num_epochs + ) + LOG.info( + f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}" + ) + else: + sampler = RandomSampler(train_dataset) + data_loader = MultipackDistributedDataloader( + train_dataset, + batch_size=cfg.micro_batch_size, + seq_max_length=cfg.max_packed_sequence_len or cfg.sequence_len, + collate_fn=DataCollatorForSeq2Seq( + tokenizer, + return_tensors="pt", + padding="longest", + ), + sampler=sampler, + packing_efficiency_estimate=cfg.sample_packing_eff_est, + sample_packing_seq_len_multiplier=cfg.micro_batch_size, + device_count=int(os.environ.get("WORLD_SIZE", 1)), + ) + data_loader_len = data_loader.len_w_stats() + actual_eff = data_loader.efficiency() + LOG.info(f"data_loader_len: {data_loader_len}") + total_num_steps = int( + math.floor( + data_loader_len + * cfg.micro_batch_size + * cfg.num_epochs + // cfg.batch_size + ) + ) + LOG.info( + f"📝 UPDATE CONFIG WITH: `sample_packing_eff_est: {math.ceil(actual_eff * 100.0) / 100.0}`" + ) + cfg.sample_packing_eff_est = math.ceil(actual_eff * 100.0) / 100.0 + else: + total_num_steps = int( + math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) + ) + LOG.info(f"total_num_steps: {total_num_steps}") + return total_num_steps + + +def setup_fsdp_envs(cfg): + os.environ["ACCELERATE_USE_FSDP"] = "true" + if cfg.fsdp_config.fsdp_sync_module_states: + os.environ["FSDP_SYNC_MODULE_STATES"] = "true" + if cfg.fsdp_config.fsdp_state_dict_type: + os.environ["FSDP_STATE_DICT_TYPE"] = cfg.fsdp_config.fsdp_state_dict_type + + +def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps): + if cfg.fsdp: + setup_fsdp_envs(cfg) warmup_steps = ( cfg.warmup_steps if cfg.warmup_steps is not None @@ -190,7 +443,14 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): if cfg.save_safetensors: training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors + if cfg.sample_packing_eff_est: + training_arguments_kwargs[ + "sample_packing_efficiency" + ] = cfg.sample_packing_eff_est + training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg + # max_steps=total_num_steps, # this is helpful in case we don't actually know total # of steps + max_seq_length=cfg.sequence_len, per_device_train_batch_size=cfg.micro_batch_size, per_device_eval_batch_size=cfg.eval_batch_size if cfg.eval_batch_size is not None @@ -204,7 +464,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None, save_steps=cfg.save_steps, output_dir=cfg.output_dir, - save_total_limit=3, + 
save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4, load_best_model_at_end=( cfg.load_best_model_at_end is not False and cfg.val_set_size > 0 @@ -222,6 +482,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): if cfg.lr_scheduler and cfg.lr_scheduler not in ("one_cycle", "log_sweep") else "cosine", weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0, + sample_packing=cfg.sample_packing if cfg.sample_packing else False, + sample_packing_seq_len_multiplier=cfg.micro_batch_size, **training_arguments_kwargs, ) @@ -316,11 +578,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): if cfg.collator_pad_to_longest: data_collator_kwargs["padding"] = "longest" else: - data_collator_kwargs["pad_to_multiple_of"] = 8 + # A100 is best at 64, while others at 8. Let's use the larger so we don't have to check + # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html + data_collator_kwargs["pad_to_multiple_of"] = 64 if cfg.is_llama_derived_model and cfg.landmark_attention: - from functools import partial - from axolotl.monkeypatch.llama_landmark_attn import ( add_mem_tokens, get_mem_id, @@ -348,7 +610,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): train_dataset=train_dataset, eval_dataset=eval_dataset, args=training_args, - data_collator=transformers.DataCollatorForSeq2Seq( + data_collator=DataCollatorForSeq2Seq( tokenizer, return_tensors="pt", **data_collator_kwargs, diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py index db29fcdfa..97d70c4c8 100644 --- a/src/axolotl/utils/validation.py +++ b/src/axolotl/utils/validation.py @@ -8,6 +8,19 @@ LOG = logging.getLogger("axolotl") def validate_config(cfg): + if cfg.max_packed_sequence_len and cfg.sample_packing: + raise ValueError( + "please set only one of max_packed_sequence_len (deprecated soon) or sample_packing" + ) + if cfg.max_packed_sequence_len: + LOG.warning( + str( + PendingDeprecationWarning( + "max_packed_sequence_len will be deprecated in favor of sample_packing" + ) + ) + ) + if cfg.gradient_accumulation_steps and cfg.batch_size: raise ValueError( "please set only one of gradient_accumulation_steps or batch_size" @@ -104,6 +117,17 @@ def validate_config(cfg): + "point to its path, and remove model_revision from the config." ) + if cfg.sample_packing and cfg.sdp_attention: + # incompatible due to bug w/ accelerate causing 0.0 loss when using llama2 + raise ValueError( + "sample_packing not compatible with sdp_attention. Use flash_attention" + ) + + if cfg.sample_packing and cfg.xformers_attention: + raise ValueError( + "sample_packing not compatible with xformers_attention. 
Use flash_attention" + ) + # TODO # MPT 7b # https://github.com/facebookresearch/bitsandbytes/issues/25 diff --git a/tests/monkeypatch/test_llama_attn_hijack_flash.py b/tests/monkeypatch/test_llama_attn_hijack_flash.py new file mode 100644 index 000000000..289c01a86 --- /dev/null +++ b/tests/monkeypatch/test_llama_attn_hijack_flash.py @@ -0,0 +1,30 @@ +""" +Unit tests for the monkeypatch utils +""" +import unittest + +import torch + +from axolotl.monkeypatch.utils import get_cu_seqlens, get_cu_seqlens_from_pos_ids + + +class TestMonkeyPatchUtils(unittest.TestCase): + """ + Unit test class for monkeypatch utils + """ + + def test_get_cu_seqlens_1d(self): + attn_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]]) + target_res = torch.tensor([0, 4, 7, 12, 14, 16], dtype=torch.int32) + self.assertTrue(torch.allclose(get_cu_seqlens(attn_mask)[0], target_res)) + + def test_get_cu_seqlens_from_pos_ids_1d(self): + position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 1, 0, 0]]) + target_res = torch.tensor([0, 4, 7, 12, 14, 16], dtype=torch.int32) + self.assertTrue( + torch.allclose(get_cu_seqlens_from_pos_ids(position_ids)[0], target_res) + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_expand_mask.py b/tests/test_expand_mask.py new file mode 100644 index 000000000..01241c295 --- /dev/null +++ b/tests/test_expand_mask.py @@ -0,0 +1,44 @@ +""" +Unit tests for the monkey patch for expand mask to handle packed sequences +""" +import unittest + +import torch + +from axolotl.monkeypatch.llama_expand_mask import _expand_mask + + +class TestExpandMask(unittest.TestCase): + """ + Test class for attention mask expansion for packed sequences + """ + + def test_output(self): + mask = torch.tensor([[1, 1, 1, 2], [2, 3, 3, 0]]) + dtype = torch.float32 + expected_output = torch.tensor( + [ + [ + [ + [0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38], + [0.0000e00, 0.0000e00, -3.4028e38, -3.4028e38], + [0.0000e00, 0.0000e00, 0.0000e00, -3.4028e38], + [-3.4028e38, -3.4028e38, -3.4028e38, 0.0000e00], + ] + ], + [ + [ + [0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38], + [-3.4028e38, 0.0000e00, -3.4028e38, -3.4028e38], + [-3.4028e38, 0.0000e00, 0.0000e00, -3.4028e38], + [-3.4028e38, -3.4028e38, -3.4028e38, -3.4028e38], + ] + ], + ] + ) + # Check that the output matches the expected output + self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_packed_dataset.py b/tests/test_packed_dataset.py index 1f19d0ecc..da8fb7a93 100644 --- a/tests/test_packed_dataset.py +++ b/tests/test_packed_dataset.py @@ -27,7 +27,7 @@ class TestPacking(unittest.TestCase): } ) - def test_resets_attention(self): + def test_increments_attention(self): prompter = AlpacaPrompter("chat") strat = AlpacaPromptTokenizingStrategy( prompter, @@ -55,10 +55,14 @@ class TestPacking(unittest.TestCase): # first example doesn't have mask reset assert example["input_ids"][0] == self.tokenizer.bos_token_id assert example["attention_mask"][0] == 1 + assert example["position_ids"][0] == 0 + assert example["position_ids"][1] == 1 # but subsequent one does assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id - assert example["attention_mask"][next_bos_index] == 0 + assert example["attention_mask"][next_bos_index] == 2 + assert example["position_ids"][next_bos_index] == 0 + assert example["position_ids"][next_bos_index + 1] == 1 if __name__ == "__main__": diff --git 
a/tests/test_prompt_tokenizers.py b/tests/test_prompt_tokenizers.py index b4dd3cbd7..1dd511f6b 100644 --- a/tests/test_prompt_tokenizers.py +++ b/tests/test_prompt_tokenizers.py @@ -134,9 +134,15 @@ class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase): "output": "Hi! How can I help?", } example = strat.tokenize_prompt(sample) - assert example["input_ids"][0:4] == [1, 835, 2184, 29901] # "### System:" - assert example["input_ids"][5:7] == [1509, 20118] # "use cot" - assert example["input_ids"][9] == 11889 # USER + assert example["input_ids"][0:5] == [ + 1, + 28962, + 1254, + 12665, + 29901, + ] # "SYSTEM:" + assert example["input_ids"][5:7] == [671, 20118] # " use cot" + assert example["input_ids"][8] == 11889 # USER class Llama2ChatTokenizationTest(unittest.TestCase): diff --git a/tests/test_prompters.py b/tests/test_prompters.py index 112f25d33..6c5b8f27c 100644 --- a/tests/test_prompters.py +++ b/tests/test_prompters.py @@ -70,7 +70,7 @@ class AlpacaPrompterTest(unittest.TestCase): ) ) assert "use cot" in res - assert res.startswith("### System:") + assert res.startswith("SYSTEM:") assert "### Instruction:" not in res assert "### Input:" not in res assert "alpacas" in res diff --git a/tests/test_validation.py b/tests/test_validation.py index 88c97f0b7..e956d7b40 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -313,3 +313,27 @@ class ValidationTest(unittest.TestCase): ) validate_config(cfg) + + def test_packing(self): + cfg = DictDefault( + { + "max_packed_sequence_len": 2048, + } + ) + with self._caplog.at_level(logging.WARNING): + validate_config(cfg) + assert any( + "max_packed_sequence_len will be deprecated in favor of sample_packing" + in record.message + for record in self._caplog.records + ) + + cfg = DictDefault( + { + "max_packed_sequence_len": 2048, + "sample_packing": True, + } + ) + regex_exp = r".*set only one of max_packed_sequence_len \(deprecated soon\) or sample_packing.*" + with pytest.raises(ValueError, match=regex_exp): + validate_config(cfg)
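
As a quick end-to-end sketch of how the pieces above fit together: the packed `position_ids` produced by `add_position_ids` (and by `ConstantLengthDataset`) restart at 0 for every packed sub-sequence, and `get_cu_seqlens_from_pos_ids` converts them into the cumulative sequence lengths that the patched flash-attention forward hands to `flash_attn_varlen_qkvpacked_func`. A minimal sketch, reusing the values from the unit test above (assumes this branch of axolotl is installed so the monkeypatch utils are importable):

```python
import torch

from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids

# One packed row holding sub-sequences of length 4, 3, 5 and 2, followed by two
# pad positions; each sub-sequence restarts its position ids at 0.
position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 1, 0, 0]])

cu_seqlens, max_seq_len = get_cu_seqlens_from_pos_ids(position_ids)
print(cu_seqlens)   # tensor([[ 0,  4,  7, 12, 14, 16]], dtype=torch.int32)
print(max_seq_len)  # tensor([5]) -- length of the longest packed sub-sequence
```

The `llama_expand_mask` patch plays the analogous role for the stock attention path, expanding the integer-valued packed attention mask into a block-diagonal causal mask (see `tests/test_expand_mask.py`).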
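
On the configuration side, the `total_num_tokens` and `sample_packing_eff_est` values that the README changes ask users to copy back into their config feed the step-count estimate in `calculate_total_num_steps` and `MultipackDistributedDataloader._len_est`. A rough sketch of that arithmetic; all concrete numbers below are made-up illustrations, not recommendations:

```python
import math

# Hypothetical example values -- in a real run the trainer logs the measured
# numbers and suggests adding them to the config.
total_num_tokens = 50_000_000      # sum of len(input_ids) over the training set
sample_packing_eff_est = 0.97      # packing efficiency measured on a previous run
sequence_len = 2048
micro_batch_size = 2               # per-device batch size
world_size = 4                     # number of GPUs

tokens_per_device = total_num_tokens // world_size

# Mirrors _len_est(): a 1% haircut plus one dropped batch to absorb variance
# between the random sampler's packing and the estimate.
batches_per_device_per_epoch = (
    math.floor(
        0.99
        * tokens_per_device
        / sample_packing_eff_est
        // sequence_len
        // micro_batch_size
    )
    - 1
)
print(batches_per_device_per_epoch)  # 3113 with the numbers above
```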