various bugfixes

Author: Wing Lian
Date: 2023-04-19 17:04:34 -04:00
Parent: 2624bc2f11
Commit: 94f5e415a3
6 changed files with 63 additions and 10 deletions

configs/stability_3b.yml (new file, +33)

@@ -0,0 +1,33 @@
+base_model: stabilityai/stablelm-base-alpha-3b
+load_in_8bit: true
+datasets:
+  - path: vicgalle/alpaca-gpt4
+    type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.04
+adapter:
+lora_model_dir:
+sequence_len: 4096
+lora_r: 8
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules:
+  - q_proj
+  - v_proj
+lora_fan_in_fan_out: false
+wandb_project: stable-llama-3b
+wandb_watch:
+wandb_run_id:
+wandb_log_model: checkpoint
+output_dir: ./stable-llama-3b
+batch_size: 128
+micro_batch_size: 16
+num_epochs: 1
+learning_rate: 0.00003
+train_on_inputs: false
+group_by_length: false
+bf16: true
+tf32: true
+early_stopping_patience: 3
+resume_from_checkpoint:
+local_rank:
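
For context on how a config like this is consumed, here is a minimal sketch that only loads the YAML into a plain dict; the loader below is an illustrative assumption, not this repo's actual code (axolotl wraps the parsed config in an attrdict-style object):

import yaml

# Assumption: plain yaml.safe_load; the repo's own config loader may differ.
with open("configs/stability_3b.yml", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

# Keys left blank in the YAML (adapter:, lora_model_dir:, wandb_watch:, ...)
# parse as None, which is why downstream code checks `is not None` rather
# than plain truthiness (see the cfg.ddp fix below).
print(cfg["base_model"])   # stabilityai/stablelm-base-alpha-3b
print(cfg.get("adapter"))  # None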

@@ -159,7 +159,7 @@ def train(
     cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
     cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
     choose_device(cfg)
-    cfg.ddp = cfg.world_size != 1
+    cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
     if cfg.ddp:
         cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
         cfg.gradient_accumulation_steps = (
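
The removed line recomputed `ddp` unconditionally from `WORLD_SIZE`, clobbering any value set explicitly in the config; the replacement only derives it when the config left it unset. A small sketch of the resulting precedence (the helper name is hypothetical):

import os

world_size = int(os.environ.get("WORLD_SIZE", 1))

def resolve_ddp(cfg_ddp):
    # Explicit config value wins; None falls back to the environment.
    return cfg_ddp if cfg_ddp is not None else world_size != 1

assert resolve_ddp(None) == (world_size != 1)  # derived from WORLD_SIZE
assert resolve_ddp(False) is False             # user override respected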

@@ -1,3 +1,4 @@
+import logging
 from typing import List
 import torch
@@ -92,11 +93,14 @@ class ConstantLengthDataset(IterableDataset):
                     : self.seq_length
                 ]
                 labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
-                yield {
-                    "input_ids": input_ids,
-                    "labels": labels,
-                    "attention_mask": attention_mask,
-                }
+                if labels.size() == input_ids.size() and attention_mask.size() == input_ids.size():
+                    yield {
+                        "input_ids": input_ids,
+                        "labels": labels,
+                        "attention_mask": attention_mask,
+                    }
+                else:
+                    logging.warning("dropping batch due to tensor size mismatch")
                 buffer = {"input_ids": [], "attention_mask": [], "labels": []}
                 buffer_len = 0
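
The packing buffer can occasionally truncate `input_ids`, `labels`, and `attention_mask` to different lengths; rather than yielding a malformed batch, the guard above drops it with a warning. A toy reproduction of the check:

import logging
import torch

seq_length = 4
input_ids = torch.tensor([1, 2, 3, 4])
attention_mask = torch.tensor([1, 1, 1, 1])
labels = torch.tensor([1, 2, 3])  # one element short: simulates the mismatch

if labels.size() == input_ids.size() and attention_mask.size() == input_ids.size():
    print("batch yielded")
else:
    logging.warning("dropping batch due to tensor size mismatch")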

@@ -65,7 +65,7 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
         elif ds_from_hub:
             ds = load_dataset(d.path, streaming=True)
         else:
-            raise Exception("unhandled dataset load")
+            raise Exception(f"unhandled dataset load for {d.path}")
         if d.type == "alpaca":
             ds_strategy = AlpacaPromptTokenizingStrategy(
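
Besides the more descriptive error message, note that the hub branch above streams the dataset instead of materializing it up front; a standalone usage sketch, with the dataset id taken from configs/stability_3b.yml in this same commit:

from datasets import load_dataset

# Records are fetched lazily from the Hugging Face Hub as they are iterated.
ds = load_dataset("vicgalle/alpaca-gpt4", streaming=True)
print(next(iter(ds["train"])))  # first example, fetched on demand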

@@ -102,13 +102,20 @@ def load_model(
                 torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
             )
-        else:
+        elif model_type:
             model = getattr(transformers, model_type).from_pretrained(
                 base_model,
                 load_in_8bit=cfg.load_in_8bit,
                 torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
             )
+        else:
+            model = AutoModelForCausalLM.from_pretrained(
+                base_model,
+                load_in_8bit=cfg.load_in_8bit,
+                torch_dtype=torch_dtype,
+                device_map=cfg.device_map,
+            )
     except Exception as e:
         logging.error(
             "Exception raised attempting to load model, retrying with AutoModelForCausalLM"
@@ -148,7 +155,7 @@ def load_model(
     model, lora_config = load_adapter(model, cfg, adapter)
-    if cfg.ddp:
+    if cfg.ddp and not load_in_8bit:
         model.to(f"cuda:{cfg.local_rank}")
     if cfg.load_4bit:
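
Two fixes land here: model loading now falls back to AutoModelForCausalLM when no explicit `model_type` is configured, and the DDP `.to(device)` move is skipped for 8-bit models, since bitsandbytes-quantized weights should not be relocated that way. A condensed sketch of the new dispatch order (the helper name is hypothetical):

import transformers
from transformers import AutoModelForCausalLM

def load_base_model(base_model, model_type=None, **kwargs):
    if model_type:
        # e.g. model_type="LlamaForCausalLM" resolves a concrete class.
        return getattr(transformers, model_type).from_pretrained(base_model, **kwargs)
    # No model_type configured: let transformers infer the architecture.
    return AutoModelForCausalLM.from_pretrained(base_model, **kwargs)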

@@ -94,13 +94,22 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         )
         trainer_kwargs["callbacks"] = [early_stop_cb]
 
+    data_collator_kwargs = {
+        "padding": True,
+    }
+    if cfg.collator_pad_to_longest:
+        data_collator_kwargs["padding"] = "longest"
+    else:
+        data_collator_kwargs["pad_to_multiple_of"] = 8
     trainer = transformers.Trainer(
         model=model,
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
         args=training_args,
         data_collator=transformers.DataCollatorForSeq2Seq(
-            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
+            tokenizer,
+            return_tensors="pt",
+            **data_collator_kwargs,
         ),
         **trainer_kwargs,
     )
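
The collator arguments are now assembled conditionally: by default `pad_to_multiple_of=8` keeps batch shapes tensor-core friendly, while a `collator_pad_to_longest` config flag switches to padding each batch to its longest sample instead. A standalone sketch of the two configurations (the helper name is hypothetical; `tokenizer` is assumed to be a loaded PreTrainedTokenizer):

import transformers

def build_collator(tokenizer, pad_to_longest=False):
    kwargs = {"padding": True}
    if pad_to_longest:
        kwargs["padding"] = "longest"  # pad each batch to its longest sample
    else:
        kwargs["pad_to_multiple_of"] = 8  # round lengths up to a multiple of 8
    return transformers.DataCollatorForSeq2Seq(
        tokenizer, return_tensors="pt", **kwargs
    )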