transformers v5 upgrade (#3272)

* Prepare for transformers v5 upgrade

* fix hf cli

* update for hf hub changes

* fix tokenizer apply_chat_template args

* remap include_tokens_per_second

* fix tps

* handle migration for warmup

* use latest hf hub

* Fix scan -> ls

* fix import

* fix for renaming of mistral common tokenizer -> backend

* update for fixed tokenization for llama

* Skip phi35 tests for now

* remove mistral patch fixed upstream in huggingface/transformers#41439

* use namespacing for patch

* don't rely on sdist for e2e tests for now

* run modal ci without waiting too

* Fix dep for ci

* fix imports

* Fix fp8 check

* fsdp2 fixes

* fix version handling

* update fsdp version tests for new v5 behavior

* Fail multigpu tests after 3 failures

* skip known v5 broken tests for now and cleanup

* bump deps

* unmark skipped test

* re-enable test_fsdp_qlora_prequant_packed test

* increase multigpu ci timeout

* skip broken gemma3 test

* reduce timeout back to the original 120min now that the hanging test is skipped

* fix for unnecessary collator for pretraining with bsz=1

* fix: safe_serialization deprecated in transformers v5 rc01 (#3318)

* torch_dtype deprecated

* load model in float32 for consistency with tests

* revert some test fixtures back

* use hf cache ls instead of scan

* don't strip fsdp_version

more fsdp_version fixes for v5
fix version in fsdp_config
fix aliasing
fix fsdp_version check
check fsdp_version is 2 in both places

* Transformers v5 rc2 (#3347)

* bump dep

* use latest fbgemm, grab model config as part of fixture, un-skip test

* import AutoConfig

* don't need the more problematic AutoConfig when specifying config.json manually

* add fixtures for argilla ultrafeedback datasets

* download phi4-reasoning

* fix arg

* update tests for phi fast tokenizer changes

* use explicit model types for gemma3

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>

* fix: AutoModelForVision2Seq -> AutoModelForImageTextToText

* chore: remove duplicate

* fix: attempt fix gemma3 text mode

* chore: lint

* ga release of v5

* need property setter for name_or_path for mistral tokenizer

* vllm not compatible with transformers v5

* setter for chat_template with mistral too

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
Co-authored-by: salman <salman.mohammadi@outlook.com>
Wing Lian
2026-01-27 17:08:24 -05:00
committed by GitHub
parent a531e9d946
commit fc4e37920b
74 changed files with 262 additions and 309 deletions

View File

@@ -44,7 +44,7 @@ def check_user_token() -> bool:
return bool(user_info)
except LocalTokenNotFoundError:
LOG.warning(
"Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
"Error verifying HuggingFace token. Remember to log in using `hf auth login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
)
return False
except HTTPError:
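Note: the hunk above only changes the login hint from `huggingface-cli login` to the newer `hf auth login` spelling. A standalone sketch of the kind of token check this code performs, assuming huggingface_hub's `HfApi.whoami` and `LocalTokenNotFoundError` (the latter is already caught in the surrounding code):

```python
# Hedged sketch: verify a Hub token is configured before touching gated repos.
# Log in first with `hf auth login` (the old `huggingface-cli login` spelling
# is what this hunk replaces in the warning message).
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError


def has_valid_hf_token() -> bool:
    try:
        return bool(HfApi().whoami())  # raises if no token is available locally
    except LocalTokenNotFoundError:
        return False
```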

View File

@@ -24,7 +24,6 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
cfg: Dictionary mapping `axolotl` config keys to values.
"""
model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg)
safe_serialization = cfg.save_safetensors is True
LOG.info("Running merge of LoRA with base model...")
model = model.merge_and_unload(progressbar=True)
@@ -42,7 +41,6 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
LOG.info(f"Saving merged model to: {str(Path(cfg.output_dir) / 'merged')}...")
model.save_pretrained(
str(Path(cfg.output_dir) / "merged"),
safe_serialization=safe_serialization,
progressbar=True,
)
tokenizer.save_pretrained(

View File

@@ -14,8 +14,6 @@ from accelerate import PartialState
from accelerate.utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
is_torch_version,
)
from huggingface_hub import split_torch_state_dict_into_shards
@@ -40,17 +38,15 @@ class BFloat16CastPlanner(_EmptyStateDictLoadPlanner):
def _distributed_checkpoint_to_merged_weights(
checkpoint_dir: Union[str, Path],
save_path: str,
safe_serialization: bool = False,
max_shard_size: str = "5GB",
) -> Path:
"""
Passthrough to `torch.distributed.checkpoint.format_utils.dcp_to_torch_save`. Will
save under `save_path` as either `model.safetensors` or `pytorch_model.bin`.
save under `save_path` as `model.safetensors`.
Args:
checkpoint_dir: Directory where distributed checkpoint is saved.
save_path: Path to save model to.
safe_serialization: Whether to save in safetensors format.
max_shard_size: Max size of model shards to save.
Returns:
@@ -76,11 +72,7 @@ def _distributed_checkpoint_to_merged_weights(
if isinstance(value, torch.Tensor) and value.dtype != torch.bfloat16:
state_dict[key] = value.to(torch.bfloat16)
weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
".safetensors", "{suffix}.safetensors"
)
filename_pattern = SAFE_WEIGHTS_NAME.replace(".safetensors", "{suffix}.safetensors")
state_dict_split = split_torch_state_dict_into_shards(
state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
)
@@ -98,19 +90,12 @@ def _distributed_checkpoint_to_merged_weights(
for shard_file, tensors in filename_to_tensors:
shard = {tensor: state_dict[tensor] for tensor in tensors}
if safe_serialization:
safe_save_file(
shard, os.path.join(save_path_, shard_file), metadata={"format": "pt"}
)
else:
torch.save(shard, os.path.join(save_path_, shard_file))
safe_save_file(
shard, os.path.join(save_path_, shard_file), metadata={"format": "pt"}
)
if index is not None:
save_index_file = (
SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
)
save_index_file = os.path.join(save_path_, save_index_file)
save_index_file = os.path.join(save_path_, SAFE_WEIGHTS_INDEX_NAME)
# Save the index as well
with open(save_index_file, "w", encoding="utf-8") as fout:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
@@ -123,13 +108,11 @@ def _distributed_checkpoint_to_merged_weights(
def merge_fsdp_weights(
checkpoint_dir: str,
output_path: str,
safe_serialization: bool = False,
remove_checkpoint_dir: bool = False,
):
"""
Merge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Should be used if
`SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}/model.safetensors` if
`safe_serialization` else `pytorch_model.bin`.
`SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}/model.safetensors`.
Note: this is a CPU-bound process.
@@ -138,8 +121,6 @@ def merge_fsdp_weights(
The directory containing the FSDP checkpoints (can be either the model or optimizer).
output_path (`str`):
The path to save the merged checkpoint.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the merged weights with safetensors (recommended).
remove_checkpoint_dir (`bool`, *optional*, defaults to `False`):
Whether to remove the checkpoint directory after merging.
@@ -177,7 +158,7 @@ def merge_fsdp_weights(
if state.is_main_process:
LOG.info(f"Merging FSDP weights from {checkpoint_dir_}")
save_path = _distributed_checkpoint_to_merged_weights(
checkpoint_dir_, output_path, safe_serialization
checkpoint_dir_, output_path
)
LOG.info(f"Successfully merged FSDP weights and saved to {save_path}")
if remove_checkpoint_dir:
@@ -210,7 +191,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
merge_fsdp_weights(
checkpoint_dir=str(fsdp_dir),
output_path=output_path,
safe_serialization=True,
)
state = PartialState()
state.wait_for_everyone()

View File

@@ -102,12 +102,10 @@ def do_quantize(
LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}.")
model.save_pretrained(
str(Path(output_dir) / "quantized"),
safe_serialization=False,
progressbar=True,
)
tokenizer.save_pretrained(
str(Path(output_dir) / "quantized"),
safe_serialization=False,
progressbar=True,
save_jinja_files=cfg.tokenizer_save_jinja_files,
)
@@ -121,7 +119,7 @@ def do_quantize(
hub_model_id.rstrip("-")
+ f"-{quantization_config_to_str[type(quantization_config)]}"
)
model.push_to_hub(hub_model_id, safe_serialization=False)
model.push_to_hub(hub_model_id)
tokenizer.push_to_hub(hub_model_id)
if processor:
processor.push_to_hub(hub_model_id)

View File

@@ -216,7 +216,7 @@ class TrainerBuilderBase(abc.ABC):
def _configure_warmup_and_logging(
self, total_num_steps: int, training_args_kwargs: dict
):
warmup_steps = 0
warmup_steps: int | float = 0
warmup_ratio = 0.0
if self.cfg.warmup_steps is not None:
warmup_steps = self.cfg.warmup_steps
@@ -230,6 +230,10 @@ class TrainerBuilderBase(abc.ABC):
else:
warmup_ratio = 0.03
# transformers v5
if warmup_ratio > 0.0 and warmup_steps == 0:
warmup_steps = warmup_ratio
if warmup_steps == 1:
warmup_steps = 2
@@ -242,7 +246,6 @@ class TrainerBuilderBase(abc.ABC):
else max(min(int(0.005 * total_num_steps), 10), 1)
)
training_args_kwargs["warmup_ratio"] = warmup_ratio
training_args_kwargs["warmup_steps"] = warmup_steps
def _configure_precision_settings(self, training_args_kwargs: dict):
@@ -530,9 +533,7 @@ class TrainerBuilderBase(abc.ABC):
"loraplus_lr_ratio",
"loraplus_lr_embedding",
"output_dir",
"save_safetensors",
"save_only_model",
"include_tokens_per_second",
"weight_decay",
"seed",
"dion_momentum",
@@ -545,6 +546,7 @@ class TrainerBuilderBase(abc.ABC):
arg_map = {
"dion_learning_rate": "dion_lr",
"include_num_input_tokens_seen": "include_tokens_per_second",
}
for kwarg, cfg_arg in arg_map.items():
if hasattr(self.cfg, cfg_arg) and getattr(self.cfg, cfg_arg) is not None:
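For readers of the hunk above: the warmup migration folds `warmup_ratio` into `warmup_steps`, and the arg map now remaps `include_tokens_per_second` onto `include_num_input_tokens_seen`. A minimal sketch of the v5-specific warmup part, under the assumption that transformers v5 accepts a float `warmup_steps` in (0, 1) and treats it as a fraction of total training steps:

```python
# Hedged sketch of the v5 warmup migration shown above (assumption: v5 reads a
# float warmup_steps in (0, 1) as a ratio, replacing the separate warmup_ratio).
def migrate_warmup(warmup_steps: int | float, warmup_ratio: float) -> int | float:
    if warmup_ratio > 0.0 and warmup_steps == 0:
        warmup_steps = warmup_ratio  # e.g. 0.03 -> warm up over 3% of total steps
    if warmup_steps == 1:
        warmup_steps = 2  # bump a literal 1, presumably so it is not read as a 1.0 ratio
    return warmup_steps
```

With this in place only `warmup_steps` needs to be forwarded to the training arguments, which is why the hunk stops passing a separate warmup ratio alongside it.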

View File

@@ -437,7 +437,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
or self.cfg.micro_batch_size > 1
):
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
if not (self.cfg.sample_packing and self.cfg.pretrain_multipack_attn):
if not (self.cfg.sample_packing and self.cfg.pretrain_multipack_attn) or (
self.cfg.micro_batch_size == 1 and is_eval is False
):
return None
if self.cfg.model_config_type == "mamba":

View File

@@ -25,7 +25,7 @@ from torch.utils.data import (
from transformers import PreTrainedModel, Trainer
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length, seed_worker
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, is_peft_available
from transformers.utils import SAFE_WEIGHTS_NAME, is_peft_available
from trl.trainer.utils import pad_to_length
from typing_extensions import override
@@ -738,43 +738,38 @@ class AxolotlTrainer(
).save_pretrained(
output_dir,
state_dict=state_dict,
safe_serialization=self.args.save_safetensors,
)
else:
LOG.info(
"Trainer.model is not a `PreTrainedModel`, only saving its state dict."
)
if self.args.save_safetensors:
safetensors.torch.save_file(
state_dict,
os.path.join(output_dir, SAFE_WEIGHTS_NAME),
metadata={"format": "pt"},
)
else:
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
safetensors.torch.save_file(
state_dict,
os.path.join(output_dir, SAFE_WEIGHTS_NAME),
metadata={"format": "pt"},
)
else:
self.model.save_pretrained(
output_dir,
state_dict=state_dict,
safe_serialization=self.args.save_safetensors,
is_main_process=self.accelerator.is_main_process,
)
if self.processing_class is not None:
self.processing_class.save_pretrained(output_dir)
elif (
self.data_collator is not None
and hasattr(self.data_collator, "tokenizer")
and self.data_collator.tokenizer is not None
):
LOG.info(
"Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
)
save_jinja_files = True
if self.axolotl_cfg:
save_jinja_files = self.axolotl_cfg.tokenizer_save_jinja_files
self.data_collator.tokenizer.save_pretrained(
output_dir, save_jinja_files=save_jinja_files
)
# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
if self.processing_class is not None:
self.processing_class.save_pretrained(output_dir)
elif (
self.data_collator is not None
and hasattr(self.data_collator, "tokenizer")
and self.data_collator.tokenizer is not None
):
LOG.info(
"Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
)
save_jinja_files = True
if self.axolotl_cfg:
save_jinja_files = self.axolotl_cfg.tokenizer_save_jinja_files
self.data_collator.tokenizer.save_pretrained(
output_dir, save_jinja_files=save_jinja_files
)
# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

View File

@@ -12,7 +12,6 @@ def save_compressed_model(
model: PreTrainedModel,
output_dir: Union[str, bytes],
trainer: Trainer,
safe_serialization: bool = False,
save_compressed: bool = False,
) -> None:
"""
@@ -22,7 +21,6 @@ def save_compressed_model(
model (PreTrainedModel): The model to be saved.
output_dir (str or bytes): Path where the model files will be written.
trainer (Trainer): Hugging Face Trainer for process synchronization.
safe_serialization (bool): Use safe serialization if True.
save_compressed (bool): Write compressed tensors if True.
"""
trainer.accelerator.wait_for_everyone()
@@ -34,7 +32,6 @@ def save_compressed_model(
modify_save_pretrained(model)
model.save_pretrained(
output_dir,
safe_serialization=safe_serialization,
save_compressed=save_compressed,
skip_sparsity_compression_stats=not save_compressed,
)

View File

@@ -26,7 +26,6 @@ from torch.distributed import DeviceMesh
from transformers import (
AutoModelForCausalLM,
AutoModelForImageTextToText,
AutoModelForVision2Seq,
AwqConfig,
BitsAndBytesConfig,
GPTQConfig,
@@ -434,7 +433,7 @@ class ModelLoader:
"""
if self.cfg.is_multimodal:
self.auto_model_loader = MULTIMODAL_AUTO_MODEL_MAPPING.get(
self.model_config.model_type, AutoModelForVision2Seq
self.model_config.model_type, AutoModelForImageTextToText
)
if isinstance(self.auto_model_loader, str):
self.auto_model_loader = AutoModelForImageTextToText
@@ -476,6 +475,7 @@ class ModelLoader:
max_memory = None
self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype
self.model_kwargs["dtype"] = self.cfg.torch_dtype
is_ds_zero3 = is_deepspeed_zero3_enabled()
@@ -670,7 +670,7 @@ class ModelLoader:
Uses the selected loader when provided; otherwise falls back to the auto loader.
"""
loader = model_loader_class or self.auto_model_loader
if loader in [AutoModelForCausalLM, AutoModelForVision2Seq]:
if loader in [AutoModelForCausalLM, AutoModelForImageTextToText]:
model = loader.from_config(
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
@@ -788,6 +788,7 @@ class ModelLoader:
# Use auto model loader (handles gptq and default cases)
model_loader_class = self.auto_model_loader
self.model_kwargs["dtype"] = self.model_kwargs["torch_dtype"]
if self.cfg.reinit_weights:
self.model = self._load_model_from_config(model_loader_class)
else:
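Two v5-facing changes land in this hunk: the multimodal fallback loader becomes `AutoModelForImageTextToText` (replacing `AutoModelForVision2Seq`), and the loader kwargs carry `dtype` alongside the deprecated `torch_dtype`. A hedged sketch of the resulting call shape; the checkpoint name is only an example:

```python
# Hedged sketch: v5-era multimodal loading. `dtype` is the v5 spelling of the
# kwarg that v4 called `torch_dtype`; the model id below is just an example.
import torch
from transformers import AutoModelForImageTextToText

model = AutoModelForImageTextToText.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",  # example image-text-to-text checkpoint
    dtype=torch.bfloat16,         # replaces torch_dtype=... under v5
)
```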

View File

@@ -220,13 +220,6 @@ class PatchManager:
patch_qwen3_next_modeling_packing()
if self.cfg.model_config_type == "mistral3" and self.cfg.processor_type:
from axolotl.monkeypatch.models.mistral3.mistral_common_tokenizer import (
apply_mistral_tokenizer_image_patch,
)
apply_mistral_tokenizer_image_patch()
if self.cfg.model_config_type == "kimi_linear":
from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import (
patch_kimi_model,

View File

@@ -31,7 +31,7 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
from axolotl.utils.mistral import HFMistralTokenizer
tokenization_mistral_common.MistralCommonTokenizer = HFMistralTokenizer
tokenization_mistral_common.MistralCommonBackend = HFMistralTokenizer
_patch_mistralcommontokenizer()

View File

@@ -111,7 +111,6 @@ class MambaLMHeadModel(nn.Module, GenerationMixin):
self,
save_directory: Union[str, os.PathLike],
state_dict: Optional[dict] = None,
safe_serialization: Optional[bool] = None,
):
if state_dict is None:
state_dict = self.state_dict()

View File

@@ -1,5 +1,5 @@
"""
Monkeypatch to fix inefficient tensor conversion in MistralCommonTokenizer.apply_chat_template
Monkeypatch to fix inefficient tensor conversion in MistralCommonBackend.apply_chat_template
"""
import importlib
@@ -12,11 +12,11 @@ LOG = get_logger(__name__)
def apply_mistral_tokenizer_image_patch():
"""Apply patch to MistralCommonTokenizer.apply_chat_template to fix image tensor conversion."""
from transformers.tokenization_mistral_common import MistralCommonTokenizer
"""Apply patch to MistralCommonBackend.apply_chat_template to fix image tensor conversion."""
from transformers.tokenization_mistral_common import MistralCommonBackend
# Get original source
original_source = inspect.getsource(MistralCommonTokenizer.apply_chat_template)
original_source = inspect.getsource(MistralCommonBackend.apply_chat_template)
original_source, _ = detab_code(original_source)
# Define the replacement
@@ -41,7 +41,7 @@ def apply_mistral_tokenizer_image_patch():
)
# Load necessary imports from the module
module_name = MistralCommonTokenizer.__module__
module_name = MistralCommonBackend.__module__
module = importlib.import_module(module_name)
# Detect what needs to be imported
@@ -79,7 +79,7 @@ def apply_mistral_tokenizer_image_patch():
exec(patched_source, globals()) # nosec B102
# Replace the method
MistralCommonTokenizer.apply_chat_template = patched_apply_chat_template
LOG.info("Successfully applied MistralCommonTokenizer tensor conversion patch")
MistralCommonBackend.apply_chat_template = patched_apply_chat_template
LOG.info("Successfully applied MistralCommonBackend tensor conversion patch")
else:
LOG.warning("Could not find target code for MistralCommonTokenizer patching")
LOG.warning("Could not find target code for MistralCommonBackend patching")

View File

@@ -155,7 +155,6 @@ class ReLoRACallback(TrainerCallback):
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
"adapter",
),
safe_serialization=True,
)
with torch.no_grad():
merge_and_save(
@@ -214,7 +213,7 @@ class ReLoRACallback(TrainerCallback):
self.last_full_model = checkpoint_folder
else:
model.model.save_pretrained(checkpoint_folder, safe_serialization=True)
model.model.save_pretrained(checkpoint_folder)
return control

View File

@@ -52,9 +52,15 @@ def patch_prepare_context_parallel_inputs() -> None:
if item in patched_source:
items_to_import.append(item)
exec(f"from {module_name} import ({', '.join(items_to_import)})", globals())
exec(patched_source, globals())
# Use a separate namespace to capture the exec'd function
namespace = {}
exec(f"from {module_name} import ({', '.join(items_to_import)})", namespace)
exec(patched_source, namespace)
# Explicitly get the function from the namespace
axolotl_prepare_context_parallel_inputs = namespace[
"axolotl_prepare_context_parallel_inputs"
]
Trainer._original_prepare_context_parallel_inputs = (
Trainer._prepare_context_parallel_inputs
)
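The namespacing change above exists so the patched function is fetched explicitly from a private dict rather than being assumed to appear in `globals()`. A generic illustration of the pattern (not axolotl code):

```python
# Minimal illustration of exec-into-a-namespace: run the generated source in a
# private dict, then pull the object out by name instead of polluting globals().
patched_source = "def patched_fn(x):\n    return x + 1\n"

namespace: dict = {}
exec(patched_source, namespace)  # nosec B102 - same trust model as the patch above
patched_fn = namespace["patched_fn"]

assert patched_fn(1) == 2
```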

View File

@@ -14,7 +14,6 @@ from transformers.models.voxtral import VoxtralProcessor
from axolotl.utils.dict import remove_none_values
from axolotl.utils.logging import get_logger
from axolotl.utils.mistral.mistral3_processor import Mistral3Processor
LOG = get_logger(__name__)
@@ -430,7 +429,7 @@ class Mistral3ProcessingStrategy(ProcessingStrategy):
def __init__(
self,
processor: Mistral3Processor,
processor,
chat_template: Optional[str] = None,
image_size: int | tuple[int, int] | None = None,
image_resize_algorithm: Resampling | None = None,
@@ -493,6 +492,8 @@ def get_processing_strategy(
image_size: int | tuple[int, int] | None = None,
image_resize_algorithm: Resampling | None = None,
):
from axolotl.utils.mistral.mistral3_processor import Mistral3Processor
processing_kwargs = {
"processor": processor,
"chat_template": chat_template,

View File

@@ -150,6 +150,8 @@ class ChatTemplatePrompter(Prompter):
return self.tokenizer.apply_chat_template(
conversation,
tokenize=True,
return_dict=False,
**chat_template_kwargs,
)
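The prompter hunk pins `return_dict=False` because (assumption) under transformers v5 `apply_chat_template(tokenize=True)` returns a dict-like `BatchEncoding` by default, while this code path expects the bare token-id list. A hedged, standalone sketch; the tokenizer repo is only an example:

```python
# Hedged sketch: force the flat token-id list under v5 instead of a BatchEncoding.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")  # example repo
conversation = [{"role": "user", "content": "hello"}]

ids = tok.apply_chat_template(
    conversation,
    tokenize=True,
    return_dict=False,  # list[int] rather than {"input_ids": ..., ...}
)
```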

View File

@@ -135,16 +135,13 @@ def setup_reference_model(
return model_ref
def setup_signal_handler(
cfg: DictDefault, model: PreTrainedModel, safe_serialization: bool
):
def setup_signal_handler(cfg: DictDefault, model: PreTrainedModel):
"""
Set up signal handler for graceful termination.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
model: The model to save on termination
safe_serialization: Whether to use safe serialization when saving
"""
# ray workers don't have access to this signal
if cfg.local_rank == 0 and not cfg.use_ray:
@@ -152,9 +149,7 @@ def setup_signal_handler(
def terminate_handler(_, __, model_weakref):
if model_weakref() is not None:
_model = model_weakref()
_model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization
)
_model.save_pretrained(cfg.output_dir)
cleanup_distributed()
sys.exit(0)
@@ -219,7 +214,6 @@ def save_trained_model(
cfg: DictDefault,
trainer: Any,
model: PreTrainedModel,
safe_serialization: bool,
):
"""
Save the trained model according to configuration and training setup.
@@ -228,7 +222,6 @@ def save_trained_model(
cfg: Dictionary mapping `axolotl` config keys to values.
trainer: The trainer object.
model: The trained model to save.
safe_serialization: Whether to use safe serialization.
"""
LOG.info(f"Training completed! Saving trained model to {cfg.output_dir}.")
@@ -283,7 +276,6 @@ def save_trained_model(
merge_fsdp_weights(
checkpoint_dir=str(fsdp_dir),
output_path=merged_path,
safe_serialization=True,
)
trainer.accelerator.wait_for_everyone()
if trainer.accelerator.is_main_process:
@@ -330,11 +322,9 @@ def save_trained_model(
pass
elif cfg.local_rank == 0:
if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
trainer.model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization
)
trainer.model.save_pretrained(cfg.output_dir)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
model.save_pretrained(cfg.output_dir)
if hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
# TODO: add integration support so this can be implemented completely within the plugin
@@ -344,7 +334,6 @@ def save_trained_model(
model=model,
output_dir=cfg.output_dir,
trainer=trainer,
safe_serialization=safe_serialization,
save_compressed=cfg.llmcompressor.save_compressed,
)
@@ -449,7 +438,6 @@ def handle_untrained_tokens_fix(
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
train_dataset: Dataset,
safe_serialization: bool,
):
"""
Apply fixes for untrained tokens if configured.
@@ -459,7 +447,6 @@ def handle_untrained_tokens_fix(
model: The model to apply fixes to.
tokenizer: The tokenizer for token identification.
train_dataset: The training dataset to use.
safe_serialization: Whether to use safe serialization when saving.
"""
if not cfg.fix_untrained_tokens:
return
@@ -483,9 +470,7 @@ def handle_untrained_tokens_fix(
fix_untrained_tokens(model, tokenizer, train_dataset, **fix_kwargs)
if cfg.local_rank == 0:
model.save_pretrained(
str(Path(cfg.output_dir)), safe_serialization=safe_serialization
)
model.save_pretrained(str(Path(cfg.output_dir)))
def setup_model_and_trainer(
@@ -582,15 +567,12 @@ def train(
) = setup_model_and_trainer(cfg, dataset_meta)
# Handle untrained tokens if configured
safe_serialization = cfg.save_safetensors is True
train_dataset = dataset_meta.train_dataset
handle_untrained_tokens_fix(
cfg, model, tokenizer, train_dataset, safe_serialization
)
handle_untrained_tokens_fix(cfg, model, tokenizer, train_dataset)
# Additional setup
save_initial_configs(cfg, tokenizer, model, peft_config, processor)
setup_signal_handler(cfg, model, safe_serialization)
setup_signal_handler(cfg, model)
setup_model_card(cfg)
# Execute the training
@@ -602,7 +584,7 @@ def train(
torch.cuda.empty_cache()
# Save the trained model and cleanup
save_trained_model(cfg, trainer, model, safe_serialization)
save_trained_model(cfg, trainer, model)
tokenizer.save_pretrained(
str(Path(cfg.output_dir)), save_jinja_files=cfg.tokenizer_save_jinja_files
)

View File

@@ -7,7 +7,11 @@ from torch import Tensor
from tqdm import tqdm
from transformers.modeling_outputs import CausalLMOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer
try:
from transformers.tokenization_python import PreTrainedTokenizer
except ImportError:
from transformers.tokenization_utils import PreTrainedTokenizer
from axolotl.utils.distributed import is_main_process

View File

@@ -7,11 +7,11 @@ import numpy as np
from mistral_common.protocol.instruct.validator import ValidationMode
from mistral_common.tokens.tokenizers.utils import download_tokenizer_from_hf_hub
from torch import Tensor
from transformers.tokenization_mistral_common import MistralCommonTokenizer
from transformers.tokenization_mistral_common import MistralCommonBackend
from transformers.tokenization_utils_base import VERY_LARGE_INTEGER
class HFMistralTokenizer(MistralCommonTokenizer):
class HFMistralTokenizer(MistralCommonBackend):
"""
Wraps mistral_common.tokens.tokenizers.mistral.MistralTokenizer
and exposes HuggingFace API for special tokens.
@@ -37,11 +37,19 @@ class HFMistralTokenizer(MistralCommonTokenizer):
def name_or_path(self) -> str:
return self._name_or_path
@name_or_path.setter
def name_or_path(self, name_or_path: str) -> None:
self._name_or_path = name_or_path
@property
def chat_template(self) -> str | None:
"""Chat template is not supported. Dummy method to satisfy HuggingFace API."""
return "[This is a dummy chat template]"
@chat_template.setter
def chat_template(self, chat_template: str | None) -> None:
pass
def _set_mode(self, mode: ValidationMode):
"""Set the mode of the MistralRequestValidator.
@@ -133,7 +141,7 @@ class HFMistralTokenizer(MistralCommonTokenizer):
r"""
Patched fn to pass `name_or_path` and remove extra kwargs.
Instantiate a `MistralCommonTokenizer` from a predefined
Instantiate a `MistralCommonBackend` from a predefined
tokenizer.
Args:
@@ -142,7 +150,7 @@ class HFMistralTokenizer(MistralCommonTokenizer):
- A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co.
- A path to a *directory* containing the tokenizer config, for instance saved
using the [`MistralCommonTokenizer.tokenization_mistral_common.save_pretrained`] method, e.g.,
using the [`MistralCommonBackend.tokenization_mistral_common.save_pretrained`] method, e.g.,
`./my_model_directory/`.
mode (`ValidationMode`, *optional*, defaults to `ValidationMode.test`):
Validation mode for the `MistralTokenizer` tokenizer.
@@ -154,7 +162,7 @@ class HFMistralTokenizer(MistralCommonTokenizer):
exist.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `huggingface-cli login` (stored in `~/.huggingface`).
when running `hf auth login` (stored in `~/.huggingface`).
local_files_only (`bool`, *optional*, defaults to `False`):
Whether or not to only rely on local files and not to attempt to download any files.
revision (`str`, *optional*, defaults to `"main"`):
@@ -179,12 +187,12 @@ class HFMistralTokenizer(MistralCommonTokenizer):
Whether or not the model should cleanup the spaces that were added when splitting the input text during the
tokenization process.
kwargs (additional keyword arguments, *optional*):
Not supported by `MistralCommonTokenizer.from_pretrained`.
Not supported by `MistralCommonBackend.from_pretrained`.
Will raise an error if used.
"""
if init_inputs:
raise ValueError(
"`init_inputs` are not supported by `MistralCommonTokenizer.from_pretrained`."
"`init_inputs` are not supported by `MistralCommonBackend.from_pretrained`."
)
# Delete trust_remote_code as it does nothing
@@ -196,7 +204,7 @@ class HFMistralTokenizer(MistralCommonTokenizer):
# Handle kwargs and AutoTokenizer case
if kwargs and not kwargs.keys() == {"_from_auto"}:
raise ValueError(
f"Kwargs {list(kwargs.keys())} are not supported by `MistralCommonTokenizer.from_pretrained`."
f"Kwargs {list(kwargs.keys())} are not supported by `MistralCommonBackend.from_pretrained`."
)
if not os.path.isfile(pretrained_model_name_or_path):
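Why the new setters matter: a read-only property raises `AttributeError` the moment downstream v5 code assigns to `name_or_path` or `chat_template` on the wrapped tokenizer. A generic illustration of that failure mode (not the axolotl class):

```python
# Generic sketch: without a setter, assignment to the property fails outright.
class WrapperWithoutSetter:
    def __init__(self) -> None:
        self._name_or_path = "mistral"

    @property
    def name_or_path(self) -> str:
        return self._name_or_path


tok = WrapperWithoutSetter()
try:
    tok.name_or_path = "my-org/my-model"
except AttributeError as err:
    print(err)  # "property ... has no setter" - what the added setters avoid
```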

View File

@@ -4,7 +4,7 @@ FSDP Configuration Schema
from typing import Literal
from pydantic import BaseModel, Field
from pydantic import AliasChoices, BaseModel, Field
class FSDPConfig(BaseModel):
@@ -12,6 +12,11 @@ class FSDPConfig(BaseModel):
FSDP Configuration Schema
"""
fsdp_version: int | None = Field(
validation_alias=AliasChoices("fsdp_version", "version"),
default=None,
json_schema_extra={"description": "FSDP version"},
)
activation_checkpointing: bool | None = Field(
default=None,
description="Enable activation checkpointing to reduce memory usage during forward passes",
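The `AliasChoices` addition means either spelling inside `fsdp_config` populates the same field. A standalone pydantic sketch mirroring the hunk above (illustrative class name, not the real schema):

```python
# Hedged sketch of the aliasing: both "version" and "fsdp_version" keys validate
# onto the same fsdp_version field.
from pydantic import AliasChoices, BaseModel, Field


class FSDPConfigSketch(BaseModel):
    fsdp_version: int | None = Field(
        default=None,
        validation_alias=AliasChoices("fsdp_version", "version"),
    )


assert FSDPConfigSketch.model_validate({"version": 2}).fsdp_version == 2
assert FSDPConfigSketch.model_validate({"fsdp_version": 2}).fsdp_version == 2
```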

View File

@@ -123,10 +123,22 @@ class ModelOutputConfig(BaseModel):
save_safetensors: bool | None = Field(
default=True,
json_schema_extra={
"description": "Save model as safetensors (require safetensors package). Default True"
"description": "Whether to save the model using safetensors format. Defaults to True."
},
)
@field_validator("save_safetensors")
@classmethod
def validate_save_safetensors(cls, v):
if v is False:
raise ValueError(
"save_safetensors=False is not supported in Transformers V5. "
"Transformers V5 always uses safetensors format for model serialization. "
"This field is deprecated and will be removed in a future version."
)
# Allow None and True, will default to True if None
return True if v is None else v
class SpecialTokensConfig(BaseModel):
"""Special tokens configuration subset"""

View File

@@ -900,6 +900,43 @@ class OptimizationValidationMixin:
return data
@model_validator(mode="before")
@classmethod
def check_fsdp_config_kwargs_prefix(cls, data):
if fsdp_config := data.get("fsdp_config"):
should_fix = False
for key, _ in fsdp_config.items():
if key.startswith("fsdp_"):
should_fix = True
LOG.warning_once(
"Configuring FSDP fields with the `fsdp_` prefix is deprecated. "
"Please omit the `fsdp_` prefix from the any fields in `fsdp_config`."
)
if should_fix:
update_fsdp_config = {}
for key, value in fsdp_config.items():
if key.startswith("fsdp_") and key != "fsdp_version":
update_fsdp_config[key.replace("fsdp_", "")] = value
else:
update_fsdp_config[key] = value
data["fsdp_config"] = update_fsdp_config
return data
@model_validator(mode="before")
@classmethod
def check_fsdp_version_in_fsdp_config(cls, data):
fsdp_config = data.get("fsdp_config") or {}
fsdp_version = data.get("fsdp_version", None)
if not fsdp_version and fsdp_config and fsdp_config.get("version"):
fsdp_cfg_version = fsdp_config.pop("version")
data["fsdp_version"] = fsdp_cfg_version
data["fsdp_config"]["fsdp_version"] = fsdp_cfg_version
elif not fsdp_version and fsdp_config and fsdp_config.get("fsdp_version"):
data["fsdp_version"] = fsdp_config.get("fsdp_version")
if fsdp_version and fsdp_config and not fsdp_config.get("fsdp_version"):
data["fsdp_config"]["fsdp_version"] = fsdp_version
return data
@model_validator(mode="after")
def check_fsdp_offload_w_8bit_optimizer(self):
if (
@@ -1001,40 +1038,6 @@ class OptimizationValidationMixin:
return data
@model_validator(mode="before")
@classmethod
def check_fsdp_version_in_fsdp_config(cls, data):
fsdp_config = data.get("fsdp_config") or {}
if fsdp_config and fsdp_config.get("fsdp_version"):
LOG.warning(
"Configuring `fsdp_version` in `fsdp_config` is deprecated. "
"Please configure `fsdp_version` as a top-level field."
)
data["fsdp_version"] = fsdp_config.pop("fsdp_version")
return data
@model_validator(mode="before")
@classmethod
def check_fsdp_config_kwargs_prefix(cls, data):
if fsdp_config := data.get("fsdp_config"):
should_fix = False
for key, _ in fsdp_config.items():
if key.startswith("fsdp_"):
should_fix = True
LOG.warning_once(
"Configuring FSDP fields with the `fsdp_` prefix is deprecated. "
"Please omit the `fsdp_` prefix from the any fields in `fsdp_config`."
)
if should_fix:
update_fsdp_config = {}
for key, value in fsdp_config.items():
if key.startswith("fsdp_") and key != "fsdp_version":
update_fsdp_config[key.replace("fsdp_", "")] = value
else:
update_fsdp_config[key] = value
data["fsdp_config"] = update_fsdp_config
return data
class SystemValidationMixin:
"""Validation methods related to system and hardware configuration."""