Compare commits

..

5 Commits

Author SHA1 Message Date
Dan Saunders
c4f4f81bed Merge branch 'main' into map-dataset-fetcher-fix 2025-06-26 11:20:05 -04:00
Dan Saunders
4ebd4aae3d handle possibly empty batch 2025-06-26 10:59:27 -04:00
NanoCode012
927bf530bc fix(doc): default messages example used wrong key (#2832)
* fix(doc): default messages example used wrong key

* feat: add links to SP, multi-gpu, multi-node on readme
2025-06-26 10:47:31 -04:00
github-actions[bot]
18954ba100 chore: update pre-commit hooks (#2821) [skip ci]
Co-authored-by: djsaunde <1245942+djsaunde@users.noreply.github.com>
2025-06-26 10:46:53 -04:00
Wing Lian
d8cf66edbd use fork for multiprocess start method for packing in parallel (#2830) 2025-06-25 13:17:33 -04:00
10 changed files with 25 additions and 13 deletions

View File

@@ -19,7 +19,7 @@ repos:
hooks: hooks:
- id: isort - id: isort
- repo: https://github.com/PyCQA/flake8 - repo: https://github.com/PyCQA/flake8
rev: 7.2.0 rev: 7.3.0
hooks: hooks:
- id: flake8 - id: flake8
- repo: https://github.com/pylint-dev/pylint - repo: https://github.com/pylint-dev/pylint
@@ -27,7 +27,7 @@ repos:
hooks: hooks:
- id: pylint - id: pylint
- repo: https://github.com/pre-commit/mirrors-mypy - repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.16.0 rev: v1.16.1
hooks: hooks:
- id: mypy - id: mypy
additional_dependencies: additional_dependencies:
@@ -36,7 +36,7 @@ repos:
'pydantic>=2.5.3', 'pydantic>=2.5.3',
] ]
- repo: https://github.com/PyCQA/bandit - repo: https://github.com/PyCQA/bandit
rev: 1.8.3 rev: 1.8.5
hooks: hooks:
- id: bandit - id: bandit
args: [ args: [

View File

@@ -43,7 +43,7 @@ Features:
- **Multiple Model Support**: Train various models like LLaMA, Mistral, Mixtral, Pythia, and more. We are compatible with HuggingFace transformers causal language models. - **Multiple Model Support**: Train various models like LLaMA, Mistral, Mixtral, Pythia, and more. We are compatible with HuggingFace transformers causal language models.
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO), Multimodal, and Reward Modelling (RM) / Process Reward Modelling (PRM). - **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO), Multimodal, and Reward Modelling (RM) / Process Reward Modelling (PRM).
- **Easy Configuration**: Re-use a single YAML file between dataset preprocess, training, evaluation, quantization, and inference. - **Easy Configuration**: Re-use a single YAML file between dataset preprocess, training, evaluation, quantization, and inference.
- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), Sequence Parallelism (SP), LoRA optimizations, Multi-GPU training (FSDP1, FSDP2, DeepSpeed), Multi-node training (Torchrun, Ray), and many more! - **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more!
- **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets. - **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets.
- **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware. - **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware.

View File

@@ -9,7 +9,7 @@ order: 3
Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. Support using tokenizer's template, a supported template, or custom jinja2. Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. Support using tokenizer's template, a supported template, or custom jinja2.
```{.json filename="data.jsonl"} ```{.json filename="data.jsonl"}
{"conversations": [{"role": "...", "content": "..."}]} {"messages": [{"role": "...", "content": "..."}, {"role": "...", "content": "..."}, ...]}
``` ```
See [configs](../config-reference.qmd) for full configs and supported templates. See [configs](../config-reference.qmd) for full configs and supported templates.

View File

@@ -116,6 +116,7 @@ class AxolotlTrainer(
sequential=self.args.sample_packing_sequentially, sequential=self.args.sample_packing_sequentially,
drop_last=True, drop_last=True,
num_processes=self.args.dataset_num_proc, num_processes=self.args.dataset_num_proc,
mp_start_method=self.args.sample_packing_mp_start_method or "fork",
) )
len(sampler) len(sampler)

View File

