seq_len_multiple for packing
This commit is contained in:
@@ -128,6 +128,7 @@ class MultipackDistributedDataloader:
|
||||
batch_size: int = 1,
|
||||
sampler: Union[Sampler, DistributedSampler] = None,
|
||||
packing_efficiency_estimate: float = 1.0,
|
||||
seq_len_multiple: int = 1,
|
||||
):
|
||||
# Dataset
|
||||
self.dataset = dataset
|
||||
@@ -135,9 +136,10 @@ class MultipackDistributedDataloader:
|
||||
[len(sample["input_ids"]) for sample in self.dataset]
|
||||
)
|
||||
assert isinstance(self.lengths, np.ndarray)
|
||||
|
||||
assert batch_size % seq_len_multiple == 0
|
||||
self.sampler = sampler
|
||||
self.batch_size = batch_size
|
||||
self.seq_len_multiple = seq_len_multiple
|
||||
self.seq_max_length = seq_max_length
|
||||
self.batch_max_length = batch_size * seq_max_length
|
||||
self.collate_fn = collate_fn
|
||||
@@ -148,7 +150,7 @@ class MultipackDistributedDataloader:
|
||||
# statistics
|
||||
self.eff_total_used = 0
|
||||
self.eff_total_slots = 0
|
||||
self.packing_efficiency_estimate = packing_efficiency_estimate
|
||||
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
||||
|
||||
def generate_batches(self, set_stats=False):
|
||||
if self.sampler:
|
||||
@@ -164,7 +166,7 @@ class MultipackDistributedDataloader:
|
||||
lengths_cumsum=lengths_cumsum,
|
||||
rank=self.rank,
|
||||
# c=self.batch_max_length,
|
||||
c=self.seq_max_length,
|
||||
c=self.seq_max_length * self.seq_len_multiple,
|
||||
n=self.num_replicas,
|
||||
)
|
||||
|
||||
@@ -181,18 +183,20 @@ class MultipackDistributedDataloader:
|
||||
all_batches, _ = self.generate_batches(set_stats=True)
|
||||
features = self.dataset.features.keys()
|
||||
len_remaining = self._len_est()
|
||||
for batches in chunk(all_batches, self.batch_size):
|
||||
for batches in chunk(all_batches, self.batch_size / self.seq_len_multiple):
|
||||
chunked_data = []
|
||||
attn_mask_cum_idx = 0
|
||||
for batch in batches:
|
||||
concatenated = {}
|
||||
batched_data = [self.dataset[batch_idx] for batch_idx in batch]
|
||||
for feature in features:
|
||||
if feature == "attention_mask":
|
||||
arrays = [
|
||||
(idx + 1) * np.array(item[feature])
|
||||
(attn_mask_cum_idx + idx + 1) * np.array(item[feature])
|
||||
for idx, item in enumerate(batched_data)
|
||||
if feature in item
|
||||
]
|
||||
attn_mask_cum_idx += len(batched_data)
|
||||
concatenated[feature] = np.concatenate(arrays)
|
||||
else:
|
||||
arrays = [
|
||||
@@ -216,7 +220,7 @@ class MultipackDistributedDataloader:
|
||||
# }
|
||||
# chunked_data.append(chunk)
|
||||
# yield self.collate_fn(chunked_data)
|
||||
yield self.collate_fn([chunked_data])
|
||||
yield self.collate_fn(chunked_data)
|
||||
len_remaining -= 1
|
||||
if not len_remaining:
|
||||
return
|
||||
|
||||
@@ -53,26 +53,34 @@ def weighted_cross_entropy(
|
||||
|
||||
@torch.jit.script
|
||||
def create_weighted_mask(labels: torch.Tensor):
|
||||
mask = labels != -100
|
||||
# Check if the tensor is 2D. If not, unsqueeze it to make it 2D
|
||||
if len(labels.shape) == 1:
|
||||
labels = labels.unsqueeze(0)
|
||||
|
||||
# Create a tensor to track group ids
|
||||
group_ids = torch.zeros_like(labels).int()
|
||||
curr_group_id = 0
|
||||
weights = torch.zeros_like(labels).float()
|
||||
for i in range(labels.shape[0]):
|
||||
mask = labels[i] != -100
|
||||
|
||||
for i in range(1, len(labels)):
|
||||
if mask[i] and not mask[i - 1]: # switch from masked to unmasked label
|
||||
curr_group_id += 1 # start new group
|
||||
group_ids[i] = (
|
||||
curr_group_id if mask[i] else 0
|
||||
) # assign group id if unmasked label
|
||||
# Create a tensor to track group ids
|
||||
group_ids = torch.zeros_like(labels[i]).int()
|
||||
curr_group_id = 0
|
||||
|
||||
# Count only unmasked labels in each group
|
||||
group_counts = torch.bincount(group_ids[mask])
|
||||
for j in range(1, len(labels[i])):
|
||||
if mask[j] and not mask[j - 1]: # switch from masked to unmasked label
|
||||
curr_group_id += 1 # start new group
|
||||
group_ids[j] = (
|
||||
curr_group_id if mask[j] else 0
|
||||
) # assign group id if unmasked label
|
||||
|
||||
mask_weights = torch.zeros_like(labels).float()
|
||||
mask_weights[mask] = 1.0 / group_counts[group_ids[mask]]
|
||||
# Count only unmasked labels in each group
|
||||
group_counts = torch.bincount(group_ids[mask])
|
||||
|
||||
return mask_weights
|
||||
mask_weights = torch.zeros_like(labels[i]).float()
|
||||
mask_weights[mask] = 1.0 / group_counts[group_ids[mask]]
|
||||
|
||||
weights[i] = mask_weights
|
||||
|
||||
return weights.squeeze() # squeeze the output to match the input dimension
|
||||
|
||||
|
||||
def trainer_weighted_loss(model_output, labels, shift_labels=True):
|
||||
@@ -168,6 +176,7 @@ class AxolotlTrainer(Trainer):
|
||||
collate_fn=self.data_collator,
|
||||
sampler=train_sampler,
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
seq_len_multiple=2,
|
||||
)
|
||||
)
|
||||
return super().get_train_dataloader()
|
||||
@@ -187,16 +196,18 @@ class AxolotlTrainer(Trainer):
|
||||
seq_max_length=self.args.max_seq_length,
|
||||
collate_fn=self.data_collator,
|
||||
sampler=eval_sampler,
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
seq_len_multiple=2,
|
||||
)
|
||||
)
|
||||
return super().get_eval_dataloader(eval_dataset)
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
if self.args.sample_packing:
|
||||
labels = inputs.pop("labels")
|
||||
outputs = model(**inputs)
|
||||
loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
|
||||
return (loss, outputs) if return_outputs else loss
|
||||
# if self.args.sample_packing:
|
||||
# labels = inputs.pop("labels")
|
||||
# outputs = model(**inputs)
|
||||
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
|
||||
# return (loss, outputs) if return_outputs else loss
|
||||
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
||||
|
||||
|
||||
@@ -262,7 +273,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
* total_num_tokens
|
||||
/ cfg.sample_packing_eff_est
|
||||
/ 2048
|
||||
/ cfg.batch_size
|
||||
// cfg.batch_size
|
||||
)
|
||||
- 1
|
||||
)
|
||||
@@ -284,6 +295,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
),
|
||||
sampler=sampler,
|
||||
packing_efficiency_estimate=cfg.sample_packing_eff_est,
|
||||
seq_len_multiple=2,
|
||||
)
|
||||
data_loader_len = len(data_loader)
|
||||
LOG.info(f"data_loader_len: {data_loader_len}")
|
||||
|
||||
Reference in New Issue
Block a user