Compare commits
1 Commits
llama-flas
...
fix-ddp_fi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
34eb4e1677 |
@@ -1000,9 +1000,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
and self.cfg.eval_steps
|
and self.cfg.eval_steps
|
||||||
and self.cfg.save_steps % self.cfg.eval_steps == 0
|
and self.cfg.save_steps % self.cfg.eval_steps == 0
|
||||||
) or False
|
) or False
|
||||||
training_arguments_kwargs["ddp_find_unused_parameters"] = (
|
ddp_find_unused_parameters = (
|
||||||
False if self.cfg.ddp else None
|
self.cfg.ddp_find_unused_parameters
|
||||||
|
if self.cfg.ddp_find_unused_parameters is not None
|
||||||
|
else (False if self.cfg.ddp else None)
|
||||||
)
|
)
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"ddp_find_unused_parameters"
|
||||||
|
] = ddp_find_unused_parameters
|
||||||
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
||||||
report_to = None
|
report_to = None
|
||||||
if self.cfg.use_wandb:
|
if self.cfg.use_wandb:
|
||||||
|
|||||||
@@ -245,6 +245,7 @@ def flashattn_forward_with_s2attn(
|
|||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
|
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
|
||||||
cu_seqlens: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
|
cu_seqlens: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
|
||||||
max_seqlen: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
|
max_seqlen: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
@@ -373,6 +374,7 @@ def flashattn_forward(
|
|||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
|
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
|
||||||
cu_seqlens: Optional[torch.Tensor] = None,
|
cu_seqlens: Optional[torch.Tensor] = None,
|
||||||
max_seqlen: Optional[torch.Tensor] = None,
|
max_seqlen: Optional[torch.Tensor] = None,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
@@ -768,6 +770,12 @@ def llama_model_forward(
|
|||||||
dtype=torch.bool,
|
dtype=torch.bool,
|
||||||
device=inputs_embeds.device,
|
device=inputs_embeds.device,
|
||||||
)
|
)
|
||||||
|
padding_mask = None
|
||||||
|
else:
|
||||||
|
if 0 in attention_mask:
|
||||||
|
padding_mask = attention_mask
|
||||||
|
else:
|
||||||
|
padding_mask = None
|
||||||
|
|
||||||
attention_mask = (
|
attention_mask = (
|
||||||
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
|
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
|
||||||
@@ -817,6 +825,7 @@ def llama_model_forward(
|
|||||||
past_key_value,
|
past_key_value,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
None,
|
None,
|
||||||
|
padding_mask,
|
||||||
cu_seqlens,
|
cu_seqlens,
|
||||||
max_seqlen,
|
max_seqlen,
|
||||||
)
|
)
|
||||||
@@ -828,6 +837,7 @@ def llama_model_forward(
|
|||||||
past_key_value=past_key_value,
|
past_key_value=past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
|
padding_mask=padding_mask,
|
||||||
cu_seqlens=cu_seqlens,
|
cu_seqlens=cu_seqlens,
|
||||||
max_seqlen=max_seqlen,
|
max_seqlen=max_seqlen,
|
||||||
)
|
)
|
||||||
@@ -874,6 +884,7 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
|
|||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
use_cache: Optional[bool] = False,
|
use_cache: Optional[bool] = False,
|
||||||
|
padding_mask: Optional[torch.LongTensor] = None,
|
||||||
cu_seqlens: Optional[torch.Tensor] = None,
|
cu_seqlens: Optional[torch.Tensor] = None,
|
||||||
max_seqlen: Optional[torch.Tensor] = None,
|
max_seqlen: Optional[torch.Tensor] = None,
|
||||||
) -> Tuple[
|
) -> Tuple[
|
||||||
@@ -906,6 +917,7 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
|
|||||||
past_key_value=past_key_value,
|
past_key_value=past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
|
padding_mask=padding_mask,
|
||||||
cu_seqlens=cu_seqlens,
|
cu_seqlens=cu_seqlens,
|
||||||
max_seqlen=max_seqlen,
|
max_seqlen=max_seqlen,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user