@@ -38,6 +38,10 @@ class AxolotlTrainingMixins:
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing." "help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
}, },
) )
sample_packing_mp_start_method: str | None = field(
default=None,
metadata={"help": "The multiprocessing start method to use."},
)
multipack_real_batches: bool = field( multipack_real_batches: bool = field(
default=False, default=False,
metadata={"help": "Use real batches for efficient training."}, metadata={"help": "Use real batches for efficient training."},

View File

@@ -9,6 +9,9 @@ from torch.utils.data._utils.worker import _worker_loop
class _MapDatasetFetcher(_BaseDatasetFetcher): class _MapDatasetFetcher(_BaseDatasetFetcher):
def fetch(self, possibly_batched_index): def fetch(self, possibly_batched_index):
if not possibly_batched_index:
return self.collate_fn([])
if isinstance(possibly_batched_index[0], list): if isinstance(possibly_batched_index[0], list):
data = [None for i in possibly_batched_index] data = [None for i in possibly_batched_index]
for i, possibly_batched_index_ in enumerate(possibly_batched_index): for i, possibly_batched_index_ in enumerate(possibly_batched_index):

View File

@@ -208,9 +208,6 @@ class SequenceParallelContextManager:
self.original_seq_len = 0 self.original_seq_len = 0
self.pad_len = 0 self.pad_len = 0
# Store kwargs passed to model forward pass
self.original_kwargs: None | dict[str, torch.Tensor] = None
# Create a partially applied version of the apply_sequence_parallelism function # Create a partially applied version of the apply_sequence_parallelism function
self.apply_sequence_parallelism = functools.partial( self.apply_sequence_parallelism = functools.partial(
apply_sequence_parallelism, apply_sequence_parallelism,
@@ -263,9 +260,6 @@ class SequenceParallelContextManager:
# Any excess positional arguments are kept as-is # Any excess positional arguments are kept as-is
remaining_args = args[len(forward_params) :] remaining_args = args[len(forward_params) :]
# Store original kwargs
self.original_kwargs = {key: value.clone() for key, value in updated_kwargs.items()}
# Apply sequence parallelism to updated kwargs # Apply sequence parallelism to updated kwargs
updated_kwargs, self.original_seq_len, self.pad_len = ( updated_kwargs, self.original_seq_len, self.pad_len = (
self.apply_sequence_parallelism(updated_kwargs) self.apply_sequence_parallelism(updated_kwargs)

View File

@@ -127,7 +127,7 @@ def pack_parallel(
bin_size: int, bin_size: int,
num_processes: int | None = None, num_processes: int | None = None,
safe_mode: bool = True, safe_mode: bool = True,
mp_start_method: str | None = "spawn", mp_start_method: str | None = "fork",
) -> list[list[int]]: ) -> list[list[int]]:
"""Pack sequences into bins using parallel processing. """Pack sequences into bins using parallel processing.
@@ -266,6 +266,7 @@ class MultipackBatchSampler(BatchSampler):
bin_size: int = 200, # The max number of samples that can be packed in a single bin bin_size: int = 200, # The max number of samples that can be packed in a single bin
num_processes: int | None = None, # Number of processes for parallel packing num_processes: int | None = None, # Number of processes for parallel packing
safe_mode: bool = True, # Conservative packing to prevent training instability safe_mode: bool = True, # Conservative packing to prevent training instability
mp_start_method: str = "fork",
**kwargs, # pylint: disable=unused-argument **kwargs, # pylint: disable=unused-argument
): ):
super().__init__(sampler, batch_size, drop_last) super().__init__(sampler, batch_size, drop_last)
@@ -278,6 +279,7 @@ class MultipackBatchSampler(BatchSampler):
self.bin_size = bin_size self.bin_size = bin_size
self.num_processes = num_processes self.num_processes = num_processes
self.safe_mode = safe_mode self.safe_mode = safe_mode
self.mp_start_method = mp_start_method
assert isinstance(self.lengths, np.ndarray) assert isinstance(self.lengths, np.ndarray)
@@ -338,8 +340,9 @@ class MultipackBatchSampler(BatchSampler):
bin_capacity=self.batch_max_len, bin_capacity=self.batch_max_len,
group_size=self.group_size, group_size=self.group_size,
bin_size=self.bin_size, bin_size=self.bin_size,
num_processes=self.num_processes, num_processes=max(4, self.num_processes) if self.num_processes else 4,
safe_mode=self.safe_mode, safe_mode=self.safe_mode,
mp_start_method=self.mp_start_method,
) )
# Map bin indices back to original indices # Map bin indices back to original indices

View File

@@ -393,6 +393,12 @@ class AxolotlInputConfig(
default=None, default=None,
json_schema_extra={"description": "Whether to pack samples sequentially"}, json_schema_extra={"description": "Whether to pack samples sequentially"},
) )
sample_packing_mp_start_method: str | None = Field(
default=None,
json_schema_extra={
"description": "The multiprocessing start method to use for packing. Should be 'fork', 'spawn' or 'forkserver'"
},
)
eval_sample_packing: bool | None = Field( eval_sample_packing: bool | None = Field(
default=None, default=None,
json_schema_extra={ json_schema_extra={

View File

@@ -467,6 +467,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
sequential=cfg.sample_packing_sequentially, sequential=cfg.sample_packing_sequentially,
drop_last=True, drop_last=True,
num_processes=cfg.dataset_processes, num_processes=cfg.dataset_processes,
mp_start_method=cfg.sample_packing_mp_start_method or "fork",
) )
data_loader = DataLoader( data_loader = DataLoader(