diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py
index 64bf48664..526121f2e 100644
--- a/src/axolotl/utils/callbacks.py
+++ b/src/axolotl/utils/callbacks.py
@@ -61,6 +61,7 @@ class SaveBetterTransformerModelCallback(
         model = BetterTransformer.reverse(kwargs["model"])
         model.save_pretrained(checkpoint_folder)
 
+        # FIXME - need to cleanup old checkpoints
         # since we're saving here, we don't need the trainer loop to attempt to save too b/c
         # the trainer will raise an exception since it can't save a BetterTransformer wrapped model
 
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index 164296ee2..13ad7c75d 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -388,9 +388,13 @@ def load_prepare_datasets(
             index=cfg.dataset_shard_idx,
         )
 
-    dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
-    train_dataset = dataset["train"]
-    eval_dataset = dataset["test"]
+    if cfg.val_set_size:
+        dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
+        train_dataset = dataset["train"]
+        eval_dataset = dataset["test"]
+    else:
+        train_dataset = dataset
+        eval_dataset = None
 
     return train_dataset, eval_dataset
 
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 91ef96ca9..49a9b6f85 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -300,6 +300,12 @@ def load_model(
     embeddings_len = math.ceil(len(tokenizer) / 32) * 32
     model.resize_token_embeddings(embeddings_len)
 
+    if cfg.sequence_len >= model.config.max_position_embeddings:
+        logging.warning(
+            f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
+        )
+        model.config.max_position_embeddings = cfg.sequence_len
+
     if not cfg.gptq and (
         (cfg.adapter == "lora" and load_in_8bit)
         or (cfg.adapter == "qlora" and cfg.load_in_4bit)
diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py
index 396036621..2e2450fba 100644
--- a/src/axolotl/utils/validation.py
+++ b/src/axolotl/utils/validation.py
@@ -80,4 +80,12 @@ def validate_config(cfg):
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
     # no 8bit adamw w bf16
+
+    # GPT-NeoX
+    # evals broken when extending context len
+    # File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 162, in forward
+    #     attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
+    # File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/optimum/bettertransformer/models/attention.py", line 74, in gpt2_wrapped_scaled_dot_product
+    # attention_mask = causal_mask + attention_mask
+    # RuntimeError: The size of tensor a (2048) must match the size of tensor b (8132) at non-singleton dimension 3
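
Note on the FIXME in callbacks.py: because the callback saves its own copy of the model at each checkpoint, old checkpoint folders will accumulate. A hypothetical cleanup helper is sketched below; the name cleanup_old_checkpoints and the save_total_limit parameter are illustrative only, not part of this diff or of the trainer API.

    import os
    import re
    import shutil

    def cleanup_old_checkpoints(output_dir: str, save_total_limit: int = 2) -> None:
        """Keep only the newest `save_total_limit` checkpoint-* folders (hypothetical)."""
        pattern = re.compile(r"^checkpoint-(\d+)$")
        # collect (step, path) pairs, sorted ascending so the oldest come first
        checkpoints = sorted(
            (int(match.group(1)), os.path.join(output_dir, name))
            for name in os.listdir(output_dir)
            if (match := pattern.match(name))
        )
        # remove everything except the newest save_total_limit checkpoints
        for _step, path in checkpoints[:-save_total_limit]:
            shutil.rmtree(path)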
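
The data.py change makes the holdout split conditional: a falsy cfg.val_set_size now yields the whole dataset for training and no eval set, instead of always calling train_test_split. A minimal sketch of the new behavior, with a toy dataset and a bare val_set_size standing in for cfg.val_set_size:

    from datasets import Dataset

    dataset = Dataset.from_dict({"text": [f"example {i}" for i in range(10)]})
    val_set_size = 0.2  # set to 0 to train on everything and skip eval

    if val_set_size:
        split = dataset.train_test_split(test_size=val_set_size, shuffle=False)
        train_dataset, eval_dataset = split["train"], split["test"]
    else:
        train_dataset, eval_dataset = dataset, None

    # val_set_size=0.2 -> 8 train rows, 2 eval rows
    # val_set_size=0   -> 10 train rows, eval_dataset is None
    print(len(train_dataset), eval_dataset is None)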
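
The models.py change rewrites only the config value; it does not rebuild size-dependent state created at model init (e.g. GPT-NeoX's registered causal-mask buffer), which is presumably why the eval failure recorded in validation.py still occurs. A sketch of the guard against a bare config; EleutherAI/pythia-70m is just an example GPT-NeoX checkpoint that ships with max_position_embeddings == 2048:

    import logging

    from transformers import AutoConfig

    sequence_len = 4096  # stands in for cfg.sequence_len
    config = AutoConfig.from_pretrained("EleutherAI/pythia-70m")

    if sequence_len >= config.max_position_embeddings:
        logging.warning(
            "increasing model.config.max_position_embeddings to %s", sequence_len
        )
        config.max_position_embeddings = sequence_len

    assert config.max_position_embeddings == 4096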