ia3 keeps casting to float32, handle it here for now

This commit is contained in:
Wing Lian
2023-09-18 19:50:08 -04:00
parent c8e42a0f4f
commit 998763bade
2 changed files with 20 additions and 2 deletions

View File

@@ -116,6 +116,8 @@ def flashattn_forward(
attention_mask: [bsz, q_len]
"""
# pylint: disable=duplicate-code
original_dtype = hidden_states.dtype
bsz, q_len, _ = hidden_states.size()
if not hasattr(self, "pretraining_tp"):
@@ -151,6 +153,13 @@ def flashattn_forward(
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
if query_states.dtype == torch.float32:
query_states = query_states.to(dtype=original_dtype)
if key_states.dtype == torch.float32:
key_states = key_states.to(dtype=original_dtype)
if value_states.dtype == torch.float32:
value_states = value_states.to(dtype=original_dtype)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
@@ -309,6 +318,10 @@ def flashattn_forward(
else:
attn_output = self.o_proj(attn_output)
# handle conversion back for IA3
if attn_output.dtype == torch.float32:
attn_output = attn_output.to(dtype=original_dtype)
return attn_output, None, past_key_value
@@ -502,6 +515,7 @@ def llama_model_forward(
)
hidden_states = inputs_embeds
original_dtype = hidden_states.dtype
if self.gradient_checkpointing and self.training:
if use_cache:
@@ -559,6 +573,10 @@ def llama_model_forward(
hidden_states = layer_outputs[0]
# handle conversion back for IA3
if hidden_states.dtype == torch.float32:
hidden_states = hidden_states.to(dtype=original_dtype)
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

View File

@@ -431,7 +431,7 @@ def load_model(
if hasattr(module, "weight"):
module.to(cfg.torch_dtype)
model, lora_config = load_adapter(model, cfg, cfg.adapter)
model, peft_config = load_adapter(model, cfg, cfg.adapter)
if cfg.ddp and not load_in_8bit:
model.to(f"cuda:{cfg.local_rank}")
@@ -462,7 +462,7 @@ def load_model(
log_gpu_memory_usage(LOG, "after adapters", model.device)
# TODO resume_from_checkpoint handling
return model, lora_config
return model, peft_config
def load_adapter(model, cfg, adapter, inference=False):