more fixes, position_ids seems broken

This commit is contained in:
Wing Lian
2023-07-18 16:47:08 -04:00
parent 66774011c4
commit 58045f0816
3 changed files with 21 additions and 19 deletions

View File

@@ -1,10 +1,10 @@
# pylint: skip-file # pylint: skip-file
from typing import Any, Callable, List from typing import Any, Callable, List, Union
import numba import numba
import numpy as np import numpy as np
from torch.utils.data import DistributedSampler from torch.utils.data import DistributedSampler, Sampler
@numba.njit @numba.njit
@@ -108,7 +108,7 @@ class MultipackDistributedDataloader:
collate_fn: Callable, collate_fn: Callable,
seq_max_length: int = 2048, seq_max_length: int = 2048,
batch_size: int = 1, batch_size: int = 1,
sampler: DistributedSampler = None, sampler: Union[Sampler, DistributedSampler] = None,
seed: int = 0, seed: int = 0,
): ):
# Dataset # Dataset

View File

@@ -134,7 +134,9 @@ def load_model(
LOG.info("patching with xpos rope") LOG.info("patching with xpos rope")
replace_llama_rope_with_xpos_rope() replace_llama_rope_with_xpos_rope()
if cfg.is_llama_derived_model and cfg.max_packed_sequence_len: if cfg.is_llama_derived_model and (
cfg.max_packed_sequence_len or cfg.sample_packing
):
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
LOG.info("patching _expand_mask") LOG.info("patching _expand_mask")

View File

@@ -14,7 +14,7 @@ import transformers
from datasets import Dataset from datasets import Dataset
from torch import nn from torch import nn
from torch.optim.lr_scheduler import OneCycleLR from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader, DistributedSampler from torch.utils.data import DataLoader, RandomSampler
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_pt_utils import get_parameter_names from transformers.trainer_pt_utils import get_parameter_names
@@ -106,13 +106,18 @@ class AxolotlTrainer(Trainer):
self, eval_dataset: Optional[Dataset] = None self, eval_dataset: Optional[Dataset] = None
) -> Union[DataLoader, MultipackDistributedDataloader]: ) -> Union[DataLoader, MultipackDistributedDataloader]:
if self.args.sample_packing: if self.args.sample_packing:
eval_dataset = (
eval_dataset if eval_dataset is not None else self.eval_dataset
)
eval_sampler = self._get_eval_sampler(eval_dataset) eval_sampler = self._get_eval_sampler(eval_dataset)
return MultipackDistributedDataloader( return self.accelerator.prepare(
eval_dataset, MultipackDistributedDataloader(
batch_size=self.args.per_device_eval_batch_size, eval_dataset,
seq_max_length=self.args.max_seq_length, batch_size=self.args.per_device_eval_batch_size,
collate_fn=self.data_collator, seq_max_length=self.args.max_seq_length,
sampler=eval_sampler, collate_fn=self.data_collator,
sampler=eval_sampler,
)
) )
return super().get_eval_dataloader(eval_dataset) return super().get_eval_dataloader(eval_dataset)
@@ -153,8 +158,8 @@ def add_position_ids(sample):
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
if cfg.sample_packing: if cfg.sample_packing:
train_dataset = train_dataset.map(add_position_ids) # train_dataset = train_dataset.map(add_position_ids)
eval_dataset = eval_dataset.map(add_position_ids) # eval_dataset = eval_dataset.map(add_position_ids)
if cfg.sample_packing_eff_est: if cfg.sample_packing_eff_est:
total_num_tokens = sum(len(s["input_ids"]) for s in train_dataset) total_num_tokens = sum(len(s["input_ids"]) for s in train_dataset)
total_num_steps = math.ceil( total_num_steps = math.ceil(
@@ -165,12 +170,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
/ cfg.batch_size / cfg.batch_size
) )
else: else:
sampler = DistributedSampler( sampler = RandomSampler(train_dataset)
train_dataset,
num_replicas=1,
rank=0,
seed=cfg.seed or 42,
)
data_loader = MultipackDistributedDataloader( data_loader = MultipackDistributedDataloader(
train_dataset, train_dataset,
batch_size=cfg.micro_batch_size, batch_size=cfg.micro_batch_size,