PoSE context length ext (#1567)
* PoSE wip * fixes for pose splitting * set pose context len so we can pick that up separately from the usable training context len * support min sample len and define num chunks * fix chunk splitting * support for curriculum/ordered learning with pose * fix sequence len sort * add curriculum_sampling to pydantic
This commit is contained in:
@@ -212,6 +212,10 @@ class AxolotlTrainingArguments(TrainingArguments):
|
||||
default=None,
|
||||
metadata={"help": "path under the model to access the layers"},
|
||||
)
|
||||
curriculum_sampling: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={"help": "whether to use sequential sampling for curriculum learning"},
|
||||
)
|
||||
|
||||
|
||||
class AxolotlTrainer(Trainer):
|
||||
@@ -347,6 +351,8 @@ class AxolotlTrainer(Trainer):
|
||||
lengths=get_dataset_lengths(self.train_dataset),
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
)
|
||||
if self.args.curriculum_sampling:
|
||||
return SequentialSampler(self.train_dataset)
|
||||
return super()._get_train_sampler()
|
||||
|
||||
def _get_eval_sampler(
|
||||
@@ -1193,6 +1199,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
False if self.cfg.ddp else None
|
||||
)
|
||||
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
||||
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
|
||||
report_to = None
|
||||
if self.cfg.use_wandb:
|
||||
report_to = "wandb"
|
||||
|
||||
@@ -503,9 +503,17 @@ class AxolotlInputConfig(
|
||||
unfrozen_parameters: Optional[List[str]] = None
|
||||
|
||||
sequence_len: int = Field(default=512)
|
||||
min_sample_len: Optional[int] = None
|
||||
sample_packing: Optional[bool] = None
|
||||
eval_sample_packing: Optional[bool] = None
|
||||
pad_to_sequence_len: Optional[bool] = None
|
||||
curriculum_sampling: Optional[bool] = None
|
||||
|
||||
# for PoSE context length extension
|
||||
use_pose: Optional[bool] = None
|
||||
pose_split_on_token_ids: Optional[List[int]] = None
|
||||
pose_max_context_len: Optional[int] = None
|
||||
pose_num_chunks: Optional[int] = None
|
||||
|
||||
pretrain_multipack_buffer_size: Optional[int] = 10_000
|
||||
pretrain_multipack_attn: Optional[bool] = Field(
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
"""Module containing the Trainer class and related functions"""
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -98,17 +99,89 @@ def add_position_ids(sample):
|
||||
return sample
|
||||
|
||||
|
||||
def add_pose_position_ids(
    sample,
    max_context_len=32768,
    split_on_token_ids: Optional[List[int]] = None,
    chunks: int = 2,
):
    """
    Use the PoSE technique to extend the context length by randomly skipping
    positions in the context. We only want to skip right before tokens in
    the split_on_token_ids list. We should attempt to randomly distribute
    the skips, but we don't need the final position_ids to be the full
    context_len. There may be multiple turns in the context, so we want to
    make sure we take into account the maximum possible number of skips
    remaining in each sample.

    :param sample: dataset row with an ``input_ids`` list of token ids.
    :param max_context_len: target (extended) context length the position
        ids may span after skipping.
    :param split_on_token_ids: token ids that mark allowed skip points;
        when empty/None the sample is split into ``chunks`` equal segments.
    :param chunks: number of equal-length segments to use when no split
        token ids are given.
    :return: the sample with ``position_ids``, ``length`` and
        ``sequence_len`` columns added.
    """

    input_ids = sample["input_ids"]
    sample_len = len(input_ids)
    # Guard against samples longer than the target context: a negative
    # (but truthy) max_skips would slip past the `and max_skips` check
    # below and make random.randint raise ValueError.
    max_skips = max(0, max_context_len - sample_len)

    if split_on_token_ids is None:
        split_on_token_ids = []

    if split_on_token_ids:
        # skip points: positions of any of the designated boundary tokens
        split_indices = [
            i for i, token_id in enumerate(input_ids) if token_id in split_on_token_ids
        ]
    else:
        # no boundary tokens configured: split into `chunks` equal segments
        chunk_len = sample_len // chunks
        split_indices = [i * chunk_len for i in range(1, chunks)]
    split_indices.append(len(input_ids))  # make sure we go to the end of the sample
    if split_indices and split_indices[0] < 2:
        # drop the first split index if it's too close to the beginning
        split_indices = split_indices[1:]

    position_ids = []
    prev_index = 0
    total_skips = 0

    for split_index in split_indices:
        # never skip before the first segment so position ids start at 0;
        # each later segment may jump ahead by a random amount within budget
        num_skips = (
            random.randint(0, max_skips)  # nosec B311
            if prev_index != 0 and max_skips
            else 0
        )
        max_skips -= num_skips
        total_skips += num_skips

        # positions for this segment, shifted by all skips accumulated so far
        segment_position_ids = list(
            range(prev_index + total_skips, split_index + total_skips)
        )

        position_ids.extend(segment_position_ids)
        prev_index = split_index

    # highest position id reached; used later to sort samples for curriculum
    # learning. Guard the empty-sample case (no positions generated).
    sample["sequence_len"] = position_ids[-1] if position_ids else 0
    position_ids = torch.tensor(position_ids)

    sample["position_ids"] = position_ids
    sample["length"] = len(position_ids)
    assert len(position_ids) == len(input_ids)

    return sample
|
||||
|
||||
|
||||
def add_length(sample):
    """Record the token count of ``input_ids`` under the ``length`` key."""
    token_count = len(sample["input_ids"])
    sample["length"] = token_count
    return sample
|
||||
|
||||
|
||||
def drop_long_seq(sample, sequence_len=2048):
|
||||
return len(sample["input_ids"]) <= sequence_len and len(sample["input_ids"]) > 0
|
||||
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
    """Filter predicate: keep a sample only when its token count lies in
    the inclusive range [min_sequence_len, sequence_len]."""
    num_tokens = len(sample["input_ids"])
    return min_sequence_len <= num_tokens <= sequence_len
|
||||
|
||||
|
||||
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
|
||||
drop_long = partial(
|
||||
drop_long_seq,
|
||||
sequence_len=cfg.sequence_len,
|
||||
min_sequence_len=cfg.min_sample_len or 2,
|
||||
)
|
||||
with zero_first(is_main_process()):
|
||||
if cfg.is_preprocess:
|
||||
min_input_len = np.min(get_dataset_lengths(train_dataset))
|
||||
@@ -153,7 +226,32 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||
desc="Group By Length",
|
||||
)
|
||||
|
||||
if cfg.sample_packing:
|
||||
if cfg.use_pose:
|
||||
pose_kwargs = {}
|
||||
if cfg.pose_num_chunks is not None:
|
||||
pose_kwargs["chunks"] = cfg.pose_num_chunks
|
||||
pose_fn = partial(
|
||||
add_pose_position_ids,
|
||||
max_context_len=cfg.pose_max_context_len,
|
||||
split_on_token_ids=cfg.pose_split_on_token_ids,
|
||||
**pose_kwargs,
|
||||
)
|
||||
train_dataset = train_dataset.map(
|
||||
pose_fn,
|
||||
num_proc=cfg.dataset_processes,
|
||||
load_from_cache_file=not cfg.is_preprocess,
|
||||
desc="Add position_id column (PoSE)",
|
||||
)
|
||||
train_dataset = train_dataset.sort("sequence_len")
|
||||
if cfg.eval_sample_packing is not False:
|
||||
if eval_dataset:
|
||||
eval_dataset = eval_dataset.map(
|
||||
pose_fn,
|
||||
num_proc=cfg.dataset_processes,
|
||||
load_from_cache_file=not cfg.is_preprocess,
|
||||
desc="Add position_id column (PoSE)",
|
||||
)
|
||||
elif cfg.sample_packing:
|
||||
train_dataset = train_dataset.map(
|
||||
add_position_ids,
|
||||
num_proc=cfg.dataset_processes,
|
||||
|
||||
Reference in New Issue
Block a user