detab the code to check
This commit is contained in:
@@ -29,66 +29,66 @@ PATCHED_CONTEXT_CODE = """
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
ORIGINAL_LLAMA_FCLM_CODE = """
|
ORIGINAL_LLAMA_FCLM_CODE = """
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
)
|
)
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
outputs = self.model(
|
outputs = self.model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
hidden_states = outputs[0]
|
hidden_states = outputs[0]
|
||||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||||
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
|
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
PATCHED_LLAMA_FCLM_CODE = """
|
PATCHED_LLAMA_FCLM_CODE = """
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
)
|
)
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
# remove num_items_in_batch otherwise self.model attempts to pass it to flash_attention
|
# remove num_items_in_batch otherwise self.model attempts to pass it to flash_attention
|
||||||
num_items_in_batch = kwargs.pop("num_items_in_batch")
|
num_items_in_batch = kwargs.pop("num_items_in_batch")
|
||||||
|
|
||||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
outputs = self.model(
|
outputs = self.model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
hidden_states = outputs[0]
|
hidden_states = outputs[0]
|
||||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||||
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, num_items_in_batch=num_items_in_batch, **kwargs)
|
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, num_items_in_batch=num_items_in_batch, **kwargs)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user