cleanup, prep for 4bit quant support

This commit is contained in:
Wing Lian
2023-04-16 11:06:41 -04:00
parent d1aed4c8e5
commit 12de7b7cf7
3 changed files with 42 additions and 7 deletions

View File

@@ -68,26 +68,27 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora", infe
from axolotl.flash_attn import replace_llama_attn_with_flash_attn
replace_llama_attn_with_flash_attn()
torch_dtype = torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32,
try:
if "llama" in base_model:
model = LlamaForCausalLM.from_pretrained(
base_model,
load_in_8bit=cfg.load_in_8bit,
torch_dtype=torch.float16 if cfg.load_in_8bit else torch.float32,
torch_dtype=torch_dtype,
device_map=cfg.device_map,
)
else:
model = getattr(transformers, model_type).from_pretrained(
base_model,
load_in_8bit=cfg.load_in_8bit,
torch_dtype=torch.float16 if cfg.load_in_8bit else torch.float32,
torch_dtype=torch_dtype,
device_map=cfg.device_map,
)
except:
model = AutoModelForCausalLM.from_pretrained(
base_model,
load_in_8bit=cfg.load_in_8bit,
torch_dtype=torch.float16 if cfg.load_in_8bit else torch.float32,
torch_dtype=torch_dtype,
device_map=cfg.device_map,
)
@@ -235,7 +236,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
save_steps = eval_steps = min(int(0.05 * total_num_steps), 200)
training_arguments_kwargs = {}
training_arguments_kwargs["bf16"] = cfg.bf16
if cfg.bf16 == "full":
training_arguments_kwargs["bf16_full_eval"] = True
else:
training_arguments_kwargs["bf16"] = cfg.bf16
training_arguments_kwargs["tf32"] = cfg.tf32
training_arguments_kwargs["warmup_steps"] = warmup_steps
training_arguments_kwargs["logging_steps"] = logging_steps
@@ -256,10 +260,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
group_by_length=cfg.group_by_length,
report_to="wandb" if cfg.use_wandb else None,
run_name=cfg.wandb_run_id if cfg.use_wandb else None,
gradient_checkpointing=cfg.gradient_checkpointing,
**training_arguments_kwargs,
)
trainer_kwargs = {}
decay_parameters = get_parameter_names(model, [nn.LayerNorm])
decay_parameters = [name for name in decay_parameters if "bias" not in name]
optimizer_grouped_parameters = [
@@ -282,13 +286,14 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
lr=training_args.learning_rate,
)
# TODO optionally use torch.optim.OneCycleLR
lr_scheduler = transformers.get_cosine_schedule_with_warmup(
adam_bnb_optim,
training_args.warmup_steps,
total_num_steps,
)
trainer_kwargs["optimizers"] = (adam_bnb_optim, lr_scheduler)
trainer_kwargs = {}
if cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback(
cfg.early_stopping_patience,
@@ -300,6 +305,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
train_dataset=train_dataset,
eval_dataset=eval_dataset,
args=training_args,
optimizers=(adam_bnb_optim, lr_scheduler),
data_collator=transformers.DataCollatorForSeq2Seq(
tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
),
@@ -342,6 +348,12 @@ def train(
cfg.gradient_accumulation_steps // cfg.world_size
)
setup_wandb_env_vars(cfg)
if cfg.device == "mps":
cfg.load_in_8bit = False
cfg.tf32 = False
if cfg.bf16:
cfg.fp16 = True
cfg.bf16 = False
# Load the model and tokenizer
model, tokenizer, lora_config = load_model(