attention_mask not needed for training (#642)
* attention_mask not needed for training
* specifically don't use attention mask for Phi
* use a different check for Phi
* small fixes since Phi removed some values from their config
This commit is contained in:
@@ -711,12 +711,8 @@ class ParallelBlock(nn.Module):
|
||||
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
||||
self.block_idx = block_idx
|
||||
|
||||
self.mixer = MHA(config=config, **mixer, layer_idx=block_idx)
|
||||
mlp_cls = mlp.pop("mlp_cls")
|
||||
if mlp_cls == "fused_mlp":
|
||||
self.mlp = FusedMLP(config=config, **mlp)
|
||||
else:
|
||||
self.mlp = MLP(config=config, **mlp)
|
||||
self.mixer = MHA(config, layer_idx=block_idx)
|
||||
self.mlp = MLP(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -76,7 +76,7 @@ def prepare_dataset(cfg, tokenizer):
|
||||
|
||||
with zero_first(is_main_process()):
|
||||
train_dataset, eval_dataset = process_datasets_for_packing(
|
||||
cfg, train_dataset, eval_dataset
|
||||
cfg, train_dataset, eval_dataset, tokenizer
|
||||
)
|
||||
if cfg.max_steps:
|
||||
total_num_steps = min(
|
||||
|
||||
@@ -397,7 +397,7 @@ def disable_datasets_caching():
|
||||
set_caching_enabled(True)
|
||||
|
||||
|
||||
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||
def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
|
||||
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
|
||||
with zero_first(is_main_process()):
|
||||
train_dataset = train_dataset.filter(drop_long, num_proc=os.cpu_count())
|
||||
@@ -414,6 +414,13 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||
eval_dataset = eval_dataset.map(
|
||||
add_position_ids, num_proc=os.cpu_count()
|
||||
)
|
||||
|
||||
# Phi doesn't want the attention_mask feature when training
|
||||
if "CodeGenTokenizer" in tokenizer.__class__.__name__:
|
||||
train_dataset = train_dataset.remove_columns("attention_mask")
|
||||
if eval_dataset:
|
||||
eval_dataset = eval_dataset.remove_columns("attention_mask")
|
||||
|
||||
return train_dataset, eval_dataset
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user