Compare commits
3 Commits
split-batc
...
v0.12.2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3cf22ae23b | ||
|
|
c10eb811fa | ||
|
|
0eef385b1a |
@@ -4,4 +4,4 @@ import pkgutil
|
|||||||
|
|
||||||
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
|
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
|
||||||
|
|
||||||
__version__ = "0.13.0.dev"
|
__version__ = "0.12.2"
|
||||||
|
|||||||
@@ -40,6 +40,12 @@ class VllmServeCliArgs:
|
|||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Number of tensor parallel workers to use."},
|
metadata={"help": "Number of tensor parallel workers to use."},
|
||||||
)
|
)
|
||||||
|
data_parallel_size: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "Number of data parallel workers to use for vLLM serving. This controls how many model replicas are used for parallel inference."
|
||||||
|
},
|
||||||
|
)
|
||||||
host: Optional[str] = field(
|
host: Optional[str] = field(
|
||||||
default=None, # nosec B104
|
default=None, # nosec B104
|
||||||
metadata={"help": "Host address to run the server on."},
|
metadata={"help": "Host address to run the server on."},
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ from axolotl.utils.data.shared import (
|
|||||||
)
|
)
|
||||||
from axolotl.utils.data.utils import (
|
from axolotl.utils.data.utils import (
|
||||||
deduplicate_and_log_datasets,
|
deduplicate_and_log_datasets,
|
||||||
drop_long_seq_in_dataset,
|
handle_long_seq_in_dataset,
|
||||||
retry_on_request_exceptions,
|
retry_on_request_exceptions,
|
||||||
)
|
)
|
||||||
from axolotl.utils.data.wrappers import get_dataset_wrapper
|
from axolotl.utils.data.wrappers import get_dataset_wrapper
|
||||||
@@ -339,9 +339,9 @@ def _load_raw_datasets(
|
|||||||
|
|
||||||
if not cfg.skip_prepare_dataset:
|
if not cfg.skip_prepare_dataset:
|
||||||
if split == "test" and cfg.eval_sequence_len:
|
if split == "test" and cfg.eval_sequence_len:
|
||||||
dataset = drop_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg)
|
dataset = handle_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg)
|
||||||
else:
|
else:
|
||||||
dataset = drop_long_seq_in_dataset(dataset, cfg.sequence_len, cfg)
|
dataset = handle_long_seq_in_dataset(dataset, cfg.sequence_len, cfg)
|
||||||
if cfg.sample_packing:
|
if cfg.sample_packing:
|
||||||
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
|
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
|
||||||
|
|
||||||
|
|||||||
@@ -148,7 +148,36 @@ def deduplicate_and_log_datasets(
|
|||||||
return dataset, other_dataset
|
return dataset, other_dataset
|
||||||
|
|
||||||
|
|
||||||
def drop_long_seq_in_dataset(
|
def truncate_long_seq(sample, sequence_len=2048, min_sequence_len=2):
    """
    Truncate samples whose sequence length is too long (> sequence_len)
    or drop those too short (< min_sequence_len).

    Works on both a batch (``sample["input_ids"]`` is a list of token-id
    lists) and a single example (``sample["input_ids"]`` is a flat list of
    ints).

    Args:
        sample: Mapping with an ``"input_ids"`` entry and, optionally,
            ``"attention_mask"``, ``"labels"`` and ``"position_ids"`` entries
            aligned with it. Over-long rows are truncated **in place**.
        sequence_len: Maximum number of tokens to keep per row.
        min_sequence_len: Rows shorter than this are flagged for removal.
            A falsy value (``None`` / ``0``) falls back to ``2``.

    Returns:
        For a batch: a list of booleans, one per row, ``False`` for rows that
        are too short and ``True`` otherwise (including truncated rows).
        For a single example: one boolean with the same meaning.

    NOTE(review): the caller passes this function to ``dataset.filter``.
    Presumably ``filter`` only consumes the returned mask and does not write
    the mutated columns back -- confirm that the truncation is actually
    persisted (a ``map`` pass may be required).
    """
    min_sequence_len = min_sequence_len or 2

    input_ids = sample["input_ids"]
    # Columns that are aligned token-for-token with input_ids and therefore
    # must be cut to the same length; absent columns are skipped.
    per_token_columns = ("input_ids", "attention_mask", "labels", "position_ids")

    # Single (unbatched) example: input_ids is a flat list of token ids.
    if len(input_ids) > 0 and isinstance(input_ids[0], int):
        length = len(input_ids)
        if length < min_sequence_len:
            return False
        if length > sequence_len:
            for column in per_token_columns:
                if column in sample:
                    sample[column] = sample[column][:sequence_len]
        return True

    # Batched (input_ids is a list of lists)
    results = []
    for i, seq in enumerate(input_ids):
        length = len(seq)
        if length < min_sequence_len:
            results.append(False)
            continue
        if length > sequence_len:
            for column in per_token_columns:
                if column in sample:
                    sample[column][i] = sample[column][i][:sequence_len]
        results.append(True)
    return results
|
||||||
|
|
||||||
|
|
||||||
|
def handle_long_seq_in_dataset(
|
||||||
dataset: Dataset, sequence_len: int, cfg: DictDefault
|
dataset: Dataset, sequence_len: int, cfg: DictDefault
|
||||||
) -> Dataset:
|
) -> Dataset:
|
||||||
"""Remove sequences longer than configured maximum from dataset.
|
"""Remove sequences longer than configured maximum from dataset.
|
||||||
@@ -192,8 +221,21 @@ def drop_long_seq_in_dataset(
|
|||||||
if filter_map_kwargs:
|
if filter_map_kwargs:
|
||||||
drop_long_kwargs["desc"] = f"Dropping Long Sequences (>{sequence_len})"
|
drop_long_kwargs["desc"] = f"Dropping Long Sequences (>{sequence_len})"
|
||||||
|
|
||||||
|
excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
|
||||||
|
if excess_length_strategy == "truncate":
|
||||||
|
process_fn = functools.partial(
|
||||||
|
truncate_long_seq,
|
||||||
|
sequence_len=sequence_len,
|
||||||
|
min_sequence_len=cfg.min_sample_len,
|
||||||
|
)
|
||||||
|
drop_long_kwargs["desc"] = (
|
||||||
|
f"Truncating/Filtering Sequences (target_len={sequence_len})"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
process_fn = drop_long
|
||||||
|
|
||||||
dataset = dataset.filter(
|
dataset = dataset.filter(
|
||||||
drop_long,
|
process_fn,
|
||||||
batched=True,
|
batched=True,
|
||||||
**filter_map_kwargs,
|
**filter_map_kwargs,
|
||||||
**drop_long_kwargs,
|
**drop_long_kwargs,
|
||||||
@@ -201,6 +243,11 @@ def drop_long_seq_in_dataset(
|
|||||||
if prior_len:
|
if prior_len:
|
||||||
dropped = prior_len - len(dataset)
|
dropped = prior_len - len(dataset)
|
||||||
if dropped:
|
if dropped:
|
||||||
LOG.warning(f"Dropped {dropped} long samples from dataset")
|
action = (
|
||||||
|
"truncated/filtered"
|
||||||
|
if excess_length_strategy == "truncate"
|
||||||
|
else "dropped"
|
||||||
|
)
|
||||||
|
LOG.warning(f"{action.title()} {dropped} samples from dataset")
|
||||||
|
|
||||||
return dataset
|
return dataset
|
||||||
|
|||||||
@@ -414,6 +414,12 @@ class AxolotlInputConfig(
|
|||||||
"description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048"
|
"description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
excess_length_strategy: Literal["drop", "truncate"] | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "What to do when a tokenized row exceeds sequence_len. 'drop' removes the row; 'truncate' slices tensors to sequence_len. Defaults to 'drop' for backward compatibility."
|
||||||
|
},
|
||||||
|
)
|
||||||
eval_sequence_len: int | None = Field(
|
eval_sequence_len: int | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from transformers import AutoTokenizer
|
|||||||
from axolotl.datasets import TokenizedPromptDataset
|
from axolotl.datasets import TokenizedPromptDataset
|
||||||
from axolotl.prompt_strategies.completion import load
|
from axolotl.prompt_strategies.completion import load
|
||||||
from axolotl.utils.collators import V2BatchSamplerDataCollatorForSeq2Seq
|
from axolotl.utils.collators import V2BatchSamplerDataCollatorForSeq2Seq
|
||||||
from axolotl.utils.data.utils import drop_long_seq_in_dataset
|
from axolotl.utils.data.utils import handle_long_seq_in_dataset
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
|
|
||||||
@@ -70,7 +70,7 @@ class TestBatchedSamplerPacking:
|
|||||||
)
|
)
|
||||||
train_dataset = concatenate_datasets([dataset_wrapper])
|
train_dataset = concatenate_datasets([dataset_wrapper])
|
||||||
|
|
||||||
train_dataset = drop_long_seq_in_dataset(train_dataset, cfg.sequence_len, cfg)
|
train_dataset = handle_long_seq_in_dataset(train_dataset, cfg.sequence_len, cfg)
|
||||||
|
|
||||||
lengths = get_dataset_lengths(train_dataset)
|
lengths = get_dataset_lengths(train_dataset)
|
||||||
batch_sampler = MultipackBatchSampler(
|
batch_sampler = MultipackBatchSampler(
|
||||||
|
|||||||
Reference in New Issue
Block a user