From 94f5e415a3550772ac82bcab719de8c160ecdea9 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 19 Apr 2023 17:04:34 -0400
Subject: [PATCH] various bugfixes

---
 configs/stability_3b.yml     | 33 +++++++++++++++++++++++++++++++++
 scripts/finetune.py          |  2 +-
 src/axolotl/datasets.py      | 14 +++++++++-----
 src/axolotl/utils/data.py    |  2 +-
 src/axolotl/utils/models.py  | 11 +++++++++--
 src/axolotl/utils/trainer.py | 11 ++++++++++-
 6 files changed, 63 insertions(+), 10 deletions(-)
 create mode 100644 configs/stability_3b.yml

diff --git a/configs/stability_3b.yml b/configs/stability_3b.yml
new file mode 100644
index 000000000..8cfd8fa8c
--- /dev/null
+++ b/configs/stability_3b.yml
@@ -0,0 +1,33 @@
+base_model: stabilityai/stablelm-base-alpha-3b
+load_in_8bit: true
+datasets:
+  - path: vicgalle/alpaca-gpt4
+    type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.04
+adapter:
+lora_model_dir:
+sequence_len: 4096
+lora_r: 8
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules:
+  - q_proj
+  - v_proj
+lora_fan_in_fan_out: false
+wandb_project: stable-llama-3b
+wandb_watch:
+wandb_run_id:
+wandb_log_model: checkpoint
+output_dir: ./stable-llama-3b
+batch_size: 128
+micro_batch_size: 16
+num_epochs: 1
+learning_rate: 0.00003
+train_on_inputs: false
+group_by_length: false
+bf16: true
+tf32: true
+early_stopping_patience: 3
+resume_from_checkpoint:
+local_rank:
diff --git a/scripts/finetune.py b/scripts/finetune.py
index 4c24a3c4f..858f33f9a 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -159,7 +159,7 @@ def train(
     cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
     cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
     choose_device(cfg)
-    cfg.ddp = cfg.world_size != 1
+    cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
     if cfg.ddp:
         cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
     cfg.gradient_accumulation_steps = (
diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py
index 862bd3229..deab5e438 100644
--- a/src/axolotl/datasets.py
+++ b/src/axolotl/datasets.py
@@ -1,3 +1,4 @@
+import logging
 from typing import List
 
 import torch
@@ -92,11 +93,14 @@ class ConstantLengthDataset(IterableDataset):
                         : self.seq_length
                     ]
                     labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
-                    yield {
-                        "input_ids": input_ids,
-                        "labels": labels,
-                        "attention_mask": attention_mask,
-                    }
+                    if labels.size() == input_ids.size() and attention_mask.size() == input_ids.size():
+                        yield {
+                            "input_ids": input_ids,
+                            "labels": labels,
+                            "attention_mask": attention_mask,
+                        }
+                    else:
+                        logging.warning("dropping batch due to tensor size mismatch")
                 buffer = {"input_ids": [], "attention_mask": [], "labels": []}
                 buffer_len = 0
 
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index 4e064a881..bbfa1aa18 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -65,7 +65,7 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
             elif ds_from_hub:
                 ds = load_dataset(d.path, streaming=True)
             else:
-                raise Exception("unhandled dataset load")
+                raise Exception(f"unhandled dataset load for {d.path}")
 
             if d.type == "alpaca":
                 ds_strategy = AlpacaPromptTokenizingStrategy(
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 4f9bdfc0b..d05cc1927 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -102,13 +102,20 @@ def load_model(
                 torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
             )
-        else:
+        elif model_type:
             model = getattr(transformers, model_type).from_pretrained(
                 base_model,
                 load_in_8bit=cfg.load_in_8bit,
                 torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
             )
+        else:
+            model = AutoModelForCausalLM.from_pretrained(
+                base_model,
+                load_in_8bit=cfg.load_in_8bit,
+                torch_dtype=torch_dtype,
+                device_map=cfg.device_map,
+            )
     except Exception as e:
         logging.error(
             "Exception raised attempting to load model, retrying with AutoModelForCausalLM"
@@ -148,7 +155,7 @@ def load_model(
 
     model, lora_config = load_adapter(model, cfg, adapter)
 
-    if cfg.ddp:
+    if cfg.ddp and not load_in_8bit:
         model.to(f"cuda:{cfg.local_rank}")
 
     if cfg.load_4bit:
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 9f4262962..e0405357c 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -94,13 +94,22 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         )
         trainer_kwargs["callbacks"] = [early_stop_cb]
 
+    data_collator_kwargs = {
+        "padding": True,
+    }
+    if cfg.collator_pad_to_longest:
+        data_collator_kwargs["padding"] = "longest"
+    else:
+        data_collator_kwargs["pad_to_multiple_of"] = 8
     trainer = transformers.Trainer(
         model=model,
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
         args=training_args,
         data_collator=transformers.DataCollatorForSeq2Seq(
-            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
+            tokenizer,
+            return_tensors="pt",
+            **data_collator_kwargs,
        ),
         **trainer_kwargs,
     )
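
Usage sketch (not part of the patch): the finetune.py and trainer.py changes above read two optional config keys, `ddp` and `collator_pad_to_longest`, from `cfg`. A minimal YAML fragment exercising them could look like the following; the key names come from the diff, the values are illustrative only.

# hypothetical excerpt from a training config
# ddp is now honored when set explicitly instead of always being derived from WORLD_SIZE
ddp: false
# when true, DataCollatorForSeq2Seq pads each batch to its longest sequence
# instead of padding to a multiple of 8
collator_pad_to_longest: true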