Compare commits

...

4 Commits

Author        SHA1        Message                                                                   Date
Dan Saunders  c3db6dd307  remove hardcode                                                           2025-08-19 15:41:32 +00:00
Dan Saunders  9a6e9d8d15  no sequence length support                                                2025-08-19 10:25:37 -04:00
VED           c10eb811fa  data_parallel_size in in VllmserveCliArgs (#3074)                         2025-08-18 08:44:37 -04:00
                          * data_parallel_size in in VllmserveCliArgs
                          * moved to 43
VED           0eef385b1a  [feat] truncation support with excess_length_strategy (#3068) [skip ci]  2025-08-18 08:39:13 -04:00
                          * feat:truncation support with excess_len
                          * pre-commit
                          * excess_length_strategy
                          * requested changes
                          * lint
                          * added handle_long_seq_in_dataset in sft
                          * comments improved
9 changed files with 79 additions and 14 deletions

View File

@@ -12,7 +12,7 @@ output_dir: ./outputs/lora-out
 adapter: lora
 lora_model_dir:
-sequence_len: 2048
+sequence_len:
 sample_packing: true
 eval_sample_packing: true
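A note on the change above: an empty YAML scalar loads as null, so this example config now carries no sequence cap at all. With the schema change further down (sequence_len: int | None), that parses cleanly instead of failing validation; a blank value and an explicit null are equivalent spellings:

sequence_len:         # empty scalar -> null: no fixed sequence cap
# sequence_len: null  # explicit form, same effect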

View File

@@ -40,6 +40,12 @@ class VllmServeCliArgs:
         default=None,
         metadata={"help": "Number of tensor parallel workers to use."},
     )
+    data_parallel_size: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "Number of data parallel workers to use for vLLM serving. This controls how many model replicas are used for parallel inference."
+        },
+    )
     host: Optional[str] = field(
         default=None,  # nosec B104
         metadata={"help": "Host address to run the server on."},

View File

@@ -268,7 +268,10 @@ class ModelLoader:
             hasattr(self.model, "config")
             and hasattr(self.model.config, "max_position_embeddings")
             and self.model.config.max_position_embeddings
-            and self.cfg.sequence_len > self.model.config.max_position_embeddings
+            and (
+                self.cfg.sequence_len is not None
+                and self.cfg.sequence_len > self.model.config.max_position_embeddings
+            )
         ):
             LOG.warning(
                 "increasing model.config.max_position_embeddings from "

View File

@@ -91,7 +91,7 @@ class PromptTokenizingStrategy(abc.ABC):
         if (
             result["input_ids"][-1] != self.tokenizer.eos_token_id
-            and len(result["input_ids"]) < self.max_length
+            and (self.max_length is None or len(result["input_ids"]) < self.max_length)
             and add_eos_token
         ):
             result["input_ids"].append(self.tokenizer.eos_token_id)

View File

@@ -28,7 +28,7 @@ from axolotl.utils.data.shared import (
 )
 from axolotl.utils.data.utils import (
     deduplicate_and_log_datasets,
-    drop_long_seq_in_dataset,
+    handle_long_seq_in_dataset,
     retry_on_request_exceptions,
 )
 from axolotl.utils.data.wrappers import get_dataset_wrapper
@@ -339,9 +339,9 @@ def _load_raw_datasets(
     if not cfg.skip_prepare_dataset:
         if split == "test" and cfg.eval_sequence_len:
-            dataset = drop_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg)
+            dataset = handle_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg)
         else:
-            dataset = drop_long_seq_in_dataset(dataset, cfg.sequence_len, cfg)
+            dataset = handle_long_seq_in_dataset(dataset, cfg.sequence_len, cfg)
     if cfg.sample_packing:
         dataset, _ = process_datasets_for_packing(cfg, dataset, None)

View File

@@ -148,7 +148,36 @@ def deduplicate_and_log_datasets(
     return dataset, other_dataset


-def drop_long_seq_in_dataset(
+def truncate_long_seq(sample, sequence_len=2048, min_sequence_len=2):
+    """
+    Truncate samples whose sequence length is too long (> sequence_len)
+    or drop those too short (< min_sequence_len).
+    """
+    min_sequence_len = min_sequence_len or 2
+    input_ids = sample["input_ids"]
+    results = []
+    # Batched (input_ids is a list of lists)
+    for i, seq in enumerate(input_ids):
+        length = len(seq)
+        if length < min_sequence_len:
+            results.append(False)
+        elif length > sequence_len:
+            sample["input_ids"][i] = seq[:sequence_len]
+            if "attention_mask" in sample:
+                sample["attention_mask"][i] = sample["attention_mask"][i][:sequence_len]
+            if "labels" in sample:
+                sample["labels"][i] = sample["labels"][i][:sequence_len]
+            if "position_ids" in sample:
+                sample["position_ids"][i] = sample["position_ids"][i][:sequence_len]
+            results.append(True)
+        else:
+            results.append(True)
+    return results
+
+
+def handle_long_seq_in_dataset(
     dataset: Dataset, sequence_len: int, cfg: DictDefault
 ) -> Dataset:
     """Remove sequences longer than configured maximum from dataset.
@@ -192,8 +221,21 @@ def drop_long_seq_in_dataset(
     if filter_map_kwargs:
         drop_long_kwargs["desc"] = f"Dropping Long Sequences (>{sequence_len})"

+    excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
+    if excess_length_strategy == "truncate":
+        process_fn = functools.partial(
+            truncate_long_seq,
+            sequence_len=sequence_len,
+            min_sequence_len=cfg.min_sample_len,
+        )
+        drop_long_kwargs["desc"] = (
+            f"Truncating/Filtering Sequences (target_len={sequence_len})"
+        )
+    else:
+        process_fn = drop_long
+
     dataset = dataset.filter(
-        drop_long,
+        process_fn,
         batched=True,
         **filter_map_kwargs,
         **drop_long_kwargs,
@@ -201,6 +243,11 @@ def drop_long_seq_in_dataset(
     if prior_len:
         dropped = prior_len - len(dataset)
         if dropped:
-            LOG.warning(f"Dropped {dropped} long samples from dataset")
+            action = (
+                "truncated/filtered"
+                if excess_length_strategy == "truncate"
+                else "dropped"
+            )
+            LOG.warning(f"{action.title()} {dropped} samples from dataset")
     return dataset
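To make the contract of truncate_long_seq concrete: called directly on a batch dict (alongside the definition in the diff above), it slices over-long columns in place and returns a keep/drop mask. The values below are illustrative:

batch = {
    "input_ids":      [[1, 2, 3, 4, 5], [9]],  # one over-long row, one too-short row
    "attention_mask": [[1, 1, 1, 1, 1], [1]],
    "labels":         [[1, 2, 3, 4, 5], [9]],
}
keep = truncate_long_seq(batch, sequence_len=4, min_sequence_len=2)
assert keep == [True, False]                  # over-long row kept, too-short row dropped
assert batch["input_ids"][0] == [1, 2, 3, 4]  # sliced in place to sequence_len
assert batch["labels"][0] == [1, 2, 3, 4]     # parallel columns sliced the same way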

View File

@@ -408,12 +408,18 @@ class AxolotlInputConfig(
     unfrozen_parameters: list[str] | None = None
-    sequence_len: int = Field(
+    sequence_len: int | None = Field(
         default=512,
         json_schema_extra={
             "description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048"
         },
     )
+    excess_length_strategy: Literal["drop", "truncate"] | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "What to do when a tokenized row exceeds sequence_len. 'drop' removes the row; 'truncate' slices tensors to sequence_len. Defaults to 'drop' for backward compatibility."
+        },
+    )
     eval_sequence_len: int | None = Field(
         default=None,
         json_schema_extra={
View File

@@ -229,7 +229,10 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
     results = []
     for seq in input_ids:
         length = len(seq)
-        results.append(min_sequence_len <= length <= sequence_len)
+        if sequence_len is not None:
+            results.append(min_sequence_len <= length <= sequence_len)
+        else:
+            results.append(min_sequence_len <= length)
     return results
@@ -405,7 +408,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
     if update:
         cfg.total_num_tokens = total_num_tokens

-    skip_estimates = cfg.model_config_type == "mamba"
+    skip_estimates = cfg.sequence_len is None or cfg.model_config_type == "mamba"
     if (
         not skip_estimates
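Given the branch added above, a quick check of drop_long_seq when no cap is configured (min_sequence_len defaults to 2 per the hunk header; the sample values are illustrative):

sample = {"input_ids": [[1], [1, 2, 3], list(range(10_000))]}
assert drop_long_seq(sample, sequence_len=None) == [False, True, True]
# length 1 fails min_sequence_len; with sequence_len=None there is no upper bound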

View File

@@ -8,7 +8,7 @@ from transformers import AutoTokenizer
 from axolotl.datasets import TokenizedPromptDataset
 from axolotl.prompt_strategies.completion import load
 from axolotl.utils.collators import V2BatchSamplerDataCollatorForSeq2Seq
-from axolotl.utils.data.utils import drop_long_seq_in_dataset
+from axolotl.utils.data.utils import handle_long_seq_in_dataset
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -70,7 +70,7 @@ class TestBatchedSamplerPacking:
         )
         train_dataset = concatenate_datasets([dataset_wrapper])
-        train_dataset = drop_long_seq_in_dataset(train_dataset, cfg.sequence_len, cfg)
+        train_dataset = handle_long_seq_in_dataset(train_dataset, cfg.sequence_len, cfg)
         lengths = get_dataset_lengths(train_dataset)
         batch_sampler = MultipackBatchSampler(