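Various bugfixes: drop packed batches whose tensor sizes mismatch, include the dataset path in the "unhandled dataset load" error, fall back to AutoModelForCausalLM when no model type is set, skip the DDP device move for 8-bit models, and make data collator padding configurable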

This commit is contained in:
Wing Lian
2023-04-19 17:04:34 -04:00
parent 2624bc2f11
commit 94f5e415a3
6 changed files with 63 additions and 10 deletions

View File

@@ -1,3 +1,4 @@
import logging
from typing import List
import torch
@@ -92,11 +93,14 @@ class ConstantLengthDataset(IterableDataset):
: self.seq_length
]
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
yield {
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
}
if labels.size() == input_ids.size() and attention_mask.size() == input_ids.size():
yield {
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
}
else:
logging.warning("dropping batch due to tensor size mismatch")
buffer = {"input_ids": [], "attention_mask": [], "labels": []}
buffer_len = 0

View File

@@ -65,7 +65,7 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
elif ds_from_hub:
ds = load_dataset(d.path, streaming=True)
else:
raise Exception("unhandled dataset load")
raise Exception(f"unhandled dataset load for {d.path}")
if d.type == "alpaca":
ds_strategy = AlpacaPromptTokenizingStrategy(

View File

@@ -102,13 +102,20 @@ def load_model(
torch_dtype=torch_dtype,
device_map=cfg.device_map,
)
else:
elif model_type:
model = getattr(transformers, model_type).from_pretrained(
base_model,
load_in_8bit=cfg.load_in_8bit,
torch_dtype=torch_dtype,
device_map=cfg.device_map,
)
else:
model = AutoModelForCausalLM.from_pretrained(
base_model,
load_in_8bit=cfg.load_in_8bit,
torch_dtype=torch_dtype,
device_map=cfg.device_map,
)
except Exception as e:
logging.error(
"Exception raised attempting to load model, retrying with AutoModelForCausalLM"
@@ -148,7 +155,7 @@ def load_model(
model, lora_config = load_adapter(model, cfg, adapter)
if cfg.ddp:
if cfg.ddp and not load_in_8bit:
model.to(f"cuda:{cfg.local_rank}")
if cfg.load_4bit:

View File

@@ -94,13 +94,22 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
)
trainer_kwargs["callbacks"] = [early_stop_cb]
data_collator_kwargs = {
"padding": True,
}
if cfg.collator_pad_to_longest:
data_collator_kwargs["padding"] = "longest"
else:
data_collator_kwargs["pad_to_multiple_of"] = 8
trainer = transformers.Trainer(
model=model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
args=training_args,
data_collator=transformers.DataCollatorForSeq2Seq(
tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
tokenizer,
return_tensors="pt",
**data_collator_kwargs,
),
**trainer_kwargs,
)