Compare commits
1 Commits
llama-flas
...
fix-ddp_fi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
34eb4e1677 |
@@ -1000,9 +1000,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
and self.cfg.eval_steps
|
||||
and self.cfg.save_steps % self.cfg.eval_steps == 0
|
||||
) or False
|
||||
training_arguments_kwargs["ddp_find_unused_parameters"] = (
|
||||
False if self.cfg.ddp else None
|
||||
ddp_find_unused_parameters = (
|
||||
self.cfg.ddp_find_unused_parameters
|
||||
if self.cfg.ddp_find_unused_parameters is not None
|
||||
else (False if self.cfg.ddp else None)
|
||||
)
|
||||
training_arguments_kwargs[
|
||||
"ddp_find_unused_parameters"
|
||||
] = ddp_find_unused_parameters
|
||||
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
||||
report_to = None
|
||||
if self.cfg.use_wandb:
|
||||
|
||||
@@ -245,6 +245,7 @@ def flashattn_forward_with_s2attn(
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
|
||||
cu_seqlens: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
|
||||
max_seqlen: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
@@ -373,6 +374,7 @@ def flashattn_forward(
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
@@ -768,6 +770,12 @@ def llama_model_forward(
|
||||
dtype=torch.bool,
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
padding_mask = None
|
||||
else:
|
||||
if 0 in attention_mask:
|
||||
padding_mask = attention_mask
|
||||
else:
|
||||
padding_mask = None
|
||||
|
||||
attention_mask = (
|
||||
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
|
||||
@@ -817,6 +825,7 @@ def llama_model_forward(
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
None,
|
||||
padding_mask,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
)
|
||||
@@ -828,6 +837,7 @@ def llama_model_forward(
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
padding_mask=padding_mask,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
@@ -874,6 +884,7 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
padding_mask: Optional[torch.LongTensor] = None,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[
|
||||
@@ -906,6 +917,7 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
padding_mask=padding_mask,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user