transformers v5 upgrade (#3272)
* Prepare for transformers v5 upgrade
* fix hf cli
* update for hf hub changes
* fix tokenizer apply_chat_template args
* remap include_tokens_per_second
* fix tps
* handle migration for warmup
* use latest hf hub
* Fix scan -> ls
* fix import
* fix for renaming of mistral common tokenizer -> backend
* update for fixed tokenization for llama
* Skip phi35 tests for now
* remove mistral patch fixed upstream in huggingface/transformers#41439
* use namespacing for patch
* don't rely on sdist for e2e tests for now
* run modal ci without waiting too
* Fix dep for ci
* fix imports
* Fix fp8 check
* fsdp2 fixes
* fix version handling
* update fsdp version tests for new v5 behavior
* Fail multigpu tests after 3 failures
* skip known v5 broken tests for now and cleanup
* bump deps
* unmark skipped test
* re-enable test_fsdp_qlora_prequant_packed test
* increase multigpu ci timeout
* skip broken gemma3 test
* reduce timeout back to original 120min now that the hanging test is skipped
* fix for unnecessary collator for pretraining with bsz=1
* fix: safe_serialization deprecated in transformers v5 rc01 (#3318)
* torch_dtype deprecated
* load model in float32 for consistency with tests
* revert some test fixtures back
* use hf cache ls instead of scan
* don't strip fsdp_version
* more fsdp_version fixes for v5
* fix version in fsdp_config
* fix aliasing
* fix fsdp_version check
* check fsdp_version is 2 in both places
* Transformers v5 rc2 (#3347)
  * bump dep
  * use latest fbgemm, grab model config as part of fixture, un-skip test
  * import AutoConfig
  * don't need more problematic autoconfig when specifying config.json manually
  * add fixtures for argilla ultrafeedback datasets
  * download phi4-reasoning
  * fix arg
  * update tests for phi fast tokenizer changes
  * use explicit model types for gemma3

  Co-authored-by: Wing Lian <wing@axolotl.ai>
* fix: AutoModelForVision2Seq -> AutoModelForImageTextToText
* chore: remove duplicate
* fix: attempt fix gemma3 text mode
* chore: lint
* ga release of v5
* need property setter for name_or_path for mistral tokenizer
* vllm not compatible with transformers v5
* setter for chat_template w mistral too

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
Co-authored-by: salman <salman.mohammadi@outlook.com>
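Most of the hunks below repeat the same handful of transformers v5 call-site migrations. A minimal sketch of those call sites, assuming a placeholder model id (nothing here is copied from the PR itself):

```python
# Hedged sketch of the v5 call-site changes this PR applies throughout axolotl.
# "org/model" is a placeholder, not a model referenced by the PR.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# v4 passed torch_dtype=...; v5 renames the kwarg to dtype
model = AutoModelForCausalLM.from_pretrained("org/model", dtype=torch.bfloat16)

# v5 always writes safetensors, so call sites simply stop passing safe_serialization=...
model.save_pretrained("./outputs/model")

# apply_chat_template: pin return_dict=False to keep receiving a flat list of token ids
tokenizer = AutoTokenizer.from_pretrained("org/model")
ids = tokenizer.apply_chat_template(
    [{"role": "user", "content": "hello"}], tokenize=True, return_dict=False
)

# CLI login moved from `huggingface-cli login` to `hf auth login`
```

The same pattern shows up in the diff: `safe_serialization` arguments, `WEIGHTS_NAME`/`WEIGHTS_INDEX_NAME` fallbacks, and `torch_dtype` usages are removed rather than toggled.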
@@ -44,7 +44,7 @@ def check_user_token() -> bool:
         return bool(user_info)
     except LocalTokenNotFoundError:
         LOG.warning(
-            "Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
+            "Error verifying HuggingFace token. Remember to log in using `hf auth login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
         )
         return False
     except HTTPError:
@@ -24,7 +24,6 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
         cfg: Dictionary mapping `axolotl` config keys to values.
     """
     model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg)
-    safe_serialization = cfg.save_safetensors is True

     LOG.info("Running merge of LoRA with base model...")
     model = model.merge_and_unload(progressbar=True)
@@ -42,7 +41,6 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
     LOG.info(f"Saving merged model to: {str(Path(cfg.output_dir) / 'merged')}...")
     model.save_pretrained(
         str(Path(cfg.output_dir) / "merged"),
-        safe_serialization=safe_serialization,
         progressbar=True,
     )
     tokenizer.save_pretrained(
@@ -14,8 +14,6 @@ from accelerate import PartialState
 from accelerate.utils import (
     SAFE_WEIGHTS_INDEX_NAME,
     SAFE_WEIGHTS_NAME,
-    WEIGHTS_INDEX_NAME,
-    WEIGHTS_NAME,
     is_torch_version,
 )
 from huggingface_hub import split_torch_state_dict_into_shards
@@ -40,17 +38,15 @@ class BFloat16CastPlanner(_EmptyStateDictLoadPlanner):
 def _distributed_checkpoint_to_merged_weights(
     checkpoint_dir: Union[str, Path],
     save_path: str,
-    safe_serialization: bool = False,
     max_shard_size: str = "5GB",
 ) -> Path:
     """
     Passthrough to `torch.distributed.checkpoint.format_utils.dcp_to_torch_save`. Will
-    save under `save_path` as either `model.safetensors` or `pytorch_model.bin`.
+    save under `save_path` as `model.safetensors`.

     Args:
         checkpoint_dir: Directory where distributed checkpoint is saved.
         save_path: Path to save model to.
-        safe_serialization: Whether to save in safetensors format.
         max_shard_size: Max size of model shards to save.

     Returns:
@@ -76,11 +72,7 @@ def _distributed_checkpoint_to_merged_weights(
         if isinstance(value, torch.Tensor) and value.dtype != torch.bfloat16:
             state_dict[key] = value.to(torch.bfloat16)

-    weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
-
-    filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
-        ".safetensors", "{suffix}.safetensors"
-    )
+    filename_pattern = SAFE_WEIGHTS_NAME.replace(".safetensors", "{suffix}.safetensors")
     state_dict_split = split_torch_state_dict_into_shards(
         state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
     )
@@ -98,19 +90,12 @@ def _distributed_checkpoint_to_merged_weights(

     for shard_file, tensors in filename_to_tensors:
         shard = {tensor: state_dict[tensor] for tensor in tensors}

-        if safe_serialization:
-            safe_save_file(
-                shard, os.path.join(save_path_, shard_file), metadata={"format": "pt"}
-            )
-        else:
-            torch.save(shard, os.path.join(save_path_, shard_file))
+        safe_save_file(
+            shard, os.path.join(save_path_, shard_file), metadata={"format": "pt"}
+        )

     if index is not None:
-        save_index_file = (
-            SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
-        )
-        save_index_file = os.path.join(save_path_, save_index_file)
+        save_index_file = os.path.join(save_path_, SAFE_WEIGHTS_INDEX_NAME)
         # Save the index as well
         with open(save_index_file, "w", encoding="utf-8") as fout:
             content = json.dumps(index, indent=2, sort_keys=True) + "\n"
@@ -123,13 +108,11 @@ def _distributed_checkpoint_to_merged_weights(
 def merge_fsdp_weights(
     checkpoint_dir: str,
     output_path: str,
-    safe_serialization: bool = False,
     remove_checkpoint_dir: bool = False,
 ):
     """
     Merge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Should be used if
-    `SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}/model.safetensors` if
-    `safe_serialization` else `pytorch_model.bin`.
+    `SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}/model.safetensors`.

     Note: this is a CPU-bound process.
@@ -138,8 +121,6 @@ def merge_fsdp_weights(
             The directory containing the FSDP checkpoints (can be either the model or optimizer).
         output_path (`str`):
             The path to save the merged checkpoint.
-        safe_serialization (`bool`, *optional*, defaults to `True`):
-            Whether to save the merged weights with safetensors (recommended).
         remove_checkpoint_dir (`bool`, *optional*, defaults to `False`):
             Whether to remove the checkpoint directory after merging.
@@ -177,7 +158,7 @@ def merge_fsdp_weights(
     if state.is_main_process:
         LOG.info(f"Merging FSDP weights from {checkpoint_dir_}")
         save_path = _distributed_checkpoint_to_merged_weights(
-            checkpoint_dir_, output_path, safe_serialization
+            checkpoint_dir_, output_path
         )
         LOG.info(f"Successfully merged FSDP weights and saved to {save_path}")
     if remove_checkpoint_dir:
@@ -210,7 +191,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
     merge_fsdp_weights(
         checkpoint_dir=str(fsdp_dir),
         output_path=output_path,
-        safe_serialization=True,
     )
     state = PartialState()
     state.wait_for_everyone()
@@ -102,12 +102,10 @@ def do_quantize(
     LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}.")
     model.save_pretrained(
         str(Path(output_dir) / "quantized"),
-        safe_serialization=False,
         progressbar=True,
     )
     tokenizer.save_pretrained(
         str(Path(output_dir) / "quantized"),
-        safe_serialization=False,
         progressbar=True,
         save_jinja_files=cfg.tokenizer_save_jinja_files,
     )
@@ -121,7 +119,7 @@ def do_quantize(
         hub_model_id.rstrip("-")
         + f"-{quantization_config_to_str[type(quantization_config)]}"
     )
-    model.push_to_hub(hub_model_id, safe_serialization=False)
+    model.push_to_hub(hub_model_id)
     tokenizer.push_to_hub(hub_model_id)
     if processor:
         processor.push_to_hub(hub_model_id)
@@ -216,7 +216,7 @@ class TrainerBuilderBase(abc.ABC):
     def _configure_warmup_and_logging(
         self, total_num_steps: int, training_args_kwargs: dict
     ):
-        warmup_steps = 0
+        warmup_steps: int | float = 0
         warmup_ratio = 0.0
         if self.cfg.warmup_steps is not None:
             warmup_steps = self.cfg.warmup_steps
@@ -230,6 +230,10 @@ class TrainerBuilderBase(abc.ABC):
         else:
             warmup_ratio = 0.03

+        # transformers v5
+        if warmup_ratio > 0.0 and warmup_steps == 0:
+            warmup_steps = warmup_ratio
+
         if warmup_steps == 1:
             warmup_steps = 2

@@ -242,7 +246,6 @@ class TrainerBuilderBase(abc.ABC):
             else max(min(int(0.005 * total_num_steps), 10), 1)
         )

-        training_args_kwargs["warmup_ratio"] = warmup_ratio
         training_args_kwargs["warmup_steps"] = warmup_steps

     def _configure_precision_settings(self, training_args_kwargs: dict):
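As a self-contained illustration of the warmup handling above (invented numbers; the ratio-as-steps behaviour is inferred from the `# transformers v5` branch in this hunk rather than quoted from transformers documentation):

```python
# Invented example: the builder folds warmup_ratio into warmup_steps, since a fractional
# warmup_steps value is now passed through to TrainingArguments instead of a warmup_ratio kwarg.
warmup_steps: int | float = 0
warmup_ratio = 0.03

if warmup_ratio > 0.0 and warmup_steps == 0:
    warmup_steps = warmup_ratio  # e.g. 0.03 is treated as 3% of total training steps

if warmup_steps == 1:
    warmup_steps = 2

training_args_kwargs = {"warmup_steps": warmup_steps}  # no separate "warmup_ratio" key anymore
print(training_args_kwargs)  # {'warmup_steps': 0.03}
```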
@@ -530,9 +533,7 @@ class TrainerBuilderBase(abc.ABC):
             "loraplus_lr_ratio",
             "loraplus_lr_embedding",
             "output_dir",
-            "save_safetensors",
             "save_only_model",
-            "include_tokens_per_second",
             "weight_decay",
             "seed",
             "dion_momentum",
@@ -545,6 +546,7 @@ class TrainerBuilderBase(abc.ABC):

         arg_map = {
             "dion_learning_rate": "dion_lr",
+            "include_num_input_tokens_seen": "include_tokens_per_second",
         }
         for kwarg, cfg_arg in arg_map.items():
             if hasattr(self.cfg, cfg_arg) and getattr(self.cfg, cfg_arg) is not None:
@@ -437,7 +437,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             or self.cfg.micro_batch_size > 1
         ):
             return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
-        if not (self.cfg.sample_packing and self.cfg.pretrain_multipack_attn):
+        if not (self.cfg.sample_packing and self.cfg.pretrain_multipack_attn) or (
+            self.cfg.micro_batch_size == 1 and is_eval is False
+        ):
             return None

         if self.cfg.model_config_type == "mamba":
@@ -25,7 +25,7 @@ from torch.utils.data import (
 from transformers import PreTrainedModel, Trainer
 from transformers.trainer import TRAINING_ARGS_NAME
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length, seed_worker
-from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, is_peft_available
+from transformers.utils import SAFE_WEIGHTS_NAME, is_peft_available
 from trl.trainer.utils import pad_to_length
 from typing_extensions import override
@@ -738,43 +738,38 @@ class AxolotlTrainer(
                 ).save_pretrained(
                     output_dir,
                     state_dict=state_dict,
                     safe_serialization=self.args.save_safetensors,
                 )
             else:
                 LOG.info(
                     "Trainer.model is not a `PreTrainedModel`, only saving its state dict."
                 )
-                if self.args.save_safetensors:
-                    safetensors.torch.save_file(
-                        state_dict,
-                        os.path.join(output_dir, SAFE_WEIGHTS_NAME),
-                        metadata={"format": "pt"},
-                    )
-                else:
-                    torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
+                safetensors.torch.save_file(
+                    state_dict,
+                    os.path.join(output_dir, SAFE_WEIGHTS_NAME),
+                    metadata={"format": "pt"},
+                )
         else:
             self.model.save_pretrained(
                 output_dir,
                 state_dict=state_dict,
                 safe_serialization=self.args.save_safetensors,
                 is_main_process=self.accelerator.is_main_process,
             )

-        if self.processing_class is not None:
-            self.processing_class.save_pretrained(output_dir)
-        elif (
-            self.data_collator is not None
-            and hasattr(self.data_collator, "tokenizer")
-            and self.data_collator.tokenizer is not None
-        ):
-            LOG.info(
-                "Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
-            )
-            save_jinja_files = True
-            if self.axolotl_cfg:
-                save_jinja_files = self.axolotl_cfg.tokenizer_save_jinja_files
-            self.data_collator.tokenizer.save_pretrained(
-                output_dir, save_jinja_files=save_jinja_files
-            )
-        # Good practice: save your training arguments together with the trained model
-        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+        if self.processing_class is not None:
+            self.processing_class.save_pretrained(output_dir)
+        elif (
+            self.data_collator is not None
+            and hasattr(self.data_collator, "tokenizer")
+            and self.data_collator.tokenizer is not None
+        ):
+            LOG.info(
+                "Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
+            )
+            save_jinja_files = True
+            if self.axolotl_cfg:
+                save_jinja_files = self.axolotl_cfg.tokenizer_save_jinja_files
+            self.data_collator.tokenizer.save_pretrained(
+                output_dir, save_jinja_files=save_jinja_files
+            )
+        # Good practice: save your training arguments together with the trained model
+        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
@@ -12,7 +12,6 @@ def save_compressed_model(
     model: PreTrainedModel,
     output_dir: Union[str, bytes],
     trainer: Trainer,
-    safe_serialization: bool = False,
     save_compressed: bool = False,
 ) -> None:
     """
@@ -22,7 +21,6 @@ def save_compressed_model(
         model (PreTrainedModel): The model to be saved.
         output_dir (str or bytes): Path where the model files will be written.
         trainer (Trainer): Hugging Face Trainer for process synchronization.
-        safe_serialization (bool): Use safe serialization if True.
         save_compressed (bool): Write compressed tensors if True.
     """
     trainer.accelerator.wait_for_everyone()
@@ -34,7 +32,6 @@ def save_compressed_model(
         modify_save_pretrained(model)
         model.save_pretrained(
             output_dir,
-            safe_serialization=safe_serialization,
             save_compressed=save_compressed,
             skip_sparsity_compression_stats=not save_compressed,
         )
@@ -26,7 +26,6 @@ from torch.distributed import DeviceMesh
 from transformers import (
     AutoModelForCausalLM,
     AutoModelForImageTextToText,
-    AutoModelForVision2Seq,
     AwqConfig,
     BitsAndBytesConfig,
     GPTQConfig,
@@ -434,7 +433,7 @@ class ModelLoader:
         """
         if self.cfg.is_multimodal:
             self.auto_model_loader = MULTIMODAL_AUTO_MODEL_MAPPING.get(
-                self.model_config.model_type, AutoModelForVision2Seq
+                self.model_config.model_type, AutoModelForImageTextToText
             )
             if isinstance(self.auto_model_loader, str):
                 self.auto_model_loader = AutoModelForImageTextToText
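`AutoModelForVision2Seq` appears to be gone in transformers v5; the PR swaps every usage for `AutoModelForImageTextToText`. A minimal sketch of the resulting fallback, with a placeholder model id:

```python
# "org/vlm" is a placeholder id; the point is only which auto class serves as the default loader.
from transformers import AutoModelForImageTextToText

model = AutoModelForImageTextToText.from_pretrained("org/vlm")
```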
@@ -476,6 +475,7 @@ class ModelLoader:
             max_memory = None

         self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype
+        self.model_kwargs["dtype"] = self.cfg.torch_dtype

         is_ds_zero3 = is_deepspeed_zero3_enabled()
@@ -670,7 +670,7 @@ class ModelLoader:
         Uses the selected loader when provided; otherwise falls back to the auto loader.
         """
         loader = model_loader_class or self.auto_model_loader
-        if loader in [AutoModelForCausalLM, AutoModelForVision2Seq]:
+        if loader in [AutoModelForCausalLM, AutoModelForImageTextToText]:
             model = loader.from_config(
                 config=self.model_config,
                 trust_remote_code=self.cfg.trust_remote_code or False,
@@ -788,6 +788,7 @@ class ModelLoader:
         # Use auto model loader (handles gptq and default cases)
         model_loader_class = self.auto_model_loader

+        self.model_kwargs["dtype"] = self.model_kwargs["torch_dtype"]
         if self.cfg.reinit_weights:
             self.model = self._load_model_from_config(model_loader_class)
         else:
@@ -220,13 +220,6 @@ class PatchManager:

         patch_qwen3_next_modeling_packing()

-        if self.cfg.model_config_type == "mistral3" and self.cfg.processor_type:
-            from axolotl.monkeypatch.models.mistral3.mistral_common_tokenizer import (
-                apply_mistral_tokenizer_image_patch,
-            )
-
-            apply_mistral_tokenizer_image_patch()
-
         if self.cfg.model_config_type == "kimi_linear":
             from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import (
                 patch_kimi_model,
@@ -31,7 +31,7 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):

     from axolotl.utils.mistral import HFMistralTokenizer

-    tokenization_mistral_common.MistralCommonTokenizer = HFMistralTokenizer
+    tokenization_mistral_common.MistralCommonBackend = HFMistralTokenizer

     _patch_mistralcommontokenizer()
@@ -111,7 +111,6 @@ class MambaLMHeadModel(nn.Module, GenerationMixin):
         self,
         save_directory: Union[str, os.PathLike],
         state_dict: Optional[dict] = None,
-        safe_serialization: Optional[bool] = None,
     ):
         if state_dict is None:
             state_dict = self.state_dict()
@@ -1,5 +1,5 @@
 """
-Monkeypatch to fix inefficient tensor conversion in MistralCommonTokenizer.apply_chat_template
+Monkeypatch to fix inefficient tensor conversion in MistralCommonBackend.apply_chat_template
 """

 import importlib
@@ -12,11 +12,11 @@ LOG = get_logger(__name__)


 def apply_mistral_tokenizer_image_patch():
-    """Apply patch to MistralCommonTokenizer.apply_chat_template to fix image tensor conversion."""
-    from transformers.tokenization_mistral_common import MistralCommonTokenizer
+    """Apply patch to MistralCommonBackend.apply_chat_template to fix image tensor conversion."""
+    from transformers.tokenization_mistral_common import MistralCommonBackend

     # Get original source
-    original_source = inspect.getsource(MistralCommonTokenizer.apply_chat_template)
+    original_source = inspect.getsource(MistralCommonBackend.apply_chat_template)
     original_source, _ = detab_code(original_source)

     # Define the replacement
@@ -41,7 +41,7 @@ def apply_mistral_tokenizer_image_patch():
     )

     # Load necessary imports from the module
-    module_name = MistralCommonTokenizer.__module__
+    module_name = MistralCommonBackend.__module__
     module = importlib.import_module(module_name)

     # Detect what needs to be imported
@@ -79,7 +79,7 @@ def apply_mistral_tokenizer_image_patch():
         exec(patched_source, globals())  # nosec B102

         # Replace the method
-        MistralCommonTokenizer.apply_chat_template = patched_apply_chat_template
-        LOG.info("Successfully applied MistralCommonTokenizer tensor conversion patch")
+        MistralCommonBackend.apply_chat_template = patched_apply_chat_template
+        LOG.info("Successfully applied MistralCommonBackend tensor conversion patch")
     else:
-        LOG.warning("Could not find target code for MistralCommonTokenizer patching")
+        LOG.warning("Could not find target code for MistralCommonBackend patching")
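Downstream code that still imports the old class name can use the same version-tolerant pattern this PR applies to `PreTrainedTokenizer` (see the `tokenization_python` hunk further down). Whether both names coexist depends on the installed transformers version, so this is an assumption rather than something the PR itself does:

```python
# Hedged compatibility shim: prefer the v5 name, fall back to the pre-v5 one.
try:
    from transformers.tokenization_mistral_common import MistralCommonBackend
except ImportError:  # transformers < 5 only ships the old class name
    from transformers.tokenization_mistral_common import (
        MistralCommonTokenizer as MistralCommonBackend,
    )
```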
@@ -155,7 +155,6 @@ class ReLoRACallback(TrainerCallback):
                     f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
                     "adapter",
                 ),
-                safe_serialization=True,
             )
             with torch.no_grad():
                 merge_and_save(
@@ -214,7 +213,7 @@ class ReLoRACallback(TrainerCallback):

             self.last_full_model = checkpoint_folder
         else:
-            model.model.save_pretrained(checkpoint_folder, safe_serialization=True)
+            model.model.save_pretrained(checkpoint_folder)

         return control
@@ -52,9 +52,15 @@ def patch_prepare_context_parallel_inputs() -> None:
         if item in patched_source:
             items_to_import.append(item)

-    exec(f"from {module_name} import ({', '.join(items_to_import)})", globals())
-    exec(patched_source, globals())
+    # Use a separate namespace to capture the exec'd function
+    namespace = {}
+    exec(f"from {module_name} import ({', '.join(items_to_import)})", namespace)
+    exec(patched_source, namespace)
+
+    # Explicitly get the function from the namespace
+    axolotl_prepare_context_parallel_inputs = namespace[
+        "axolotl_prepare_context_parallel_inputs"
+    ]
     Trainer._original_prepare_context_parallel_inputs = (
         Trainer._prepare_context_parallel_inputs
     )
@@ -14,7 +14,6 @@ from transformers.models.voxtral import VoxtralProcessor

 from axolotl.utils.dict import remove_none_values
 from axolotl.utils.logging import get_logger
-from axolotl.utils.mistral.mistral3_processor import Mistral3Processor

 LOG = get_logger(__name__)
@@ -430,7 +429,7 @@ class Mistral3ProcessingStrategy(ProcessingStrategy):

     def __init__(
         self,
-        processor: Mistral3Processor,
+        processor,
         chat_template: Optional[str] = None,
         image_size: int | tuple[int, int] | None = None,
         image_resize_algorithm: Resampling | None = None,
@@ -493,6 +492,8 @@ def get_processing_strategy(
     image_size: int | tuple[int, int] | None = None,
     image_resize_algorithm: Resampling | None = None,
 ):
+    from axolotl.utils.mistral.mistral3_processor import Mistral3Processor
+
     processing_kwargs = {
         "processor": processor,
         "chat_template": chat_template,
@@ -150,6 +150,8 @@ class ChatTemplatePrompter(Prompter):

         return self.tokenizer.apply_chat_template(
             conversation,
+            tokenize=True,
+            return_dict=False,
             **chat_template_kwargs,
         )
@@ -135,16 +135,13 @@ def setup_reference_model(
     return model_ref


-def setup_signal_handler(
-    cfg: DictDefault, model: PreTrainedModel, safe_serialization: bool
-):
+def setup_signal_handler(cfg: DictDefault, model: PreTrainedModel):
     """
     Set up signal handler for graceful termination.

     Args:
         cfg: Dictionary mapping `axolotl` config keys to values.
         model: The model to save on termination
-        safe_serialization: Whether to use safe serialization when saving
     """
     # ray workers don't have access to this signal
     if cfg.local_rank == 0 and not cfg.use_ray:
@@ -152,9 +149,7 @@ def setup_signal_handler(
     def terminate_handler(_, __, model_weakref):
         if model_weakref() is not None:
             _model = model_weakref()
-            _model.save_pretrained(
-                cfg.output_dir, safe_serialization=safe_serialization
-            )
+            _model.save_pretrained(cfg.output_dir)

         cleanup_distributed()
         sys.exit(0)
@@ -219,7 +214,6 @@ def save_trained_model(
     cfg: DictDefault,
     trainer: Any,
     model: PreTrainedModel,
-    safe_serialization: bool,
 ):
     """
     Save the trained model according to configuration and training setup.
@@ -228,7 +222,6 @@ def save_trained_model(
         cfg: Dictionary mapping `axolotl` config keys to values.
         trainer: The trainer object.
         model: The trained model to save.
-        safe_serialization: Whether to use safe serialization.
     """
     LOG.info(f"Training completed! Saving trained model to {cfg.output_dir}.")

@@ -283,7 +276,6 @@ def save_trained_model(
             merge_fsdp_weights(
                 checkpoint_dir=str(fsdp_dir),
                 output_path=merged_path,
-                safe_serialization=True,
             )
             trainer.accelerator.wait_for_everyone()
             if trainer.accelerator.is_main_process:
@@ -330,11 +322,9 @@ def save_trained_model(
         pass
     elif cfg.local_rank == 0:
         if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
-            trainer.model.save_pretrained(
-                cfg.output_dir, safe_serialization=safe_serialization
-            )
+            trainer.model.save_pretrained(cfg.output_dir)

-        model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
+        model.save_pretrained(cfg.output_dir)

     if hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
         # TODO: add integration support so this can be implemented completely within the plugin
@@ -344,7 +334,6 @@ def save_trained_model(
             model=model,
             output_dir=cfg.output_dir,
             trainer=trainer,
-            safe_serialization=safe_serialization,
             save_compressed=cfg.llmcompressor.save_compressed,
         )
@@ -449,7 +438,6 @@ def handle_untrained_tokens_fix(
     model: PreTrainedModel,
     tokenizer: PreTrainedTokenizer,
     train_dataset: Dataset,
-    safe_serialization: bool,
 ):
     """
     Apply fixes for untrained tokens if configured.
@@ -459,7 +447,6 @@ def handle_untrained_tokens_fix(
         model: The model to apply fixes to.
         tokenizer: The tokenizer for token identification.
         train_dataset: The training dataset to use.
-        safe_serialization: Whether to use safe serialization when saving.
     """
     if not cfg.fix_untrained_tokens:
         return
@@ -483,9 +470,7 @@ def handle_untrained_tokens_fix(
         fix_untrained_tokens(model, tokenizer, train_dataset, **fix_kwargs)

     if cfg.local_rank == 0:
-        model.save_pretrained(
-            str(Path(cfg.output_dir)), safe_serialization=safe_serialization
-        )
+        model.save_pretrained(str(Path(cfg.output_dir)))


 def setup_model_and_trainer(
@@ -582,15 +567,12 @@ def train(
     ) = setup_model_and_trainer(cfg, dataset_meta)

     # Handle untrained tokens if configured
-    safe_serialization = cfg.save_safetensors is True
     train_dataset = dataset_meta.train_dataset
-    handle_untrained_tokens_fix(
-        cfg, model, tokenizer, train_dataset, safe_serialization
-    )
+    handle_untrained_tokens_fix(cfg, model, tokenizer, train_dataset)

     # Additional setup
     save_initial_configs(cfg, tokenizer, model, peft_config, processor)
-    setup_signal_handler(cfg, model, safe_serialization)
+    setup_signal_handler(cfg, model)
     setup_model_card(cfg)

     # Execute the training
@@ -602,7 +584,7 @@ def train(
     torch.cuda.empty_cache()

     # Save the trained model and cleanup
-    save_trained_model(cfg, trainer, model, safe_serialization)
+    save_trained_model(cfg, trainer, model)
     tokenizer.save_pretrained(
         str(Path(cfg.output_dir)), save_jinja_files=cfg.tokenizer_save_jinja_files
     )
@@ -7,7 +7,11 @@ from torch import Tensor
 from tqdm import tqdm
 from transformers.modeling_outputs import CausalLMOutput
 from transformers.modeling_utils import PreTrainedModel
-from transformers.tokenization_utils import PreTrainedTokenizer
+
+try:
+    from transformers.tokenization_python import PreTrainedTokenizer
+except ImportError:
+    from transformers.tokenization_utils import PreTrainedTokenizer

 from axolotl.utils.distributed import is_main_process
@@ -7,11 +7,11 @@ import numpy as np
 from mistral_common.protocol.instruct.validator import ValidationMode
 from mistral_common.tokens.tokenizers.utils import download_tokenizer_from_hf_hub
 from torch import Tensor
-from transformers.tokenization_mistral_common import MistralCommonTokenizer
+from transformers.tokenization_mistral_common import MistralCommonBackend
 from transformers.tokenization_utils_base import VERY_LARGE_INTEGER


-class HFMistralTokenizer(MistralCommonTokenizer):
+class HFMistralTokenizer(MistralCommonBackend):
     """
     Wraps mistral_common.tokens.tokenizers.mistral.MistralTokenizer
     and exposes HuggingFace API for special tokens.
@@ -37,11 +37,19 @@ class HFMistralTokenizer(MistralCommonTokenizer):
     def name_or_path(self) -> str:
         return self._name_or_path

+    @name_or_path.setter
+    def name_or_path(self, name_or_path: str) -> None:
+        self._name_or_path = name_or_path
+
     @property
     def chat_template(self) -> str | None:
         """Chat template is not supported. Dummy method to satisfy HuggingFace API."""
         return "[This is a dummy chat template]"

+    @chat_template.setter
+    def chat_template(self, chat_template: str | None) -> None:
+        pass
+
     def _set_mode(self, mode: ValidationMode):
         """Set the mode of the MistralRequestValidator.

@@ -133,7 +141,7 @@ class HFMistralTokenizer(MistralCommonTokenizer):
        r"""
        Patched fn to pass `name_or_path` and remove extra kwargs.

-        Instantiate a `MistralCommonTokenizer` from a predefined
+        Instantiate a `MistralCommonBackend` from a predefined
        tokenizer.

        Args:
@@ -142,7 +150,7 @@ class HFMistralTokenizer(MistralCommonTokenizer):

                - A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co.
                - A path to a *directory* containing the tokenizer config, for instance saved
-                  using the [`MistralCommonTokenizer.tokenization_mistral_common.save_pretrained`] method, e.g.,
+                  using the [`MistralCommonBackend.tokenization_mistral_common.save_pretrained`] method, e.g.,
                  `./my_model_directory/`.
            mode (`ValidationMode`, *optional*, defaults to `ValidationMode.test`):
                Validation mode for the `MistralTokenizer` tokenizer.
@@ -154,7 +162,7 @@ class HFMistralTokenizer(MistralCommonTokenizer):
                exist.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
-                when running `huggingface-cli login` (stored in `~/.huggingface`).
+                when running `hf auth login` (stored in `~/.huggingface`).
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether or not to only rely on local files and not to attempt to download any files.
            revision (`str`, *optional*, defaults to `"main"`):
@@ -179,12 +187,12 @@ class HFMistralTokenizer(MistralCommonTokenizer):
                Whether or not the model should cleanup the spaces that were added when splitting the input text during the
                tokenization process.
            kwargs (additional keyword arguments, *optional*):
-                Not supported by `MistralCommonTokenizer.from_pretrained`.
+                Not supported by `MistralCommonBackend.from_pretrained`.
                Will raise an error if used.
        """
        if init_inputs:
            raise ValueError(
-                "`init_inputs` are not supported by `MistralCommonTokenizer.from_pretrained`."
+                "`init_inputs` are not supported by `MistralCommonBackend.from_pretrained`."
            )

        # Delete trust_remote_code as it does nothing
@@ -196,7 +204,7 @@ class HFMistralTokenizer(MistralCommonTokenizer):
        # Handle kwargs and AutoTokenizer case
        if kwargs and not kwargs.keys() == {"_from_auto"}:
            raise ValueError(
-                f"Kwargs {list(kwargs.keys())} are not supported by `MistralCommonTokenizer.from_pretrained`."
+                f"Kwargs {list(kwargs.keys())} are not supported by `MistralCommonBackend.from_pretrained`."
            )

        if not os.path.isfile(pretrained_model_name_or_path):
@@ -4,7 +4,7 @@ FSDP Configuration Schema

 from typing import Literal

-from pydantic import BaseModel, Field
+from pydantic import AliasChoices, BaseModel, Field


 class FSDPConfig(BaseModel):
@@ -12,6 +12,11 @@ class FSDPConfig(BaseModel):
     FSDP Configuration Schema
     """

+    fsdp_version: int | None = Field(
+        validation_alias=AliasChoices("fsdp_version", "version"),
+        default=None,
+        json_schema_extra={"description": "FSDP version"},
+    )
     activation_checkpointing: bool | None = Field(
         default=None,
         description="Enable activation checkpointing to reduce memory usage during forward passes",
@@ -123,10 +123,22 @@ class ModelOutputConfig(BaseModel):
     save_safetensors: bool | None = Field(
         default=True,
         json_schema_extra={
-            "description": "Save model as safetensors (require safetensors package). Default True"
+            "description": "Whether to save the model using safetensors format. Defaults to True."
         },
     )

+    @field_validator("save_safetensors")
+    @classmethod
+    def validate_save_safetensors(cls, v):
+        if v is False:
+            raise ValueError(
+                "save_safetensors=False is not supported in Transformers V5. "
+                "Transformers V5 always uses safetensors format for model serialization. "
+                "This field is deprecated and will be removed in a future version."
+            )
+        # Allow None and True, will default to True if None
+        return True if v is None else v
+

 class SpecialTokensConfig(BaseModel):
     """Special tokens configuration subset"""
@@ -900,6 +900,43 @@ class OptimizationValidationMixin:

         return data

+    @model_validator(mode="before")
+    @classmethod
+    def check_fsdp_config_kwargs_prefix(cls, data):
+        if fsdp_config := data.get("fsdp_config"):
+            should_fix = False
+            for key, _ in fsdp_config.items():
+                if key.startswith("fsdp_"):
+                    should_fix = True
+                    LOG.warning_once(
+                        "Configuring FSDP fields with the `fsdp_` prefix is deprecated. "
+                        "Please omit the `fsdp_` prefix from the any fields in `fsdp_config`."
+                    )
+            if should_fix:
+                update_fsdp_config = {}
+                for key, value in fsdp_config.items():
+                    if key.startswith("fsdp_") and key != "fsdp_version":
+                        update_fsdp_config[key.replace("fsdp_", "")] = value
+                    else:
+                        update_fsdp_config[key] = value
+                data["fsdp_config"] = update_fsdp_config
+        return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_fsdp_version_in_fsdp_config(cls, data):
+        fsdp_config = data.get("fsdp_config") or {}
+        fsdp_version = data.get("fsdp_version", None)
+        if not fsdp_version and fsdp_config and fsdp_config.get("version"):
+            fsdp_cfg_version = fsdp_config.pop("version")
+            data["fsdp_version"] = fsdp_cfg_version
+            data["fsdp_config"]["fsdp_version"] = fsdp_cfg_version
+        elif not fsdp_version and fsdp_config and fsdp_config.get("fsdp_version"):
+            data["fsdp_version"] = fsdp_config.get("fsdp_version")
+        if fsdp_version and fsdp_config and not fsdp_config.get("fsdp_version"):
+            data["fsdp_config"]["fsdp_version"] = fsdp_version
+        return data
+
     @model_validator(mode="after")
     def check_fsdp_offload_w_8bit_optimizer(self):
         if (
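A hedged before/after sketch of what the two validators above do to a user config dict; `offload_params` is just an example field name:

```python
# Illustrative input: legacy `fsdp_` prefixes plus the `version` alias inside fsdp_config.
raw = {
    "fsdp_config": {
        "fsdp_offload_params": True,  # deprecated `fsdp_` prefix, stripped with a warning
        "version": 2,                 # accepted alias for fsdp_version
    }
}

# Expected shape after validation: prefix removed, version mirrored to the top level
# and kept inside fsdp_config as fsdp_version.
normalized = {
    "fsdp_version": 2,
    "fsdp_config": {"offload_params": True, "fsdp_version": 2},
}
```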
@@ -1001,40 +1038,6 @@ class OptimizationValidationMixin:

         return data

-    @model_validator(mode="before")
-    @classmethod
-    def check_fsdp_version_in_fsdp_config(cls, data):
-        fsdp_config = data.get("fsdp_config") or {}
-        if fsdp_config and fsdp_config.get("fsdp_version"):
-            LOG.warning(
-                "Configuring `fsdp_version` in `fsdp_config` is deprecated. "
-                "Please configure `fsdp_version` as a top-level field."
-            )
-            data["fsdp_version"] = fsdp_config.pop("fsdp_version")
-        return data
-
-    @model_validator(mode="before")
-    @classmethod
-    def check_fsdp_config_kwargs_prefix(cls, data):
-        if fsdp_config := data.get("fsdp_config"):
-            should_fix = False
-            for key, _ in fsdp_config.items():
-                if key.startswith("fsdp_"):
-                    should_fix = True
-                    LOG.warning_once(
-                        "Configuring FSDP fields with the `fsdp_` prefix is deprecated. "
-                        "Please omit the `fsdp_` prefix from the any fields in `fsdp_config`."
-                    )
-            if should_fix:
-                update_fsdp_config = {}
-                for key, value in fsdp_config.items():
-                    if key.startswith("fsdp_") and key != "fsdp_version":
-                        update_fsdp_config[key.replace("fsdp_", "")] = value
-                    else:
-                        update_fsdp_config[key] = value
-                data["fsdp_config"] = update_fsdp_config
-        return data
-

 class SystemValidationMixin:
     """Validation methods related to system and hardware configuration."""