Compare commits

1 commit

Author:  Wing Lian
SHA1:    34eb4e1677
Message: fix handling of ddp_find_unused_parameters
Date:    2024-03-14 17:45:42 -04:00
2 changed files with 19 additions and 2 deletions

@@ -1000,9 +1000,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 and self.cfg.eval_steps
                 and self.cfg.save_steps % self.cfg.eval_steps == 0
             ) or False
-        training_arguments_kwargs["ddp_find_unused_parameters"] = (
-            False if self.cfg.ddp else None
+        ddp_find_unused_parameters = (
+            self.cfg.ddp_find_unused_parameters
+            if self.cfg.ddp_find_unused_parameters is not None
+            else (False if self.cfg.ddp else None)
         )
+        training_arguments_kwargs[
+            "ddp_find_unused_parameters"
+        ] = ddp_find_unused_parameters
         training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
         report_to = None
         if self.cfg.use_wandb:
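
Note: the old code forced ddp_find_unused_parameters to False whenever DDP was enabled, silently overriding any value set in the YAML config. After this change an explicit config value takes precedence. A minimal standalone sketch of the new precedence (the dict-based cfg below is a stand-in for axolotl's config object, not the real class):

    def resolve_ddp_find_unused_parameters(cfg):
        # An explicit setting always wins; otherwise default to False
        # under DDP and leave it unset (None) for single-process runs.
        explicit = cfg.get("ddp_find_unused_parameters")
        if explicit is not None:
            return explicit
        return False if cfg.get("ddp") else None

    assert resolve_ddp_find_unused_parameters(
        {"ddp": True, "ddp_find_unused_parameters": True}
    ) is True
    assert resolve_ddp_find_unused_parameters({"ddp": True}) is False
    assert resolve_ddp_find_unused_parameters({}) is None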

@@ -245,6 +245,7 @@ def flashattn_forward_with_s2attn(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    padding_mask: Optional[torch.LongTensor] = None,  # pylint: disable=unused-argument
     cu_seqlens: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
     max_seqlen: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
@@ -373,6 +374,7 @@ def flashattn_forward(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    padding_mask: Optional[torch.LongTensor] = None,  # pylint: disable=unused-argument
     cu_seqlens: Optional[torch.Tensor] = None,
     max_seqlen: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
@@ -768,6 +770,12 @@ def llama_model_forward(
                 dtype=torch.bool,
                 device=inputs_embeds.device,
             )
+            padding_mask = None
+        else:
+            if 0 in attention_mask:
+                padding_mask = attention_mask
+            else:
+                padding_mask = None
         attention_mask = (
             self._prepare_decoder_attention_mask(  # pylint: disable=protected-access
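
This appears to mirror the padding_mask derivation that upstream transformers used in its Llama attention around this time: the mask is only propagated when the batch actually contains padded positions. Illustrated in isolation (tensor values are made up):

    import torch

    attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]])
    # `0 in attention_mask` is True when any position is padded, so the
    # mask is forwarded to the attention layers; an all-ones mask
    # (no padding anywhere) yields padding_mask = None instead.
    padding_mask = attention_mask if 0 in attention_mask else None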
@@ -817,6 +825,7 @@ def llama_model_forward(
                 past_key_value,
                 output_attentions,
                 None,
+                padding_mask,
                 cu_seqlens,
                 max_seqlen,
             )
@@ -828,6 +837,7 @@ def llama_model_forward(
                 past_key_value=past_key_value,
                 output_attentions=output_attentions,
                 use_cache=use_cache,
+                padding_mask=padding_mask,
                 cu_seqlens=cu_seqlens,
                 max_seqlen=max_seqlen,
             )
@@ -874,6 +884,7 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
+        padding_mask: Optional[torch.LongTensor] = None,
         cu_seqlens: Optional[torch.Tensor] = None,
         max_seqlen: Optional[torch.Tensor] = None,
     ) -> Tuple[
@@ -906,6 +917,7 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
+            padding_mask=padding_mask,
             cu_seqlens=cu_seqlens,
             max_seqlen=max_seqlen,
         )
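
Note the asymmetry between the two call sites in llama_model_forward: under gradient checkpointing the layer runs through torch.utils.checkpoint, whose reentrant implementation only forwards positional arguments, so padding_mask has to be inserted at the slot matching the decoder layer signature rather than passed as a keyword. A toy sketch of that constraint (the layer body here is invented for brevity):

    import torch
    from torch.utils.checkpoint import checkpoint

    def layer_forward(hidden_states, padding_mask=None):
        # Stand-in for a decoder layer: under reentrant checkpointing,
        # extra inputs like padding_mask reach it positionally.
        return hidden_states * 2 if padding_mask is None else hidden_states

    hidden = torch.ones(2, 4, requires_grad=True)
    # use_reentrant=True matches the behavior of that era; keyword
    # arguments to layer_forward are not supported on this path.
    out = checkpoint(layer_forward, hidden, None, use_reentrant=True)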