more fixes, position_ids seems broken
This commit is contained in:
@@ -1,10 +1,10 @@
|
|||||||
# pylint: skip-file
|
# pylint: skip-file
|
||||||
|
|
||||||
from typing import Any, Callable, List
|
from typing import Any, Callable, List, Union
|
||||||
|
|
||||||
import numba
|
import numba
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from torch.utils.data import DistributedSampler
|
from torch.utils.data import DistributedSampler, Sampler
|
||||||
|
|
||||||
|
|
||||||
@numba.njit
|
@numba.njit
|
||||||
@@ -108,7 +108,7 @@ class MultipackDistributedDataloader:
|
|||||||
collate_fn: Callable,
|
collate_fn: Callable,
|
||||||
seq_max_length: int = 2048,
|
seq_max_length: int = 2048,
|
||||||
batch_size: int = 1,
|
batch_size: int = 1,
|
||||||
sampler: DistributedSampler = None,
|
sampler: Union[Sampler, DistributedSampler] = None,
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
):
|
):
|
||||||
# Dataset
|
# Dataset
|
||||||
|
|||||||
@@ -134,7 +134,9 @@ def load_model(
|
|||||||
LOG.info("patching with xpos rope")
|
LOG.info("patching with xpos rope")
|
||||||
replace_llama_rope_with_xpos_rope()
|
replace_llama_rope_with_xpos_rope()
|
||||||
|
|
||||||
if cfg.is_llama_derived_model and cfg.max_packed_sequence_len:
|
if cfg.is_llama_derived_model and (
|
||||||
|
cfg.max_packed_sequence_len or cfg.sample_packing
|
||||||
|
):
|
||||||
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
|
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
|
||||||
|
|
||||||
LOG.info("patching _expand_mask")
|
LOG.info("patching _expand_mask")
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import transformers
|
|||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.optim.lr_scheduler import OneCycleLR
|
from torch.optim.lr_scheduler import OneCycleLR
|
||||||
from torch.utils.data import DataLoader, DistributedSampler
|
from torch.utils.data import DataLoader, RandomSampler
|
||||||
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
||||||
from transformers.trainer_pt_utils import get_parameter_names
|
from transformers.trainer_pt_utils import get_parameter_names
|
||||||
|
|
||||||
@@ -106,13 +106,18 @@ class AxolotlTrainer(Trainer):
|
|||||||
self, eval_dataset: Optional[Dataset] = None
|
self, eval_dataset: Optional[Dataset] = None
|
||||||
) -> Union[DataLoader, MultipackDistributedDataloader]:
|
) -> Union[DataLoader, MultipackDistributedDataloader]:
|
||||||
if self.args.sample_packing:
|
if self.args.sample_packing:
|
||||||
|
eval_dataset = (
|
||||||
|
eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||||
|
)
|
||||||
eval_sampler = self._get_eval_sampler(eval_dataset)
|
eval_sampler = self._get_eval_sampler(eval_dataset)
|
||||||
return MultipackDistributedDataloader(
|
return self.accelerator.prepare(
|
||||||
eval_dataset,
|
MultipackDistributedDataloader(
|
||||||
batch_size=self.args.per_device_eval_batch_size,
|
eval_dataset,
|
||||||
seq_max_length=self.args.max_seq_length,
|
batch_size=self.args.per_device_eval_batch_size,
|
||||||
collate_fn=self.data_collator,
|
seq_max_length=self.args.max_seq_length,
|
||||||
sampler=eval_sampler,
|
collate_fn=self.data_collator,
|
||||||
|
sampler=eval_sampler,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
return super().get_eval_dataloader(eval_dataset)
|
return super().get_eval_dataloader(eval_dataset)
|
||||||
|
|
||||||
@@ -153,8 +158,8 @@ def add_position_ids(sample):
|
|||||||
|
|
||||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||||
if cfg.sample_packing:
|
if cfg.sample_packing:
|
||||||
train_dataset = train_dataset.map(add_position_ids)
|
# train_dataset = train_dataset.map(add_position_ids)
|
||||||
eval_dataset = eval_dataset.map(add_position_ids)
|
# eval_dataset = eval_dataset.map(add_position_ids)
|
||||||
if cfg.sample_packing_eff_est:
|
if cfg.sample_packing_eff_est:
|
||||||
total_num_tokens = sum(len(s["input_ids"]) for s in train_dataset)
|
total_num_tokens = sum(len(s["input_ids"]) for s in train_dataset)
|
||||||
total_num_steps = math.ceil(
|
total_num_steps = math.ceil(
|
||||||
@@ -165,12 +170,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
/ cfg.batch_size
|
/ cfg.batch_size
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
sampler = DistributedSampler(
|
sampler = RandomSampler(train_dataset)
|
||||||
train_dataset,
|
|
||||||
num_replicas=1,
|
|
||||||
rank=0,
|
|
||||||
seed=cfg.seed or 42,
|
|
||||||
)
|
|
||||||
data_loader = MultipackDistributedDataloader(
|
data_loader = MultipackDistributedDataloader(
|
||||||
train_dataset,
|
train_dataset,
|
||||||
batch_size=cfg.micro_batch_size,
|
batch_size=cfg.micro_batch_size,
|
||||||
|
|||||||
Reference in New Issue
Block a user