more fixes, position_ids seems broken
@@ -1,10 +1,10 @@
 # pylint: skip-file

-from typing import Any, Callable, List
+from typing import Any, Callable, List, Union

 import numba
 import numpy as np
-from torch.utils.data import DistributedSampler
+from torch.utils.data import DistributedSampler, Sampler


 @numba.njit
@@ -108,7 +108,7 @@ class MultipackDistributedDataloader:
         collate_fn: Callable,
         seq_max_length: int = 2048,
         batch_size: int = 1,
-        sampler: DistributedSampler = None,
+        sampler: Union[Sampler, DistributedSampler] = None,
         seed: int = 0,
     ):
         # Dataset
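A usage sketch based only on the constructor parameters visible in this diff; the dataset and collator below are placeholders, and any other keyword arguments the class accepts are not shown here:

import torch
from torch.utils.data import RandomSampler, TensorDataset

# Placeholders: a toy dataset and an identity collator.
# MultipackDistributedDataloader itself comes from this project
# (its import path is not part of this diff).
train_dataset = TensorDataset(torch.arange(16))

def data_collator(features):
    return features

loader = MultipackDistributedDataloader(
    train_dataset,
    collate_fn=data_collator,
    seq_max_length=2048,
    batch_size=1,
    sampler=RandomSampler(train_dataset),  # any torch Sampler now satisfies the annotation
    seed=42,
)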
@@ -134,7 +134,9 @@ def load_model(
         LOG.info("patching with xpos rope")
         replace_llama_rope_with_xpos_rope()

-    if cfg.is_llama_derived_model and cfg.max_packed_sequence_len:
+    if cfg.is_llama_derived_model and (
+        cfg.max_packed_sequence_len or cfg.sample_packing
+    ):
         from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask

         LOG.info("patching _expand_mask")
@@ -14,7 +14,7 @@ import transformers
 from datasets import Dataset
 from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
-from torch.utils.data import DataLoader, DistributedSampler
+from torch.utils.data import DataLoader, RandomSampler
 from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
 from transformers.trainer_pt_utils import get_parameter_names
@@ -106,13 +106,18 @@ class AxolotlTrainer(Trainer):
         self, eval_dataset: Optional[Dataset] = None
     ) -> Union[DataLoader, MultipackDistributedDataloader]:
         if self.args.sample_packing:
+            eval_dataset = (
+                eval_dataset if eval_dataset is not None else self.eval_dataset
+            )
             eval_sampler = self._get_eval_sampler(eval_dataset)
-            return MultipackDistributedDataloader(
-                eval_dataset,
-                batch_size=self.args.per_device_eval_batch_size,
-                seq_max_length=self.args.max_seq_length,
-                collate_fn=self.data_collator,
-                sampler=eval_sampler,
+            return self.accelerator.prepare(
+                MultipackDistributedDataloader(
+                    eval_dataset,
+                    batch_size=self.args.per_device_eval_batch_size,
+                    seq_max_length=self.args.max_seq_length,
+                    collate_fn=self.data_collator,
+                    sampler=eval_sampler,
+                )
             )
         return super().get_eval_dataloader(eval_dataset)
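The eval dataloader is now routed through accelerate's prepare step. A standalone sketch of what prepare does with a dataloader, using a plain torch DataLoader as a stand-in; whether accelerate treats this project's custom dataloader class identically is not shown in this diff:

import torch
from accelerate import Accelerator
from torch.utils.data import DataLoader, TensorDataset

accelerator = Accelerator()

# Stand-in dataset and loader; inside the trainer the loader would be the
# MultipackDistributedDataloader built above.
dataset = TensorDataset(torch.arange(8).float())
loader = DataLoader(dataset, batch_size=2)

# prepare() wraps the dataloader so batches are sharded across processes
# and moved onto accelerator.device automatically.
loader = accelerator.prepare(loader)

for (batch,) in loader:
    print(batch.device)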
@@ -153,8 +158,8 @@ def add_position_ids(sample):

 def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     if cfg.sample_packing:
-        train_dataset = train_dataset.map(add_position_ids)
-        eval_dataset = eval_dataset.map(add_position_ids)
+        # train_dataset = train_dataset.map(add_position_ids)
+        # eval_dataset = eval_dataset.map(add_position_ids)
         if cfg.sample_packing_eff_est:
             total_num_tokens = sum(len(s["input_ids"]) for s in train_dataset)
             total_num_steps = math.ceil(
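The body of add_position_ids is not part of this diff; what follows is only an assumed sketch of what such a map helper typically does for sample packing, namely restarting position indices at each sample boundary:

def add_position_ids(sample):
    # Hypothetical sketch; the project's actual implementation is not shown here.
    # Each sample carries its own 0..n-1 positions so that, once samples are
    # packed into a single sequence, positional embeddings restart per sample
    # instead of continuing across sample boundaries.
    sample["position_ids"] = list(range(len(sample["input_ids"])))
    return sample

# Applied row by row via datasets.Dataset.map, as in the commented-out lines above:
# train_dataset = train_dataset.map(add_position_ids)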
@@ -165,12 +170,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
                 / cfg.batch_size
             )
         else:
-            sampler = DistributedSampler(
-                train_dataset,
-                num_replicas=1,
-                rank=0,
-                seed=cfg.seed or 42,
-            )
+            sampler = RandomSampler(train_dataset)
             data_loader = MultipackDistributedDataloader(
                 train_dataset,
                 batch_size=cfg.micro_batch_size,
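The single-replica DistributedSampler is swapped for a plain RandomSampler; both are standard torch samplers, compared here on a stand-in dataset:

import torch
from torch.utils.data import DistributedSampler, RandomSampler, TensorDataset

dataset = TensorDataset(torch.arange(10))

# Old: a DistributedSampler pinned to one replica; it shuffles from a fixed
# seed and expects set_epoch() calls to reshuffle between epochs.
old_sampler = DistributedSampler(dataset, num_replicas=1, rank=0, seed=42)

# New: a RandomSampler draws a fresh permutation each time it is iterated,
# with no distributed bookkeeping.
new_sampler = RandomSampler(dataset)

print(list(old_sampler))
print(list(new_sampler))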