IA3 keeps casting to float32; handle it here for now
This commit is contained in:
@@ -116,6 +116,8 @@ def flashattn_forward(
|
|||||||
attention_mask: [bsz, q_len]
|
attention_mask: [bsz, q_len]
|
||||||
"""
|
"""
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
original_dtype = hidden_states.dtype
|
||||||
|
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
if not hasattr(self, "pretraining_tp"):
|
if not hasattr(self, "pretraining_tp"):
|
||||||
@@ -151,6 +153,13 @@ def flashattn_forward(
|
|||||||
key_states = self.k_proj(hidden_states)
|
key_states = self.k_proj(hidden_states)
|
||||||
value_states = self.v_proj(hidden_states)
|
value_states = self.v_proj(hidden_states)
|
||||||
|
|
||||||
|
if query_states.dtype == torch.float32:
|
||||||
|
query_states = query_states.to(dtype=original_dtype)
|
||||||
|
if key_states.dtype == torch.float32:
|
||||||
|
key_states = key_states.to(dtype=original_dtype)
|
||||||
|
if value_states.dtype == torch.float32:
|
||||||
|
value_states = value_states.to(dtype=original_dtype)
|
||||||
|
|
||||||
query_states = query_states.view(
|
query_states = query_states.view(
|
||||||
bsz, q_len, self.num_heads, self.head_dim
|
bsz, q_len, self.num_heads, self.head_dim
|
||||||
).transpose(1, 2)
|
).transpose(1, 2)
|
||||||
@@ -309,6 +318,10 @@ def flashattn_forward(
|
|||||||
else:
|
else:
|
||||||
attn_output = self.o_proj(attn_output)
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
# handle conversion back for IA3
|
||||||
|
if attn_output.dtype == torch.float32:
|
||||||
|
attn_output = attn_output.to(dtype=original_dtype)
|
||||||
|
|
||||||
return attn_output, None, past_key_value
|
return attn_output, None, past_key_value
|
||||||
|
|
||||||
|
|
||||||
@@ -502,6 +515,7 @@ def llama_model_forward(
|
|||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
|
original_dtype = hidden_states.dtype
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
if use_cache:
|
if use_cache:
|
||||||
@@ -559,6 +573,10 @@ def llama_model_forward(
|
|||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
# handle conversion back for IA3
|
||||||
|
if hidden_states.dtype == torch.float32:
|
||||||
|
hidden_states = hidden_states.to(dtype=original_dtype)
|
||||||
|
|
||||||
if use_cache:
|
if use_cache:
|
||||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||||
|
|
||||||
|
|||||||
@@ -431,7 +431,7 @@ def load_model(
|
|||||||
if hasattr(module, "weight"):
|
if hasattr(module, "weight"):
|
||||||
module.to(cfg.torch_dtype)
|
module.to(cfg.torch_dtype)
|
||||||
|
|
||||||
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
model, peft_config = load_adapter(model, cfg, cfg.adapter)
|
||||||
|
|
||||||
if cfg.ddp and not load_in_8bit:
|
if cfg.ddp and not load_in_8bit:
|
||||||
model.to(f"cuda:{cfg.local_rank}")
|
model.to(f"cuda:{cfg.local_rank}")
|
||||||
@@ -462,7 +462,7 @@ def load_model(
|
|||||||
log_gpu_memory_usage(LOG, "after adapters", model.device)
|
log_gpu_memory_usage(LOG, "after adapters", model.device)
|
||||||
|
|
||||||
# TODO resume_from_checkpoint handling
|
# TODO resume_from_checkpoint handling
|
||||||
return model, lora_config
|
return model, peft_config
|
||||||
|
|
||||||
|
|
||||||
def load_adapter(model, cfg, adapter, inference=False):
|
def load_adapter(model, cfg, adapter, inference=False):
|
||||||
|
|||||||
Reference in New Issue
Block a user