sample_packing_seq_len_multiplier config
This commit is contained in:
@@ -128,7 +128,7 @@ class MultipackDistributedDataloader:
|
||||
batch_size: int = 1,
|
||||
sampler: Union[Sampler, DistributedSampler] = None,
|
||||
packing_efficiency_estimate: float = 1.0,
|
||||
seq_len_multiple: int = 1,
|
||||
sample_packing_seq_len_multiplier: int = 1,
|
||||
):
|
||||
# Dataset
|
||||
self.dataset = dataset
|
||||
@@ -136,10 +136,11 @@ class MultipackDistributedDataloader:
|
||||
[len(sample["input_ids"]) for sample in self.dataset]
|
||||
)
|
||||
assert isinstance(self.lengths, np.ndarray)
|
||||
assert batch_size % seq_len_multiple == 0
|
||||
assert batch_size % sample_packing_seq_len_multiplier == 0
|
||||
assert batch_size >= sample_packing_seq_len_multiplier
|
||||
self.sampler = sampler
|
||||
self.batch_size = batch_size
|
||||
self.seq_len_multiple = seq_len_multiple
|
||||
self.sample_packing_seq_len_multiplier = sample_packing_seq_len_multiplier
|
||||
self.seq_max_length = seq_max_length
|
||||
self.batch_max_length = batch_size * seq_max_length
|
||||
self.collate_fn = collate_fn
|
||||
@@ -166,7 +167,7 @@ class MultipackDistributedDataloader:
|
||||
lengths_cumsum=lengths_cumsum,
|
||||
rank=self.rank,
|
||||
# c=self.batch_max_length,
|
||||
c=self.seq_max_length * self.seq_len_multiple,
|
||||
c=self.seq_max_length * self.sample_packing_seq_len_multiplier,
|
||||
n=self.num_replicas,
|
||||
)
|
||||
|
||||
@@ -183,7 +184,9 @@ class MultipackDistributedDataloader:
|
||||
all_batches, _ = self.generate_batches(set_stats=True)
|
||||
features = self.dataset.features.keys()
|
||||
len_remaining = self._len_est()
|
||||
for batches in chunk(all_batches, self.batch_size // self.seq_len_multiple):
|
||||
for batches in chunk(
|
||||
all_batches, self.batch_size // self.sample_packing_seq_len_multiplier
|
||||
):
|
||||
chunked_data = []
|
||||
attn_mask_cum_idx = 0
|
||||
for batch in batches:
|
||||
|
||||
@@ -117,6 +117,10 @@ class AxolotlTrainingArguments(TrainingArguments):
|
||||
default=2048,
|
||||
metadata={"help": "The maximum sequence length the model can handle"},
|
||||
)
|
||||
sample_packing_seq_len_multiplier: int = field(
|
||||
default=1,
|
||||
metadata={"help": "the multiplier for the max len for packed sequences"},
|
||||
)
|
||||
|
||||
|
||||
class AxolotlTrainer(Trainer):
|
||||
@@ -176,7 +180,7 @@ class AxolotlTrainer(Trainer):
|
||||
collate_fn=self.data_collator,
|
||||
sampler=train_sampler,
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
seq_len_multiple=2,
|
||||
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
|
||||
)
|
||||
)
|
||||
return super().get_train_dataloader()
|
||||
@@ -197,7 +201,7 @@ class AxolotlTrainer(Trainer):
|
||||
collate_fn=self.data_collator,
|
||||
sampler=eval_sampler,
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
seq_len_multiple=2,
|
||||
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
|
||||
)
|
||||
)
|
||||
return super().get_eval_dataloader(eval_dataset)
|
||||
@@ -295,7 +299,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
),
|
||||
sampler=sampler,
|
||||
packing_efficiency_estimate=cfg.sample_packing_eff_est,
|
||||
seq_len_multiple=2,
|
||||
sample_packing_seq_len_multiplier=cfg.sample_packing_seq_len_multiplier,
|
||||
)
|
||||
data_loader_len = len(data_loader)
|
||||
LOG.info(f"data_loader_len: {data_loader_len}")
|
||||
@@ -430,6 +434,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
else "cosine",
|
||||
weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
|
||||
sample_packing=cfg.sample_packing if cfg.sample_packing else False,
|
||||
sample_packing_seq_len_multiplier=cfg.sample_packing_seq_len_multiplier or 1,
|
||||
**training_arguments_kwargs,
|
||||
)
|
||||
|
||||
@@ -523,7 +528,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
if cfg.collator_pad_to_longest:
|
||||
data_collator_kwargs["padding"] = "longest"
|
||||
else:
|
||||
data_collator_kwargs["pad_to_multiple_of"] = 8
|
||||
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
|
||||
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
||||
data_collator_kwargs["pad_to_multiple_of"] = 64
|
||||
|
||||
if cfg.is_llama_derived_model and cfg.landmark_attention:
|
||||
from axolotl.monkeypatch.llama_landmark_attn import (
|
||||
|
||||
Reference in New Issue
Block a user