fix for position_ids w packing
This commit is contained in:
@@ -5,6 +5,7 @@ import math
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
@@ -166,10 +167,15 @@ def add_position_ids(sample):
|
|||||||
return sample
|
return sample
|
||||||
|
|
||||||
|
|
||||||
|
def drop_long_seq(sample, sequence_len=2048):
|
||||||
|
return len(sample["input_ids"]) <= sequence_len
|
||||||
|
|
||||||
|
|
||||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||||
if cfg.sample_packing:
|
if cfg.sample_packing:
|
||||||
# train_dataset = train_dataset.map(add_position_ids)
|
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
|
||||||
# eval_dataset = eval_dataset.map(add_position_ids)
|
train_dataset = train_dataset.filter(drop_long).map(add_position_ids)
|
||||||
|
eval_dataset = eval_dataset.filter(drop_long).map(add_position_ids)
|
||||||
if cfg.sample_packing_eff_est:
|
if cfg.sample_packing_eff_est:
|
||||||
total_num_tokens = sum(len(s["input_ids"]) for s in train_dataset)
|
total_num_tokens = sum(len(s["input_ids"]) for s in train_dataset)
|
||||||
total_num_steps = (
|
total_num_steps = (
|
||||||
@@ -417,8 +423,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
data_collator_kwargs["pad_to_multiple_of"] = 8
|
data_collator_kwargs["pad_to_multiple_of"] = 8
|
||||||
|
|
||||||
if cfg.is_llama_derived_model and cfg.landmark_attention:
|
if cfg.is_llama_derived_model and cfg.landmark_attention:
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.llama_landmark_attn import (
|
from axolotl.monkeypatch.llama_landmark_attn import (
|
||||||
add_mem_tokens,
|
add_mem_tokens,
|
||||||
get_mem_id,
|
get_mem_id,
|
||||||
|
|||||||
Reference in New Issue
Block a user