various bugfixes (#856)
* various bugfixes
  - use latest tinyllama release
  - check if val_set_size is empty first
  - update sdp and xformers llama patches for updated upstream transformers
  - fix system prompt when no input
  - calculate total and total supervised tokens even when not sample packing
* add fix for when eval size is estimated to be too small
* should be len 1 for dataset length
* add catchall kwargs
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
base_model: PY007/TinyLlama-1.1B-step-50K-105b
|
base_model: PY007/TinyLlama-1.1B-intermediate-step-715k-1.5T
|
||||||
|
|
||||||
model_type: LlamaForCausalLM
|
model_type: LlamaForCausalLM
|
||||||
tokenizer_type: LlamaTokenizer
|
tokenizer_type: LlamaTokenizer
|
||||||
|
|||||||
@@ -543,16 +543,16 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
"dataloader_prefetch_factor"
|
"dataloader_prefetch_factor"
|
||||||
] = self.cfg.dataloader_prefetch_factor
|
] = self.cfg.dataloader_prefetch_factor
|
||||||
|
|
||||||
if self.cfg.eval_steps:
|
if self.cfg.val_set_size == 0:
|
||||||
|
# no eval set, so don't eval
|
||||||
|
training_arguments_kwargs["evaluation_strategy"] = "no"
|
||||||
|
elif self.cfg.eval_steps:
|
||||||
training_arguments_kwargs["evaluation_strategy"] = "steps"
|
training_arguments_kwargs["evaluation_strategy"] = "steps"
|
||||||
training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
|
training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
|
||||||
elif self.cfg.evaluation_strategy:
|
elif self.cfg.evaluation_strategy:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs[
|
||||||
"evaluation_strategy"
|
"evaluation_strategy"
|
||||||
] = self.cfg.evaluation_strategy
|
] = self.cfg.evaluation_strategy
|
||||||
elif self.cfg.val_set_size == 0:
|
|
||||||
# no eval set, so don't eval
|
|
||||||
training_arguments_kwargs["evaluation_strategy"] = "no"
|
|
||||||
else:
|
else:
|
||||||
# we have an eval set, but no steps defined, default to use epoch
|
# we have an eval set, but no steps defined, default to use epoch
|
||||||
training_arguments_kwargs["evaluation_strategy"] = "epoch"
|
training_arguments_kwargs["evaluation_strategy"] = "epoch"
|
||||||
|
|||||||
@@ -25,6 +25,8 @@ def sdp_attention_forward(
|
|||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
|
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
|
||||||
|
**kwargs, # pylint: disable=unused-argument
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|||||||
@@ -29,6 +29,8 @@ def xformers_forward(
|
|||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
|
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
|
||||||
|
**kwargs, # pylint: disable=unused-argument
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ class AlpacaPrompter(Prompter):
|
|||||||
else:
|
else:
|
||||||
res = (
|
res = (
|
||||||
self.system_format.format(system=self.system_no_input_prompt)
|
self.system_format.format(system=self.system_no_input_prompt)
|
||||||
if self.system_prompt
|
if self.system_no_input_prompt
|
||||||
else ""
|
else ""
|
||||||
) + self.turn_no_input_format.format(instruction=instruction)
|
) + self.turn_no_input_format.format(instruction=instruction)
|
||||||
if output:
|
if output:
|
||||||
|
|||||||
@@ -181,7 +181,9 @@ class MultipackBatchSampler(BatchSampler):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
|
# shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
|
||||||
return (
|
return min(
|
||||||
|
1,
|
||||||
|
(
|
||||||
world_size
|
world_size
|
||||||
* math.floor(
|
* math.floor(
|
||||||
0.99
|
0.99
|
||||||
@@ -190,4 +192,5 @@ class MultipackBatchSampler(BatchSampler):
|
|||||||
// self.batch_max_len
|
// self.batch_max_len
|
||||||
)
|
)
|
||||||
- 1
|
- 1
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -142,9 +142,6 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
|
|||||||
|
|
||||||
|
|
||||||
def calculate_total_num_steps(cfg, train_dataset):
|
def calculate_total_num_steps(cfg, train_dataset):
|
||||||
if cfg.sample_packing:
|
|
||||||
# we have to drop anything longer then sequence len otherwise
|
|
||||||
# flash attention with position ids fails
|
|
||||||
if not cfg.total_num_tokens:
|
if not cfg.total_num_tokens:
|
||||||
total_num_tokens = np.sum(
|
total_num_tokens = np.sum(
|
||||||
train_dataset.data.column("input_ids")
|
train_dataset.data.column("input_ids")
|
||||||
@@ -168,6 +165,10 @@ def calculate_total_num_steps(cfg, train_dataset):
|
|||||||
)
|
)
|
||||||
cfg.total_supervised_tokens = total_supervised_tokens
|
cfg.total_supervised_tokens = total_supervised_tokens
|
||||||
|
|
||||||
|
if cfg.sample_packing:
|
||||||
|
# we have to drop anything longer then sequence len otherwise
|
||||||
|
# flash attention with position ids fails
|
||||||
|
|
||||||
if cfg.sample_packing_eff_est:
|
if cfg.sample_packing_eff_est:
|
||||||
total_num_steps = (
|
total_num_steps = (
|
||||||
# match count to len est in dataloader
|
# match count to len est in dataloader
|
||||||
|
|||||||
Reference in New Issue
Block a user