diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 8b2b7ad6a..1fc47a87f 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -158,8 +158,8 @@ def load_model( for k, v in cfg.tokens.items(): tokenizer.add_special_tokens({k: v}) - if load_in_8bit and cfg.load_4bit: - logging.info("converting model w/ prepare_model_for_int8_training") + if cfg.adapter and load_in_8bit and not cfg.load_4bit: + logging.info("converting PEFT model w/ prepare_model_for_int8_training") model = prepare_model_for_int8_training(model) model, lora_config = load_adapter(model, cfg, adapter) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 4d4719969..73be3dbd2 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -17,9 +17,21 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): total_num_steps = int( math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) ) - warmup_steps = cfg.warmup_steps if cfg.warmup_steps is not None else min(int(0.03 * total_num_steps), 100) - logging_steps = cfg.logging_steps if cfg.logging_steps is not None else max(min(int(0.005 * total_num_steps), 10), 1) - save_steps = eval_steps = cfg.save_steps if cfg.save_steps is not None else min(int(0.05 * total_num_steps), 200) + warmup_steps = ( + cfg.warmup_steps + if cfg.warmup_steps is not None + else min(int(0.03 * total_num_steps), 100) + ) + logging_steps = ( + cfg.logging_steps + if cfg.logging_steps is not None + else max(min(int(0.005 * total_num_steps), 10), 1) + ) + save_steps = eval_steps = ( + cfg.save_steps + if cfg.save_steps is not None + else min(int(0.05 * total_num_steps), 200) + ) training_arguments_kwargs = {} if cfg.bf16 == "full": @@ -31,19 +43,32 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): training_arguments_kwargs["logging_steps"] = logging_steps if cfg.gradient_checkpointing is not None: if cfg.load_4bit: - from alpaca_lora_4bit.gradient_checkpointing import apply_gradient_checkpointing - gradient_checkpointing_ratio = cfg.gradient_checkpointing_ratio if cfg.gradient_checkpointing_ratio else 1.0 - apply_gradient_checkpointing(model, checkpoint_ratio=gradient_checkpointing_ratio) + from alpaca_lora_4bit.gradient_checkpointing import ( + apply_gradient_checkpointing, + ) + + gradient_checkpointing_ratio = ( + cfg.gradient_checkpointing_ratio + if cfg.gradient_checkpointing_ratio + else 1.0 + ) + apply_gradient_checkpointing( + model, checkpoint_ratio=gradient_checkpointing_ratio + ) else: - training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing + training_arguments_kwargs[ + "gradient_checkpointing" + ] = cfg.gradient_checkpointing if cfg.fsdp: training_arguments_kwargs["fsdp"] = cfg.fsdp if cfg.fsdp_config: training_arguments_kwargs["fsdp_config"] = dict(cfg.fsdp_config) - # deepspeed - if os.environ.get("ACCELERATE_USE_DEEPSPEED") == "true" and torch.cuda.device_count() > 1: + if ( + os.environ.get("ACCELERATE_USE_DEEPSPEED") == "true" + and torch.cuda.device_count() > 1 + ): if cfg.deepspeed: training_arguments_kwargs["deepspeed"] = cfg.deepspeed else: @@ -62,12 +87,14 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): save_steps=save_steps, output_dir=cfg.output_dir, save_total_limit=3, - load_best_model_at_end=True if cfg.val_set_size > 0 and save_steps % eval_steps == 0 else False, + load_best_model_at_end=True + if cfg.val_set_size > 0 and save_steps % eval_steps == 0 + else False, ddp_find_unused_parameters=False if cfg.ddp else None, group_by_length=cfg.group_by_length, report_to="wandb" if cfg.use_wandb else None, run_name=cfg.wandb_run_id if cfg.use_wandb else None, - optim=cfg.optimizer if cfg.optimizer != "adam8bit" else cfg.optimizer, + optim=cfg.optimizer if cfg.optimizer else None, lr_scheduler_type=cfg.lr_scheduler if cfg.lr_scheduler else None, weight_decay=cfg.weight_decay if cfg.weight_decay else 0.0, **training_arguments_kwargs, @@ -78,22 +105,33 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): if cfg.optimizer == "adamw_anyprecision": if Path(cfg.torchdistx_path).exists(): sys.path.append(cfg.torchdistx_path) - torchdistx = importlib.import_module('torchdistx') - if cfg.optimizer == "adam8bit" and not cfg.load_4bit and not "deepspeed" in training_arguments_kwargs: + importlib.import_module("torchdistx") + if ( + cfg.optimizer == "adamw_bnb_8bit" + and not cfg.load_4bit + and not "deepspeed" in training_arguments_kwargs + ): decay_parameters = get_parameter_names(model, [nn.LayerNorm]) decay_parameters = [name for name in decay_parameters if "bias" not in name] optimizer_grouped_parameters = [ { - "params": [p for n, p in model.named_parameters() if n in decay_parameters], + "params": [ + p + for n, p in model.named_parameters() + if (n in decay_parameters and p.requires_grad) + ], "weight_decay": training_args.weight_decay, }, { "params": [ - p for n, p in model.named_parameters() if n not in decay_parameters + p + for n, p in model.named_parameters() + if (n not in decay_parameters and p.requires_grad) ], "weight_decay": 0.0, }, ] + optimizer = bnb.optim.Adam8bit( optimizer_grouped_parameters, betas=(training_args.adam_beta1, training_args.adam_beta2),