seq_len_multiple for packing

Commit metadata:
- Author: Wing Lian
- Date: 2023-08-02 23:20:19 -04:00
- Parent commit: 8378335dc9
- Commit: 98c9bc69de
- 2 changed files with 43 additions and 27 deletions

View File

@@ -128,6 +128,7 @@ class MultipackDistributedDataloader:
batch_size: int = 1,
sampler: Union[Sampler, DistributedSampler] = None,
packing_efficiency_estimate: float = 1.0,
seq_len_multiple: int = 1,
):
# Dataset
self.dataset = dataset
@@ -135,9 +136,10 @@ class MultipackDistributedDataloader:
[len(sample["input_ids"]) for sample in self.dataset]
)
assert isinstance(self.lengths, np.ndarray)
assert batch_size % seq_len_multiple == 0
self.sampler = sampler
self.batch_size = batch_size
self.seq_len_multiple = seq_len_multiple
self.seq_max_length = seq_max_length
self.batch_max_length = batch_size * seq_max_length
self.collate_fn = collate_fn
@@ -148,7 +150,7 @@ class MultipackDistributedDataloader:
# statistics
self.eff_total_used = 0
self.eff_total_slots = 0
self.packing_efficiency_estimate = packing_efficiency_estimate
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
def generate_batches(self, set_stats=False):
if self.sampler:
@@ -164,7 +166,7 @@ class MultipackDistributedDataloader:
lengths_cumsum=lengths_cumsum,
rank=self.rank,
# c=self.batch_max_length,
c=self.seq_max_length,
c=self.seq_max_length * self.seq_len_multiple,
n=self.num_replicas,
)
@@ -181,18 +183,20 @@ class MultipackDistributedDataloader:
all_batches, _ = self.generate_batches(set_stats=True)
features = self.dataset.features.keys()
len_remaining = self._len_est()
for batches in chunk(all_batches, self.batch_size):
for batches in chunk(all_batches, self.batch_size / self.seq_len_multiple):
chunked_data = []
attn_mask_cum_idx = 0
for batch in batches:
concatenated = {}
batched_data = [self.dataset[batch_idx] for batch_idx in batch]
for feature in features:
if feature == "attention_mask":
arrays = [
(idx + 1) * np.array(item[feature])
(attn_mask_cum_idx + idx + 1) * np.array(item[feature])
for idx, item in enumerate(batched_data)
if feature in item
]
attn_mask_cum_idx += len(batched_data)
concatenated[feature] = np.concatenate(arrays)
else:
arrays = [
@@ -216,7 +220,7 @@ class MultipackDistributedDataloader:
# }
# chunked_data.append(chunk)
# yield self.collate_fn(chunked_data)
yield self.collate_fn([chunked_data])
yield self.collate_fn(chunked_data)
len_remaining -= 1
if not len_remaining:
return

View File

@@ -53,26 +53,34 @@ def weighted_cross_entropy(
@torch.jit.script
def create_weighted_mask(labels: torch.Tensor):
mask = labels != -100
# Check if the tensor is 2D. If not, unsqueeze it to make it 2D
if len(labels.shape) == 1:
labels = labels.unsqueeze(0)
# Create a tensor to track group ids
group_ids = torch.zeros_like(labels).int()
curr_group_id = 0
weights = torch.zeros_like(labels).float()
for i in range(labels.shape[0]):
mask = labels[i] != -100
for i in range(1, len(labels)):
if mask[i] and not mask[i - 1]: # switch from masked to unmasked label
curr_group_id += 1 # start new group
group_ids[i] = (
curr_group_id if mask[i] else 0
) # assign group id if unmasked label
# Create a tensor to track group ids
group_ids = torch.zeros_like(labels[i]).int()
curr_group_id = 0
# Count only unmasked labels in each group
group_counts = torch.bincount(group_ids[mask])
for j in range(1, len(labels[i])):
if mask[j] and not mask[j - 1]: # switch from masked to unmasked label
curr_group_id += 1 # start new group
group_ids[j] = (
curr_group_id if mask[j] else 0
) # assign group id if unmasked label
mask_weights = torch.zeros_like(labels).float()
mask_weights[mask] = 1.0 / group_counts[group_ids[mask]]
# Count only unmasked labels in each group
group_counts = torch.bincount(group_ids[mask])
return mask_weights
mask_weights = torch.zeros_like(labels[i]).float()
mask_weights[mask] = 1.0 / group_counts[group_ids[mask]]
weights[i] = mask_weights
return weights.squeeze() # squeeze the output to match the input dimension
def trainer_weighted_loss(model_output, labels, shift_labels=True):
@@ -168,6 +176,7 @@ class AxolotlTrainer(Trainer):
collate_fn=self.data_collator,
sampler=train_sampler,
packing_efficiency_estimate=self.args.sample_packing_efficiency,
seq_len_multiple=2,
)
)
return super().get_train_dataloader()
@@ -187,16 +196,18 @@ class AxolotlTrainer(Trainer):
seq_max_length=self.args.max_seq_length,
collate_fn=self.data_collator,
sampler=eval_sampler,
packing_efficiency_estimate=self.args.sample_packing_efficiency,
seq_len_multiple=2,
)
)
return super().get_eval_dataloader(eval_dataset)
def compute_loss(self, model, inputs, return_outputs=False):
if self.args.sample_packing:
labels = inputs.pop("labels")
outputs = model(**inputs)
loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
return (loss, outputs) if return_outputs else loss
# if self.args.sample_packing:
# labels = inputs.pop("labels")
# outputs = model(**inputs)
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
# return (loss, outputs) if return_outputs else loss
return super().compute_loss(model, inputs, return_outputs=return_outputs)
@@ -262,7 +273,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
* total_num_tokens
/ cfg.sample_packing_eff_est
/ 2048
/ cfg.batch_size
// cfg.batch_size
)
- 1
)
@@ -284,6 +295,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
),
sampler=sampler,
packing_efficiency_estimate=cfg.sample_packing_eff_est,
seq_len_multiple=2,
)
data_loader_len = len(data_loader)
LOG.info(f"data_loader_len: {data_loader_len}")