Compare commits

..

1 Commits

Author SHA1 Message Date
Wing Lian
faed3905fd version tag 0.13.2
Some checks failed
ci-cd / build-axolotl (<nil>, 128, 12.8.1, linux/amd64, 3.11, 2.8.0) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.11, 2.9.0) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 128, 12.8.1, true, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl (vllm, 129, 12.9.1, linux/amd64,linux/arm64, 3.12, 2.9.1) (push) Has been cancelled
publish pypi / Create Release (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 128, 12.8.1, linux/amd64, 3.11, 2.8.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.11, 2.9.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 128, 12.8.1, true, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 129, 12.9.1, linux/amd64,linux/arm64, 3.12, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud-no-tmux (<nil>, 128, 12.8.1, true, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud-no-tmux (<nil>, 130, 13.0.0, <nil>, 3.11, 2.9.1) (push) Has been cancelled
publish pypi / Upload release to PyPI (push) Has been cancelled
2026-01-22 10:58:38 -05:00
90 changed files with 323 additions and 560 deletions

View File

@@ -38,7 +38,7 @@ jobs:
cuda_version: 12.9.1
python_version: "3.12"
pytorch: 2.9.1
axolotl_extras:
axolotl_extras: vllm
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0

View File

@@ -45,7 +45,7 @@ jobs:
cuda_version: 12.9.1
python_version: "3.12"
pytorch: 2.9.1
axolotl_extras: "fbgemm-gpu"
axolotl_extras: "fbgemm-gpu,vllm"
num_gpus: 2
dockerfile: "Dockerfile-uv.jinja"
- cuda: 130

View File

@@ -115,10 +115,10 @@ jobs:
- name: Pre-Download dataset fixture
run: |
hf download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
- name: Show HF cache
run: hf cache ls
run: hf cache scan
- name: Run tests
run: |
@@ -132,7 +132,7 @@ jobs:
pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml
- name: Show HF cache
run: hf cache ls
run: hf cache scan
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
@@ -210,7 +210,7 @@ jobs:
axolotl --help
- name: Show HF cache
run: hf cache ls
run: hf cache scan
- name: Run tests
run: |
@@ -219,10 +219,10 @@ jobs:
pytest -v --durations=10 tests/cli/
- name: Show HF cache
run: hf cache ls
run: hf cache scan
gate-skip-e2e:
needs: [pre-commit]
needs: [pre-commit, pytest, pytest-sdist]
runs-on: ubuntu-latest
outputs:
skip: ${{ steps.compute.outputs.skip }}
@@ -258,7 +258,7 @@ jobs:
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 120
needs: [pre-commit, pytest]
needs: [pre-commit, pytest, pytest-sdist, gate-skip-e2e]
strategy:
fail-fast: false
@@ -269,7 +269,7 @@ jobs:
python_version: "3.12"
pytorch: 2.9.1
num_gpus: 1
axolotl_extras:
axolotl_extras: vllm
dockerfile: "Dockerfile-uv.jinja"
steps:
- name: Checkout

View File

@@ -224,6 +224,9 @@
# eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
# eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
# # Save model as safetensors (require safetensors package)
# save_safetensors:
# # Whether to mask out or include the human's prompt from the training labels
# train_on_inputs: false
# # Group similarly sized data to minimize padding.
@@ -509,6 +512,7 @@ profiler_steps: ${PROFILER_STEPS}
loss_watchdog_threshold: ${LOSS_WATCHDOG_THRESHOLD}
loss_watchdog_patience: ${LOSS_WATCHDOG_PATIENCE}
save_safetensors: ${SAVE_SAFETENSORS}
train_on_inputs: ${TRAIN_ON_INPUTS}
group_by_length: ${GROUP_BY_LENGTH}
gradient_checkpointing: ${GRADIENT_CHECKPOINTING}

View File

@@ -1 +1 @@
0.14.0
0.13.2

View File

@@ -2,7 +2,7 @@
set -e
# Only run two tests at a time to avoid OOM on GPU (with coverage collection)
pytest -v --durations=10 -n2 --maxfail=3 \
pytest -v --durations=10 -n2 --maxfail=4 \
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \
--ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \
/workspace/axolotl/tests/e2e/multigpu/ \

View File

@@ -86,7 +86,7 @@ export HF_DATASETS_OFFLINE=1
Download a base model using the Hugging Face CLI:
```bash
hf download meta-llama/Meta-Llama-3.1-8B --local-dir ~/hfdata/llama3.1-8B
huggingface-cli download meta-llama/Meta-Llama-3.1-8B --local-dir ~/hfdata/llama3.1-8B
```
### 10. Create Axolotl Configuration

View File

@@ -165,7 +165,7 @@ We recommend using WSL2 (Windows Subsystem for Linux) or Docker.
```
4. (Optional) Login to Hugging Face:
```{.bash}
hf auth login
huggingface-cli login
```
## Troubleshooting {#sec-troubleshooting}

View File

@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f4b5712\""
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@318b7e2\""
]
},
{

View File

@@ -1,77 +0,0 @@
base_model: google/gemma-3-1b-it
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
# gemma3 doesn't seem to play nice with ddp
ddp_find_unused_parameters: true
chat_template: gemma3
eot_tokens:
- <end_of_turn>
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
field_messages: conversations
message_property_mappings:
role: from
content: value
dataset_prepared_path:
val_set_size: 0
output_dir: ./outputs/eaft-gemma-3-1b
use_eaft: true
eaft_alpha: 1.0
eaft_k: 20
sequence_len: 1024
sample_packing: false
adapter:
lora_model_dir:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
eval_batch_size: 1
max_steps: 1000
evaluation_strategy: "no"
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 5e-5
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_ratio: 0.1
weight_decay: 0.0
debug:
deepspeed:
fsdp:
fsdp_config:
special_tokens:

View File

@@ -19,6 +19,7 @@ datasets:
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: jamba-large-fsdp-qlora-ft
save_safetensors: true
adapter: qlora
sequence_len: 2048
sample_packing: true

View File

@@ -12,6 +12,7 @@ datasets:
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out/qlora-llama3_1-405b
save_safetensors: true
adapter: qlora

View File

@@ -47,5 +47,6 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
tokens:
save_safetensors: False
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -60,6 +60,3 @@ indent-style = "space"
skip-magic-trailing-comma = false
line-ending = "auto"
docstring-code-format = false
[tool.uv.extra-build-dependencies]
axolotl = ["huggingface_hub"]

View File

@@ -9,17 +9,17 @@ liger-kernel==0.6.4
# END section
packaging==26.0
huggingface_hub>=1.1.7
huggingface_hub>=0.36.0
peft>=0.18.1
tokenizers>=0.22.1
transformers==5.0.0
transformers==4.57.6
accelerate==1.12.0
datasets==4.5.0
deepspeed>=0.18.3
trl==0.27.1
trl==0.27.0
hf_xet==1.2.0
kernels==0.11.5
trackio>=0.13.0
typing-extensions>=4.15.0

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f4b5712"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@318b7e2"'
)

View File

@@ -5,6 +5,6 @@ import os
from axolotl.logging_config import configure_logging
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("HF_XET_HIGH_PERFORMANCE", "1")
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
configure_logging()

View File

@@ -44,7 +44,7 @@ def check_user_token() -> bool:
return bool(user_info)
except LocalTokenNotFoundError:
LOG.warning(
"Error verifying HuggingFace token. Remember to log in using `hf auth login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
"Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
)
return False
except HTTPError:

View File

@@ -24,6 +24,7 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
cfg: Dictionary mapping `axolotl` config keys to values.
"""
model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg)
safe_serialization = cfg.save_safetensors is True
LOG.info("Running merge of LoRA with base model...")
model = model.merge_and_unload(progressbar=True)
@@ -41,6 +42,7 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
LOG.info(f"Saving merged model to: {str(Path(cfg.output_dir) / 'merged')}...")
model.save_pretrained(
str(Path(cfg.output_dir) / "merged"),
safe_serialization=safe_serialization,
progressbar=True,
)
tokenizer.save_pretrained(

View File

@@ -14,6 +14,8 @@ from accelerate import PartialState
from accelerate.utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
is_torch_version,
)
from huggingface_hub import split_torch_state_dict_into_shards
@@ -38,15 +40,17 @@ class BFloat16CastPlanner(_EmptyStateDictLoadPlanner):
def _distributed_checkpoint_to_merged_weights(
checkpoint_dir: Union[str, Path],
save_path: str,
safe_serialization: bool = False,
max_shard_size: str = "5GB",
) -> Path:
"""
Passthrough to `torch.distributed.checkpoint.format_utils.dcp_to_torch_save`. Will
save under `save_path` as `model.safetensors`.
save under `save_path` as either `model.safetensors` or `pytorch_model.bin`.
Args:
checkpoint_dir: Directory where distributed checkpoint is saved.
save_path: Path to save model to.
safe_serialization: Whether to save in safetensors format.
max_shard_size: Max size of model shards to save.
Returns:
@@ -72,7 +76,11 @@ def _distributed_checkpoint_to_merged_weights(
if isinstance(value, torch.Tensor) and value.dtype != torch.bfloat16:
state_dict[key] = value.to(torch.bfloat16)
filename_pattern = SAFE_WEIGHTS_NAME.replace(".safetensors", "{suffix}.safetensors")
weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
".safetensors", "{suffix}.safetensors"
)
state_dict_split = split_torch_state_dict_into_shards(
state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
)
@@ -90,12 +98,19 @@ def _distributed_checkpoint_to_merged_weights(
for shard_file, tensors in filename_to_tensors:
shard = {tensor: state_dict[tensor] for tensor in tensors}
safe_save_file(
shard, os.path.join(save_path_, shard_file), metadata={"format": "pt"}
)
if safe_serialization:
safe_save_file(
shard, os.path.join(save_path_, shard_file), metadata={"format": "pt"}
)
else:
torch.save(shard, os.path.join(save_path_, shard_file))
if index is not None:
save_index_file = os.path.join(save_path_, SAFE_WEIGHTS_INDEX_NAME)
save_index_file = (
SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
)
save_index_file = os.path.join(save_path_, save_index_file)
# Save the index as well
with open(save_index_file, "w", encoding="utf-8") as fout:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
@@ -108,11 +123,13 @@ def _distributed_checkpoint_to_merged_weights(
def merge_fsdp_weights(
checkpoint_dir: str,
output_path: str,
safe_serialization: bool = False,
remove_checkpoint_dir: bool = False,
):
"""
Merge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Should be used if
`SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}/model.safetensors`.
`SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}/model.safetensors` if
`safe_serialization` else `pytorch_model.bin`.
Note: this is a CPU-bound process.
@@ -121,6 +138,8 @@ def merge_fsdp_weights(
The directory containing the FSDP checkpoints (can be either the model or optimizer).
output_path (`str`):
The path to save the merged checkpoint.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the merged weights with safetensors (recommended).
remove_checkpoint_dir (`bool`, *optional*, defaults to `False`):
Whether to remove the checkpoint directory after merging.
@@ -158,7 +177,7 @@ def merge_fsdp_weights(
if state.is_main_process:
LOG.info(f"Merging FSDP weights from {checkpoint_dir_}")
save_path = _distributed_checkpoint_to_merged_weights(
checkpoint_dir_, output_path
checkpoint_dir_, output_path, safe_serialization
)
LOG.info(f"Successfully merged FSDP weights and saved to {save_path}")
if remove_checkpoint_dir:
@@ -191,6 +210,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
merge_fsdp_weights(
checkpoint_dir=str(fsdp_dir),
output_path=output_path,
safe_serialization=True,
)
state = PartialState()
state.wait_for_everyone()

View File

@@ -102,10 +102,12 @@ def do_quantize(
LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}.")
model.save_pretrained(
str(Path(output_dir) / "quantized"),
safe_serialization=False,
progressbar=True,
)
tokenizer.save_pretrained(
str(Path(output_dir) / "quantized"),
safe_serialization=False,
progressbar=True,
save_jinja_files=cfg.tokenizer_save_jinja_files,
)
@@ -119,7 +121,7 @@ def do_quantize(
hub_model_id.rstrip("-")
+ f"-{quantization_config_to_str[type(quantization_config)]}"
)
model.push_to_hub(hub_model_id)
model.push_to_hub(hub_model_id, safe_serialization=False)
tokenizer.push_to_hub(hub_model_id)
if processor:
processor.push_to_hub(hub_model_id)

View File

@@ -216,7 +216,7 @@ class TrainerBuilderBase(abc.ABC):
def _configure_warmup_and_logging(
self, total_num_steps: int, training_args_kwargs: dict
):
warmup_steps: int | float = 0
warmup_steps = 0
warmup_ratio = 0.0
if self.cfg.warmup_steps is not None:
warmup_steps = self.cfg.warmup_steps
@@ -230,10 +230,6 @@ class TrainerBuilderBase(abc.ABC):
else:
warmup_ratio = 0.03
# transformers v5
if warmup_ratio > 0.0 and warmup_steps == 0:
warmup_steps = warmup_ratio
if warmup_steps == 1:
warmup_steps = 2
@@ -246,6 +242,7 @@ class TrainerBuilderBase(abc.ABC):
else max(min(int(0.005 * total_num_steps), 10), 1)
)
training_args_kwargs["warmup_ratio"] = warmup_ratio
training_args_kwargs["warmup_steps"] = warmup_steps
def _configure_precision_settings(self, training_args_kwargs: dict):
@@ -533,7 +530,9 @@ class TrainerBuilderBase(abc.ABC):
"loraplus_lr_ratio",
"loraplus_lr_embedding",
"output_dir",
"save_safetensors",
"save_only_model",
"include_tokens_per_second",
"weight_decay",
"seed",
"dion_momentum",
@@ -546,7 +545,6 @@ class TrainerBuilderBase(abc.ABC):
arg_map = {
"dion_learning_rate": "dion_lr",
"include_num_input_tokens_seen": "include_tokens_per_second",
}
for kwarg, cfg_arg in arg_map.items():
if hasattr(self.cfg, cfg_arg) and getattr(self.cfg, cfg_arg) is not None:

View File

@@ -373,18 +373,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
data_collator_kwargs["pad_to_multiple_of"] = multiple
if self.cfg.use_eaft:
from functools import partial
from axolotl.monkeypatch.loss.eaft import eaft_loss
configured_eaft_loss = partial(
eaft_loss,
alpha=self.cfg.eaft_alpha if self.cfg.eaft_alpha is not None else 1.0,
k=self.cfg.eaft_k if self.cfg.eaft_k is not None else 20,
)
trainer_kwargs["compute_loss_func"] = configured_eaft_loss
trainer_cls = self._get_trainer_cls()
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
@@ -449,9 +437,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
or self.cfg.micro_batch_size > 1
):
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
if not (self.cfg.sample_packing and self.cfg.pretrain_multipack_attn) or (
self.cfg.micro_batch_size == 1 and is_eval is False
):
if not (self.cfg.sample_packing and self.cfg.pretrain_multipack_attn):
return None
if self.cfg.model_config_type == "mamba":

View File

@@ -25,7 +25,7 @@ from torch.utils.data import (
from transformers import PreTrainedModel, Trainer
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length, seed_worker
from transformers.utils import SAFE_WEIGHTS_NAME, is_peft_available
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, is_peft_available
from trl.trainer.utils import pad_to_length
from typing_extensions import override
@@ -738,38 +738,43 @@ class AxolotlTrainer(
).save_pretrained(
output_dir,
state_dict=state_dict,
safe_serialization=self.args.save_safetensors,
)
else:
LOG.info(
"Trainer.model is not a `PreTrainedModel`, only saving its state dict."
)
safetensors.torch.save_file(
state_dict,
os.path.join(output_dir, SAFE_WEIGHTS_NAME),
metadata={"format": "pt"},
)
if self.args.save_safetensors:
safetensors.torch.save_file(
state_dict,
os.path.join(output_dir, SAFE_WEIGHTS_NAME),
metadata={"format": "pt"},
)
else:
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
self.model.save_pretrained(
output_dir,
state_dict=state_dict,
safe_serialization=self.args.save_safetensors,
is_main_process=self.accelerator.is_main_process,
)
if self.processing_class is not None:
self.processing_class.save_pretrained(output_dir)
elif (
self.data_collator is not None
and hasattr(self.data_collator, "tokenizer")
and self.data_collator.tokenizer is not None
):
LOG.info(
"Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
)
save_jinja_files = True
if self.axolotl_cfg:
save_jinja_files = self.axolotl_cfg.tokenizer_save_jinja_files
self.data_collator.tokenizer.save_pretrained(
output_dir, save_jinja_files=save_jinja_files
)
# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
if self.processing_class is not None:
self.processing_class.save_pretrained(output_dir)
elif (
self.data_collator is not None
and hasattr(self.data_collator, "tokenizer")
and self.data_collator.tokenizer is not None
):
LOG.info(
"Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
)
save_jinja_files = True
if self.axolotl_cfg:
save_jinja_files = self.axolotl_cfg.tokenizer_save_jinja_files
self.data_collator.tokenizer.save_pretrained(
output_dir, save_jinja_files=save_jinja_files
)
# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

View File

@@ -1,10 +1,12 @@
"""Module for TRL RL trainers"""
from trl import RewardTrainer
from trl.experimental.cpo import CPOTrainer
from trl.experimental.kto import KTOTrainer
from trl.experimental.orpo import ORPOTrainer
from trl.experimental.prm import PRMTrainer
from trl import (
CPOTrainer,
KTOTrainer,
ORPOTrainer,
PRMTrainer,
RewardTrainer,
)
from axolotl.core.trainers.mixins import DistributedParallelMixin, RngLoaderMixin
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin

View File

@@ -8,11 +8,7 @@ from dataclasses import dataclass, field
from typing import Optional, Type
from transformers import TrainingArguments
from trl import RewardConfig
from trl.experimental.cpo import CPOConfig
from trl.experimental.kto import KTOConfig
from trl.experimental.orpo import ORPOConfig
from trl.experimental.prm import PRMConfig
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
from axolotl.integrations.config import merge_training_args

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f4b5712"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@318b7e2"
```
## Usage
@@ -36,7 +36,6 @@ plugins:
- cohere
- cohere2
- deepseek_v3
- exaone4
- gemma
- gemma2
- gemma3
@@ -46,11 +45,8 @@ plugins:
- glm
- glm4
- glm4_moe
- glm4_moe_lite
- glm46v
- glm4v
- glm4v_moe
- glm_image
- gpt_oss
- granite
- granitemoe

View File

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f4b5712"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@318b7e2"`'
)

View File

@@ -1,7 +0,0 @@
from .args import KernelsArgs
from .plugin import KernelsPlugin
__all__ = [
"KernelsArgs",
"KernelsPlugin",
]

View File

@@ -1,35 +0,0 @@
from pydantic import BaseModel, model_validator
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class KernelsArgs(BaseModel):
use_scattermoe: bool | None = True
@model_validator(mode="before")
@classmethod
def check_use_kernels(cls, data):
if data.get("use_kernels") is not True:
LOG.warning(
"`use_kernels` must be set to True to use this. Automatically setting it to True."
)
data["use_kernels"] = True
return data
@model_validator(mode="before")
@classmethod
def check_experts_implementation(cls, data):
experts_implementation = data.get("experts_implementation")
if experts_implementation is None:
# transformers may default to batched_mm when unset
data["experts_implementation"] = "eager"
elif experts_implementation != "eager":
LOG.warning(
"`experts_implementation` must be set to 'eager' to use this. Automatically setting it to 'eager'."
)
data["experts_implementation"] = "eager"
return data

View File

@@ -1,61 +0,0 @@
from kernels import (
LayerRepository,
Mode,
register_kernel_mapping,
replace_kernel_forward_from_hub,
)
from axolotl.integrations.base import BasePlugin
from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix
class KernelsPlugin(BasePlugin):
def get_input_args(self):
return "axolotl.integrations.kernels.KernelsArgs"
def pre_model_load(self, cfg):
if cfg.use_scattermoe:
self._register_kernels()
self._kernelize_model(cfg.model_config_type)
def _register_kernels(self):
register_kernel_mapping(
{
"HFScatterMoEParallelExperts": {
"cuda": {
Mode.TRAINING: LayerRepository(
repo_id="axolotl-ai-co/scattermoe",
layer_name="HFScatterMoEGatedMLP",
),
Mode.INFERENCE: LayerRepository(
repo_id="axolotl-ai-co/scattermoe",
layer_name="HFScatterMoEGatedMLP",
),
},
}
}
)
def _kernelize_model(self, model_type: str):
if model_type == "olmoe":
from transformers.models.olmoe.modeling_olmoe import OlmoeSparseMoeBlock
replace_kernel_forward_from_hub(
OlmoeSparseMoeBlock, "HFScatterMoEParallelExperts"
)
else:
try:
model_moe_cls = get_model_moe_block(model_type)
replace_kernel_forward_from_hub(
model_moe_cls, "HFScatterMoEParallelExperts"
)
except Exception as err:
raise ValueError(f"Unsupported model type: {model_type}") from err
def get_model_moe_block(model_type: str):
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
module = __import__(module_path, fromlist=[f"{model_cls_prefix}SparseMoeBlock"])
model_cls = getattr(module, f"{model_cls_prefix}SparseMoeBlock")
return model_cls

View File

@@ -12,6 +12,7 @@ def save_compressed_model(
model: PreTrainedModel,
output_dir: Union[str, bytes],
trainer: Trainer,
safe_serialization: bool = False,
save_compressed: bool = False,
) -> None:
"""
@@ -21,6 +22,7 @@ def save_compressed_model(
model (PreTrainedModel): The model to be saved.
output_dir (str or bytes): Path where the model files will be written.
trainer (Trainer): Hugging Face Trainer for process synchronization.
safe_serialization (bool): Use safe serialization if True.
save_compressed (bool): Write compressed tensors if True.
"""
trainer.accelerator.wait_for_everyone()
@@ -32,6 +34,7 @@ def save_compressed_model(
modify_save_pretrained(model)
model.save_pretrained(
output_dir,
safe_serialization=safe_serialization,
save_compressed=save_compressed,
skip_sparsity_compression_stats=not save_compressed,
)

View File

@@ -26,6 +26,7 @@ from torch.distributed import DeviceMesh
from transformers import (
AutoModelForCausalLM,
AutoModelForImageTextToText,
AutoModelForVision2Seq,
AwqConfig,
BitsAndBytesConfig,
GPTQConfig,
@@ -225,7 +226,6 @@ class ModelLoader:
):
self.model = self.model.merge_and_unload()
self._configure_experts_implementation()
self._apply_activation_checkpointing()
self._resize_token_embeddings()
self._adjust_model_config()
@@ -233,10 +233,6 @@ class ModelLoader:
self._configure_qat()
log_gpu_memory_usage(LOG, "Memory usage after model load", 0)
def _configure_experts_implementation(self):
if self.cfg.experts_implementation is not None:
self.model.set_experts_implementation(self.cfg.experts_implementation)
def _apply_activation_checkpointing(self):
if self.cfg.activation_offloading is True:
from axolotl.core.trainers.mixins.activation_checkpointing import (
@@ -438,7 +434,7 @@ class ModelLoader:
"""
if self.cfg.is_multimodal:
self.auto_model_loader = MULTIMODAL_AUTO_MODEL_MAPPING.get(
self.model_config.model_type, AutoModelForImageTextToText
self.model_config.model_type, AutoModelForVision2Seq
)
if isinstance(self.auto_model_loader, str):
self.auto_model_loader = AutoModelForImageTextToText
@@ -480,7 +476,6 @@ class ModelLoader:
max_memory = None
self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype
self.model_kwargs["dtype"] = self.cfg.torch_dtype
is_ds_zero3 = is_deepspeed_zero3_enabled()
@@ -675,7 +670,7 @@ class ModelLoader:
Uses the selected loader when provided; otherwise falls back to the auto loader.
"""
loader = model_loader_class or self.auto_model_loader
if loader in [AutoModelForCausalLM, AutoModelForImageTextToText]:
if loader in [AutoModelForCausalLM, AutoModelForVision2Seq]:
model = loader.from_config(
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
@@ -793,7 +788,6 @@ class ModelLoader:
# Use auto model loader (handles gptq and default cases)
model_loader_class = self.auto_model_loader
self.model_kwargs["dtype"] = self.model_kwargs["torch_dtype"]
if self.cfg.reinit_weights:
self.model = self._load_model_from_config(model_loader_class)
else:

View File

@@ -220,6 +220,13 @@ class PatchManager:
patch_qwen3_next_modeling_packing()
if self.cfg.model_config_type == "mistral3" and self.cfg.processor_type:
from axolotl.monkeypatch.models.mistral3.mistral_common_tokenizer import (
apply_mistral_tokenizer_image_patch,
)
apply_mistral_tokenizer_image_patch()
if self.cfg.model_config_type == "kimi_linear":
from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import (
patch_kimi_model,

View File

@@ -31,7 +31,7 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
from axolotl.utils.mistral import HFMistralTokenizer
tokenization_mistral_common.MistralCommonBackend = HFMistralTokenizer
tokenization_mistral_common.MistralCommonTokenizer = HFMistralTokenizer
_patch_mistralcommontokenizer()

View File

@@ -111,6 +111,7 @@ class MambaLMHeadModel(nn.Module, GenerationMixin):
self,
save_directory: Union[str, os.PathLike],
state_dict: Optional[dict] = None,
safe_serialization: Optional[bool] = None,
):
if state_dict is None:
state_dict = self.state_dict()

View File

@@ -1,51 +0,0 @@
"""
eaft (entropy-aware focal training) loss implementation
weights examples by entropy approximation from top-k logits
Reference: https://github.com/ymxyll/LlamaFactory-EAFT/blob/e2ce19e8efcc226450ee8f2b81dfe4e69f1f945d/src/llamafactory/train/trainer_utils.py
"""
import torch
import torch.nn.functional as F
def eaft_loss(outputs, labels, num_items_in_batch=None, alpha=1.0, k=20):
"""
compute eaft loss with entropy weighting
args:
outputs: model outputs containing logits
labels: target labels for computing loss
num_items_in_batch: for sample packing support
alpha: exponent for entropy weighting (default 1.0)
k: number of top logits for entropy approximation (default 20)
"""
logits = outputs.logits
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
vocab_size = shift_logits.size(-1)
shift_logits_view = shift_logits.view(-1, vocab_size)
shift_labels_view = shift_labels.view(-1)
mask = shift_labels_view != -100
with torch.no_grad():
top_k_logits, _ = torch.topk(
shift_logits_view[mask].float(), k=min(k, vocab_size), dim=-1
)
top_k_probs = F.softmax(top_k_logits, dim=-1)
entropy = -(top_k_probs * torch.log(top_k_probs + 1e-10)).sum(dim=-1)
weights = torch.pow(entropy, alpha)
loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
per_token_loss = loss_fct(shift_logits_view[mask], shift_labels_view[mask])
weighted_loss = per_token_loss * weights
if num_items_in_batch is not None:
loss = weighted_loss.sum() / num_items_in_batch
else:
loss = weighted_loss.mean()
return loss

View File

@@ -1,5 +1,5 @@
"""
Monkeypatch to fix inefficient tensor conversion in MistralCommonBackend.apply_chat_template
Monkeypatch to fix inefficient tensor conversion in MistralCommonTokenizer.apply_chat_template
"""
import importlib
@@ -12,11 +12,11 @@ LOG = get_logger(__name__)
def apply_mistral_tokenizer_image_patch():
"""Apply patch to MistralCommonBackend.apply_chat_template to fix image tensor conversion."""
from transformers.tokenization_mistral_common import MistralCommonBackend
"""Apply patch to MistralCommonTokenizer.apply_chat_template to fix image tensor conversion."""
from transformers.tokenization_mistral_common import MistralCommonTokenizer
# Get original source
original_source = inspect.getsource(MistralCommonBackend.apply_chat_template)
original_source = inspect.getsource(MistralCommonTokenizer.apply_chat_template)
original_source, _ = detab_code(original_source)
# Define the replacement
@@ -41,7 +41,7 @@ def apply_mistral_tokenizer_image_patch():
)
# Load necessary imports from the module
module_name = MistralCommonBackend.__module__
module_name = MistralCommonTokenizer.__module__
module = importlib.import_module(module_name)
# Detect what needs to be imported
@@ -79,7 +79,7 @@ def apply_mistral_tokenizer_image_patch():
exec(patched_source, globals()) # nosec B102
# Replace the method
MistralCommonBackend.apply_chat_template = patched_apply_chat_template
LOG.info("Successfully applied MistralCommonBackend tensor conversion patch")
MistralCommonTokenizer.apply_chat_template = patched_apply_chat_template
LOG.info("Successfully applied MistralCommonTokenizer tensor conversion patch")
else:
LOG.warning("Could not find target code for MistralCommonBackend patching")
LOG.warning("Could not find target code for MistralCommonTokenizer patching")

View File

@@ -155,6 +155,7 @@ class ReLoRACallback(TrainerCallback):
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
"adapter",
),
safe_serialization=True,
)
with torch.no_grad():
merge_and_save(
@@ -213,7 +214,7 @@ class ReLoRACallback(TrainerCallback):
self.last_full_model = checkpoint_folder
else:
model.model.save_pretrained(checkpoint_folder)
model.model.save_pretrained(checkpoint_folder, safe_serialization=True)
return control

View File

@@ -52,15 +52,9 @@ def patch_prepare_context_parallel_inputs() -> None:
if item in patched_source:
items_to_import.append(item)
# Use a separate namespace to capture the exec'd function
namespace = {}
exec(f"from {module_name} import ({', '.join(items_to_import)})", namespace)
exec(patched_source, namespace)
exec(f"from {module_name} import ({', '.join(items_to_import)})", globals())
exec(patched_source, globals())
# Explicitly get the function from the namespace
axolotl_prepare_context_parallel_inputs = namespace[
"axolotl_prepare_context_parallel_inputs"
]
Trainer._original_prepare_context_parallel_inputs = (
Trainer._prepare_context_parallel_inputs
)

View File

@@ -14,6 +14,7 @@ from transformers.models.voxtral import VoxtralProcessor
from axolotl.utils.dict import remove_none_values
from axolotl.utils.logging import get_logger
from axolotl.utils.mistral.mistral3_processor import Mistral3Processor
LOG = get_logger(__name__)
@@ -429,7 +430,7 @@ class Mistral3ProcessingStrategy(ProcessingStrategy):
def __init__(
self,
processor,
processor: Mistral3Processor,
chat_template: Optional[str] = None,
image_size: int | tuple[int, int] | None = None,
image_resize_algorithm: Resampling | None = None,
@@ -492,8 +493,6 @@ def get_processing_strategy(
image_size: int | tuple[int, int] | None = None,
image_resize_algorithm: Resampling | None = None,
):
from axolotl.utils.mistral.mistral3_processor import Mistral3Processor
processing_kwargs = {
"processor": processor,
"chat_template": chat_template,

View File

@@ -150,8 +150,6 @@ class ChatTemplatePrompter(Prompter):
return self.tokenizer.apply_chat_template(
conversation,
tokenize=True,
return_dict=False,
**chat_template_kwargs,
)

View File

@@ -135,13 +135,16 @@ def setup_reference_model(
return model_ref
def setup_signal_handler(cfg: DictDefault, model: PreTrainedModel):
def setup_signal_handler(
cfg: DictDefault, model: PreTrainedModel, safe_serialization: bool
):
"""
Set up signal handler for graceful termination.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
model: The model to save on termination
safe_serialization: Whether to use safe serialization when saving
"""
# ray workers don't have access to this signal
if cfg.local_rank == 0 and not cfg.use_ray:
@@ -149,7 +152,9 @@ def setup_signal_handler(cfg: DictDefault, model: PreTrainedModel):
def terminate_handler(_, __, model_weakref):
if model_weakref() is not None:
_model = model_weakref()
_model.save_pretrained(cfg.output_dir)
_model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization
)
cleanup_distributed()
sys.exit(0)
@@ -214,6 +219,7 @@ def save_trained_model(
cfg: DictDefault,
trainer: Any,
model: PreTrainedModel,
safe_serialization: bool,
):
"""
Save the trained model according to configuration and training setup.
@@ -222,6 +228,7 @@ def save_trained_model(
cfg: Dictionary mapping `axolotl` config keys to values.
trainer: The trainer object.
model: The trained model to save.
safe_serialization: Whether to use safe serialization.
"""
LOG.info(f"Training completed! Saving trained model to {cfg.output_dir}.")
@@ -276,6 +283,7 @@ def save_trained_model(
merge_fsdp_weights(
checkpoint_dir=str(fsdp_dir),
output_path=merged_path,
safe_serialization=True,
)
trainer.accelerator.wait_for_everyone()
if trainer.accelerator.is_main_process:
@@ -322,9 +330,11 @@ def save_trained_model(
pass
elif cfg.local_rank == 0:
if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
trainer.model.save_pretrained(cfg.output_dir)
trainer.model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization
)
model.save_pretrained(cfg.output_dir)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
if hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
# TODO: add integration support so this can be implemented completely within the plugin
@@ -334,6 +344,7 @@ def save_trained_model(
model=model,
output_dir=cfg.output_dir,
trainer=trainer,
safe_serialization=safe_serialization,
save_compressed=cfg.llmcompressor.save_compressed,
)
@@ -438,6 +449,7 @@ def handle_untrained_tokens_fix(
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
train_dataset: Dataset,
safe_serialization: bool,
):
"""
Apply fixes for untrained tokens if configured.
@@ -447,6 +459,7 @@ def handle_untrained_tokens_fix(
model: The model to apply fixes to.
tokenizer: The tokenizer for token identification.
train_dataset: The training dataset to use.
safe_serialization: Whether to use safe serialization when saving.
"""
if not cfg.fix_untrained_tokens:
return
@@ -470,7 +483,9 @@ def handle_untrained_tokens_fix(
fix_untrained_tokens(model, tokenizer, train_dataset, **fix_kwargs)
if cfg.local_rank == 0:
model.save_pretrained(str(Path(cfg.output_dir)))
model.save_pretrained(
str(Path(cfg.output_dir)), safe_serialization=safe_serialization
)
def setup_model_and_trainer(
@@ -567,12 +582,15 @@ def train(
) = setup_model_and_trainer(cfg, dataset_meta)
# Handle untrained tokens if configured
safe_serialization = cfg.save_safetensors is True
train_dataset = dataset_meta.train_dataset
handle_untrained_tokens_fix(cfg, model, tokenizer, train_dataset)
handle_untrained_tokens_fix(
cfg, model, tokenizer, train_dataset, safe_serialization
)
# Additional setup
save_initial_configs(cfg, tokenizer, model, peft_config, processor)
setup_signal_handler(cfg, model)
setup_signal_handler(cfg, model, safe_serialization)
setup_model_card(cfg)
# Execute the training
@@ -584,7 +602,7 @@ def train(
torch.cuda.empty_cache()
# Save the trained model and cleanup
save_trained_model(cfg, trainer, model)
save_trained_model(cfg, trainer, model, safe_serialization)
tokenizer.save_pretrained(
str(Path(cfg.output_dir)), save_jinja_files=cfg.tokenizer_save_jinja_files
)

View File

@@ -7,11 +7,7 @@ from torch import Tensor
from tqdm import tqdm
from transformers.modeling_outputs import CausalLMOutput
from transformers.modeling_utils import PreTrainedModel
try:
from transformers.tokenization_python import PreTrainedTokenizer
except ImportError:
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.tokenization_utils import PreTrainedTokenizer
from axolotl.utils.distributed import is_main_process

View File

@@ -7,11 +7,11 @@ import numpy as np
from mistral_common.protocol.instruct.validator import ValidationMode
from mistral_common.tokens.tokenizers.utils import download_tokenizer_from_hf_hub
from torch import Tensor
from transformers.tokenization_mistral_common import MistralCommonBackend
from transformers.tokenization_mistral_common import MistralCommonTokenizer
from transformers.tokenization_utils_base import VERY_LARGE_INTEGER
class HFMistralTokenizer(MistralCommonBackend):
class HFMistralTokenizer(MistralCommonTokenizer):
"""
Wraps mistral_common.tokens.tokenizers.mistral.MistralTokenizer
and exposes HuggingFace API for special tokens.
@@ -37,19 +37,11 @@ class HFMistralTokenizer(MistralCommonBackend):
def name_or_path(self) -> str:
return self._name_or_path
@name_or_path.setter
def name_or_path(self, name_or_path: str) -> None:
self._name_or_path = name_or_path
@property
def chat_template(self) -> str | None:
"""Chat template is not supported. Dummy method to satisfy HuggingFace API."""
return "[This is a dummy chat template]"
@chat_template.setter
def chat_template(self, chat_template: str | None) -> None:
pass
def _set_mode(self, mode: ValidationMode):
"""Set the mode of the MistralRequestValidator.
@@ -141,7 +133,7 @@ class HFMistralTokenizer(MistralCommonBackend):
r"""
Patched fn to pass `name_or_path` and remove extra kwargs.
Instantiate a `MistralCommonBackend` from a predefined
Instantiate a `MistralCommonTokenizer` from a predefined
tokenizer.
Args:
@@ -150,7 +142,7 @@ class HFMistralTokenizer(MistralCommonBackend):
- A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co.
- A path to a *directory* containing the tokenizer config, for instance saved
using the [`MistralCommonBackend.tokenization_mistral_common.save_pretrained`] method, e.g.,
using the [`MistralCommonTokenizer.tokenization_mistral_common.save_pretrained`] method, e.g.,
`./my_model_directory/`.
mode (`ValidationMode`, *optional*, defaults to `ValidationMode.test`):
Validation mode for the `MistralTokenizer` tokenizer.
@@ -162,7 +154,7 @@ class HFMistralTokenizer(MistralCommonBackend):
exist.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `hf auth login` (stored in `~/.huggingface`).
when running `huggingface-cli login` (stored in `~/.huggingface`).
local_files_only (`bool`, *optional*, defaults to `False`):
Whether or not to only rely on local files and not to attempt to download any files.
revision (`str`, *optional*, defaults to `"main"`):
@@ -187,12 +179,12 @@ class HFMistralTokenizer(MistralCommonBackend):
Whether or not the model should cleanup the spaces that were added when splitting the input text during the
tokenization process.
kwargs (additional keyword arguments, *optional*):
Not supported by `MistralCommonBackend.from_pretrained`.
Not supported by `MistralCommonTokenizer.from_pretrained`.
Will raise an error if used.
"""
if init_inputs:
raise ValueError(
"`init_inputs` are not supported by `MistralCommonBackend.from_pretrained`."
"`init_inputs` are not supported by `MistralCommonTokenizer.from_pretrained`."
)
# Delete trust_remote_code as it does nothing
@@ -204,7 +196,7 @@ class HFMistralTokenizer(MistralCommonBackend):
# Handle kwargs and AutoTokenizer case
if kwargs and not kwargs.keys() == {"_from_auto"}:
raise ValueError(
f"Kwargs {list(kwargs.keys())} are not supported by `MistralCommonBackend.from_pretrained`."
f"Kwargs {list(kwargs.keys())} are not supported by `MistralCommonTokenizer.from_pretrained`."
)
if not os.path.isfile(pretrained_model_name_or_path):

View File

@@ -619,13 +619,6 @@ class AxolotlInputConfig(
},
)
experts_implementation: str | None = Field(
default=None,
json_schema_extra={
"description": "Which experts implementation to use for MoE models,"
},
)
scaling_softmax: bool | None = Field(
default=None,
json_schema_extra={
@@ -683,24 +676,6 @@ class AxolotlInputConfig(
"description": "Number of chunks to use for chunked cross entropy loss"
},
)
use_eaft: bool | None = Field(
default=None,
json_schema_extra={
"description": "Enable Entropy-Aware Focal Training loss (EAFT)"
},
)
eaft_alpha: float | None = Field(
default=1.0,
json_schema_extra={
"description": "Exponent for entropy weighting in EAFT (default: 1.0)"
},
)
eaft_k: int | None = Field(
default=20,
json_schema_extra={
"description": "Number of top logits for entropy approximation (default: 20)"
},
)
tiled_mlp: bool | None = Field(
default=None,

View File

@@ -4,7 +4,7 @@ FSDP Configuration Schema
from typing import Literal
from pydantic import AliasChoices, BaseModel, Field
from pydantic import BaseModel, Field
class FSDPConfig(BaseModel):
@@ -12,11 +12,6 @@ class FSDPConfig(BaseModel):
FSDP Configuration Schema
"""
fsdp_version: int | None = Field(
validation_alias=AliasChoices("fsdp_version", "version"),
default=None,
json_schema_extra={"description": "FSDP version"},
)
activation_checkpointing: bool | None = Field(
default=None,
description="Enable activation checkpointing to reduce memory usage during forward passes",

View File

@@ -123,22 +123,10 @@ class ModelOutputConfig(BaseModel):
save_safetensors: bool | None = Field(
default=True,
json_schema_extra={
"description": "Whether to save the model using safetensors format. Defaults to True."
"description": "Save model as safetensors (require safetensors package). Default True"
},
)
@field_validator("save_safetensors")
@classmethod
def validate_save_safetensors(cls, v):
if v is False:
raise ValueError(
"save_safetensors=False is not supported in Transformers V5. "
"Transformers V5 always uses safetensors format for model serialization. "
"This field is deprecated and will be removed in a future version."
)
# Allow None and True, will default to True if None
return True if v is None else v
class SpecialTokensConfig(BaseModel):
"""Special tokens configuration subset"""

View File

@@ -900,43 +900,6 @@ class OptimizationValidationMixin:
return data
@model_validator(mode="before")
@classmethod
def check_fsdp_config_kwargs_prefix(cls, data):
if fsdp_config := data.get("fsdp_config"):
should_fix = False
for key, _ in fsdp_config.items():
if key.startswith("fsdp_"):
should_fix = True
LOG.warning_once(
"Configuring FSDP fields with the `fsdp_` prefix is deprecated. "
"Please omit the `fsdp_` prefix from the any fields in `fsdp_config`."
)
if should_fix:
update_fsdp_config = {}
for key, value in fsdp_config.items():
if key.startswith("fsdp_") and key != "fsdp_version":
update_fsdp_config[key.replace("fsdp_", "")] = value
else:
update_fsdp_config[key] = value
data["fsdp_config"] = update_fsdp_config
return data
@model_validator(mode="before")
@classmethod
def check_fsdp_version_in_fsdp_config(cls, data):
fsdp_config = data.get("fsdp_config") or {}
fsdp_version = data.get("fsdp_version", None)
if not fsdp_version and fsdp_config and fsdp_config.get("version"):
fsdp_cfg_version = fsdp_config.pop("version")
data["fsdp_version"] = fsdp_cfg_version
data["fsdp_config"]["fsdp_version"] = fsdp_cfg_version
elif not fsdp_version and fsdp_config and fsdp_config.get("fsdp_version"):
data["fsdp_version"] = fsdp_config.get("fsdp_version")
if fsdp_version and fsdp_config and not fsdp_config.get("fsdp_version"):
data["fsdp_config"]["fsdp_version"] = fsdp_version
return data
@model_validator(mode="after")
def check_fsdp_offload_w_8bit_optimizer(self):
if (
@@ -1038,6 +1001,40 @@ class OptimizationValidationMixin:
return data
@model_validator(mode="before")
@classmethod
def check_fsdp_version_in_fsdp_config(cls, data):
fsdp_config = data.get("fsdp_config") or {}
if fsdp_config and fsdp_config.get("fsdp_version"):
LOG.warning(
"Configuring `fsdp_version` in `fsdp_config` is deprecated. "
"Please configure `fsdp_version` as a top-level field."
)
data["fsdp_version"] = fsdp_config.pop("fsdp_version")
return data
@model_validator(mode="before")
@classmethod
def check_fsdp_config_kwargs_prefix(cls, data):
if fsdp_config := data.get("fsdp_config"):
should_fix = False
for key, _ in fsdp_config.items():
if key.startswith("fsdp_"):
should_fix = True
LOG.warning_once(
"Configuring FSDP fields with the `fsdp_` prefix is deprecated. "
"Please omit the `fsdp_` prefix from the any fields in `fsdp_config`."
)
if should_fix:
update_fsdp_config = {}
for key, value in fsdp_config.items():
if key.startswith("fsdp_") and key != "fsdp_version":
update_fsdp_config[key.replace("fsdp_", "")] = value
else:
update_fsdp_config[key] = value
data["fsdp_config"] = update_fsdp_config
return data
class SystemValidationMixin:
"""Validation methods related to system and hardware configuration."""

View File

@@ -83,12 +83,6 @@ def download_smollm2_135m_model():
snapshot_download_w_retry("HuggingFaceTB/SmolLM2-135M", repo_type="model")
@pytest.fixture(scope="session", autouse=True)
def download_smollm2_135m_instruct_model():
# download the model
snapshot_download_w_retry("HuggingFaceTB/SmolLM2-135M-Instruct", repo_type="model")
@pytest.fixture(scope="session", autouse=True)
def download_smollm2_135m_gptq_model():
# download the model
@@ -149,20 +143,12 @@ def download_argilla_distilabel_intel_orca_dpo_dataset():
)
@pytest.fixture(scope="session", autouse=True)
def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():
# download the dataset
snapshot_download_w_retry(
"argilla/ultrafeedback-binarized-preferences-cleaned", repo_type="dataset"
)
@pytest.fixture(scope="session", autouse=True)
def download_argilla_ultrafeedback_binarized_preferences_cleaned_kto_dataset():
# download the dataset
snapshot_download_w_retry(
"argilla/ultrafeedback-binarized-preferences-cleaned-kto", repo_type="dataset"
)
# @pytest.fixture(scope="session", autouse=True)
# def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():
# # download the dataset
# snapshot_download_w_retry(
# "argilla/ultrafeedback-binarized-preferences-cleaned", repo_type="dataset"
# )
# @pytest.fixture(scope="session", autouse=True)
@@ -265,9 +251,7 @@ def download_llama_1b_model_fixture():
def download_llama3_8b_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"NousResearch/Meta-Llama-3-8B",
repo_type="model",
allow_patterns=["*token*", "config.json"],
"NousResearch/Meta-Llama-3-8B", repo_type="model", allow_patterns=["*token*"]
)
@@ -277,7 +261,7 @@ def download_llama3_8b_instruct_model_fixture():
snapshot_download_w_retry(
"NousResearch/Meta-Llama-3-8B-Instruct",
repo_type="model",
allow_patterns=["*token*", "config.json"],
allow_patterns=["*token*"],
)
@@ -285,19 +269,7 @@ def download_llama3_8b_instruct_model_fixture():
def download_phi_35_mini_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"microsoft/Phi-3.5-mini-instruct",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture(scope="session", autouse=True)
def download_phi_4_reasoning_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"microsoft/Phi-4-reasoning",
repo_type="model",
allow_patterns=["*token*", "config.json"],
"microsoft/Phi-3.5-mini-instruct", repo_type="model", allow_patterns=["*token*"]
)
@@ -307,7 +279,7 @@ def download_phi_3_medium_model_fixture():
snapshot_download_w_retry(
"microsoft/Phi-3-medium-128k-instruct",
repo_type="model",
allow_patterns=["*token*", "config.json"],
allow_patterns=["*token*"],
)
@@ -590,8 +562,6 @@ def test_load_fixtures(
download_mhenrichsen_alpaca_2k_dataset,
download_mhenrichsen_alpaca_2k_w_revision_dataset,
download_mlabonne_finetome_100k_dataset,
download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset,
download_argilla_ultrafeedback_binarized_preferences_cleaned_kto_dataset,
download_argilla_distilabel_capybara_dpo_7k_binarized_dataset,
download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset,
download_argilla_dpo_pairs_dataset,
@@ -603,7 +573,6 @@ def test_load_fixtures(
download_llama3_8b_instruct_model_fixture,
download_phi_35_mini_model_fixture,
download_phi_3_medium_model_fixture,
download_phi_4_reasoning_model_fixture,
download_mistral_7b_model_fixture,
download_gemma_2b_model_fixture,
download_gemma2_9b_model_fixture,

View File

@@ -53,6 +53,7 @@ def fixture_base_cfg():
# Checkpointing and saving
"save_steps": 100,
"output_dir": "./model-out",
"save_safetensors": True,
"save_total_limit": 4,
"save_only_model": False,
# Hardware/performance settings

View File

@@ -10,7 +10,7 @@ from axolotl.utils import get_pytorch_version
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_model_output_exists
from ..utils import check_model_output_exists
@pytest.fixture()
@@ -39,6 +39,7 @@ def min_cfg(temp_dir):
"optimizer": "adamw_torch_fused",
"output_dir": temp_dir,
"lr_scheduler": "cosine",
"save_safetensors": True,
"max_steps": 10,
"bf16": "auto",
"save_first_step": False,
@@ -91,6 +92,7 @@ class TestCutCrossEntropyIntegration:
"optimizer": "adamw_torch_fused",
"output_dir": temp_dir,
"lr_scheduler": "cosine",
"save_safetensors": True,
"max_steps": 10,
"bf16": "auto",
"save_first_step": False,

View File

@@ -48,6 +48,7 @@ class FP8IntegrationTestCase:
"sample_packing": True,
"fp8": True,
"torch_compile": True,
"save_safetensors": True,
"save_first_step": False,
}
)

View File

@@ -11,7 +11,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_model_output_exists
from ..utils import check_model_output_exists
class LogHooksPlugin(BasePlugin):

View File

@@ -65,6 +65,7 @@ def min_cfg(temp_dir):
},
"max_steps": 5,
"output_dir": temp_dir,
"save_safetensors": True,
"use_tensorboard": True,
"save_first_step": False,
}

View File

@@ -48,6 +48,7 @@ class LigerIntegrationTestCase:
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"max_steps": 5,
"save_first_step": False,
@@ -98,6 +99,7 @@ class LigerIntegrationTestCase:
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"max_steps": 5,
"save_first_step": False,

View File

@@ -57,6 +57,7 @@ class TestLLMCompressorIntegration:
"learning_rate": 1e-5,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"max_steps": 5,
"llmcompressor": {

View File

@@ -220,6 +220,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"learning_rate": 0.0001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
"save_first_step": False,
@@ -314,6 +315,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"learning_rate": 0.0001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
"save_first_step": False,
@@ -406,6 +408,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"learning_rate": 0.0001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
"save_first_step": False,

View File

@@ -11,7 +11,7 @@ from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import most_recent_subdir, require_torch_2_7_0, supports_fp8
from tests.e2e.utils import most_recent_subdir, require_hopper, require_torch_2_7_0
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
@@ -49,7 +49,7 @@ class TestFP8FSDP2:
"""Test class for FP8 mixed precision with FSDP2 functionality."""
@require_torch_2_7_0
@supports_fp8
@require_hopper
def test_fp8_fsdp2_smoke(self, temp_dir):
"""Smoke test for 2-GPU FP8 + torch.compile + FSDP2 training"""
cfg = DictDefault(
@@ -94,6 +94,7 @@ class TestFP8FSDP2:
"reshard_after_forward": True,
},
"use_tensorboard": True,
"save_safetensors": True,
"save_first_step": False,
}
)

View File

@@ -244,7 +244,6 @@ class TestFSDP1:
verify_training_success(temp_dir)
@pytest.mark.skip("broken in transformers v5")
@pytest.mark.parametrize(
"adapter_config",
[

View File

@@ -150,10 +150,6 @@ class TestFSDP2:
},
"use_tensorboard": True,
"bf16": True,
# explicitly disable LORA kernels, as they may be auto-enabled
"lora_mlp_kernel": False,
"lora_qkv_kernel": False,
"lora_o_kernel": False,
}
)

View File

@@ -23,7 +23,6 @@ def download_model():
snapshot_download("axolotl-mirrors/gemma-3-4b-pt", repo_type="model")
@pytest.mark.skip(reason="FIXME")
class TestMultiGPUGemma3:
"""
Test case for Gemma3 models using LoRA
@@ -33,7 +32,6 @@ class TestMultiGPUGemma3:
cfg = DictDefault(
{
"base_model": "axolotl-mirrors/gemma-3-4b-pt",
"unfrozen_parameters": ["model.language_model.*", "lm_head"],
"sequence_len": 2048,
"ddp_find_unused_parameters": True,
"sample_packing": True,

View File

@@ -901,6 +901,7 @@ class TestMultiGPULlama:
"flash_attention": True,
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
# "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"),
"use_tensorboard": True,
"save_first_step": False,

View File

@@ -66,6 +66,7 @@ class TestActivationCheckpointing:
"flash_attention": True,
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
"gradient_checkpointing": gradient_checkpointing,
"save_first_step": False,
"dataset_num_proc": 4,

View File

@@ -46,6 +46,7 @@ class TestLlamaPeftEmbeddings:
"flash_attention": True,
"sample_packing": False,
"bf16": "auto",
"save_safetensors": True,
"embeddings_skip_upcast": True,
"save_first_step": False,
}

View File

@@ -58,6 +58,7 @@ class TestResumeLlama:
"save_total_limit": 5,
"max_steps": 15,
"use_tensorboard": True,
"save_safetensors": True,
"save_first_step": False,
"include_tkps": True,
}

View File

@@ -63,6 +63,7 @@ class TestReLoraLlama(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"save_safetensors": True,
"use_tensorboard": True,
"save_first_step": False,
}

View File

@@ -57,6 +57,7 @@ class TestActivationOffloading:
"flash_attention": True,
"sample_packing": True,
"bf16": "auto",
"save_safetensors": True,
"gradient_checkpointing": True,
"activation_offloading": True,
"save_first_step": False,

View File

@@ -64,6 +64,7 @@ class TestDeepseekV3:
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
"save_first_step": False,
}
@@ -112,6 +113,7 @@ class TestDeepseekV3:
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
"save_first_step": False,
}

View File

@@ -41,6 +41,7 @@ class TestDiffusion:
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"bf16": True,
"save_safetensors": True,
"save_first_step": False,
"logging_steps": 1,
"eval_steps": 3,
@@ -96,6 +97,7 @@ class TestDiffusion:
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"bf16": True,
"save_safetensors": True,
"save_first_step": False,
"logging_steps": 1,
"eval_steps": 2,

View File

@@ -44,6 +44,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):
"optimizer": "adamw_torch_fused",
"embedding_lr_scale": 0.5,
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
"save_first_step": False,
@@ -88,6 +89,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):
"optimizer": "adamw_torch_fused",
"embedding_lr": 0.000005,
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
"save_first_step": False,

View File

@@ -61,6 +61,7 @@ class TestGemma2:
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
}
)
@@ -110,6 +111,7 @@ class TestGemma2:
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
}
)

View File

@@ -60,6 +60,7 @@ class TestGemma3Text:
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
"save_first_step": False,
}
@@ -109,6 +110,7 @@ class TestGemma3Text:
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
"save_first_step": False,
}

View File

@@ -43,6 +43,7 @@ class TestLlama:
"flash_attention": True,
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
"save_first_step": False,
}
)
@@ -89,6 +90,7 @@ class TestLlama:
"flash_attention": True,
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
"save_first_step": False,
}
)
@@ -132,6 +134,7 @@ class TestLlama:
"flash_attention": True,
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
"save_first_step": False,
}
)
@@ -171,6 +174,7 @@ class TestLlama:
"sample_packing": False,
"batch_flattening": True,
"bf16": True,
"save_safetensors": True,
"save_first_step": False,
}
)

View File

@@ -49,6 +49,7 @@ class TestPretrainLlama:
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
"save_first_step": False,

View File

@@ -51,6 +51,7 @@ class TestLlamaVision(unittest.TestCase):
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
"save_first_step": False,
}
@@ -96,6 +97,7 @@ class TestLlamaVision(unittest.TestCase):
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
"save_first_step": False,
}

View File

@@ -49,6 +49,7 @@ class TestMamba(unittest.TestCase):
"max_steps": 20,
"save_steps": 10,
"eval_steps": None,
"save_safetensors": False,
"save_first_step": False,
}
)

View File

@@ -224,6 +224,7 @@ class TestCustomOptimizers(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "schedule_free_adamw",
"lr_scheduler": "constant",
"save_safetensors": True,
"max_steps": 10,
"save_first_step": False,
}

View File

@@ -54,6 +54,7 @@ class TestQATLlama:
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
"save_first_step": False,
}

View File

@@ -46,6 +46,7 @@ class TestSaveFirstStepCallback(unittest.TestCase):
"flash_attention": True,
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
"save_first_step": True,
}
)
@@ -85,6 +86,7 @@ class TestSaveFirstStepCallback(unittest.TestCase):
"flash_attention": True,
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
"save_first_step": False,
}
)

View File

@@ -50,6 +50,7 @@ class TestStreamingDatasets:
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
"save_first_step": False,

View File

@@ -167,13 +167,6 @@ def require_hopper(test_case):
return unittest.skipUnless(is_hopper(), "test requires h100/hopper GPU")(test_case)
def supports_fp8(test_case):
compute_capability = torch.cuda.get_device_capability()
return unittest.skipUnless(
compute_capability >= (9, 0), "test requires h100 or newer GPU"
)(test_case)
def check_tensorboard(
temp_run_dir: str,
tag: str,
@@ -200,10 +193,21 @@ def check_model_output_exists(temp_dir: str, cfg: DictDefault) -> None:
"""
helper function to check if a model output file exists after training
checks based on adapter or not (always safetensors in Transformers V5)
checks based on adapter or not and if safetensors saves are enabled or not
"""
if not cfg.adapter:
assert (Path(temp_dir) / "model.safetensors").exists()
if cfg.save_safetensors:
if not cfg.adapter:
assert (Path(temp_dir) / "model.safetensors").exists()
else:
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
else:
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
# check for both, b/c in trl, it often defaults to saving safetensors
if not cfg.adapter:
assert (Path(temp_dir) / "pytorch_model.bin").exists() or (
Path(temp_dir) / "model.safetensors"
).exists()
else:
assert (Path(temp_dir) / "adapter_model.bin").exists() or (
Path(temp_dir) / "adapter_model.safetensors"
).exists()

View File

@@ -13,7 +13,6 @@ def reload_modules(hf_hub_offline):
import datasets
import huggingface_hub.constants
# from huggingface_hub.utils import reset_sessions
# Reload the constants module first, as others depend on it
importlib.reload(huggingface_hub.constants)

View File

@@ -0,0 +1,35 @@
"""Integration tests for MistralCommonTokenizer patches."""
import pytest
class TestMistralTokenizerPatchIntegration:
"""Test MistralCommonTokenizer patch integration."""
@pytest.mark.integration
def test_mistral_tokenizer_image_patch(self):
"""Test that MistralCommonTokenizer image patch can be applied."""
try:
from transformers.tokenization_mistral_common import MistralCommonTokenizer
except ImportError:
pytest.skip("MistralCommonTokenizer not available")
from axolotl.monkeypatch.models.mistral3.mistral_common_tokenizer import (
apply_mistral_tokenizer_image_patch,
)
# Store original method
original_apply_chat_template = MistralCommonTokenizer.apply_chat_template
# Apply patch
apply_mistral_tokenizer_image_patch()
# Verify patch was applied
assert (
MistralCommonTokenizer.apply_chat_template != original_apply_chat_template
), "apply_chat_template was not patched"
# Verify the method is still callable
assert callable(MistralCommonTokenizer.apply_chat_template), (
"Patched method is not callable"
)

View File

@@ -141,7 +141,6 @@ def fixture_phi35_tokenizer():
@pytest.fixture(name="phi4_tokenizer", scope="session", autouse=True)
@enable_hf_offline
def fixture_phi4_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-4-reasoning")
return tokenizer
@@ -179,7 +178,6 @@ def fixture_devstral_1_1_tokenizer():
@pytest.fixture(name="qwen3_tokenizer")
@enable_hf_offline
def qwen3_tokenizer_fixture(
download_qwen3_half_billion_model,
): # pylint: disable=unused-argument,redefined-outer-name

View File

@@ -37,7 +37,7 @@ PARAMETRIZE_PARAMS = [
"gemma2_tokenizer_chat_template_jinja",
"<end_of_turn>",
),
# ("phi35_tokenizer", "phi_35", None, "<|end|>"), # seems to be broken w transformers v5
("phi35_tokenizer", "phi_35", None, "<|end|>"),
("phi4_tokenizer", "phi_4", None, "<|im_end|>"),
]

View File

@@ -127,7 +127,8 @@ class NormalizeConfigTestCase(unittest.TestCase):
self.assertNotIn("fsdp_auto_wrap_policy", cfg_with_version.fsdp_config)
self.assertNotIn("fsdp_offload_params", cfg_with_version.fsdp_config)
self.assertNotIn("fsdp_cpu_ram_efficient_loading", cfg_with_version.fsdp_config)
self.assertIn("fsdp_version", cfg_with_version.fsdp_config)
self.assertNotIn("fsdp_version", cfg_with_version.fsdp_config)
self.assertNotIn("version", cfg_with_version.fsdp_config)
cfg_without_version = self._get_base_cfg() | DictDefault(
{
@@ -190,7 +191,9 @@ class NormalizeConfigTestCase(unittest.TestCase):
self.assertEqual(cfg.fsdp_config.activation_checkpointing, True)
# Check original fsdp_ keys are removed
self.assertNotIn("fsdp_version", cfg.fsdp_config)
self.assertNotIn("fsdp_state_dict_type", cfg.fsdp_config)
self.assertNotIn("fsdp_reshard_after_forward", cfg.fsdp_config)
self.assertIn("fsdp_version", cfg.fsdp_config)
# Ensure no duplicate version key
self.assertNotIn("version", cfg.fsdp_config)

View File

@@ -16,9 +16,7 @@ def metric(tokenizer):
@fixture()
def model():
return AutoModelForCausalLM.from_pretrained(
MODEL_NAME, trust_remote_code=True, dtype="float32"
)
return AutoModelForCausalLM.from_pretrained(MODEL_NAME, trust_remote_code=True)
@fixture()

View File

@@ -17,7 +17,6 @@ class TestTokenizers:
test class for the load_tokenizer fn
"""
@pytest.mark.skip("LlamaTokenizer no longer has a Fast/Slow tokenizer")
@enable_hf_offline
def test_default_use_fast(self):
cfg = DictDefault(
@@ -28,7 +27,6 @@ class TestTokenizers:
tokenizer = load_tokenizer(cfg)
assert "Fast" in tokenizer.__class__.__name__
@pytest.mark.skip("LlamaTokenizer no longer has a Fast/Slow tokenizer")
@enable_hf_offline
def test_dont_use_fast(self):
cfg = DictDefault(

View File

@@ -13,29 +13,17 @@ class TestFSDPValidation:
test class for pydantic fsdp validation
"""
def test_fsdp_version_from_fsdp_config(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(
fsdp_config={
"version": 2,
},
)
cfg = validate_config(
cfg,
)
assert cfg.fsdp_version == 2
def test_fsdp_version_in_fsdp_config(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(
fsdp_version=2,
fsdp_config={
"reshard_after_forward": True,
"fsdp_version": 2,
},
)
cfg = validate_config(
cfg,
)
assert cfg.fsdp_version == 2
assert cfg.fsdp_config.fsdp_version == 2
assert cfg.fsdp_config.fsdp_version is None
def test_fsdp_offload_w_8bit_optim(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(
@@ -128,10 +116,9 @@ class TestFSDPValidation:
)
cfg = validate_config(cfg)
assert cfg.fsdp_version == 2
assert cfg.fsdp_config.fsdp_version == 2
for key in cfg.fsdp_config.keys():
if key != "fsdp_version":
assert not key.startswith("fsdp_")
assert cfg.fsdp_config.fsdp_version is None
for keys in cfg.fsdp_config.keys():
assert not keys.startswith("fsdp_")
assert cfg.fsdp_config.auto_wrap_policy == "TRANSFORMER_BASED_WRAP"
assert cfg.fsdp_config.transformer_layer_cls_to_wrap == "LlamaDecoderLayer"
assert cfg.fsdp_config.reshard_after_forward is True