sample_packing_seq_len_multiplier config

This commit is contained in:
Wing Lian
2023-08-03 08:24:33 -04:00
parent 7e1edc662a
commit b8905e2a91
2 changed files with 19 additions and 9 deletions

View File

@@ -128,7 +128,7 @@ class MultipackDistributedDataloader:
batch_size: int = 1,
sampler: Union[Sampler, DistributedSampler] = None,
packing_efficiency_estimate: float = 1.0,
seq_len_multiple: int = 1,
sample_packing_seq_len_multiplier: int = 1,
):
# Dataset
self.dataset = dataset
@@ -136,10 +136,11 @@ class MultipackDistributedDataloader:
[len(sample["input_ids"]) for sample in self.dataset]
)
assert isinstance(self.lengths, np.ndarray)
assert batch_size % seq_len_multiple == 0
assert batch_size % sample_packing_seq_len_multiplier == 0
assert batch_size >= sample_packing_seq_len_multiplier
self.sampler = sampler
self.batch_size = batch_size
self.seq_len_multiple = seq_len_multiple
self.sample_packing_seq_len_multiplier = sample_packing_seq_len_multiplier
self.seq_max_length = seq_max_length
self.batch_max_length = batch_size * seq_max_length
self.collate_fn = collate_fn
@@ -166,7 +167,7 @@ class MultipackDistributedDataloader:
lengths_cumsum=lengths_cumsum,
rank=self.rank,
# c=self.batch_max_length,
c=self.seq_max_length * self.seq_len_multiple,
c=self.seq_max_length * self.sample_packing_seq_len_multiplier,
n=self.num_replicas,
)
@@ -183,7 +184,9 @@ class MultipackDistributedDataloader:
all_batches, _ = self.generate_batches(set_stats=True)
features = self.dataset.features.keys()
len_remaining = self._len_est()
for batches in chunk(all_batches, self.batch_size // self.seq_len_multiple):
for batches in chunk(
all_batches, self.batch_size // self.sample_packing_seq_len_multiplier
):
chunked_data = []
attn_mask_cum_idx = 0
for batch in batches:

View File

@@ -117,6 +117,10 @@ class AxolotlTrainingArguments(TrainingArguments):
default=2048,
metadata={"help": "The maximum sequence length the model can handle"},
)
sample_packing_seq_len_multiplier: int = field(
default=1,
metadata={"help": "the multiplier for the max len for packed sequences"},
)
class AxolotlTrainer(Trainer):
@@ -176,7 +180,7 @@ class AxolotlTrainer(Trainer):
collate_fn=self.data_collator,
sampler=train_sampler,
packing_efficiency_estimate=self.args.sample_packing_efficiency,
seq_len_multiple=2,
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
)
)
return super().get_train_dataloader()
@@ -197,7 +201,7 @@ class AxolotlTrainer(Trainer):
collate_fn=self.data_collator,
sampler=eval_sampler,
packing_efficiency_estimate=self.args.sample_packing_efficiency,
seq_len_multiple=2,
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
)
)
return super().get_eval_dataloader(eval_dataset)
@@ -295,7 +299,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
),
sampler=sampler,
packing_efficiency_estimate=cfg.sample_packing_eff_est,
seq_len_multiple=2,
sample_packing_seq_len_multiplier=cfg.sample_packing_seq_len_multiplier,
)
data_loader_len = len(data_loader)
LOG.info(f"data_loader_len: {data_loader_len}")
@@ -430,6 +434,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
else "cosine",
weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
sample_packing=cfg.sample_packing if cfg.sample_packing else False,
sample_packing_seq_len_multiplier=cfg.sample_packing_seq_len_multiplier or 1,
**training_arguments_kwargs,
)
@@ -523,7 +528,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
if cfg.collator_pad_to_longest:
data_collator_kwargs["padding"] = "longest"
else:
data_collator_kwargs["pad_to_multiple_of"] = 8
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
data_collator_kwargs["pad_to_multiple_of"] = 64
if cfg.is_llama_derived_model and cfg.landmark_attention:
from axolotl.monkeypatch.llama_landmark_attn import (