Compare commits
2 Commits
release-v0
...
split-batc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bb65157dcf | ||
|
|
7fd3d8abc4 |
@@ -4,4 +4,4 @@ import pkgutil
|
||||
|
||||
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
|
||||
|
||||
__version__ = "0.12.2"
|
||||
__version__ = "0.13.0.dev"
|
||||
|
||||
@@ -40,12 +40,6 @@ class VllmServeCliArgs:
|
||||
default=None,
|
||||
metadata={"help": "Number of tensor parallel workers to use."},
|
||||
)
|
||||
data_parallel_size: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Number of data parallel workers to use for vLLM serving. This controls how many model replicas are used for parallel inference."
|
||||
},
|
||||
)
|
||||
host: Optional[str] = field(
|
||||
default=None, # nosec B104
|
||||
metadata={"help": "Host address to run the server on."},
|
||||
|
||||
@@ -424,7 +424,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
):
|
||||
if training_args.pretraining:
|
||||
if (
|
||||
self.cfg.pretraining_sample_concatenation is False
|
||||
not self.cfg.pretraining_sample_concatenation
|
||||
or self.cfg.micro_batch_size > 1
|
||||
):
|
||||
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
||||
|
||||
@@ -272,6 +272,20 @@ class AxolotlTrainer(
|
||||
num_workers=self.args.dataloader_num_workers,
|
||||
rank=self.args.process_index,
|
||||
)
|
||||
|
||||
if (
|
||||
self.args.accelerator_config is not None
|
||||
and self.args.accelerator_config.split_batches
|
||||
and self.args.accelerator_config.dispatch_batches
|
||||
):
|
||||
if self.args.sample_packing and self.args.pretraining:
|
||||
if not self.args.eval_sample_packing and not is_training:
|
||||
dataloader_params["batch_size"] *= self.accelerator.num_processes
|
||||
else:
|
||||
dataloader_params["batch_size"] = self.accelerator.num_processes
|
||||
elif not self.args.sample_packing and self.args.pretraining:
|
||||
dataloader_params["batch_size"] *= self.accelerator.num_processes
|
||||
|
||||
if self.args.sample_packing and (
|
||||
(is_training and not self.args.pretraining)
|
||||
or (not is_training and self.args.eval_sample_packing is not False)
|
||||
|
||||
0
src/axolotl/exception_handling.py
Normal file
0
src/axolotl/exception_handling.py
Normal file
@@ -28,7 +28,7 @@ from axolotl.utils.data.shared import (
|
||||
)
|
||||
from axolotl.utils.data.utils import (
|
||||
deduplicate_and_log_datasets,
|
||||
handle_long_seq_in_dataset,
|
||||
drop_long_seq_in_dataset,
|
||||
retry_on_request_exceptions,
|
||||
)
|
||||
from axolotl.utils.data.wrappers import get_dataset_wrapper
|
||||
@@ -339,9 +339,9 @@ def _load_raw_datasets(
|
||||
|
||||
if not cfg.skip_prepare_dataset:
|
||||
if split == "test" and cfg.eval_sequence_len:
|
||||
dataset = handle_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg)
|
||||
dataset = drop_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg)
|
||||
else:
|
||||
dataset = handle_long_seq_in_dataset(dataset, cfg.sequence_len, cfg)
|
||||
dataset = drop_long_seq_in_dataset(dataset, cfg.sequence_len, cfg)
|
||||
if cfg.sample_packing:
|
||||
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
|
||||
|
||||
|
||||
@@ -148,36 +148,7 @@ def deduplicate_and_log_datasets(
|
||||
return dataset, other_dataset
|
||||
|
||||
|
||||
def truncate_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
||||
"""
|
||||
Truncate samples whose sequence length is too long (> sequence_len)
|
||||
or drop those too short (< min_sequence_len).
|
||||
"""
|
||||
min_sequence_len = min_sequence_len or 2
|
||||
|
||||
input_ids = sample["input_ids"]
|
||||
results = []
|
||||
|
||||
# Batched (input_ids is a list of lists)
|
||||
for i, seq in enumerate(input_ids):
|
||||
length = len(seq)
|
||||
if length < min_sequence_len:
|
||||
results.append(False)
|
||||
elif length > sequence_len:
|
||||
sample["input_ids"][i] = seq[:sequence_len]
|
||||
if "attention_mask" in sample:
|
||||
sample["attention_mask"][i] = sample["attention_mask"][i][:sequence_len]
|
||||
if "labels" in sample:
|
||||
sample["labels"][i] = sample["labels"][i][:sequence_len]
|
||||
if "position_ids" in sample:
|
||||
sample["position_ids"][i] = sample["position_ids"][i][:sequence_len]
|
||||
results.append(True)
|
||||
else:
|
||||
results.append(True)
|
||||
return results
|
||||
|
||||
|
||||
def handle_long_seq_in_dataset(
|
||||
def drop_long_seq_in_dataset(
|
||||
dataset: Dataset, sequence_len: int, cfg: DictDefault
|
||||
) -> Dataset:
|
||||
"""Remove sequences longer than configured maximum from dataset.
|
||||
@@ -221,21 +192,8 @@ def handle_long_seq_in_dataset(
|
||||
if filter_map_kwargs:
|
||||
drop_long_kwargs["desc"] = f"Dropping Long Sequences (>{sequence_len})"
|
||||
|
||||
excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
|
||||
if excess_length_strategy == "truncate":
|
||||
process_fn = functools.partial(
|
||||
truncate_long_seq,
|
||||
sequence_len=sequence_len,
|
||||
min_sequence_len=cfg.min_sample_len,
|
||||
)
|
||||
drop_long_kwargs["desc"] = (
|
||||
f"Truncating/Filtering Sequences (target_len={sequence_len})"
|
||||
)
|
||||
else:
|
||||
process_fn = drop_long
|
||||
|
||||
dataset = dataset.filter(
|
||||
process_fn,
|
||||
drop_long,
|
||||
batched=True,
|
||||
**filter_map_kwargs,
|
||||
**drop_long_kwargs,
|
||||
@@ -243,11 +201,6 @@ def handle_long_seq_in_dataset(
|
||||
if prior_len:
|
||||
dropped = prior_len - len(dataset)
|
||||
if dropped:
|
||||
action = (
|
||||
"truncated/filtered"
|
||||
if excess_length_strategy == "truncate"
|
||||
else "dropped"
|
||||
)
|
||||
LOG.warning(f"{action.title()} {dropped} samples from dataset")
|
||||
LOG.warning(f"Dropped {dropped} long samples from dataset")
|
||||
|
||||
return dataset
|
||||
|
||||
@@ -414,12 +414,6 @@ class AxolotlInputConfig(
|
||||
"description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048"
|
||||
},
|
||||
)
|
||||
excess_length_strategy: Literal["drop", "truncate"] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "What to do when a tokenized row exceeds sequence_len. 'drop' removes the row; 'truncate' slices tensors to sequence_len. Defaults to 'drop' for backward compatibility."
|
||||
},
|
||||
)
|
||||
eval_sequence_len: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
|
||||
@@ -8,7 +8,7 @@ from transformers import AutoTokenizer
|
||||
from axolotl.datasets import TokenizedPromptDataset
|
||||
from axolotl.prompt_strategies.completion import load
|
||||
from axolotl.utils.collators import V2BatchSamplerDataCollatorForSeq2Seq
|
||||
from axolotl.utils.data.utils import handle_long_seq_in_dataset
|
||||
from axolotl.utils.data.utils import drop_long_seq_in_dataset
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
|
||||
@@ -70,7 +70,7 @@ class TestBatchedSamplerPacking:
|
||||
)
|
||||
train_dataset = concatenate_datasets([dataset_wrapper])
|
||||
|
||||
train_dataset = handle_long_seq_in_dataset(train_dataset, cfg.sequence_len, cfg)
|
||||
train_dataset = drop_long_seq_in_dataset(train_dataset, cfg.sequence_len, cfg)
|
||||
|
||||
lengths = get_dataset_lengths(train_dataset)
|
||||
batch_sampler = MultipackBatchSampler(
|
||||
|
||||
Reference in New Issue
Block a user