fix for position_ids w packing

This commit is contained in:
Wing Lian
2023-07-21 20:31:54 -04:00
parent 2e295c9f94
commit 2f2974196d

View File

@@ -5,6 +5,7 @@ import math
import os import os
import sys import sys
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import partial
from pathlib import Path from pathlib import Path
from typing import Optional, Union from typing import Optional, Union
@@ -166,10 +167,15 @@ def add_position_ids(sample):
return sample return sample
def drop_long_seq(sample, sequence_len=2048):
return len(sample["input_ids"]) <= sequence_len
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
if cfg.sample_packing: if cfg.sample_packing:
# train_dataset = train_dataset.map(add_position_ids) drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
# eval_dataset = eval_dataset.map(add_position_ids) train_dataset = train_dataset.filter(drop_long).map(add_position_ids)
eval_dataset = eval_dataset.filter(drop_long).map(add_position_ids)
if cfg.sample_packing_eff_est: if cfg.sample_packing_eff_est:
total_num_tokens = sum(len(s["input_ids"]) for s in train_dataset) total_num_tokens = sum(len(s["input_ids"]) for s in train_dataset)
total_num_steps = ( total_num_steps = (
@@ -417,8 +423,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
data_collator_kwargs["pad_to_multiple_of"] = 8 data_collator_kwargs["pad_to_multiple_of"] = 8
if cfg.is_llama_derived_model and cfg.landmark_attention: if cfg.is_llama_derived_model and cfg.landmark_attention:
from functools import partial
from axolotl.monkeypatch.llama_landmark_attn import ( from axolotl.monkeypatch.llama_landmark_attn import (
add_mem_tokens, add_mem_tokens,
get_mem_id, get_mem_id,