PoSE context length ext (#1567)
* PoSE wip * fixes for pose splitting * set pose context len so we can pick that up seperately from the usable training context len * support min sample len and define num chunks * fix chunk splitting * support for curriculum/ordered learning with pose * fix sequence len sort * add curriculum_sampling to pydantic
This commit is contained in:
@@ -212,6 +212,10 @@ class AxolotlTrainingArguments(TrainingArguments):
|
|||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "path under the model to access the layers"},
|
metadata={"help": "path under the model to access the layers"},
|
||||||
)
|
)
|
||||||
|
curriculum_sampling: Optional[bool] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "whether to use sequential sampling for curriculum learning"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AxolotlTrainer(Trainer):
|
class AxolotlTrainer(Trainer):
|
||||||
@@ -347,6 +351,8 @@ class AxolotlTrainer(Trainer):
|
|||||||
lengths=get_dataset_lengths(self.train_dataset),
|
lengths=get_dataset_lengths(self.train_dataset),
|
||||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||||
)
|
)
|
||||||
|
if self.args.curriculum_sampling:
|
||||||
|
return SequentialSampler(self.train_dataset)
|
||||||
return super()._get_train_sampler()
|
return super()._get_train_sampler()
|
||||||
|
|
||||||
def _get_eval_sampler(
|
def _get_eval_sampler(
|
||||||
@@ -1193,6 +1199,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
False if self.cfg.ddp else None
|
False if self.cfg.ddp else None
|
||||||
)
|
)
|
||||||
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
||||||
|
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
|
||||||
report_to = None
|
report_to = None
|
||||||
if self.cfg.use_wandb:
|
if self.cfg.use_wandb:
|
||||||
report_to = "wandb"
|
report_to = "wandb"
|
||||||
|
|||||||
@@ -503,9 +503,17 @@ class AxolotlInputConfig(
|
|||||||
unfrozen_parameters: Optional[List[str]] = None
|
unfrozen_parameters: Optional[List[str]] = None
|
||||||
|
|
||||||
sequence_len: int = Field(default=512)
|
sequence_len: int = Field(default=512)
|
||||||
|
min_sample_len: Optional[int] = None
|
||||||
sample_packing: Optional[bool] = None
|
sample_packing: Optional[bool] = None
|
||||||
eval_sample_packing: Optional[bool] = None
|
eval_sample_packing: Optional[bool] = None
|
||||||
pad_to_sequence_len: Optional[bool] = None
|
pad_to_sequence_len: Optional[bool] = None
|
||||||
|
curriculum_sampling: Optional[bool] = None
|
||||||
|
|
||||||
|
# for PoSE context length extension
|
||||||
|
use_pose: Optional[bool] = None
|
||||||
|
pose_split_on_token_ids: Optional[List[int]] = None
|
||||||
|
pose_max_context_len: Optional[int] = None
|
||||||
|
pose_num_chunks: Optional[int] = None
|
||||||
|
|
||||||
pretrain_multipack_buffer_size: Optional[int] = 10_000
|
pretrain_multipack_buffer_size: Optional[int] = 10_000
|
||||||
pretrain_multipack_attn: Optional[bool] = Field(
|
pretrain_multipack_attn: Optional[bool] = Field(
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
"""Module containing the Trainer class and related functions"""
|
"""Module containing the Trainer class and related functions"""
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
|
import random
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -98,17 +99,89 @@ def add_position_ids(sample):
|
|||||||
return sample
|
return sample
|
||||||
|
|
||||||
|
|
||||||
|
def add_pose_position_ids(
|
||||||
|
sample,
|
||||||
|
max_context_len=32768,
|
||||||
|
split_on_token_ids: Optional[List[int]] = None,
|
||||||
|
chunks: int = 2,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
use the PoSE technique to extend the context length by randomly skipping
|
||||||
|
positions in the context. We only want to skip right before tokens in
|
||||||
|
the split_on_token_ids list. We should attempt to randomly distribute
|
||||||
|
the skips, but we don't need the final position_ids to be the full
|
||||||
|
context_len. There may be multiple turns in the context, so we want to
|
||||||
|
make sure we take into account the maximum possible number of skips
|
||||||
|
remaining in each sample.
|
||||||
|
"""
|
||||||
|
|
||||||
|
input_ids = sample["input_ids"]
|
||||||
|
sample_len = len(input_ids)
|
||||||
|
max_skips = max_context_len - sample_len
|
||||||
|
|
||||||
|
if split_on_token_ids is None:
|
||||||
|
split_on_token_ids = []
|
||||||
|
|
||||||
|
if split_on_token_ids:
|
||||||
|
split_indices = [
|
||||||
|
i for i, token_id in enumerate(input_ids) if token_id in split_on_token_ids
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
chunk_len = sample_len // chunks
|
||||||
|
split_indices = [i * chunk_len for i in range(1, chunks)]
|
||||||
|
split_indices.append(len(input_ids)) # make sure we go to the end of the sample
|
||||||
|
if split_indices[0] < 2:
|
||||||
|
# drop the first split index if it's too close to the beginning
|
||||||
|
split_indices = split_indices[1:]
|
||||||
|
|
||||||
|
position_ids = []
|
||||||
|
prev_index = 0
|
||||||
|
total_skips = 0
|
||||||
|
|
||||||
|
for split_index in split_indices:
|
||||||
|
num_skips = (
|
||||||
|
random.randint(0, max_skips) # nosec B311
|
||||||
|
if prev_index != 0 and max_skips
|
||||||
|
else 0
|
||||||
|
)
|
||||||
|
max_skips -= num_skips
|
||||||
|
total_skips += num_skips
|
||||||
|
|
||||||
|
segment_position_ids = list(
|
||||||
|
range(prev_index + total_skips, split_index + total_skips)
|
||||||
|
)
|
||||||
|
|
||||||
|
position_ids.extend(segment_position_ids)
|
||||||
|
prev_index = split_index
|
||||||
|
|
||||||
|
sample["sequence_len"] = position_ids[-1]
|
||||||
|
position_ids = torch.tensor(position_ids)
|
||||||
|
|
||||||
|
sample["position_ids"] = position_ids
|
||||||
|
sample["length"] = len(position_ids)
|
||||||
|
assert len(position_ids) == len(input_ids)
|
||||||
|
|
||||||
|
return sample
|
||||||
|
|
||||||
|
|
||||||
def add_length(sample):
|
def add_length(sample):
|
||||||
sample["length"] = len(sample["input_ids"])
|
sample["length"] = len(sample["input_ids"])
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
|
|
||||||
def drop_long_seq(sample, sequence_len=2048):
|
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
||||||
return len(sample["input_ids"]) <= sequence_len and len(sample["input_ids"]) > 0
|
return (
|
||||||
|
len(sample["input_ids"]) <= sequence_len
|
||||||
|
and len(sample["input_ids"]) >= min_sequence_len
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||||
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
|
drop_long = partial(
|
||||||
|
drop_long_seq,
|
||||||
|
sequence_len=cfg.sequence_len,
|
||||||
|
min_sequence_len=cfg.min_sample_len or 2,
|
||||||
|
)
|
||||||
with zero_first(is_main_process()):
|
with zero_first(is_main_process()):
|
||||||
if cfg.is_preprocess:
|
if cfg.is_preprocess:
|
||||||
min_input_len = np.min(get_dataset_lengths(train_dataset))
|
min_input_len = np.min(get_dataset_lengths(train_dataset))
|
||||||
@@ -153,7 +226,32 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
|||||||
desc="Group By Length",
|
desc="Group By Length",
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.sample_packing:
|
if cfg.use_pose:
|
||||||
|
pose_kwargs = {}
|
||||||
|
if cfg.pose_num_chunks is not None:
|
||||||
|
pose_kwargs["chunks"] = cfg.pose_num_chunks
|
||||||
|
pose_fn = partial(
|
||||||
|
add_pose_position_ids,
|
||||||
|
max_context_len=cfg.pose_max_context_len,
|
||||||
|
split_on_token_ids=cfg.pose_split_on_token_ids,
|
||||||
|
**pose_kwargs,
|
||||||
|
)
|
||||||
|
train_dataset = train_dataset.map(
|
||||||
|
pose_fn,
|
||||||
|
num_proc=cfg.dataset_processes,
|
||||||
|
load_from_cache_file=not cfg.is_preprocess,
|
||||||
|
desc="Add position_id column (PoSE)",
|
||||||
|
)
|
||||||
|
train_dataset = train_dataset.sort("sequence_len")
|
||||||
|
if cfg.eval_sample_packing is not False:
|
||||||
|
if eval_dataset:
|
||||||
|
eval_dataset = eval_dataset.map(
|
||||||
|
pose_fn,
|
||||||
|
num_proc=cfg.dataset_processes,
|
||||||
|
load_from_cache_file=not cfg.is_preprocess,
|
||||||
|
desc="Add position_id column (PoSE)",
|
||||||
|
)
|
||||||
|
elif cfg.sample_packing:
|
||||||
train_dataset = train_dataset.map(
|
train_dataset = train_dataset.map(
|
||||||
add_position_ids,
|
add_position_ids,
|
||||||
num_proc=cfg.dataset_processes,
|
num_proc=cfg.dataset_processes,
|
||||||
|
|||||||
Reference in New Issue
Block a user