cleanup, prep for 4bit quant support
@@ -68,26 +68,27 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora", infe
         from axolotl.flash_attn import replace_llama_attn_with_flash_attn
         replace_llama_attn_with_flash_attn()
 
+    torch_dtype = torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32,
     try:
         if "llama" in base_model:
             model = LlamaForCausalLM.from_pretrained(
                 base_model,
                 load_in_8bit=cfg.load_in_8bit,
-                torch_dtype=torch.float16 if cfg.load_in_8bit else torch.float32,
+                torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
             )
         else:
             model = getattr(transformers, model_type).from_pretrained(
                 base_model,
                 load_in_8bit=cfg.load_in_8bit,
-                torch_dtype=torch.float16 if cfg.load_in_8bit else torch.float32,
+                torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
             )
     except:
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
             load_in_8bit=cfg.load_in_8bit,
-            torch_dtype=torch.float16 if cfg.load_in_8bit else torch.float32,
+            torch_dtype=torch_dtype,
             device_map=cfg.device_map,
         )
 
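Note: the hunk above replaces the three per-branch dtype ternaries with a single torch_dtype computed up front. A minimal sketch of that selection logic follows, assuming cfg exposes load_in_8bit and fp16 booleans; the committed assignment ends in a stray trailing comma, which would make torch_dtype a one-element tuple, so the sketch drops it.

import torch

def pick_torch_dtype(load_in_8bit: bool, fp16: bool) -> torch.dtype:
    # half precision whenever 8-bit loading or fp16 training is requested,
    # full fp32 otherwise
    return torch.float16 if (load_in_8bit or fp16) else torch.float32

assert pick_torch_dtype(load_in_8bit=True, fp16=False) is torch.float16
assert pick_torch_dtype(load_in_8bit=False, fp16=False) is torch.float32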
@@ -235,7 +236,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     save_steps = eval_steps = min(int(0.05 * total_num_steps), 200)
 
     training_arguments_kwargs = {}
-    training_arguments_kwargs["bf16"] = cfg.bf16
+    if cfg.bf16 == "full":
+        training_arguments_kwargs["bf16_full_eval"] = True
+    else:
+        training_arguments_kwargs["bf16"] = cfg.bf16
     training_arguments_kwargs["tf32"] = cfg.tf32
     training_arguments_kwargs["warmup_steps"] = warmup_steps
     training_arguments_kwargs["logging_steps"] = logging_steps
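Note: the bf16 handling above distinguishes a plain boolean from the string value "full". A sketch of the mapping into TrainingArguments keyword arguments, assuming cfg.bf16 is either a bool or "full" as the hunk implies:

def bf16_kwargs(bf16_setting):
    # mirror of the branch above: "full" turns on bf16 evaluation
    # (bf16_full_eval); anything else is passed through as the bf16 flag
    kwargs = {}
    if bf16_setting == "full":
        kwargs["bf16_full_eval"] = True
    else:
        kwargs["bf16"] = bf16_setting
    return kwargs

assert bf16_kwargs("full") == {"bf16_full_eval": True}
assert bf16_kwargs(True) == {"bf16": True}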
@@ -256,10 +260,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         group_by_length=cfg.group_by_length,
         report_to="wandb" if cfg.use_wandb else None,
         run_name=cfg.wandb_run_id if cfg.use_wandb else None,
         gradient_checkpointing=cfg.gradient_checkpointing,
         **training_arguments_kwargs,
     )
-
+    trainer_kwargs = {}
     decay_parameters = get_parameter_names(model, [nn.LayerNorm])
     decay_parameters = [name for name in decay_parameters if "bias" not in name]
     optimizer_grouped_parameters = [
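Note: the decay_parameters lines above feed the optimizer's parameter groups. A self-contained sketch of that grouping, assuming model is a torch.nn.Module and weight_decay comes from the training arguments; LayerNorm weights and all biases get no weight decay:

import torch.nn as nn
from transformers.trainer_pt_utils import get_parameter_names

def build_param_groups(model: nn.Module, weight_decay: float):
    # names of parameters that should receive weight decay
    decay_parameters = get_parameter_names(model, [nn.LayerNorm])
    decay_parameters = [name for name in decay_parameters if "bias" not in name]
    return [
        {
            "params": [p for n, p in model.named_parameters() if n in decay_parameters],
            "weight_decay": weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if n not in decay_parameters],
            "weight_decay": 0.0,
        },
    ]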
@@ -282,13 +286,14 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         lr=training_args.learning_rate,
     )
 
+    # TODO optionally use torch.optim.OneCycleLR
     lr_scheduler = transformers.get_cosine_schedule_with_warmup(
         adam_bnb_optim,
         training_args.warmup_steps,
         total_num_steps,
     )
+    trainer_kwargs["optimizers"] = (adam_bnb_optim, lr_scheduler)
 
-    trainer_kwargs = {}
     if cfg.early_stopping_patience:
         early_stop_cb = EarlyStoppingCallback(
             cfg.early_stopping_patience,
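Note: the hunk above pairs the 8-bit Adam optimizer with a cosine warmup schedule and stashes both in trainer_kwargs. A sketch of that wiring, assuming adam_bnb_optim is bitsandbytes' Adam8bit over the parameter groups built earlier (its construction sits just above this hunk and is not shown):

import bitsandbytes as bnb
import transformers

def build_optimizers(optimizer_grouped_parameters, learning_rate, warmup_steps, total_num_steps):
    # 8-bit Adam over the decay / no-decay parameter groups
    adam_bnb_optim = bnb.optim.Adam8bit(
        optimizer_grouped_parameters,
        betas=(0.9, 0.999),
        lr=learning_rate,
    )
    # linear warmup followed by cosine decay, as in the hunk above
    lr_scheduler = transformers.get_cosine_schedule_with_warmup(
        adam_bnb_optim, warmup_steps, total_num_steps
    )
    # handing the pair to the Trainer via its optimizers kwarg bypasses the
    # Trainer's default optimizer and scheduler construction
    return {"optimizers": (adam_bnb_optim, lr_scheduler)}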
@@ -300,6 +305,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
         args=training_args,
+        optimizers=(adam_bnb_optim, lr_scheduler),
         data_collator=transformers.DataCollatorForSeq2Seq(
             tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
         ),
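Note: the collator above pads each batch dynamically and rounds sequence lengths up to a multiple of 8, which keeps fp16/bf16 matmuls on tensor-core friendly shapes. A small usage sketch, assuming tokenizer is any Hugging Face tokenizer with a pad token set:

import transformers

data_collator = transformers.DataCollatorForSeq2Seq(
    tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
)
batch = data_collator([
    {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [1, 2, 3]},
    {"input_ids": [1, 2], "attention_mask": [1, 1], "labels": [1, 2]},
])
# both sequences are padded out to length 8 (labels padded with -100)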
@@ -342,6 +348,12 @@ def train(
         cfg.gradient_accumulation_steps // cfg.world_size
     )
     setup_wandb_env_vars(cfg)
+    if cfg.device == "mps":
+        cfg.load_in_8bit = False
+        cfg.tf32 = False
+        if cfg.bf16:
+            cfg.fp16 = True
+        cfg.bf16 = False
 
     # Load the model and tokenizer
     model, tokenizer, lora_config = load_model(
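Note: the new MPS branch above disables features that are CUDA-only and routes bf16 requests to fp16 on Apple silicon. A standalone sketch of the same overrides; SimpleCfg below is just a stand-in for the attribute-style cfg object used in the script:

from types import SimpleNamespace as SimpleCfg

def apply_mps_overrides(cfg):
    if cfg.device == "mps":
        cfg.load_in_8bit = False  # bitsandbytes int8 loading requires CUDA
        cfg.tf32 = False          # tf32 is an NVIDIA tensor-core mode
        if cfg.bf16:
            cfg.fp16 = True       # fall back to fp16 mixed precision
        cfg.bf16 = False
    return cfg

cfg = apply_mps_overrides(
    SimpleCfg(device="mps", load_in_8bit=True, tf32=True, bf16=True, fp16=False)
)
assert (cfg.load_in_8bit, cfg.bf16, cfg.fp16) == (False, False, True)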