From 2f2974196d8a96a58568a7ed8b01a82604b9100f Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 21 Jul 2023 20:31:54 -0400 Subject: [PATCH] fix for position_ids w packing --- src/axolotl/utils/trainer.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 5009d645b..eff18f8fb 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -5,6 +5,7 @@ import math import os import sys from dataclasses import dataclass, field +from functools import partial from pathlib import Path from typing import Optional, Union @@ -166,10 +167,15 @@ def add_position_ids(sample): return sample +def drop_long_seq(sample, sequence_len=2048): + return len(sample["input_ids"]) <= sequence_len + + def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): if cfg.sample_packing: - # train_dataset = train_dataset.map(add_position_ids) - # eval_dataset = eval_dataset.map(add_position_ids) + drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len) + train_dataset = train_dataset.filter(drop_long).map(add_position_ids) + eval_dataset = eval_dataset.filter(drop_long).map(add_position_ids) if cfg.sample_packing_eff_est: total_num_tokens = sum(len(s["input_ids"]) for s in train_dataset) total_num_steps = ( @@ -417,8 +423,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): data_collator_kwargs["pad_to_multiple_of"] = 8 if cfg.is_llama_derived_model and cfg.landmark_attention: - from functools import partial - from axolotl.monkeypatch.llama_landmark_attn import ( add_mem_tokens, get_mem_id,