From d1aed4c8e5322eb122bceac9549c63d738ef6b4b Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sun, 16 Apr 2023 06:59:47 -0400
Subject: [PATCH] deepspeed doesn't work with flash-attn, and the gpu savings
 w flash attn are better than the deepspeed headaches

---
 configs/cerebras_1_3B_alpaca.yml |   2 +-
 configs/llama_65B_alpaca.yml     |   2 +-
 configs/llama_7B_alpaca.yml      |   2 +-
 configs/pythia_1_2B_alpaca.yml   |   2 +-
 ds_config.json                   |  33 ++--------
 scripts/finetune.py              | 107 ++++++++++++++++---------
 6 files changed, 68 insertions(+), 80 deletions(-)

diff --git a/configs/cerebras_1_3B_alpaca.yml b/configs/cerebras_1_3B_alpaca.yml
index 48800cb93..c02058a5d 100644
--- a/configs/cerebras_1_3B_alpaca.yml
+++ b/configs/cerebras_1_3B_alpaca.yml
@@ -34,6 +34,6 @@ train_on_inputs: false
 group_by_length: false
 bf16: True
 tf32: True
+early_stopping_patience:
 resume_from_checkpoint:
 local_rank:
-deepspeed:
diff --git a/configs/llama_65B_alpaca.yml b/configs/llama_65B_alpaca.yml
index 3ad22cb39..c6a81b572 100644
--- a/configs/llama_65B_alpaca.yml
+++ b/configs/llama_65B_alpaca.yml
@@ -36,6 +36,6 @@ train_on_inputs: false
 group_by_length: false
 bf16: true
 tf32: true
+early_stopping_patience:
 resume_from_checkpoint:
 local_rank:
-deepspeed:
diff --git a/configs/llama_7B_alpaca.yml b/configs/llama_7B_alpaca.yml
index 6dfb77bb1..6d9cfd9b2 100644
--- a/configs/llama_7B_alpaca.yml
+++ b/configs/llama_7B_alpaca.yml
@@ -36,6 +36,6 @@ train_on_inputs: false
 group_by_length: false
 bf16: true
 tf32: true
+early_stopping_patience:
 resume_from_checkpoint:
 local_rank:
-deepspeed:
diff --git a/configs/pythia_1_2B_alpaca.yml b/configs/pythia_1_2B_alpaca.yml
index 1b733c054..303b246c8 100644
--- a/configs/pythia_1_2B_alpaca.yml
+++ b/configs/pythia_1_2B_alpaca.yml
@@ -36,6 +36,6 @@ train_on_inputs: false
 group_by_length: false
 bf16: True
 tf32: True
+early_stopping_patience:
 resume_from_checkpoint:
 local_rank:
-deepspeed:
diff --git a/ds_config.json b/ds_config.json
index 2fc3be619..05fc98177 100644
--- a/ds_config.json
+++ b/ds_config.json
@@ -1,6 +1,6 @@
 {
     "bf16": {
-        "enabled": "auto",
+        "enabled": "auto"
     },
     "fp16": {
         "enabled": "auto",
@@ -10,15 +10,6 @@
         "hysteresis": 2,
         "min_loss_scale": 1
     },
-    "optimizer": {
-        "type": "AdamW",
-        "params": {
-            "lr": "auto",
-            "betas": "auto",
-            "eps": "auto",
-            "weight_decay": "auto"
-        }
-    },
     "scheduler": {
         "type": "WarmupLR",
         "params": {
@@ -28,29 +19,19 @@
         }
     },
     "zero_optimization": {
-        "stage": 3,
-        "offload_optimizer": {
-            "device": "cpu",
-            "pin_memory": true
-        },
-        "offload_param": {
-            "device": "cpu",
-            "pin_memory": true
-        },
+        "stage": 2,
         "overlap_comm": true,
+        "allgather_partitions": true,
+        "allgather_bucket_size": 5e8,
         "contiguous_gradients": true,
-        "sub_group_size": 1e9,
         "reduce_bucket_size": "auto",
-        "stage3_prefetch_bucket_size": "auto",
-        "stage3_param_persistence_threshold": "auto",
-        "stage3_max_live_parameters": 1e9,
-        "stage3_max_reuse_distance": 1e9,
-        "stage3_gather_16bit_weights_on_model_save": true
+        "reduce_scatter": true
     },
     "gradient_accumulation_steps": "auto",
     "gradient_clipping": "auto",
     "steps_per_print": 5,
     "train_batch_size": "auto",
     "train_micro_batch_size_per_gpu": "auto",
-    "wall_clock_breakdown": false
+    "wall_clock_breakdown": false,
+    "round_robin_gradients": true
 }
diff --git a/scripts/finetune.py b/scripts/finetune.py
index bfa4fb4f4..c85994f8f 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -20,7 +20,13 @@ from peft import (
     PeftModel,
 )
 from torch import nn
-from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, LlamaTokenizer
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    LlamaForCausalLM,
+    LlamaTokenizer,
+    EarlyStoppingCallback,
+)
 
 # add src to the pythonpath so we don't need to pip install this
 from transformers.trainer_pt_utils import get_parameter_names
@@ -54,11 +60,11 @@ def setup_wandb_env_vars(cfg):
         os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
 
 
-def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
+def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora", inference: bool = False):
     if adapter != "lora":
         raise NotImplementedError(f"{adapter} peft adapter not available")
     if "llama" in base_model:
-        if cfg.device not in ["mps", "cpu"]:
+        if cfg.device not in ["mps", "cpu"] and inference is False:
             from axolotl.flash_attn import replace_llama_attn_with_flash_attn
 
             replace_llama_attn_with_flash_attn()
@@ -185,7 +191,7 @@ def do_inference(cfg, model, tokenizer):
         generated = model.generate(inputs=batch["input_ids"],
                                    do_sample=True, use_cache=True,
                                    repetition_penalty=1.1,
-                                   max_new_tokens=50,
+                                   max_new_tokens=100,
                                    temperature=0.9,
                                    top_p=0.95,
                                    top_k=40,
@@ -224,19 +230,15 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     total_num_steps = int(
         math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
     )
+    warmup_steps = min(int(0.03 * total_num_steps), 100)
+    logging_steps = min(int(0.005 * total_num_steps), 10)
     save_steps = eval_steps = min(int(0.05 * total_num_steps), 200)
 
     training_arguments_kwargs = {}
-
-    if not cfg.deepspeed:
-        warmup_steps = min(int(0.03 * total_num_steps), 100)
-        logging_steps = min(int(0.005 * total_num_steps), 10)
-
-        training_arguments_kwargs["warmup_steps"] = warmup_steps
-        training_arguments_kwargs["logging_steps"] = logging_steps
-        training_arguments_kwargs["logging_steps"] = logging_steps
-        training_arguments_kwargs["bf16"] = cfg.bf16
-        training_arguments_kwargs["tf32"] = cfg.tf32
+    training_arguments_kwargs["bf16"] = cfg.bf16
+    training_arguments_kwargs["tf32"] = cfg.tf32
+    training_arguments_kwargs["warmup_steps"] = warmup_steps
+    training_arguments_kwargs["logging_steps"] = logging_steps
 
     training_args = transformers.TrainingArguments(
         per_device_train_batch_size=cfg.micro_batch_size,
@@ -258,37 +260,40 @@
     )
 
     trainer_kwargs = {}
+    decay_parameters = get_parameter_names(model, [nn.LayerNorm])
+    decay_parameters = [name for name in decay_parameters if "bias" not in name]
+    optimizer_grouped_parameters = [
+        {
+            "params": [p for n, p in model.named_parameters() if n in decay_parameters],
+            "weight_decay": training_args.weight_decay,
+        },
+        {
+            "params": [
+                p for n, p in model.named_parameters() if n not in decay_parameters
+            ],
+            "weight_decay": 0.0,
+        },
+    ]
 
-    if not cfg.deepspeed:
-        decay_parameters = get_parameter_names(model, [nn.LayerNorm])
-        decay_parameters = [name for name in decay_parameters if "bias" not in name]
-        optimizer_grouped_parameters = [
-            {
-                "params": [p for n, p in model.named_parameters() if n in decay_parameters],
-                "weight_decay": training_args.weight_decay,
-            },
-            {
-                "params": [
-                    p for n, p in model.named_parameters() if n not in decay_parameters
-                ],
-                "weight_decay": 0.0,
-            },
-        ]
+    adam_bnb_optim = bnb.optim.Adam8bit(
+        optimizer_grouped_parameters,
+        betas=(training_args.adam_beta1, training_args.adam_beta2),
+        eps=training_args.adam_epsilon,
+        lr=training_args.learning_rate,
+    )
 
-        adam_bnb_optim = bnb.optim.Adam8bit(
-            optimizer_grouped_parameters,
-            betas=(training_args.adam_beta1, training_args.adam_beta2),
-            eps=training_args.adam_epsilon,
-            lr=training_args.learning_rate,
+    lr_scheduler = transformers.get_cosine_schedule_with_warmup(
+        adam_bnb_optim,
+        training_args.warmup_steps,
+        total_num_steps,
+    )
+    trainer_kwargs["optimizers"] = (adam_bnb_optim, lr_scheduler)
+
+    if cfg.early_stopping_patience:
+        early_stop_cb = EarlyStoppingCallback(
+            cfg.early_stopping_patience,
         )
-
-        lr_scheduler = transformers.get_cosine_schedule_with_warmup(
-            adam_bnb_optim,
-            training_args.warmup_steps,
-            total_num_steps,
-        )
-        trainer_kwargs["optimizers"] = (adam_bnb_optim, lr_scheduler)
-
+        trainer_kwargs["callbacks"] = [early_stop_cb]
 
     trainer = transformers.Trainer(
         model=model,
@@ -340,7 +345,7 @@ def train(
 
     # Load the model and tokenizer
     model, tokenizer, lora_config = load_model(
-        cfg.base_model, cfg.model_type, cfg.tokenizer_type, cfg, adapter=cfg.adapter
+        cfg.base_model, cfg.model_type, cfg.tokenizer_type, cfg, adapter=cfg.adapter, inference=("inference" in kwargs)
     )
 
     if "inference" in kwargs:
@@ -422,17 +427,19 @@ def train(
         lora_config.save_pretrained(cfg.output_dir)
 
     # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
-    signal.signal(
-        signal.SIGINT,
-        lambda signal, frame: (model.save_pretrained(cfg.output_dir), exit(0)),
-    )
+    if cfg.local_rank == 0:
+        signal.signal(
+            signal.SIGINT,
+            lambda signal, frame: (model.save_pretrained(cfg.output_dir), exit(0)),
+        )
 
     logging.info("Starting trainer...")
     trainer.train(resume_from_checkpoint=cfg.resume_from_checkpoint)
 
-    # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
-    logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
-    model.save_pretrained(cfg.output_dir)
+    if cfg.local_rank == 0:
+        # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
+        logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
+        model.save_pretrained(cfg.output_dir)
 
 
 if __name__ == "__main__":
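
Note on the early-stopping hunk: in 2023-era transformers, EarlyStoppingCallback
asserts at train start that load_best_model_at_end is True, that a
metric_for_best_model is defined, and that an evaluation strategy is enabled; the
patience counter only ticks when evaluation actually runs. setup_trainer already
ties save_steps to eval_steps, which best-model tracking needs, but none of the
hunks above set the best-model flags, so a config that enables
early_stopping_patience likely also needs TrainingArguments along these lines.
A minimal sketch, not part of the patch; the concrete values are illustrative:

    import transformers

    # Illustrative values only; setup_trainer derives eval_steps/save_steps
    # from total_num_steps instead of hard-coding them.
    training_args = transformers.TrainingArguments(
        output_dir="./out",
        evaluation_strategy="steps",        # eval must run for patience to tick
        eval_steps=200,
        save_steps=200,                     # kept equal to eval_steps, as in setup_trainer
        load_best_model_at_end=True,        # asserted by EarlyStoppingCallback
        metric_for_best_model="eval_loss",
        greater_is_better=False,
    )
    early_stop_cb = transformers.EarlyStoppingCallback(early_stopping_patience=3)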