Compare commits

4 commits: `split-batc...` → `no-seq-len`
| Author | SHA1 | Date |
|---|---|---|
| | c3db6dd307 | |
| | 9a6e9d8d15 | |
| | c10eb811fa | |
| | 0eef385b1a | |
```diff
@@ -12,7 +12,7 @@ output_dir: ./outputs/lora-out
 adapter: lora
 lora_model_dir:
 
-sequence_len: 2048
+sequence_len:
 sample_packing: true
 eval_sample_packing: true
 
```
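The example config clears the value of `sequence_len` rather than deleting the key. In YAML an empty scalar deserializes to `None`, which is exactly the case the guards added in the later hunks have to tolerate. A minimal standalone check with PyYAML:

```python
import yaml

# An empty YAML scalar loads as None, so `sequence_len:` in the config
# exercises the sequence_len-is-None code paths added further down.
cfg = yaml.safe_load("sequence_len:\nsample_packing: true\n")
assert cfg["sequence_len"] is None
assert cfg["sample_packing"] is True
```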
```diff
@@ -40,6 +40,12 @@ class VllmServeCliArgs:
         default=None,
         metadata={"help": "Number of tensor parallel workers to use."},
     )
+    data_parallel_size: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "Number of data parallel workers to use for vLLM serving. This controls how many model replicas are used for parallel inference."
+        },
+    )
     host: Optional[str] = field(
         default=None,  # nosec B104
         metadata={"help": "Host address to run the server on."},
```
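Axolotl's actual CLI wiring is not shown in this diff. As a rough illustration of how such a dataclass field can surface as a command-line flag, here is a standalone sketch using `transformers.HfArgumentParser` (an assumption for illustration, not the PR's own entry point):

```python
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class VllmServeCliArgs:
    tensor_parallel_size: Optional[int] = field(
        default=None,
        metadata={"help": "Number of tensor parallel workers to use."},
    )
    data_parallel_size: Optional[int] = field(
        default=None,
        metadata={"help": "Number of model replicas for parallel inference."},
    )


# HfArgumentParser derives --tensor_parallel_size / --data_parallel_size
# flags from the field names; axolotl's real CLI may expose them differently.
(args,) = HfArgumentParser(VllmServeCliArgs).parse_args_into_dataclasses(
    ["--data_parallel_size", "2"]
)
assert args.data_parallel_size == 2 and args.tensor_parallel_size is None
```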
```diff
@@ -268,7 +268,10 @@ class ModelLoader:
             hasattr(self.model, "config")
             and hasattr(self.model.config, "max_position_embeddings")
             and self.model.config.max_position_embeddings
-            and self.cfg.sequence_len > self.model.config.max_position_embeddings
+            and (
+                self.cfg.sequence_len is not None
+                and self.cfg.sequence_len > self.model.config.max_position_embeddings
+            )
         ):
             LOG.warning(
                 "increasing model.config.max_position_embeddings from "
```
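The extra `is not None` check matters because, with the schema change later in this compare, `cfg.sequence_len` may now be `None`, and ordered comparisons against `None` raise in Python 3:

```python
# Why the guard is needed: Python 3 refuses ordered comparisons with None.
sequence_len = None
max_position_embeddings = 2048

try:
    sequence_len > max_position_embeddings
except TypeError as err:
    print(err)  # '>' not supported between instances of 'NoneType' and 'int'

# The rewritten condition short-circuits before reaching the comparison:
needs_bump = sequence_len is not None and sequence_len > max_position_embeddings
assert needs_bump is False
```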
```diff
@@ -91,7 +91,7 @@ class PromptTokenizingStrategy(abc.ABC):
 
         if (
             result["input_ids"][-1] != self.tokenizer.eos_token_id
-            and len(result["input_ids"]) < self.max_length
+            and (self.max_length is None or len(result["input_ids"]) < self.max_length)
             and add_eos_token
         ):
             result["input_ids"].append(self.tokenizer.eos_token_id)
```
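Extracted into a standalone function for clarity (the helper name `maybe_append_eos` is illustrative, not the PR's), the updated EOS logic treats `max_length=None` as "no limit":

```python
def maybe_append_eos(input_ids, eos_token_id, max_length, add_eos_token=True):
    """Append EOS when it is missing and there is room (None = no limit)."""
    if (
        input_ids[-1] != eos_token_id
        and (max_length is None or len(input_ids) < max_length)
        and add_eos_token
    ):
        input_ids.append(eos_token_id)
    return input_ids


# With an unbounded max_length the EOS token is always appended when absent.
assert maybe_append_eos([1, 2, 3], eos_token_id=0, max_length=None) == [1, 2, 3, 0]
# With a tight max_length the sequence is left as-is.
assert maybe_append_eos([1, 2, 3], eos_token_id=0, max_length=3) == [1, 2, 3]
```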
```diff
@@ -28,7 +28,7 @@ from axolotl.utils.data.shared import (
 )
 from axolotl.utils.data.utils import (
     deduplicate_and_log_datasets,
-    drop_long_seq_in_dataset,
+    handle_long_seq_in_dataset,
     retry_on_request_exceptions,
 )
 from axolotl.utils.data.wrappers import get_dataset_wrapper
```
```diff
@@ -339,9 +339,9 @@ def _load_raw_datasets(
 
     if not cfg.skip_prepare_dataset:
         if split == "test" and cfg.eval_sequence_len:
-            dataset = drop_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg)
+            dataset = handle_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg)
         else:
-            dataset = drop_long_seq_in_dataset(dataset, cfg.sequence_len, cfg)
+            dataset = handle_long_seq_in_dataset(dataset, cfg.sequence_len, cfg)
         if cfg.sample_packing:
             dataset, _ = process_datasets_for_packing(cfg, dataset, None)
 
```
```diff
@@ -148,7 +148,36 @@ def deduplicate_and_log_datasets(
     return dataset, other_dataset
 
 
-def drop_long_seq_in_dataset(
+def truncate_long_seq(sample, sequence_len=2048, min_sequence_len=2):
+    """
+    Truncate samples whose sequence length is too long (> sequence_len)
+    or drop those too short (< min_sequence_len).
+    """
+    min_sequence_len = min_sequence_len or 2
+
+    input_ids = sample["input_ids"]
+    results = []
+
+    # Batched (input_ids is a list of lists)
+    for i, seq in enumerate(input_ids):
+        length = len(seq)
+        if length < min_sequence_len:
+            results.append(False)
+        elif length > sequence_len:
+            sample["input_ids"][i] = seq[:sequence_len]
+            if "attention_mask" in sample:
+                sample["attention_mask"][i] = sample["attention_mask"][i][:sequence_len]
+            if "labels" in sample:
+                sample["labels"][i] = sample["labels"][i][:sequence_len]
+            if "position_ids" in sample:
+                sample["position_ids"][i] = sample["position_ids"][i][:sequence_len]
+            results.append(True)
+        else:
+            results.append(True)
+    return results
+
+
+def handle_long_seq_in_dataset(
     dataset: Dataset, sequence_len: int, cfg: DictDefault
 ) -> Dataset:
     """Remove sequences longer than configured maximum from dataset.
```
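A quick check of the new function's contract on a toy batch (paste the `truncate_long_seq` definition from the hunk above to run this; the values are made up):

```python
# Toy batch: one too-short row, one too-long row, one in-range row.
sample = {
    "input_ids": [[1], [1, 2, 3, 4, 5], [1, 2, 3]],
    "labels": [[1], [1, 2, 3, 4, 5], [1, 2, 3]],
}
keep = truncate_long_seq(sample, sequence_len=4, min_sequence_len=2)

assert keep == [False, True, True]             # too-short rows are dropped
assert sample["input_ids"][1] == [1, 2, 3, 4]  # too-long rows are sliced in place
assert sample["labels"][1] == [1, 2, 3, 4]     # parallel columns stay aligned
```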
```diff
@@ -192,8 +221,21 @@ def drop_long_seq_in_dataset(
     if filter_map_kwargs:
         drop_long_kwargs["desc"] = f"Dropping Long Sequences (>{sequence_len})"
 
+    excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
+    if excess_length_strategy == "truncate":
+        process_fn = functools.partial(
+            truncate_long_seq,
+            sequence_len=sequence_len,
+            min_sequence_len=cfg.min_sample_len,
+        )
+        drop_long_kwargs["desc"] = (
+            f"Truncating/Filtering Sequences (target_len={sequence_len})"
+        )
+    else:
+        process_fn = drop_long
+
     dataset = dataset.filter(
-        drop_long,
+        process_fn,
         batched=True,
         **filter_map_kwargs,
         **drop_long_kwargs,
```
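The dispatch keeps a single `Dataset.filter` call site: `functools.partial` pre-binds the length bounds so both strategies share the same batched signature. A self-contained illustration of the pattern, using toy functions rather than the PR's own:

```python
import functools


def drop_long(sample, sequence_len=2048, min_sequence_len=2):
    # Keep rows whose length falls within [min_sequence_len, sequence_len].
    return [min_sequence_len <= len(s) <= sequence_len for s in sample["input_ids"]]


def truncate_long(sample, sequence_len=2048, min_sequence_len=2):
    # Slice over-long rows in place, drop under-short ones (toy version).
    keep = []
    for i, seq in enumerate(sample["input_ids"]):
        if len(seq) < min_sequence_len:
            keep.append(False)
            continue
        sample["input_ids"][i] = seq[:sequence_len]
        keep.append(True)
    return keep


strategy = "truncate"  # cfg.excess_length_strategy, defaulting to "drop"
process_fn = (
    functools.partial(truncate_long, sequence_len=8, min_sequence_len=2)
    if strategy == "truncate"
    else drop_long
)

batch = {"input_ids": [[1], list(range(12))]}
assert process_fn(batch) == [False, True]
assert len(batch["input_ids"][1]) == 8
```

One caveat worth flagging when reading the hunk above: 🤗 `datasets.Dataset.filter` consumes only the returned booleans, so whether in-place edits to the batch survive depends on how the backing store materializes batches; with an Arrow-backed dataset, a `map` pass is usually what persists modified columns.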
```diff
@@ -201,6 +243,11 @@ def drop_long_seq_in_dataset(
     if prior_len:
         dropped = prior_len - len(dataset)
         if dropped:
-            LOG.warning(f"Dropped {dropped} long samples from dataset")
+            action = (
+                "truncated/filtered"
+                if excess_length_strategy == "truncate"
+                else "dropped"
+            )
+            LOG.warning(f"{action.title()} {dropped} samples from dataset")
 
     return dataset
```
```diff
@@ -408,12 +408,18 @@ class AxolotlInputConfig(
 
     unfrozen_parameters: list[str] | None = None
 
-    sequence_len: int = Field(
+    sequence_len: int | None = Field(
         default=512,
         json_schema_extra={
             "description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048"
         },
     )
+    excess_length_strategy: Literal["drop", "truncate"] | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "What to do when a tokenized row exceeds sequence_len. 'drop' removes the row; 'truncate' slices tensors to sequence_len. Defaults to 'drop' for backward compatibility."
+        },
+    )
     eval_sequence_len: int | None = Field(
         default=None,
         json_schema_extra={
```
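Taken together with the schema change, a config that opts into truncation might look like the following. The keys match the schema in the hunk above; the values are hypothetical examples, parsed here just to show the shapes:

```python
import yaml

# Hypothetical axolotl config fragment exercising the new option.
cfg = yaml.safe_load("""\
sequence_len: 2048
excess_length_strategy: truncate  # 'drop' (the default) or 'truncate'
""")
assert cfg["excess_length_strategy"] in ("drop", "truncate")
```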
```diff
@@ -229,7 +229,10 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
     results = []
     for seq in input_ids:
         length = len(seq)
-        results.append(min_sequence_len <= length <= sequence_len)
+        if sequence_len is not None:
+            results.append(min_sequence_len <= length <= sequence_len)
+        else:
+            results.append(min_sequence_len <= length)
     return results
 
 
```
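With this guard in place, `sequence_len=None` simply turns the upper bound off and only the minimum-length check remains. A condensed, runnable copy of the updated loop on toy data:

```python
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
    # Condensed copy of the updated filter: None disables the upper bound.
    results = []
    for seq in sample["input_ids"]:
        length = len(seq)
        if sequence_len is not None:
            results.append(min_sequence_len <= length <= sequence_len)
        else:
            results.append(min_sequence_len <= length)
    return results


batch = {"input_ids": [[1], [1, 2, 3], list(range(10_000))]}
assert drop_long_seq(batch, sequence_len=None) == [False, True, True]
assert drop_long_seq(batch, sequence_len=8) == [False, True, False]
```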
```diff
@@ -405,7 +408,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
     if update:
         cfg.total_num_tokens = total_num_tokens
 
-    skip_estimates = cfg.model_config_type == "mamba"
+    skip_estimates = cfg.sequence_len is None or cfg.model_config_type == "mamba"
 
     if (
         not skip_estimates
```
```diff
@@ -8,7 +8,7 @@ from transformers import AutoTokenizer
 from axolotl.datasets import TokenizedPromptDataset
 from axolotl.prompt_strategies.completion import load
 from axolotl.utils.collators import V2BatchSamplerDataCollatorForSeq2Seq
-from axolotl.utils.data.utils import drop_long_seq_in_dataset
+from axolotl.utils.data.utils import handle_long_seq_in_dataset
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 
```
```diff
@@ -70,7 +70,7 @@ class TestBatchedSamplerPacking:
         )
         train_dataset = concatenate_datasets([dataset_wrapper])
 
-        train_dataset = drop_long_seq_in_dataset(train_dataset, cfg.sequence_len, cfg)
+        train_dataset = handle_long_seq_in_dataset(train_dataset, cfg.sequence_len, cfg)
 
         lengths = get_dataset_lengths(train_dataset)
         batch_sampler = MultipackBatchSampler(
```