diff --git a/src/axolotl/evaluate.py b/src/axolotl/evaluate.py
index 63567ed40..61a9f8fad 100644
--- a/src/axolotl/evaluate.py
+++ b/src/axolotl/evaluate.py
@@ -11,12 +11,11 @@ from accelerate.logging import get_logger
 from datasets import Dataset
 from transformers.trainer import Trainer
 
+from axolotl.telemetry.errors import send_errors
 from axolotl.train import (
     TrainDatasetMeta,
     setup_model_and_tokenizer,
 )
-from axolotl.telemetry.errors import send_errors
-from axolotl.train import TrainDatasetMeta
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import cleanup_distributed
 from axolotl.utils.trainer import setup_trainer
diff --git a/src/axolotl/loaders/adapter.py b/src/axolotl/loaders/adapter.py
index 26e3313de..5517ff50a 100644
--- a/src/axolotl/loaders/adapter.py
+++ b/src/axolotl/loaders/adapter.py
@@ -19,6 +19,7 @@ from peft import (
 from transformers import PreTrainedModel
 
 from axolotl.loaders.utils import get_linear_embedding_layers
+from axolotl.telemetry.errors import send_errors
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.logging import get_logger
 
diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py
index 64f03391e..ed1e8bbf2 100644
--- a/src/axolotl/loaders/model.py
+++ b/src/axolotl/loaders/model.py
@@ -46,6 +46,7 @@ from axolotl.loaders.utils import (
     load_model_config,
 )
 from axolotl.models.mamba import fix_mamba_attn_for_loss
+from axolotl.telemetry.errors import send_errors
 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import (
diff --git a/src/axolotl/loaders/processor.py b/src/axolotl/loaders/processor.py
index cfe63d5de..5d1c36618 100644
--- a/src/axolotl/loaders/processor.py
+++ b/src/axolotl/loaders/processor.py
@@ -8,6 +8,7 @@ from transformers import (
     PreTrainedTokenizerBase,
 )
 
+from axolotl.telemetry.errors import send_errors
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.logging import get_logger
 
diff --git a/src/axolotl/loaders/tokenizer.py b/src/axolotl/loaders/tokenizer.py
index b1d95bed9..cfc5da42c 100644
--- a/src/axolotl/loaders/tokenizer.py
+++ b/src/axolotl/loaders/tokenizer.py
@@ -12,6 +12,7 @@ from transformers import (
 from axolotl.integrations.base import PluginManager
 from axolotl.loaders.utils import get_linear_embedding_layers, load_model_config
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
+from axolotl.telemetry.errors import send_errors
 from axolotl.utils.chat_templates import get_chat_template_from_config
 from axolotl.utils.distributed import (
     barrier,
diff --git a/src/axolotl/telemetry/callbacks.py b/src/axolotl/telemetry/callbacks.py
index 7650ec7c9..c788d3174 100644
--- a/src/axolotl/telemetry/callbacks.py
+++ b/src/axolotl/telemetry/callbacks.py
@@ -59,12 +59,14 @@ class TelemetryCallback(TrainerCallback):
         self.telemetry_manager.send_event(
             event_type="train-end",
             properties={
-                "loss": state.log_history[-1].get("loss", 0)
-                if state.log_history
-                else None,
-                "learning_rate": state.log_history[-1].get("learning_rate", 0)
-                if state.log_history
-                else None,
+                "loss": (
+                    state.log_history[-1].get("loss", 0) if state.log_history else None
+                ),
+                "learning_rate": (
+                    state.log_history[-1].get("learning_rate", 0)
+                    if state.log_history
+                    else None
+                ),
             }
             | self.tracker.metrics.to_dict(),
         )
diff --git a/src/axolotl/telemetry/manager.py b/src/axolotl/telemetry/manager.py
index e03fc3a61..b8aa5c2ad 100644
--- a/src/axolotl/telemetry/manager.py
+++ b/src/axolotl/telemetry/manager.py
@@ -307,9 +307,11 @@ class TelemetryManager:
             gpu_info.append(
                 {
                     "name": torch.hip.get_device_name(i),
-                    "memory": torch.hip.get_device_properties(i).total_memory
-                    if hasattr(torch.hip, "get_device_properties")
-                    else None,
+                    "memory": (
+                        torch.hip.get_device_properties(i).total_memory
+                        if hasattr(torch.hip, "get_device_properties")
+                        else None
+                    ),
                 }
             )
 
diff --git a/src/axolotl/telemetry/runtime_metrics.py b/src/axolotl/telemetry/runtime_metrics.py
index 229c74edf..0b4be85b3 100644
--- a/src/axolotl/telemetry/runtime_metrics.py
+++ b/src/axolotl/telemetry/runtime_metrics.py
@@ -202,8 +202,8 @@ class RuntimeMetricsTracker:
         memory_used = self._get_allocated_memory()
         for i, memory in memory_used.items():
             memory_metrics[f"gpu_{i}_memory_bytes"] = memory
-            memory_metrics[
-                f"gpu_{i}_peak_memory_bytes"
-            ] = self.metrics.peak_gpu_memory.get(i, 0)
+            memory_metrics[f"gpu_{i}_peak_memory_bytes"] = (
+                self.metrics.peak_gpu_memory.get(i, 0)
+            )
 
         return memory_metrics
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index c0fb3c01a..d341e3c77 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -33,7 +33,6 @@ from axolotl.loaders import (
     load_tokenizer,
 )
 from axolotl.telemetry.errors import send_errors
-from axolotl.telemetry.errors import send_errors
 from axolotl.telemetry.manager import TelemetryManager
 from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
 from axolotl.utils.dict import DictDefault
@@ -91,11 +90,11 @@ def setup_model_and_tokenizer(
     if model.generation_config is not None:
         model.generation_config.do_sample = True
 
-    TELEMETRY_MANAGER.track_event(
+    TELEMETRY_MANAGER.send_event(
         event_type="model-load", properties=model.config.to_dict()
     )
     if peft_config:
-        TELEMETRY_MANAGER.track_event(
+        TELEMETRY_MANAGER.send_event(
             event_type="peft-config-load", properties=peft_config.to_dict()
         )
 
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
deleted file mode 100644
index 00d70c54c..000000000
--- a/src/axolotl/utils/models.py
+++ /dev/null
@@ -1,1415 +0,0 @@
-"""Module for models and model loading"""
-
-# pylint: disable=too-many-lines
-import gc
-import importlib
-import logging
-import math
-import os
-import types
-from functools import cached_property
-from typing import Any, Dict, Optional, Tuple, Union  # noqa: F401
-
-import addict
-import bitsandbytes as bnb
-import torch
-import transformers
-import transformers.modeling_utils
-from accelerate import init_empty_weights
-from bitsandbytes.nn import Params4bit
-from peft import (
-    LoftQConfig,
-    PeftConfig,
-    PeftModel,
-    PeftModelForCausalLM,
-    prepare_model_for_kbit_training,
-)
-from peft.tuners.lora import QuantLinear
-from torch import nn
-from transformers import (  # noqa: F401
-    AddedToken,
-    AutoConfig,
-    AutoModelForCausalLM,
-    AutoModelForVision2Seq,
-    AutoProcessor,
-    AutoTokenizer,
-    AwqConfig,
-    BitsAndBytesConfig,
-    GPTQConfig,
-    LlavaForConditionalGeneration,
-    MllamaForConditionalGeneration,
-    PreTrainedModel,
-    PreTrainedTokenizerBase,
-    ProcessorMixin,
-)
-from transformers.integrations.deepspeed import (
-    HfTrainerDeepSpeedConfig,
-    is_deepspeed_zero3_enabled,
-)
-
-from axolotl.common.architectures import MOE_ARCH_BLOCK
-from axolotl.models.mamba import fix_mamba_attn_for_loss
-from axolotl.monkeypatch.multipack import (
-    SUPPORTED_MULTIPACK_MODEL_TYPES,
-    patch_for_multipack,
-)
-from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
-from axolotl.telemetry.errors import send_errors
-from
axolotl.utils.bench import log_gpu_memory_usage -from axolotl.utils.chat_templates import get_chat_template_from_config -from axolotl.utils.dict import DictDefault -from axolotl.utils.distributed import get_device_count, get_device_type, zero_only -from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper -from axolotl.utils.lora_embeddings import get_linear_embedding_layers -from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant - -LOG = logging.getLogger("axolotl") - - -# copied from accelerator.FullyShardedDataParallelPlugin -def get_module_class_from_name(module, name): - """ - Gets a class from a module by its name. - - Args: - module (`torch.nn.Module`): The module to get the class from. - name (`str`): The name of the class. - """ - modules_children = list(module.children()) - if module.__class__.__name__ == name: - return module.__class__ - - if len(modules_children) == 0: - return None - - for child_module in modules_children: - module_class = get_module_class_from_name(child_module, name) - if module_class is not None: - return module_class - - return None - - -def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]): - if cfg.is_multimodal: - model_config = model_config.text_config - - quant_config_exists = ( - hasattr(model_config, "quantization_config") - and model_config.quantization_config - ) - quant_config_method_is_gptq = ( - quant_config_exists - and "quant_method" in model_config.quantization_config - and model_config.quantization_config["quant_method"] == "gptq" - ) - - if cfg.gptq and not quant_config_method_is_gptq: - raise ValueError( - "model_config.quantization_config is not set or quant_method is not set to gptq. " - "Please make sure to point to a GPTQ model." - ) - - if not cfg.gptq and quant_config_exists and not cfg.load_in_4bit: - raise ValueError( - "model_config.quantization_config is set but `gptq` flag is not. " - "Please use the `gptq` flag to train quantized model or point to a non-quantized model." - ) - - lora_modules_to_save = get_linear_embedding_layers(model_config.model_type) - if ( - cfg.adapter - and cfg.tokens - and ( - not cfg.lora_modules_to_save - or not all(x in cfg.lora_modules_to_save for x in lora_modules_to_save) - ) - ): - lora_modules_to_save = ", ".join(map(lambda x: f"`{x}`", lora_modules_to_save)) - raise ValueError( - f"`lora_modules_to_save` not properly set when adding new tokens. Please include [{lora_modules_to_save}] in `lora_modules_to_save`." 
- ) - - -def load_model_config(cfg): - model_config_name = cfg.base_model_config or cfg.base_model - if not model_config_name and cfg.tokenizer_config: - model_config_name = cfg.tokenizer_config - trust_remote_code = cfg.trust_remote_code is True - config_kwargs = {} - if cfg.revision_of_model: - config_kwargs["revision"] = cfg.revision_of_model - if cfg.num_labels: - # num_labels is used to initialize classifier models - config_kwargs["num_labels"] = cfg.num_labels - try: - model_config = AutoConfig.from_pretrained( - model_config_name, - trust_remote_code=trust_remote_code, - **config_kwargs, - ) - except ValueError as err: - if "mamba" in model_config_name: - return addict.Dict( - { - "model_type": "mamba", - } - ) - raise err - - if cfg.overrides_of_model_config: - for key, val in cfg.overrides_of_model_config.items(): - setattr(model_config, key, val) - - check_model_config(cfg, model_config) - - return model_config - - -@send_errors -def load_tokenizer(cfg): - model_config = load_model_config(cfg) - tokenizer_kwargs = {} - use_fast = True # this is the default - - if cfg.tokenizer_use_fast is not None: - use_fast = cfg.tokenizer_use_fast - if cfg.tokenizer_legacy is not None: - # True is the default w/ https://github.com/huggingface/transformers/pull/25224 - tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy - - tokenizer_cls = AutoTokenizer - if cfg.tokenizer_type: - tokenizer_cls = getattr(transformers, cfg.tokenizer_type) - - tokenizer = tokenizer_cls.from_pretrained( - cfg.tokenizer_config, - trust_remote_code=cfg.trust_remote_code or False, - use_fast=use_fast, - **tokenizer_kwargs, - ) - - if ( - tokenizer.__class__.__name__ - in [ - "LlamaTokenizer", - "LlamaTokenizerFast", - "CodeLlamaTokenizer", - "CodeLlamaTokenizerFast", - ] - and hasattr(tokenizer, "pad_token") - and not tokenizer.pad_token - ): - # set a pad_token, but use eos_token so we don't add a new token - tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN - - if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast": - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - os.environ["TOKENIZERS_PARALLELISM"] = "false" - - # Mistral's official FA implementation requires left padding - if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing: - tokenizer.padding_side = "left" - - # Qwen base only has single token, so we need to set the special tokens - if cfg.is_qwen_derived_model: - token_ids = ["bos_token_id", "eos_token_id", "pad_token_id", "unk_token_id"] - for attr_name in token_ids: - if getattr(tokenizer, attr_name) is None: - setattr(tokenizer, attr_name, tokenizer.eod_id) - - token_names = ["bos_token", "eos_token", "pad_token", "unk_token"] - for attr_name in token_names: - if getattr(tokenizer, attr_name) is None: - setattr(tokenizer, attr_name, "<|endoftext|>") - - additional_special_tokens = None - if cfg.special_tokens: - special_tokens = cfg.special_tokens.to_dict() - additional_special_tokens = special_tokens.pop( - "additional_special_tokens", None - ) - lora_modules_to_save = get_linear_embedding_layers(model_config.model_type) - for k, val in special_tokens.items(): - # check if new special token is not already in tokenizer and - # is adapter training to make sure lora_modules_to_save is set - # pylint: disable=too-many-boolean-expressions - if ( - (getattr(tokenizer, k) is None or getattr(tokenizer, k) != val) - and (len(tokenizer.encode(val, add_special_tokens=False)) > 2) - and cfg.adapter - and ( - not cfg.lora_modules_to_save - or not all( - x in cfg.lora_modules_to_save 
for x in lora_modules_to_save - ) - ) - and k != "pad_token" - ): - lora_modules_to_save = ", ".join( - [f"`{x}`" for x in lora_modules_to_save] - ) - raise ValueError( - f"Please set lora_modules_to_save to [{lora_modules_to_save}] when using an adapter and changing the special tokens." - ) - - tokenizer.add_special_tokens( - {k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)} - ) - - # If we add bos_token and eos_token, we need to update the post processor to - # handle them correctly. - # https://github.com/huggingface/transformers/pull/24132 - bos_or_eos_in_special_tokens = ( - "bos_token" in cfg.special_tokens and "eos_token" in cfg.special_tokens - ) - if ( - tokenizer.__class__.__name__ - in ( - "LlamaTokenizerFast", - "CodeLlamaTokenizerFast", - ) - and bos_or_eos_in_special_tokens - ): - tokenizer.update_post_processor() - - if cfg.tokens: - tokenizer.add_tokens( - [ - AddedToken(token, rstrip=False, lstrip=False, normalized=False) - for token in cfg.tokens - ] - ) - - # Additional special tokens are a List, and need to be treated differently than regular special - # tokens. We add them after we have called `add_tokens` in case these additional special tokens - # are new tokens. - # - # Usage: - # - # ```py - # special_tokens: - # additional_special_tokens: ["<|im_start|>", "<|im_end|>"] - # ``` - if additional_special_tokens is not None: - tokenizer.add_special_tokens( - {"additional_special_tokens": additional_special_tokens} - ) - - with zero_only(): - LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}") - LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}") - LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}") - LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}") - - if cfg.chat_template: - chat_template_string = get_chat_template_from_config( - cfg=cfg, - tokenizer=tokenizer, - ) - if cfg.default_system_message and cfg.chat_template == "chatml": - chat_template_string = chat_template_string.replace( - "You are a helpful assistant.", cfg.default_system_message - ) - - tokenizer.chat_template = chat_template_string - else: - LOG.info( - "No Chat template selected. Consider adding a chat template for easier inference." - ) - return tokenizer - - -@send_errors -def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase): - processor_kwargs: Dict[str, Any] = {} # do we actually need this? 
- - processor_cls = AutoProcessor - if cfg.processor_type: - processor_cls = getattr(transformers, cfg.processor_type) - - processor = processor_cls.from_pretrained( - cfg.processor_config, - trust_remote_code=cfg.trust_remote_code or False, - tokenizer=tokenizer, - **processor_kwargs, - ) - - return processor - - -class ModelLoader: - """ - ModelLoader: managing all the config and monkey patches while loading model - """ - - def __init__( - self, - cfg: DictDefault, - tokenizer: PreTrainedTokenizerBase, - *, - processor: ProcessorMixin = None, # pylint: disable=unused-argument - inference: bool = False, - reference_model: bool = False, - **kwargs, # pylint: disable=unused-argument - ) -> None: - self.cfg = cfg - self.tokenizer = tokenizer - self.inference: bool = inference - self.reference_model: bool = reference_model - - # init model kwargs - self.model_kwargs: Dict[str, Any] = {} - if cfg.overrides_of_model_kwargs: - for key, val in cfg.overrides_of_model_kwargs.items(): - self.model_kwargs[key] = val - - # init model - self.model: PreTrainedModel - self.base_model = cfg.base_model - self.model_type = cfg.type_of_model - - # init model config - self.model_config = load_model_config(cfg) - if cfg.is_multimodal: - self.text_model_config = self.model_config.text_config - else: - self.text_model_config = self.model_config - - self.AutoModelLoader = AutoModelForCausalLM # pylint: disable=invalid-name - - def apply_patches(self) -> None: - # load any patches from plugins - from axolotl.integrations.base import PluginManager - - plugin_manager = PluginManager.get_instance() - plugin_manager.pre_model_load(self.cfg) - - if self.cfg.adapter: - from axolotl.monkeypatch.transformers_fa_utils import ( - patch_fa_peft_integration, - ) - - patch_fa_peft_integration() - - if self.cfg.gradient_checkpointing == "unsloth": - transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper - - if self.cfg.flash_attention: - self.patch_attention() - - if self.cfg.sample_packing and self.cfg.s2_attention: - raise ValueError( - "Received `sample_packing=true` and `s2_attention=true`; however, \ - shifted-sparse attention does not currently support sample packing." 
- ) - - if ( - self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES - and self.cfg.flash_attention - and self.cfg.sample_packing - ): - if "auto_map" in self.model_config: - try: - auto_map_config = self.model_config["auto_map"] - except TypeError: - auto_map_config = self.model_config.auto_map - has_remote_code = "AutoModelForCausalLM" in auto_map_config - else: - has_remote_code = False - - if has_remote_code and self.cfg.trust_remote_code is False: - # if explicitly set in the YAML, we should prefer that, for example if explicitly disabled - has_remote_code = self.cfg.trust_remote_code - patch_for_multipack( - self.cfg.model_config_type, - model_name=self.cfg.base_model, - has_remote_code=has_remote_code, - ) - - if self.cfg.is_llama_derived_model: - self.patch_loss_llama() - elif self.cfg.is_llama_derived_model: - self.patch_llama_derived_model() - - if ( - self.cfg.model_config_type == "mistral" - and self.cfg.flash_attn_cross_entropy_loss - ): - from axolotl.monkeypatch.mistral_attn_hijack_flash import ( - patch_mistral_cross_entropy, - ) - - patch_mistral_cross_entropy() - - if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o: - from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora - - patch_self_attn_lora(self.cfg) - - def patch_attention(self) -> None: - if hasattr(self.model_config, "model_type"): - if self.model_config.model_type == "mllama" and self.cfg.flash_attention: - from axolotl.monkeypatch.attention.mllama import patch_mllama - - patch_mllama() - - if self.model_config.model_type == "btlm": - from axolotl.monkeypatch.btlm_attn_hijack_flash import ( - replace_btlm_attn_with_flash_attn, - ) - - replace_btlm_attn_with_flash_attn(self.cfg.base_model) - - if ( - self.model_config.model_type == "stablelm_epoch" - and self.cfg.sample_packing - ): - from axolotl.monkeypatch.stablelm_attn_hijack_flash import ( - replace_stablelm_attn_with_flash_attn, - ) - - replace_stablelm_attn_with_flash_attn(self.cfg.base_model) - - @cached_property - def has_flash_attn(self) -> bool: - """Check if flash attention is installed""" - return importlib.util.find_spec("flash_attn") is not None - - def patch_loss_llama(self) -> None: - """Patch loss functions and other optimizations""" - if self.has_flash_attn: - from axolotl.monkeypatch.llama_attn_hijack_flash import ( - patch_fa_llama_cross_entropy, - patch_llama_rms_norm, - ) - - if self.cfg.flash_attn_cross_entropy and self.has_flash_attn: - patch_fa_llama_cross_entropy() - elif self.cfg.unsloth_cross_entropy_loss: - from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch - - integrate_cross_entropy_loss_patch(model_type="llama") - - if self.cfg.flash_attn_rms_norm and self.has_flash_attn: - patch_llama_rms_norm() - elif self.cfg.unsloth_rms_norm: - from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm - - patch_unsloth_layernorm() - - if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o: - from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora - - patch_self_attn_lora() - - def patch_llama_derived_model(self) -> None: - """Modify all llama derived models in one block""" - self.patch_loss_llama() - - if self.cfg.flash_attention: - from axolotl.monkeypatch.llama_attn_hijack_flash import ( - replace_llama_attn_with_flash_attn, - ) - - if self.cfg.sample_packing: - if self.cfg.device not in ["mps", "cpu"] and not self.inference: - LOG.info("patching with flash attention for sample packing") - replace_llama_attn_with_flash_attn( - packed=True, - 
cross_entropy=self.cfg.flash_attn_cross_entropy, - rms_norm=self.cfg.flash_attn_rms_norm, - ) - elif self.cfg.s2_attention: - LOG.info("patching w/ flash-enabled, shifted-sparse attention") - replace_llama_attn_with_flash_attn( - packed=False, - cross_entropy=self.cfg.flash_attn_cross_entropy, - rms_norm=self.cfg.flash_attn_rms_norm, - use_shifted_sparse_attn=True, - ) - elif self.cfg.flash_attn_cross_entropy or self.cfg.flash_attn_rms_norm: - replace_llama_attn_with_flash_attn( - packed=False, - cross_entropy=self.cfg.flash_attn_cross_entropy, - rms_norm=self.cfg.flash_attn_rms_norm, - ) - elif self.cfg.xformers_attention: - from axolotl.monkeypatch.llama_attn_hijack_xformers import ( - hijack_llama_attention, - ) - - LOG.info("patching with xformers attention") - hijack_llama_attention() - elif self.cfg.sample_packing: - from axolotl.monkeypatch.llama_patch_multipack import ( - hijack_llama_prepare_4d_mask, - ) - - LOG.info("patching llama _prepare_4d_causal_attention_mask*") - hijack_llama_prepare_4d_mask() - elif self.cfg.s2_attention: - raise NotImplementedError( - "Shifted-sparse attention not currently implemented without flash attention." - ) - - def set_auto_model_loader(self) -> None: - """set self.AutoModelLoader - - default value: AutoModelForCausalLM (set at __init__) - - when using a multi modality model, self.AutoModelLoader should - be set according to model type of the model - """ - if self.cfg.is_multimodal: - if self.model_config.model_type == "llava": - self.AutoModelLoader = ( # pylint: disable=invalid-name - LlavaForConditionalGeneration - ) - elif self.model_config.model_type == "mllama": - self.AutoModelLoader = ( # pylint: disable=invalid-name - MllamaForConditionalGeneration - ) - else: - self.AutoModelLoader = ( - AutoModelForVision2Seq # pylint: disable=invalid-name - ) - - def set_device_map_config(self) -> None: - device_map = self.cfg.device_map - max_memory = self.cfg.max_memory - - if self.cfg.gpu_memory_limit: - gpu_memory_limit = ( - str(self.cfg.gpu_memory_limit) + "GiB" - if isinstance(self.cfg.gpu_memory_limit, int) - else self.cfg.gpu_memory_limit - ) - - max_memory = {} - num_device = get_device_count() - for i in range(num_device): - max_memory[i] = gpu_memory_limit - max_memory["cpu"] = "256GiB" # something sufficiently large to fit anything - - if max_memory is not None: - # Based on https://github.com/togethercomputer/OpenChatKit/blob/main/inference/bot.py - from accelerate import infer_auto_device_map - - with init_empty_weights(): - model_canvas = self.AutoModelLoader.from_config( - self.model_config, - trust_remote_code=self.cfg.trust_remote_code or False, - ) - model_canvas.tie_weights() - device_map = infer_auto_device_map( - model_canvas, - max_memory=max_memory, - dtype=self.cfg.torch_dtype, - ) - # We can discard max_memory now as we have a device map set up for us - max_memory = None - - self.model_kwargs["device_map"] = device_map - self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype - - cur_device = get_device_type() - if "mps" in str(cur_device): - self.model_kwargs["device_map"] = "mps:0" - elif "npu" in str(cur_device): - self.model_kwargs["device_map"] = "npu:0" - - # TODO can we put the reference model on it's own gpu? 
I think we have to move logits around to calculate loss - # if cfg.rl: - # if torch.cuda.device_count() > 1: - # if reference_model: - # model_kwargs["device_map"] = "cuda:" + str( - # torch.cuda.current_device() + 1 - # ) - # else: - # model_kwargs["device_map"] = "cuda:" + str(torch.cuda.current_device()) - - if is_deepspeed_zero3_enabled(): - del self.model_kwargs["device_map"] - - def set_quantization_config(self) -> None: - self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit - self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit - - if self.cfg.gptq: - if not hasattr(self.model_config, "quantization_config"): - LOG.warning( - "model config does not contain quantization_config information" - ) - else: - if self.cfg.gptq_disable_exllama is not None: - self.model_config.quantization_config[ - "disable_exllama" - ] = self.cfg.gptq_disable_exllama - self.model_kwargs["quantization_config"] = GPTQConfig( - **self.model_config.quantization_config - ) - if ( - self.cfg.adapter in ["qlora", "lora"] - and hasattr(self.model_config, "quantization_config") - and self.model_config.quantization_config["quant_method"] - in ["gptq", "awq", "bitsandbytes"] - ): - if self.model_config.quantization_config["quant_method"] == "gptq": - self.model_kwargs["quantization_config"] = GPTQConfig( - **self.model_config.quantization_config - ) - elif self.model_config.quantization_config["quant_method"] == "awq": - self.model_kwargs["quantization_config"] = AwqConfig( - **self.model_config.quantization_config - ) - elif ( - self.model_config.quantization_config["quant_method"] == "bitsandbytes" - ): - self.model_kwargs["quantization_config"] = BitsAndBytesConfig( - **self.model_config.quantization_config - ) - elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]: - bnb_config = { - "load_in_4bit": True, - "llm_int8_threshold": 6.0, - "llm_int8_has_fp16_weight": False, - "bnb_4bit_compute_dtype": self.cfg.torch_dtype, - "bnb_4bit_use_double_quant": True, - "bnb_4bit_quant_type": "nf4", - "bnb_4bit_quant_storage": torch.bfloat16, - } - if self.cfg.model_config_type in ["jamba", "qwen2_moe"] and not ( - self.cfg.deepspeed or self.cfg.fsdp - ): - # for some reason, this causes the loss to be off by an order of magnitude - # but deepspeed needs this still in bfloat16 - bnb_config["bnb_4bit_quant_storage"] = torch.float32 - - if self.cfg.bnb_config_kwargs: - bnb_config.update(self.cfg.bnb_config_kwargs) - - self.model_kwargs["quantization_config"] = BitsAndBytesConfig( - **bnb_config, - ) - elif self.cfg.adapter == "lora" and self.model_kwargs["load_in_8bit"]: - bnb_config = { - "load_in_8bit": True, - } - # Exclude mamba blocks from int8 quantization for jamba - if self.cfg.model_config_type == "jamba": - bnb_config["llm_int8_skip_modules"] = ["mamba"] - self.model_kwargs["quantization_config"] = BitsAndBytesConfig( - **bnb_config, - ) - - # no longer needed per https://github.com/huggingface/transformers/pull/26610 - if "quantization_config" in self.model_kwargs or self.cfg.gptq: - self.model_kwargs.pop("load_in_8bit", None) - self.model_kwargs.pop("load_in_4bit", None) - - def set_attention_config(self) -> None: - """ - sample packing uses custom FA2 patch - """ - if self.cfg.flash_attention: - if not self.cfg.sample_packing and self.cfg.s2_attention: - pass - self.model_kwargs["attn_implementation"] = "flash_attention_2" - self.model_config._attn_implementation = ( # pylint: disable=protected-access - "flash_attention_2" - ) - elif self.cfg.sdp_attention: - 
self.model_kwargs["attn_implementation"] = "sdpa" - self.model_config._attn_implementation = ( # pylint: disable=protected-access - "sdpa" - ) - elif self.cfg.eager_attention: - self.model_kwargs["attn_implementation"] = "eager" - self.model_config._attn_implementation = ( # pylint: disable=protected-access - "eager" - ) - - if self.cfg.low_cpu_mem_usage: - self.model_kwargs["low_cpu_mem_usage"] = True - - def build_model(self, qlora_fsdp) -> bool: - def _configure_zero3_memory_efficient_loading(): - """ - Set the deepspeed config to load the model into RAM first before moving to VRAM. - - We need to return hf_ds_cfg as it needs to exist before model loading. - """ - hf_ds_cfg = None - - if os.getenv("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3": - hf_ds_cfg = HfTrainerDeepSpeedConfig(self.cfg.deepspeed) - hf_ds_cfg.fill_match( - "train_micro_batch_size_per_gpu", self.cfg.micro_batch_size - ) - hf_ds_cfg.fill_match( - "gradient_accumulation_steps", self.cfg.gradient_accumulation_steps - ) - hf_ds_cfg.fill_match( - "train_batch_size", - int(os.getenv("WORLD_SIZE", "1")) - * self.cfg.micro_batch_size - * self.cfg.gradient_accumulation_steps, - ) - if "device_map" in self.model_kwargs: - del self.model_kwargs["device_map"] - - transformers.modeling_utils.is_deepspeed_zero3_enabled = lambda: True - transformers.integrations.deepspeed.is_deepspeed_zero3_enabled = ( - lambda: True - ) - - return hf_ds_cfg - - skip_move_to_device = False - if ( # pylint: disable=condition-evals-to-constant) - (self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading) - and not qlora_fsdp - and False - ): - self.model = load_sharded_model( - self.base_model, - self.model_config, - self.cfg, - torch_dtype=self.cfg.torch_dtype, - ) - skip_move_to_device = True - elif ( - qlora_fsdp - and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading - and ( - self.cfg.model_config_type == "dbrx" - or self.cfg.qlora_sharded_model_loading - ) - ): - quant_storage = self.cfg.torch_dtype - quantization_config = hasattr( - self.model_config, "quantization_config" - ) and getattr(self.model_config, "quantization_config") - quantization_config = ( - quantization_config or self.model_kwargs["quantization_config"] - ) - if self.cfg.is_multimodal: - self.model_config.text_config = self.text_model_config - self.model = load_sharded_model_quant( - self.base_model, - self.model_config, - self.cfg, - quant_storage=quant_storage, - quantization_config=quantization_config, - ) - skip_move_to_device = True - elif ( - self.model_config.model_type == "llama" - and not self.cfg.trust_remote_code - and not self.cfg.gptq - ): - if self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading: - skip_move_to_device = True - if "device_map" in self.model_kwargs: - del self.model_kwargs["device_map"] - - _ = _configure_zero3_memory_efficient_loading() - - if self.cfg.is_multimodal: - self.model_config.text_config = self.text_model_config - self.model = self.AutoModelLoader.from_pretrained( - self.base_model, - config=self.model_config, - **self.model_kwargs, - ) - - # TODO (MengqingCao) split these patches seperately - if self.cfg.flash_attention and not self.inference: - from axolotl.monkeypatch.llama_attn_hijack_flash import ( - is_xformers_swiglu_available, - replace_llama_mlp_with_swiglu, - replace_llama_qkv_with_fused, - ) - - if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available(): - LOG.info("patching with SwiGLU") - replace_llama_mlp_with_swiglu(self.model) - - if self.cfg.flash_attn_fuse_qkv: - LOG.info("patching 
with fused QKV") - replace_llama_qkv_with_fused(self.model) - elif self.model_type == "MambaLMHeadModel": - # FIXME this is janky at best and hacked together to make it work - MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name - - self.model_kwargs["dtype"] = self.model_kwargs["torch_dtype"] - self.model_kwargs["device"] = torch.cuda.current_device() - del self.model_kwargs["torch_dtype"] - del self.model_kwargs["device_map"] - - self.model = MambaLMHeadModel.from_pretrained( - self.base_model, - **self.model_kwargs, - ) - elif ( - self.model_type - and self.model_type != "AutoModelForCausalLM" - and not self.cfg.trust_remote_code - ): - if self.cfg.is_multimodal: - self.model_config.text_config = self.text_model_config - if self.cfg.gptq: - self.model = self.AutoModelLoader.from_pretrained( - self.base_model, - config=self.model_config, - trust_remote_code=self.cfg.trust_remote_code or False, - **self.model_kwargs, - ) - else: - self.model = getattr(transformers, self.model_type).from_pretrained( - self.base_model, - config=self.model_config, - trust_remote_code=self.cfg.trust_remote_code or False, - **self.model_kwargs, - ) - else: - # Shouldn't be a problem most of the time. will obviously error if the model doesn't support this - # when training starts - if ( - hasattr(self.text_model_config, "max_seq_len") - and self.text_model_config.max_seq_len - and self.cfg.sequence_len > self.text_model_config.max_seq_len - ): - self.text_model_config.max_seq_len = self.cfg.sequence_len - LOG.warning(f"increasing context length to {self.cfg.sequence_len}") - elif ( - hasattr(self.text_model_config, "max_sequence_length") - and self.text_model_config.max_sequence_length - and self.cfg.sequence_len > self.text_model_config.max_sequence_length - ): - self.text_model_config.max_sequence_length = self.cfg.sequence_len - LOG.warning(f"increasing context length to {self.cfg.sequence_len}") - if self.cfg.gptq: - if self.cfg.is_multimodal: - self.model_config.text_config = self.text_model_config - self.model = self.AutoModelLoader.from_pretrained( - self.base_model, - config=self.model_config, - trust_remote_code=self.cfg.trust_remote_code or False, - **self.model_kwargs, - ) - else: - if ( - self.cfg.fsdp - and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading - ): - # disabling either of these two still leads to VRAM spike before setting back down - skip_move_to_device = True - if "device_map" in self.model_kwargs: - del self.model_kwargs["device_map"] - - _ = _configure_zero3_memory_efficient_loading() - - if self.cfg.is_multimodal: - self.model_config.text_config = self.text_model_config - self.model = self.AutoModelLoader.from_pretrained( - self.base_model, - config=self.model_config, - trust_remote_code=self.cfg.trust_remote_code or False, - **self.model_kwargs, - ) - if is_deepspeed_zero3_enabled(): - skip_move_to_device = True - - return skip_move_to_device - - def ajust_model_config(self) -> None: - if ( - hasattr(self.model, "config") - and hasattr(self.model.config, "max_position_embeddings") - and self.model.config.max_position_embeddings - and self.cfg.sequence_len > self.model.config.max_position_embeddings - ): - LOG.warning( - f"increasing model.config.max_position_embeddings from {self.model.config.max_position_embeddings} to {self.cfg.sequence_len}" - ) - self.model.config.max_position_embeddings = self.cfg.sequence_len - - if ( - hasattr(self.model, "config") - and hasattr(self.model.config, "bos_token_id") - and self.model.config.bos_token_id - and 
self.model.config.bos_token_id != self.tokenizer.bos_token_id - ): - self.model.config.bos_token_id = self.tokenizer.bos_token_id - - if ( - hasattr(self.model, "config") - and hasattr(self.model.config, "eos_token_id") - and self.model.config.eos_token_id - and self.model.config.eos_token_id != self.tokenizer.eos_token_id - ): - self.model.config.eos_token_id = self.tokenizer.eos_token_id - - def set_z3_leaf_modules(self) -> None: - from deepspeed.utils import ( # pylint: disable=no-name-in-module - set_z3_leaf_modules, - ) - - if self.cfg.model_config_type in MOE_ARCH_BLOCK: - moe_blocks = MOE_ARCH_BLOCK[self.cfg.model_config_type] - moe_blocks = [moe_blocks] if isinstance(moe_blocks, str) else moe_blocks - set_z3_leaf_modules( - self.model, - [ - get_module_class_from_name(self.model, module_name) - for module_name in moe_blocks - ], - ) - - def prepare_model(self, qlora_fsdp) -> None: - skip_prepare_model_for_kbit_training = False - if self.cfg.model_config_type == "qwen" and self.cfg.adapter == "lora": - # Qwen doesn't play nicely with LoRA if this is enabled - skip_prepare_model_for_kbit_training = True - - loftq_bits = ( - self.cfg.peft - and self.cfg.peft.loftq_config - and self.cfg.peft.loftq_config.loftq_bits - ) - if self.cfg.adapter == "lora" and loftq_bits: - skip_prepare_model_for_kbit_training = True - - if qlora_fsdp or ( - self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading - ): - # make sure everything is in the same dtype - skip_prepare_model_for_kbit_training = True - - if is_deepspeed_zero3_enabled(): - skip_prepare_model_for_kbit_training = True - - if ( - not skip_prepare_model_for_kbit_training - and self.cfg.adapter in ["lora", "qlora"] - and (self.cfg.load_in_8bit or self.cfg.load_in_4bit) - ): - LOG.info("converting PEFT model w/ prepare_model_for_kbit_training") - self.model = prepare_model_for_kbit_training( - self.model, use_gradient_checkpointing=self.cfg.gradient_checkpointing - ) - - def convert_embedding_modules_dtype( - self, embedding_modules, dist_dtype, before_kbit_train_or_finetune - ) -> None: - for name, module in self.model.named_modules(): - if "norm" in name: - module.to(dist_dtype) - if before_kbit_train_or_finetune: - if name.endswith(".gate"): - module.to(dist_dtype) - if self.model_config.model_type == "btlm": - # don't upcast lm_head for btlm - continue - if any(m in name for m in embedding_modules): - if hasattr(module, "weight"): - module.to(dist_dtype) - - # TODO: Deprecate this. 
- def apply_unsloth_lora_patch(self) -> None: - if self.cfg.unsloth_lora_mlp: - from axolotl.monkeypatch.unsloth_ import integrate_lora_mlp_patch - - integrate_lora_mlp_patch(self.model) - if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o: - from axolotl.monkeypatch.unsloth_ import integrate_lora_patch - - integrate_lora_patch(self.model, self.cfg) - if self.cfg.unsloth_rope: - from axolotl.monkeypatch.unsloth_ import integrate_rope_embeddings - - integrate_rope_embeddings() - - def apply_lora_patch(self) -> None: - if ( - self.cfg.lora_mlp_kernel - or self.cfg.lora_qkv_kernel - or self.cfg.lora_o_kernel - ): - from axolotl.monkeypatch.lora_kernels import apply_lora_kernel_patches - - apply_lora_kernel_patches(self.model, self.cfg) - - def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: - self.apply_patches() - self.set_auto_model_loader() - self.set_device_map_config() - if self.cfg.revision_of_model: - self.model_kwargs["revision"] = self.cfg.revision_of_model - self.set_quantization_config() - self.set_attention_config() - - qlora_fsdp = self.cfg.fsdp and self.cfg.adapter == "qlora" - skip_move_to_device = False - - try: - skip_move_to_device = self.build_model(qlora_fsdp) - except Exception as err: # pylint: disable=broad-exception-caught - LOG.exception(err) - raise err - - if isinstance(self.model, (PeftModel, PeftModelForCausalLM)) and not qlora_fsdp: - self.model = self.model.merge_and_unload() - - embeddings_len = ( - math.ceil(len(self.tokenizer) / 32) * 32 - if self.cfg.resize_token_embeddings_to_32x - else len(self.tokenizer) - ) - if hasattr(self.model, "get_input_embeddings") and ( - self.model.get_input_embeddings().num_embeddings < embeddings_len - or ( - self.model.get_input_embeddings().num_embeddings > embeddings_len - and self.cfg.shrink_embeddings - ) - ): - resize_kwargs = {} - if self.cfg.mean_resizing_embeddings is not None: - resize_kwargs["mean_resizing"] = self.cfg.mean_resizing_embeddings - self.model.resize_token_embeddings(embeddings_len, **resize_kwargs) - else: - self.model.tie_weights() - - self.ajust_model_config() - - # log device memory usage - if hasattr(self.model, "device") and self.model.device.type in ( - "cuda", - "mps", - "npu", - ): - log_gpu_memory_usage(LOG, "after model load", self.model.device) - - # make sure these are fp32 per Ramesh et al. (2021) - embedding_modules = get_linear_embedding_layers(self.cfg.model_config_type) - if not self.cfg.fsdp: - # FSDP doesn't like mixed Float and BFloat16 - self.convert_embedding_modules_dtype( - embedding_modules, - dist_dtype=torch.float32, - before_kbit_train_or_finetune=True, - ) - - if is_deepspeed_zero3_enabled(): - self.set_z3_leaf_modules() - - needs_fa2_dtype = self.cfg.adapter or self.cfg.fsdp - if self.cfg.adapter in ["lora", "qlora"]: - needs_fa2_dtype = True - if self.cfg.gradient_checkpointing: - self.model.gradient_checkpointing_enable( - gradient_checkpointing_kwargs=self.cfg.gradient_checkpointing_kwargs - ) - - self.prepare_model(qlora_fsdp) - - should_convert = ( - # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to - # convert them back to fp16/bf16 for flash-attn compatibility. 
- ((needs_fa2_dtype or self.cfg.flash_attention) and not qlora_fsdp) - or self.cfg.cut_cross_entropy # Cut cross entropy requires embedding layers to be in fp16/bf16 for backward pass - ) - - if should_convert: - LOG.info("Converting modules to %s", self.cfg.torch_dtype) - self.convert_embedding_modules_dtype( - embedding_modules=embedding_modules, - dist_dtype=self.cfg.torch_dtype, - before_kbit_train_or_finetune=False, - ) - - # --------------------------------------------------------- - # load lora or adapter - # --------------------------------------------------------- - lora_config = None - if not self.reference_model or self.cfg.lora_model_dir: - # if we're not loading the reference model, then we're loading the model for training - # then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config - if ( - self.cfg.adapter - and self.cfg.rl in ["dpo", "ipo", "kto"] - and not self.cfg.merge_lora - ): - _, lora_config = load_lora( - self.model, self.cfg, inference=False, config_only=True - ) - else: - self.model, lora_config = load_adapter( - self.model, self.cfg, self.cfg.adapter - ) - - # --------------------------------------------------------- - # put model to accelerator - # --------------------------------------------------------- - if ( - self.cfg.ddp - and not self.cfg.load_in_8bit - and not (self.cfg.rl and self.cfg.load_in_4bit) - and not skip_move_to_device - ): - # TODO revaldate this conditional - self.model.to(f"{str(get_device_type())}:{self.cfg.local_rank}") - - if get_device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1: - setattr(self.model, "is_parallelizable", True) - setattr(self.model, "model_parallel", True) - - # --------------------------------------------------------- - # parameters that require gradient updates - # --------------------------------------------------------- - requires_grad = [] - for name, param in self.model.named_parameters(recurse=True): - if param.requires_grad: - requires_grad.append(f"{name}: {param.requires_grad}") - if len(requires_grad) == 0: - LOG.warning("there are no parameters that require gradient updates") - if hasattr(self.model, "config"): - self.model.config.use_cache = False - - if self.cfg.flash_optimum: - from optimum.bettertransformer import BetterTransformer - - self.model = BetterTransformer.transform(self.model) - - if self.cfg.adapter is not None: - log_gpu_memory_usage(LOG, "after adapters", self.model.device) - - self.apply_unsloth_lora_patch() - self.apply_lora_patch() - - for _ in range(3): - gc.collect() - torch.cuda.empty_cache() - - # TODO resume_from_checkpoint handling - return self.model, lora_config - - -@send_errors -def load_model( - cfg: DictDefault, - tokenizer: PreTrainedTokenizerBase, - *, - processor: ProcessorMixin = None, - inference: bool = False, - reference_model: bool = False, - **kwargs, -) -> Tuple[PreTrainedModel, PeftConfig | None]: - """Load a model for a given configuration and tokenizer""" - loader = ModelLoader( - cfg, - tokenizer, - processor=processor, - inference=inference, - reference_model=reference_model, - **kwargs, - ) - return loader.load_model() - - -@send_errors -def load_adapter(model, cfg, adapter, inference=False): - # type: (PreTrainedModel, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]] - - if adapter is None: - return model, None - if hasattr(model, "enable_input_require_grads"): - model.enable_input_require_grads() - if adapter in ["lora", "qlora"]: - return load_lora(model, cfg, 
inference=inference) - if adapter == "llama-adapter": - return load_llama_adapter(model, cfg) - - raise NotImplementedError(f"{adapter} peft adapter not available") - - -def load_llama_adapter(model, cfg): - # type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]] - from peft import AdaptionPromptConfig, get_peft_model - - peft_config = AdaptionPromptConfig( - adapter_layers=cfg.peft_adapter.layers, # layers (L) - adapter_len=cfg.peft_adapter.len, # prompt length (K) - task_type="CAUSAL_LM", - ) - - if cfg.lora_model_dir: - LOG.debug("Loading pretrained PEFT - llama_adapter") - model = PeftModel.from_pretrained( - model, - cfg.lora_model_dir, - torch_dtype=torch.float16, - ) - else: - model = get_peft_model(model, peft_config) - - model.print_trainable_parameters() - - return model, peft_config - - -def find_all_linear_names(model): - cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear) - lora_module_names = set() - for name, module in model.named_modules(): - if ( - isinstance(module, cls) - or "Linear" in module.__class__.__name__ - and module.__class__.__name__ not in ("LlamaLinearScalingRotaryEmbedding",) - ): - names = name.split(".") - lora_module_names.add(names[0] if len(names) == 1 else names[-1]) - - embedding_modules = get_linear_embedding_layers(model.config.model_type) - output_embedding = embedding_modules[1] - if output_embedding in lora_module_names: # needed for 16-bit - lora_module_names.remove(output_embedding) - - return list(lora_module_names) - - -def setup_quantized_meta_for_peft(model: nn.Module): - """Replaces `quant_state.to` with a dummy function to prevent PEFT from moving `quant_state` to meta device""" - - def temp_to_method(self, *args, **kwargs): # pylint: disable=unused-argument - return self - - for param in model.parameters(): - if isinstance(param, Params4bit): - param.quant_state._orig_to = ( # pylint: disable=protected-access - param.quant_state.to - ) - param.quant_state.to = types.MethodType(temp_to_method, param.quant_state) - - -def setup_quantized_peft_meta_for_training(model: nn.Module): - """Replaces dummy `quant_state.to` method with the original function to allow training to continue""" - for param in model.parameters(): - if isinstance(param, Params4bit) and hasattr(param.quant_state, "_orig_to"): - param.quant_state.to = ( - param.quant_state._orig_to # pylint: disable=protected-access - ) - param.quant_state._orig_to = None # pylint: disable=protected-access - - -def load_lora(model, cfg, inference=False, config_only=False): - # type: (PreTrainedModel, DictDefault, bool, bool) -> Tuple[Optional[PreTrainedModel], Optional[PeftConfig]] - - from peft import LoraConfig, get_peft_model - - lora_target_modules = cfg.lora_target_modules or [] - - if cfg.lora_target_linear: - linear_names = find_all_linear_names(model) - LOG.info(f"found linear modules: {repr(sorted(linear_names))}") - lora_target_modules_as_list = ( - lora_target_modules - if isinstance(lora_target_modules, list) - else [lora_target_modules] - ) - lora_target_modules = list(set(lora_target_modules_as_list + linear_names)) - - lora_config_kwargs = {} - loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits - if loftq_bits: - lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits) - lora_config_kwargs["init_lora_weights"] = "loftq" - if cfg.peft_init_lora_weights: - lora_config_kwargs["init_lora_weights"] = cfg.peft_init_lora_weights - if cfg.peft_use_dora: - 
lora_config_kwargs["use_dora"] = cfg.peft_use_dora - LOG.info("Initializing LoRA weights using dora. This might take longer.") - if cfg.peft_use_rslora: - lora_config_kwargs["use_rslora"] = cfg.peft_use_rslora - if cfg.peft_layer_replication: - lora_config_kwargs["layer_replication"] = cfg.peft_layer_replication - - lora_config = LoraConfig( - r=cfg.lora_r, - lora_alpha=cfg.lora_alpha, - target_modules=lora_target_modules, - layers_to_transform=cfg.peft_layers_to_transform, - layers_pattern=cfg.peft_layers_pattern, - lora_dropout=cfg.lora_dropout, - fan_in_fan_out=cfg.lora_fan_in_fan_out, - modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None, - bias="none", - task_type="CAUSAL_LM", - **lora_config_kwargs, - ) - - if config_only: - return None, lora_config - - rank = int(os.environ.get("LOCAL_RANK", 0)) - - if ( - cfg.fsdp - and cfg.adapter - and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading - and rank != 0 - ): - setup_quantized_meta_for_peft(model) - - if cfg.lora_model_dir: - LOG.debug("Loading pretrained PEFT - LoRA") - model_kwargs: Any = {} - if cfg.lora_on_cpu: - model_kwargs["max_memory"] = {"cpu": "256GiB"} - model_kwargs["device_map"] = {"": "cpu"} - model = PeftModel.from_pretrained( - model, - cfg.lora_model_dir, - is_trainable=(not inference), - **model_kwargs, - ) - else: - model = get_peft_model(model, lora_config) - - if rank == 0: - try: - model.print_trainable_parameters() - except AttributeError as exc: - LOG.warning( - "Exception caught during model.print_trainable_parameters(): %s", exc - ) - elif ( - cfg.fsdp - and cfg.adapter - and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading - and rank != 0 - ): - setup_quantized_peft_meta_for_training(model) - - return model, lora_config - - -def ensure_dtype(model, dtype=torch.bfloat16): - for name, module in model.named_modules(): - weight_mismatch = False - bias_mismatch = False - try: - weight_mismatch = module.weight.dtype != dtype - except AttributeError: - pass - try: - bias_mismatch = module.bias.dtype != dtype - except AttributeError: - pass - - if weight_mismatch: - print(f"Converting module {name}.weight: {module.weight.dtype} -> {dtype}") - if bias_mismatch: - print(f"Converting module {name}.bias: {module.bias.dtype} -> {dtype}") - if weight_mismatch or bias_mismatch: - module.to(dtype) diff --git a/tests/telemetry/test_callbacks.py b/tests/telemetry/test_callbacks.py index 4324126e7..97d56a9c6 100644 --- a/tests/telemetry/test_callbacks.py +++ b/tests/telemetry/test_callbacks.py @@ -1,4 +1,5 @@ """Tests for telemetry callback module.""" + # pylint: disable=redefined-outer-name import time @@ -15,9 +16,9 @@ def calc_expected_metrics(step, last_step, current_time, last_time, start_time=9 time_diff = current_time - last_time step_diff = step - last_step return { - "steps_per_second": step_diff / time_diff - if time_diff > 0 and step_diff > 0 - else 0, + "steps_per_second": ( + step_diff / time_diff if time_diff > 0 and step_diff > 0 else 0 + ), "time_since_last_report": time_diff, "elapsed_time": current_time - start_time, } diff --git a/tests/telemetry/test_errors.py b/tests/telemetry/test_errors.py index 3d00c0f28..2f0510b21 100644 --- a/tests/telemetry/test_errors.py +++ b/tests/telemetry/test_errors.py @@ -1,4 +1,5 @@ """Tests for telemetry error utilities""" + # pylint: disable=redefined-outer-name from unittest.mock import MagicMock, patch diff --git a/tests/telemetry/test_manager.py b/tests/telemetry/test_manager.py index 814f965a0..4549e22b0 100644 --- 
a/tests/telemetry/test_manager.py +++ b/tests/telemetry/test_manager.py @@ -1,4 +1,5 @@ """Tests for TelemetryManager class and utilities""" + # pylint: disable=redefined-outer-name,protected-access import os @@ -38,9 +39,13 @@ def telemetry_manager_class(): @pytest.fixture def manager(telemetry_manager_class, mock_whitelist): """Create a TelemetryManager instance with mocked dependencies""" - with patch("posthog.capture"), patch("posthog.flush"), patch("time.sleep"), patch( - "axolotl.telemetry.manager.WHITELIST_PATH", mock_whitelist - ), patch.dict(os.environ, {"RANK": "0"}): + with ( + patch("posthog.capture"), + patch("posthog.flush"), + patch("time.sleep"), + patch("axolotl.telemetry.manager.WHITELIST_PATH", mock_whitelist), + patch.dict(os.environ, {"RANK": "0"}), + ): manager = telemetry_manager_class() # Manually enable for most tests manager.enabled = True @@ -49,8 +54,10 @@ def manager(telemetry_manager_class, mock_whitelist): def test_singleton_instance(telemetry_manager_class): """Test that TelemetryManager is a singleton""" - with patch("posthog.capture"), patch("time.sleep"), patch.dict( - os.environ, {"RANK": "0"} + with ( + patch("posthog.capture"), + patch("time.sleep"), + patch.dict(os.environ, {"RANK": "0"}), ): first = telemetry_manager_class() second = telemetry_manager_class() @@ -60,8 +67,10 @@ def test_singleton_instance(telemetry_manager_class): def test_telemetry_disabled_by_default(telemetry_manager_class): """Test that telemetry is disabled by default (opt-in)""" - with patch.dict(os.environ, {"RANK": "0"}, clear=True), patch("time.sleep"), patch( - "logging.Logger.info" + with ( + patch.dict(os.environ, {"RANK": "0"}, clear=True), + patch("time.sleep"), + patch("logging.Logger.info"), ): manager = telemetry_manager_class() assert not manager.enabled @@ -69,8 +78,9 @@ def test_telemetry_disabled_by_default(telemetry_manager_class): def test_telemetry_enabled_with_explicit_opt_in(telemetry_manager_class): """Test that telemetry is enabled when AXOLOTL_DO_NOT_TRACK=0""" - with patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "RANK": "0"}), patch( - "time.sleep" + with ( + patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "RANK": "0"}), + patch("time.sleep"), ): manager = telemetry_manager_class() assert manager.enabled @@ -78,8 +88,9 @@ def test_telemetry_enabled_with_explicit_opt_in(telemetry_manager_class): def test_telemetry_disabled_with_axolotl_do_not_track(telemetry_manager_class): """Test that telemetry is disabled when AXOLOTL_DO_NOT_TRACK=1""" - with patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "1", "RANK": "0"}), patch( - "time.sleep" + with ( + patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "1", "RANK": "0"}), + patch("time.sleep"), ): manager = telemetry_manager_class() assert not manager.enabled @@ -87,17 +98,21 @@ def test_telemetry_disabled_with_axolotl_do_not_track(telemetry_manager_class): def test_telemetry_disabled_with_do_not_track(telemetry_manager_class): """Test that telemetry is disabled when DO_NOT_TRACK=1""" - with patch.dict( - os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "DO_NOT_TRACK": "1", "RANK": "0"} - ), patch("time.sleep"): + with ( + patch.dict( + os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "DO_NOT_TRACK": "1", "RANK": "0"} + ), + patch("time.sleep"), + ): manager = telemetry_manager_class() assert not manager.enabled def test_telemetry_disabled_for_non_main_process(telemetry_manager_class): """Test that telemetry is disabled for non-main processes""" - with patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "RANK": 
"1"}), patch( - "time.sleep" + with ( + patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "RANK": "1"}), + patch("time.sleep"), ): manager = telemetry_manager_class() assert not manager.enabled @@ -105,9 +120,11 @@ def test_telemetry_disabled_for_non_main_process(telemetry_manager_class): def test_opt_in_info_displayed(telemetry_manager_class): """Test that opt-in info is displayed when telemetry is not configured""" - with patch.dict(os.environ, {"RANK": "0"}, clear=True), patch( - "logging.Logger.warning" - ) as mock_warning, patch("time.sleep"): + with ( + patch.dict(os.environ, {"RANK": "0"}, clear=True), + patch("logging.Logger.warning") as mock_warning, + patch("time.sleep"), + ): telemetry_manager_class() info_displayed = False for call in mock_warning.call_args_list: @@ -120,8 +137,9 @@ def test_opt_in_info_displayed(telemetry_manager_class): def test_is_whitelisted(telemetry_manager_class, mock_whitelist): """Test org whitelist functionality""" - with patch("axolotl.telemetry.manager.WHITELIST_PATH", mock_whitelist), patch.dict( - os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"} + with ( + patch("axolotl.telemetry.manager.WHITELIST_PATH", mock_whitelist), + patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}), ): manager = telemetry_manager_class() @@ -150,8 +168,9 @@ def test_system_info_collection(manager): def test_send_event(telemetry_manager_class): """Test basic event sending""" - with patch("posthog.capture") as mock_capture, patch.dict( - os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"} + with ( + patch("posthog.capture") as mock_capture, + patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}), ): manager = telemetry_manager_class() @@ -171,8 +190,9 @@ def test_send_event(telemetry_manager_class): def test_send_system_info(telemetry_manager_class): """Test sending system info""" - with patch("posthog.capture") as mock_capture, patch.dict( - os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"} + with ( + patch("posthog.capture") as mock_capture, + patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}), ): manager = telemetry_manager_class() manager.send_system_info() @@ -183,8 +203,9 @@ def test_send_system_info(telemetry_manager_class): def test_redacted_properties(telemetry_manager_class): """Test path redaction in send_event method""" - with patch("posthog.capture") as mock_capture, patch.dict( - os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"} + with ( + patch("posthog.capture") as mock_capture, + patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}), ): manager = telemetry_manager_class() # Test with properties containing various paths and non-paths @@ -237,9 +258,10 @@ def test_disable_telemetry(manager): def test_exception_handling_during_send(manager): """Test that exceptions in PostHog are handled gracefully""" - with patch("posthog.capture", side_effect=Exception("Test error")), patch( - "logging.Logger.warning" - ) as mock_warning: + with ( + patch("posthog.capture", side_effect=Exception("Test error")), + patch("logging.Logger.warning") as mock_warning, + ): manager.send_event("test_event") warning_logged = False for call in mock_warning.call_args_list: diff --git a/tests/telemetry/test_runtime_metrics.py b/tests/telemetry/test_runtime_metrics.py index 5a6ef5a08..c8916e072 100644 --- a/tests/telemetry/test_runtime_metrics.py +++ b/tests/telemetry/test_runtime_metrics.py @@ -1,4 +1,5 @@ """Tests for runtime metrics telemetry module""" + # pylint: disable=redefined-outer-name from unittest.mock import MagicMock, patch