From 2df63ef8158b318fc069b5c2b9d3050cd023725e Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 15 Apr 2023 12:16:42 -0400 Subject: [PATCH] refactor trainer setup to account for deepspeed integration --- scripts/finetune.py | 153 +++++++++++++++++++++++++------------------- 1 file changed, 86 insertions(+), 67 deletions(-) diff --git a/scripts/finetune.py b/scripts/finetune.py index d52975a96..c271e3e82 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -16,7 +16,7 @@ from peft import ( LoraConfig, get_peft_model, prepare_model_for_int8_training, - get_peft_model_state_dict, PeftModel, + PeftModel, ) from torch import nn from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, LlamaTokenizer @@ -214,6 +214,89 @@ def choose_config(path: Path): return chosen_file +def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): + total_num_steps = int( + math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) + ) + save_steps = eval_steps = min(int(0.05 * total_num_steps), 200) + + training_arguments_kwargs = {} + + if not cfg.deepspeed: + warmup_steps = min(int(0.03 * total_num_steps), 100) + logging_steps = min(int(0.005 * total_num_steps), 10) + + training_arguments_kwargs["warmup_steps"] = warmup_steps + training_arguments_kwargs["logging_steps"] = logging_steps + training_arguments_kwargs["logging_steps"] = logging_steps + training_arguments_kwargs["bf16"] = cfg.bf16 + training_arguments_kwargs["tf32"] = cfg.tf32 + + training_args = transformers.TrainingArguments( + per_device_train_batch_size=cfg.micro_batch_size, + gradient_accumulation_steps=cfg.gradient_accumulation_steps, + num_train_epochs=cfg.num_epochs, + learning_rate=cfg.learning_rate, + evaluation_strategy="steps" if cfg.val_set_size > 0 else "no", + save_strategy="steps", + eval_steps=eval_steps if cfg.val_set_size > 0 else None, + save_steps=save_steps, + output_dir=cfg.output_dir, + save_total_limit=3, + load_best_model_at_end=True if cfg.val_set_size > 0 else False, + ddp_find_unused_parameters=False if cfg.ddp else None, + group_by_length=cfg.group_by_length, + report_to="wandb" if cfg.use_wandb else None, + run_name=cfg.wandb_run_name if cfg.use_wandb else None, + **training_arguments_kwargs, + ) + + trainer_kwargs = {} + + if not cfg.deepspeed: + decay_parameters = get_parameter_names(model, [nn.LayerNorm]) + decay_parameters = [name for name in decay_parameters if "bias" not in name] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if n in decay_parameters], + "weight_decay": training_args.weight_decay, + }, + { + "params": [ + p for n, p in model.named_parameters() if n not in decay_parameters + ], + "weight_decay": 0.0, + }, + ] + + adam_bnb_optim = bnb.optim.Adam8bit( + optimizer_grouped_parameters, + betas=(training_args.adam_beta1, training_args.adam_beta2), + eps=training_args.adam_epsilon, + lr=training_args.learning_rate, + ) + + lr_scheduler = transformers.get_cosine_schedule_with_warmup( + adam_bnb_optim, + training_args.warmup_steps, + total_num_steps, + ) + trainer_kwargs["optimizers"] = (adam_bnb_optim, lr_scheduler) + + + trainer = transformers.Trainer( + model=model, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + args=training_args, + data_collator=transformers.DataCollatorForSeq2Seq( + tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True + ), + **trainer_kwargs, + ) + + return trainer + def train( config: Path = Path("configs/"), **kwargs, @@ -308,73 +391,8 @@ def train( tokenizer, ) - total_num_steps = int( - math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) - ) - warmup_steps = min(int(0.03 * total_num_steps), 100) - logging_steps = min(int(0.005 * total_num_steps), 10) - save_steps = eval_steps = min(int(0.05 * total_num_steps), 200) + trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer) - training_args = transformers.TrainingArguments( - per_device_train_batch_size=cfg.micro_batch_size, - gradient_accumulation_steps=cfg.gradient_accumulation_steps, - warmup_steps=warmup_steps, - num_train_epochs=cfg.num_epochs, - learning_rate=cfg.learning_rate, - bf16=cfg.bf16, - tf32=cfg.tf32, - logging_steps=logging_steps, - evaluation_strategy="steps" if cfg.val_set_size > 0 else "no", - save_strategy="steps", - eval_steps=eval_steps if cfg.val_set_size > 0 else None, - save_steps=save_steps, - output_dir=cfg.output_dir, - save_total_limit=3, - load_best_model_at_end=True if cfg.val_set_size > 0 else False, - ddp_find_unused_parameters=False if cfg.ddp else None, - group_by_length=cfg.group_by_length, - report_to="wandb" if cfg.use_wandb else None, - run_name=cfg.wandb_run_name if cfg.use_wandb else None, - ) - - decay_parameters = get_parameter_names(model, [nn.LayerNorm]) - decay_parameters = [name for name in decay_parameters if "bias" not in name] - optimizer_grouped_parameters = [ - { - "params": [p for n, p in model.named_parameters() if n in decay_parameters], - "weight_decay": training_args.weight_decay, - }, - { - "params": [ - p for n, p in model.named_parameters() if n not in decay_parameters - ], - "weight_decay": 0.0, - }, - ] - - adam_bnb_optim = bnb.optim.Adam8bit( - optimizer_grouped_parameters, - betas=(training_args.adam_beta1, training_args.adam_beta2), - eps=training_args.adam_epsilon, - lr=training_args.learning_rate, - ) - - lr_scheduler = transformers.get_cosine_schedule_with_warmup( - adam_bnb_optim, - training_args.warmup_steps, - total_num_steps, - ) - - trainer = transformers.Trainer( - model=model, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - args=training_args, - optimizers=(adam_bnb_optim, lr_scheduler), - data_collator=transformers.DataCollatorForSeq2Seq( - tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True - ), - ) model.config.use_cache = False if torch.__version__ >= "2" and sys.platform != "win32": @@ -391,6 +409,7 @@ def train( trainer.train(resume_from_checkpoint=cfg.resume_from_checkpoint) + # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading model.save_pretrained(cfg.output_dir)