more fixes, position_ids seems broken

This commit is contained in:
Wing Lian
2023-07-18 16:47:08 -04:00
parent 66774011c4
commit 58045f0816
3 changed files with 21 additions and 19 deletions

View File

@@ -1,10 +1,10 @@
# pylint: skip-file
from typing import Any, Callable, List
from typing import Any, Callable, List, Union
import numba
import numpy as np
from torch.utils.data import DistributedSampler
from torch.utils.data import DistributedSampler, Sampler
@numba.njit
@@ -108,7 +108,7 @@ class MultipackDistributedDataloader:
collate_fn: Callable,
seq_max_length: int = 2048,
batch_size: int = 1,
sampler: DistributedSampler = None,
sampler: Union[Sampler, DistributedSampler] = None,
seed: int = 0,
):
# Dataset

View File

@@ -134,7 +134,9 @@ def load_model(
LOG.info("patching with xpos rope")
replace_llama_rope_with_xpos_rope()
if cfg.is_llama_derived_model and cfg.max_packed_sequence_len:
if cfg.is_llama_derived_model and (
cfg.max_packed_sequence_len or cfg.sample_packing
):
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
LOG.info("patching _expand_mask")

View File

@@ -14,7 +14,7 @@ import transformers
from datasets import Dataset
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.data import DataLoader, RandomSampler
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_pt_utils import get_parameter_names
@@ -106,13 +106,18 @@ class AxolotlTrainer(Trainer):
self, eval_dataset: Optional[Dataset] = None
) -> Union[DataLoader, MultipackDistributedDataloader]:
if self.args.sample_packing:
eval_dataset = (
eval_dataset if eval_dataset is not None else self.eval_dataset
)
eval_sampler = self._get_eval_sampler(eval_dataset)
return MultipackDistributedDataloader(
eval_dataset,
batch_size=self.args.per_device_eval_batch_size,
seq_max_length=self.args.max_seq_length,
collate_fn=self.data_collator,
sampler=eval_sampler,
return self.accelerator.prepare(
MultipackDistributedDataloader(
eval_dataset,
batch_size=self.args.per_device_eval_batch_size,
seq_max_length=self.args.max_seq_length,
collate_fn=self.data_collator,
sampler=eval_sampler,
)
)
return super().get_eval_dataloader(eval_dataset)
@@ -153,8 +158,8 @@ def add_position_ids(sample):
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
if cfg.sample_packing:
train_dataset = train_dataset.map(add_position_ids)
eval_dataset = eval_dataset.map(add_position_ids)
# train_dataset = train_dataset.map(add_position_ids)
# eval_dataset = eval_dataset.map(add_position_ids)
if cfg.sample_packing_eff_est:
total_num_tokens = sum(len(s["input_ids"]) for s in train_dataset)
total_num_steps = math.ceil(
@@ -165,12 +170,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
/ cfg.batch_size
)
else:
sampler = DistributedSampler(
train_dataset,
num_replicas=1,
rank=0,
seed=cfg.seed or 42,
)
sampler = RandomSampler(train_dataset)
data_loader = MultipackDistributedDataloader(
train_dataset,
batch_size=cfg.micro_batch_size,