Compare commits
2 Commits
squash_pos
...
no-seq-len
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c3db6dd307 | ||
|
|
9a6e9d8d15 |
@@ -12,6 +12,5 @@ reviews:
|
||||
auto_review:
|
||||
enabled: true
|
||||
drafts: false
|
||||
auto_incremental_review: true
|
||||
chat:
|
||||
auto_reply: true
|
||||
|
||||
@@ -41,12 +41,6 @@ model, and final model output, you may need at least 3TB of free disk space to k
|
||||
axolotl train examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
|
||||
```
|
||||
|
||||
To simplify fine-tuning across 2 nodes × 8x H100 (80GB) GPUs, we've partnered with [Baseten](https://baseten.co) to showcase multi-node
|
||||
training of the 120B model using Baseten Truss. You can read more about this recipe on
|
||||
[Baseten's blog](https://www.baseten.co/blog/how-to-fine-tune-gpt-oss-120b-with-baseten-and-axolotl/). The recipe can
|
||||
be found on their
|
||||
[GitHub](https://github.com/basetenlabs/ml-cookbook/tree/main/examples/oss-gpt-120b-axolotl/training).
|
||||
|
||||
ERRATA: Transformers saves the model Architecture prefixed with `FSDP` which needs to be manually renamed in `config.json`.
|
||||
See https://github.com/huggingface/transformers/pull/40207 for the status of this issue.
|
||||
|
||||
@@ -67,23 +61,9 @@ mv ./outputs/gpt-oss-out/merged/* ./outputs/gpt-oss-out/
|
||||
|
||||
### Inferencing your fine-tuned model
|
||||
|
||||
#### vLLM
|
||||
|
||||
GPT-OSS support in vLLM does not exist in a stable release yet. See https://x.com/MaziyarPanahi/status/1955741905515323425
|
||||
for more information about using a special vllm-openai docker image for inferencing with vLLM.
|
||||
|
||||
Optionally, vLLM can be installed from nightly:
|
||||
|
||||
```bash
|
||||
pip install --no-build-isolation --pre -U vllm --extra-index-url https://wheels.vllm.ai/nightly
|
||||
```
|
||||
and the vLLM server can be started with the following command (modify `--tensor-parallel-size 8` to match your environment):
|
||||
```bash
|
||||
vllm serve ./outputs/gpt-oss-out/ --served-model-name axolotl/gpt-oss-20b --host 0.0.0.0 --port 8888 --tensor-parallel-size 8
|
||||
```
|
||||
|
||||
#### SGLang
|
||||
|
||||
SGLang has 0-day support in main, see https://github.com/sgl-project/sglang/issues/8833 for infomation on installing
|
||||
SGLang from source. Once you've installed SGLang, run the following command to launch a SGLang server:
|
||||
|
||||
|
||||
@@ -44,7 +44,7 @@ bf16: true
|
||||
tf32: true
|
||||
|
||||
flash_attention: true
|
||||
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
|
||||
attn_implementation: kernels-community/vllm-flash-attn3
|
||||
|
||||
gradient_checkpointing: true
|
||||
activation_offloading: true
|
||||
|
||||
@@ -40,7 +40,7 @@ bf16: true
|
||||
tf32: true
|
||||
|
||||
flash_attention: true
|
||||
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
|
||||
attn_implementation: kernels-community/vllm-flash-attn3
|
||||
|
||||
gradient_checkpointing: true
|
||||
activation_offloading: true
|
||||
|
||||
@@ -15,7 +15,7 @@ datasets:
|
||||
field_thinking: thinking
|
||||
template_thinking_key: thinking
|
||||
|
||||
dataset_prepared_path: ./outputs/last_run_prepared
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0
|
||||
output_dir: ./outputs/gpt-oss-out/
|
||||
|
||||
@@ -41,7 +41,7 @@ bf16: true
|
||||
tf32: true
|
||||
|
||||
flash_attention: true
|
||||
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
|
||||
attn_implementation: kernels-community/vllm-flash-attn3
|
||||
|
||||
gradient_checkpointing: true
|
||||
activation_offloading: true
|
||||
|
||||
@@ -15,7 +15,7 @@ datasets:
|
||||
field_thinking: thinking
|
||||
template_thinking_key: thinking
|
||||
|
||||
dataset_prepared_path: ./outputs/last_run_prepared
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0
|
||||
output_dir: ./outputs/gpt-oss-out/
|
||||
|
||||
@@ -40,7 +40,7 @@ bf16: true
|
||||
tf32: true
|
||||
|
||||
flash_attention: true
|
||||
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
|
||||
attn_implementation: kernels-community/vllm-flash-attn3
|
||||
|
||||
gradient_checkpointing: true
|
||||
activation_offloading: true
|
||||
|
||||
@@ -53,7 +53,7 @@ bf16: true
|
||||
tf32: true
|
||||
|
||||
flash_attention: true
|
||||
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
|
||||
attn_implementation: kernels-community/vllm-flash-attn3
|
||||
|
||||
gradient_checkpointing: true
|
||||
activation_offloading: true
|
||||
|
||||
@@ -12,7 +12,7 @@ output_dir: ./outputs/lora-out
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sequence_len:
|
||||
sample_packing: true
|
||||
eval_sample_packing: true
|
||||
|
||||
|
||||
@@ -13,8 +13,8 @@ liger-kernel==0.6.1
|
||||
packaging==23.2
|
||||
|
||||
huggingface_hub>=0.33.0
|
||||
peft>=0.17.0
|
||||
transformers==4.55.3
|
||||
peft==0.17.0
|
||||
transformers==4.55.2
|
||||
tokenizers>=0.21.1
|
||||
accelerate==1.10.0
|
||||
datasets==4.0.0
|
||||
|
||||
4
setup.py
4
setup.py
@@ -118,9 +118,9 @@ def get_package_version():
|
||||
|
||||
|
||||
extras_require = {
|
||||
"flash-attn": ["flash-attn==2.8.3"],
|
||||
"flash-attn": ["flash-attn==2.8.2"],
|
||||
"ring-flash-attn": [
|
||||
"flash-attn==2.8.3",
|
||||
"flash-attn==2.8.2",
|
||||
"ring-flash-attn>=0.1.7",
|
||||
"yunchang==0.6.0",
|
||||
],
|
||||
|
||||
@@ -82,7 +82,7 @@ class ModalCloud(Cloud):
|
||||
return res
|
||||
|
||||
def get_image(self):
|
||||
docker_tag = "main-py3.11-cu126-2.7.1"
|
||||
docker_tag = "main-py3.11-cu124-2.6.0"
|
||||
if self.config.docker_tag:
|
||||
docker_tag = self.config.docker_tag
|
||||
docker_image = f"axolotlai/axolotl:{docker_tag}"
|
||||
@@ -200,7 +200,7 @@ class ModalCloud(Cloud):
|
||||
if family in ["a10", "a10g"]:
|
||||
return modal.gpu.A10G(count=count)
|
||||
if family == "h100":
|
||||
return f"H100:{count}"
|
||||
return modal.gpu.H100(count=count)
|
||||
if family == "t4":
|
||||
return modal.gpu.T4(count=count)
|
||||
if family == "l4":
|
||||
|
||||
@@ -64,7 +64,7 @@ def do_inference(
|
||||
importlib.import_module("axolotl.prompters"), prompter
|
||||
)
|
||||
elif cfg.chat_template:
|
||||
chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)
|
||||
chat_template_str = get_chat_template(cfg.chat_template)
|
||||
elif cfg.datasets[0].type == "chat_template":
|
||||
chat_template_str = get_chat_template_from_config(
|
||||
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
|
||||
|
||||
@@ -97,8 +97,7 @@ def do_cli(
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
os.environ["AXOLOTL_IS_PREPROCESS"] = "1"
|
||||
is_preprocess = kwargs.pop("is_preprocess", True)
|
||||
parsed_cfg = load_cfg(config, is_preprocess=is_preprocess, **kwargs)
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
parsed_cfg.is_preprocess = True
|
||||
parser = transformers.HfArgumentParser(PreprocessCliArgs)
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
|
||||
@@ -3,12 +3,11 @@
|
||||
import random
|
||||
from copy import deepcopy
|
||||
from itertools import product
|
||||
from typing import Any
|
||||
|
||||
|
||||
def generate_sweep_configs(
|
||||
base_config: dict[str, list], sweeps_config: dict[str, list]
|
||||
) -> list[dict[str, Any]]:
|
||||
) -> list[dict[str, list]]:
|
||||
"""
|
||||
Recursively generates all possible configurations by applying sweeps to the base config.
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ import os
|
||||
import subprocess # nosec
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterator, Literal
|
||||
|
||||
import yaml
|
||||
@@ -89,12 +88,7 @@ def generate_config_files(config: str, sweep: str | None) -> Iterator[tuple[str,
|
||||
# Generate all possible configurations
|
||||
permutations = generate_sweep_configs(base_config, sweep_config)
|
||||
is_group = len(permutations) > 1
|
||||
base_output_dir = base_config.get("output_dir", "./model-out")
|
||||
for idx, permutation in enumerate(permutations, start=1):
|
||||
permutation_dir = Path(permutation.get("output_dir", base_output_dir))
|
||||
permutation_id = f"sweep{idx:04d}"
|
||||
permutation["output_dir"] = str(permutation_dir / permutation_id)
|
||||
|
||||
for permutation in permutations:
|
||||
# pylint: disable=consider-using-with
|
||||
temp_file = tempfile.NamedTemporaryFile(
|
||||
mode="w",
|
||||
|
||||
@@ -6,6 +6,7 @@ from dataclasses import dataclass
|
||||
|
||||
from datasets import Dataset
|
||||
|
||||
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
|
||||
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
|
||||
from axolotl.loaders import load_processor, load_tokenizer
|
||||
from axolotl.utils.data import prepare_datasets, prepare_preference_datasets
|
||||
|
||||
@@ -476,8 +476,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
)
|
||||
):
|
||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
||||
if self.cfg.squash_position_ids:
|
||||
kwargs["squash_position_ids"] = True
|
||||
else:
|
||||
collator = BatchSamplerDataCollatorForSeq2Seq
|
||||
else:
|
||||
|
||||
@@ -268,7 +268,10 @@ class ModelLoader:
|
||||
hasattr(self.model, "config")
|
||||
and hasattr(self.model.config, "max_position_embeddings")
|
||||
and self.model.config.max_position_embeddings
|
||||
and self.cfg.sequence_len > self.model.config.max_position_embeddings
|
||||
and (
|
||||
self.cfg.sequence_len is not None
|
||||
and self.cfg.sequence_len > self.model.config.max_position_embeddings
|
||||
)
|
||||
):
|
||||
LOG.warning(
|
||||
"increasing model.config.max_position_embeddings from "
|
||||
|
||||
@@ -277,14 +277,6 @@ class PatchManager:
|
||||
has_remote_code=has_remote_code,
|
||||
)
|
||||
|
||||
if self.cfg.sample_packing:
|
||||
from axolotl.monkeypatch.data.batch_dataset_fetcher import (
|
||||
apply_multipack_dataloader_patch,
|
||||
)
|
||||
|
||||
LOG.info("Applying multipack dataloader patch for sample packing...")
|
||||
apply_multipack_dataloader_patch()
|
||||
|
||||
def _apply_fsdp2_bnb_patches(self):
|
||||
"""Apply FSDP2 BNB patches."""
|
||||
if (
|
||||
|
||||
@@ -187,7 +187,7 @@ def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
|
||||
|
||||
# Linear4Bit will keep it's bias term in fp32. If the weight dtype is in bf16 we are not able to
|
||||
# wrap this. Therefore we must ensure the bias has the same dtype as the weight
|
||||
if hasattr(module.base_layer, "bias") and module.base_layer.bias is not None:
|
||||
if module.base_layer.bias is not None:
|
||||
if module.base_layer.weight.dtype != module.base_layer.bias.dtype:
|
||||
log_bias_dtype_mismatch = True
|
||||
module.base_layer.bias.data = module.base_layer.bias.data.to(
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Monkey patches for the dataset fetcher to handle batches of packed indexes."""
|
||||
"""monkey patches for the dataset fetcher to handle batches of packed indexes"""
|
||||
|
||||
# pylint: disable=protected-access
|
||||
|
||||
@@ -6,20 +6,10 @@ import torch
|
||||
from torch.utils.data._utils.fetch import _BaseDatasetFetcher
|
||||
from torch.utils.data._utils.worker import _worker_loop
|
||||
|
||||
_ORIGINAL_MAP_DATASET_FETCHER = None
|
||||
_ORIGINAL_WORKER_LOOP = None
|
||||
_IS_PATCHED = False
|
||||
|
||||
|
||||
class _MapDatasetFetcher(_BaseDatasetFetcher):
|
||||
"""
|
||||
Custom dataset fetcher that handles nested batch structures from
|
||||
MultipackBatchSampler.
|
||||
"""
|
||||
|
||||
def fetch(self, possibly_batched_index):
|
||||
if isinstance(possibly_batched_index[0], list):
|
||||
# Handle nested structure from MultipackBatchSampler
|
||||
data = [None for i in possibly_batched_index]
|
||||
for i, possibly_batched_index_ in enumerate(possibly_batched_index):
|
||||
if self.auto_collation:
|
||||
@@ -33,7 +23,6 @@ class _MapDatasetFetcher(_BaseDatasetFetcher):
|
||||
else:
|
||||
data[i] = self.dataset[possibly_batched_index_]
|
||||
else:
|
||||
# Standard batch handling
|
||||
if self.auto_collation:
|
||||
if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
|
||||
data = self.dataset.__getitems__(possibly_batched_index)
|
||||
@@ -45,54 +34,14 @@ class _MapDatasetFetcher(_BaseDatasetFetcher):
|
||||
|
||||
|
||||
def patch_fetchers():
|
||||
"""Apply patches to PyTorch's DataLoader components."""
|
||||
torch.utils.data._utils.fetch._MapDatasetFetcher = _MapDatasetFetcher
|
||||
torch.utils.data.dataloader._utils.fetch._MapDatasetFetcher = _MapDatasetFetcher
|
||||
|
||||
|
||||
def patched_worker_loop(*args, **kwargs):
|
||||
"""Worker loop that ensures patches are applied in worker processes."""
|
||||
patch_fetchers()
|
||||
return _worker_loop(*args, **kwargs)
|
||||
|
||||
|
||||
def apply_multipack_dataloader_patch():
|
||||
"""
|
||||
This patch allows DataLoader to correctly process batches that contain multiple bins
|
||||
of packed sequences.
|
||||
"""
|
||||
# pylint: disable=global-statement
|
||||
global _ORIGINAL_MAP_DATASET_FETCHER, _ORIGINAL_WORKER_LOOP, _IS_PATCHED
|
||||
|
||||
if _IS_PATCHED:
|
||||
return
|
||||
|
||||
# Store original implementations
|
||||
_ORIGINAL_MAP_DATASET_FETCHER = torch.utils.data._utils.fetch._MapDatasetFetcher
|
||||
_ORIGINAL_WORKER_LOOP = torch.utils.data._utils.worker._worker_loop
|
||||
|
||||
# Apply patches
|
||||
patch_fetchers()
|
||||
torch.utils.data._utils.worker._worker_loop = patched_worker_loop
|
||||
|
||||
_IS_PATCHED = True
|
||||
|
||||
|
||||
def remove_multipack_dataloader_patch():
|
||||
"""Remove the monkeypatch and restore original PyTorch DataLoader behavior."""
|
||||
# pylint: disable=global-statement
|
||||
global _IS_PATCHED
|
||||
|
||||
if not _IS_PATCHED:
|
||||
return
|
||||
|
||||
if _ORIGINAL_MAP_DATASET_FETCHER:
|
||||
torch.utils.data._utils.fetch._MapDatasetFetcher = _ORIGINAL_MAP_DATASET_FETCHER
|
||||
torch.utils.data.dataloader._utils.fetch._MapDatasetFetcher = (
|
||||
_ORIGINAL_MAP_DATASET_FETCHER
|
||||
)
|
||||
|
||||
if _ORIGINAL_WORKER_LOOP:
|
||||
torch.utils.data._utils.worker._worker_loop = _ORIGINAL_WORKER_LOOP
|
||||
|
||||
_IS_PATCHED = False
|
||||
torch.utils.data._utils.worker._worker_loop = patched_worker_loop
|
||||
patch_fetchers()
|
||||
|
||||
@@ -91,7 +91,7 @@ class PromptTokenizingStrategy(abc.ABC):
|
||||
|
||||
if (
|
||||
result["input_ids"][-1] != self.tokenizer.eos_token_id
|
||||
and len(result["input_ids"]) < self.max_length
|
||||
and (self.max_length is None or len(result["input_ids"]) < self.max_length)
|
||||
and add_eos_token
|
||||
):
|
||||
result["input_ids"].append(self.tokenizer.eos_token_id)
|
||||
|
||||
@@ -253,9 +253,7 @@ def save_trained_model(
|
||||
# final model weights have already been saved by `ReLoRACallback.on_train_end`
|
||||
return
|
||||
|
||||
if ( # pylint: disable=too-many-nested-blocks
|
||||
trainer.is_fsdp_enabled or cfg.fsdp_config
|
||||
):
|
||||
if trainer.is_fsdp_enabled or cfg.fsdp_config:
|
||||
if cfg.fsdp_config or cfg.fsdp:
|
||||
if cfg.fsdp_config.final_state_dict_type:
|
||||
state_dict_type = cfg.fsdp_config.final_state_dict_type
|
||||
@@ -287,8 +285,6 @@ def save_trained_model(
|
||||
if trainer.accelerator.is_main_process:
|
||||
# move all files in merged_path to cfg.output_dir
|
||||
for merged_file in Path(merged_path).iterdir():
|
||||
if (Path(cfg.output_dir) / merged_file.name).exists():
|
||||
(Path(cfg.output_dir) / merged_file.name).unlink()
|
||||
shutil.move(str(merged_file), cfg.output_dir)
|
||||
shutil.rmtree(merged_path) # remove what should be an empty dir
|
||||
# TODO(wing):see https://github.com/huggingface/transformers/pull/40207
|
||||
|
||||
@@ -408,7 +408,7 @@ class AxolotlInputConfig(
|
||||
|
||||
unfrozen_parameters: list[str] | None = None
|
||||
|
||||
sequence_len: int = Field(
|
||||
sequence_len: int | None = Field(
|
||||
default=512,
|
||||
json_schema_extra={
|
||||
"description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048"
|
||||
@@ -459,12 +459,6 @@ class AxolotlInputConfig(
|
||||
"description": "The multiprocessing start method to use for packing. Should be 'fork', 'spawn' or 'forkserver'"
|
||||
},
|
||||
)
|
||||
squash_position_ids: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Whether to squash position_ids for packing, effectively extending context length."
|
||||
},
|
||||
)
|
||||
eval_sample_packing: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
|
||||
@@ -229,7 +229,10 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
||||
results = []
|
||||
for seq in input_ids:
|
||||
length = len(seq)
|
||||
results.append(min_sequence_len <= length <= sequence_len)
|
||||
if sequence_len is not None:
|
||||
results.append(min_sequence_len <= length <= sequence_len)
|
||||
else:
|
||||
results.append(min_sequence_len <= length)
|
||||
return results
|
||||
|
||||
|
||||
@@ -405,7 +408,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
if update:
|
||||
cfg.total_num_tokens = total_num_tokens
|
||||
|
||||
skip_estimates = cfg.model_config_type == "mamba"
|
||||
skip_estimates = cfg.sequence_len is None or cfg.model_config_type == "mamba"
|
||||
|
||||
if (
|
||||
not skip_estimates
|
||||
|
||||
@@ -48,13 +48,7 @@ class TestBatchedSamplerPacking:
|
||||
max_seq_length,
|
||||
sequential,
|
||||
):
|
||||
from axolotl.monkeypatch.data.batch_dataset_fetcher import (
|
||||
apply_multipack_dataloader_patch,
|
||||
remove_multipack_dataloader_patch,
|
||||
)
|
||||
|
||||
# Apply the patch for multipack handling
|
||||
apply_multipack_dataloader_patch()
|
||||
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
|
||||
|
||||
dataset = dataset_winglian_tiny_shakespeare["train"]
|
||||
|
||||
@@ -107,14 +101,10 @@ class TestBatchedSamplerPacking:
|
||||
for pack in batch:
|
||||
batch_idxs.extend(pack)
|
||||
|
||||
try:
|
||||
for batch in loader:
|
||||
assert batch["input_ids"].numel() <= batch_size * max_seq_length
|
||||
assert batch["input_ids"].shape[1] == max_seq_length
|
||||
for batch in loader:
|
||||
assert batch["input_ids"].numel() <= batch_size * max_seq_length
|
||||
assert batch["input_ids"].shape[1] == max_seq_length
|
||||
|
||||
original_idxs = set(range(len(train_dataset)))
|
||||
assert original_idxs == set(batch_idxs)
|
||||
assert len(batch_idxs) == len(set(batch_idxs))
|
||||
finally:
|
||||
# Clean up: remove the patch after the test
|
||||
remove_multipack_dataloader_patch()
|
||||
original_idxs = set(range(len(train_dataset)))
|
||||
assert original_idxs == set(batch_idxs)
|
||||
assert len(batch_idxs) == len(set(batch_idxs))
|
||||
|
||||
Reference in New Issue
Block a user