From 998763bade444ea01e11bc2aea3a572af2435de3 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 18 Sep 2023 19:50:08 -0400 Subject: [PATCH] ia3 keeps casting to float32, handle it here for now --- .../monkeypatch/llama_attn_hijack_flash.py | 18 ++++++++++++++++++ src/axolotl/utils/models.py | 4 ++-- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index 4f6b71575..262c10d90 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -116,6 +116,8 @@ def flashattn_forward( attention_mask: [bsz, q_len] """ # pylint: disable=duplicate-code + original_dtype = hidden_states.dtype + bsz, q_len, _ = hidden_states.size() if not hasattr(self, "pretraining_tp"): @@ -151,6 +153,13 @@ def flashattn_forward( key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) + if query_states.dtype == torch.float32: + query_states = query_states.to(dtype=original_dtype) + if key_states.dtype == torch.float32: + key_states = key_states.to(dtype=original_dtype) + if value_states.dtype == torch.float32: + value_states = value_states.to(dtype=original_dtype) + query_states = query_states.view( bsz, q_len, self.num_heads, self.head_dim ).transpose(1, 2) @@ -309,6 +318,10 @@ def flashattn_forward( else: attn_output = self.o_proj(attn_output) + # handle conversion back for IA3 + if attn_output.dtype == torch.float32: + attn_output = attn_output.to(dtype=original_dtype) + return attn_output, None, past_key_value @@ -502,6 +515,7 @@ def llama_model_forward( ) hidden_states = inputs_embeds + original_dtype = hidden_states.dtype if self.gradient_checkpointing and self.training: if use_cache: @@ -559,6 +573,10 @@ def llama_model_forward( hidden_states = layer_outputs[0] + # handle conversion back for IA3 + if hidden_states.dtype == torch.float32: + hidden_states = hidden_states.to(dtype=original_dtype) + if use_cache: next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index a5510d3bd..1719542a1 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -431,7 +431,7 @@ def load_model( if hasattr(module, "weight"): module.to(cfg.torch_dtype) - model, lora_config = load_adapter(model, cfg, cfg.adapter) + model, peft_config = load_adapter(model, cfg, cfg.adapter) if cfg.ddp and not load_in_8bit: model.to(f"cuda:{cfg.local_rank}") @@ -462,7 +462,7 @@ def load_model( log_gpu_memory_usage(LOG, "after adapters", model.device) # TODO resume_from_checkpoint handling - return model, lora_config + return model, peft_config def load_adapter(model, cfg, adapter, inference=False):