From b9d07aa95a7a2291907cfbeabc63f95c61570ed9 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 23 May 2023 11:33:41 -0400
Subject: [PATCH] prepare does all this already for qlora?

---
 src/axolotl/utils/models.py | 24 ++++++++++++------------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 992abf3ed..1bcc4b0bc 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -204,17 +204,17 @@ def load_model(
     """### Post-processing on the model
     Finally, we need to apply some post-processing on the 8-bit model to enable training, let's freeze all our layers, and cast the layer-norm in `float32` for stability. We also cast the output of the last layer in `float32` for the same reasons.
     """
-    if cfg.adapter == "qlora":
-        for param in model.parameters():
-            param.requires_grad = False  # freeze the model - train adapters later
-            if param.ndim == 1:
-                # cast the small parameters (e.g. layernorm) to fp32 for stability
-                param.data = param.data.to(torch.float32)
-
-        class CastOutputToFloat(nn.Sequential):
-            def forward(self, x):
-                return super().forward(x).to(torch.float32)
-
-        model.lm_head = CastOutputToFloat(model.lm_head)
+    # if cfg.adapter == "qlora":
+    #     for param in model.parameters():
+    #         param.requires_grad = False  # freeze the model - train adapters later
+    #         if param.ndim == 1:
+    #             # cast the small parameters (e.g. layernorm) to fp32 for stability
+    #             param.data = param.data.to(torch.float32)
+    #
+    #     class CastOutputToFloat(nn.Linear):
+    #         def forward(self, x):
+    #             return super().forward(x).to(torch.float32)
+    #
+    #     model.lm_head = CastOutputToFloat(model.lm_head.in_features, model.lm_head.out_features, model.lm_head.bias)
 
     if not tokenizer:
         try:
@@ -255,7 +255,7 @@ def load_model(
         embeddings_len = math.ceil(len(tokenizer) / 32) * 32
         model.resize_token_embeddings(embeddings_len)
 
-    if cfg.adapter and load_in_8bit and not cfg.load_4bit:
+    if ((cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora") and not cfg.load_4bit:
         logging.info("converting PEFT model w/ prepare_model_for_int8_training")
         model = prepare_model_for_int8_training(model)
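Note (reviewer context, not part of the patch): the change assumes peft's `prepare_model_for_int8_training` already covers everything the commented-out qlora block did, so the second hunk just widens the condition so qlora models also pass through it. Below is a minimal sketch of that behavior, assuming the peft implementation current in mid-2023; `prepare_for_int8_sketch` is a hypothetical stand-in, and the forward-hook cast is a simplification of how peft wraps the output layer.

```python
import torch


def prepare_for_int8_sketch(model, use_gradient_checkpointing=True):
    # Freeze the whole base model; only adapter weights added later will train.
    for param in model.parameters():
        param.requires_grad = False
        if param.ndim == 1:
            # Cast 1-dim params (layernorms, biases) to fp32 for numerical stability.
            param.data = param.data.to(torch.float32)

    if use_gradient_checkpointing:
        # Inputs must require grad so backward still reaches the adapters
        # through the frozen, checkpointed base model.
        model.enable_input_require_grads()
        model.gradient_checkpointing_enable()

    # Cast the output of the lm_head / output embedding to fp32 -- the same
    # effect as the CastOutputToFloat wrapper this patch removes.
    output_embedding = model.get_output_embeddings()
    if output_embedding is not None:
        output_embedding.register_forward_hook(
            lambda module, inputs, output: output.to(torch.float32)
        )
    return model
```

Since the helper already freezes parameters and casts the layernorms and the output head, running the removed qlora block beforehand would have duplicated that work.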