diff --git a/examples/dbrx/16bit-lora.yaml b/examples/dbrx/16bit-lora.yaml
index e5e3ea921..23c804584 100644
--- a/examples/dbrx/16bit-lora.yaml
+++ b/examples/dbrx/16bit-lora.yaml
@@ -22,6 +22,7 @@ wandb_watch:
 wandb_name:
 wandb_log_model:
 
+qlora_fsdp_alt_loader: true
 adapter: lora
 lora_model_dir:
 lora_r: 8
diff --git a/examples/dbrx/8bit-lora.yaml b/examples/dbrx/8bit-lora.yaml
index 89e24db05..109a27144 100644
--- a/examples/dbrx/8bit-lora.yaml
+++ b/examples/dbrx/8bit-lora.yaml
@@ -22,6 +22,7 @@ wandb_watch:
 wandb_name:
 wandb_log_model:
 
+qlora_fsdp_alt_loader: true
 adapter: lora
 lora_model_dir:
 lora_r: 8
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index d99155ac2..10ae3ac9e 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -188,6 +188,7 @@ class LoraConfig(BaseModel):
     peft_use_dora: Optional[bool] = None
     peft_use_rslora: Optional[bool] = None
     peft_layer_replication: Optional[List[Tuple[int, int]]] = None
+    qlora_fsdp_alt_loader: Optional[bool] = None
 
     lora_on_cpu: Optional[bool] = None
     gptq: Optional[bool] = None
diff --git a/src/axolotl/utils/model_shard_quant.py b/src/axolotl/utils/model_shard_quant.py
index 65f23b9e0..d225e9094 100644
--- a/src/axolotl/utils/model_shard_quant.py
+++ b/src/axolotl/utils/model_shard_quant.py
@@ -70,6 +70,7 @@ def load_and_quantize(
     to_meta: bool = False,
     verbose: bool = False,
     quant_method: str = "bnb",
+    is_dora: bool = False,
 ):
     """
     Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`.
@@ -108,6 +109,12 @@ def load_and_quantize(
                 # FSDP only syncs parameters and buffers, so the quant_state isn't copied. This
                 # workaround quantizes Params4bit to initialize quant_state on all ranks, then
                 # replaces Params4bit's data with a meta tensor to free memory on non-rank 0.
+                if is_dora:
+                    setattr(
+                        submodule,
+                        "dora_scale",
+                        value.norm(p=2, dim=1).to(dtype=dtype).to("cpu"),
+                    )
                 value = type(param)(
                     value.to(device=device, dtype=dtype).data, **param.__dict__
                 ).cuda(device)
@@ -177,6 +184,7 @@ def load_sharded_model_quant(
     with init_empty_weights():
         model = AutoModelForCausalLM.from_config(
             model_config,
+            attn_implementation=model_config._attn_implementation,  # pylint: disable=protected-access
            trust_remote_code=cfg.trust_remote_code,
         )
     if hasattr(model, "transformer"):
@@ -249,6 +257,7 @@ def load_sharded_model_quant(
                 to_meta=(low_memory and cfg.local_rank != 0),
                 verbose=verbose,
                 quant_method=quant_method,
+                is_dora=cfg.peft_use_dora,
             )
 
     if cfg.local_rank == 0 and verbose:
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 8537b7e75..11fbc0d6a 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -34,6 +34,7 @@ from transformers import (  # noqa: F401
     PreTrainedTokenizerBase,
 )
 from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
+from transformers.quantizers import AutoHfQuantizer
 
 from axolotl.models.mamba import fix_mamba_attn_for_loss
 from axolotl.monkeypatch.multipack import (
@@ -568,7 +569,7 @@ def load_model(
         elif (
             qlora_fsdp
             and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
-            and cfg.model_config_type == "dbrx"
+            and cfg.qlora_fsdp_alt_loader
         ):
             quant_storage = cfg.torch_dtype
             model = load_sharded_model_quant(
@@ -577,6 +578,11 @@ def load_model(
                 cfg,
                 quant_storage=quant_storage,
             )
+            if model_kwargs["quantization_config"]:
+                hf_quantizer = AutoHfQuantizer.from_config(
+                    model_kwargs["quantization_config"]
+                )
+                model.hf_quantizer = hf_quantizer
             skip_move_to_device = True
         elif (
             model_config.model_type == "llama"
@@ -1003,3 +1009,10 @@ def ensure_dtype(model, dtype=torch.bfloat16):
             module.to(dtype)
         except AttributeError:
             pass
+    for name, param in model.named_parameters():
+        try:
+            if param.data.dtype != dtype:
+                print(f"Converting module {name}: {param.data.dtype} -> {dtype}")
+                param.data = param.data.to(dtype)
+        except AttributeError:
+            pass
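
For reference, the alternate sharded-quantized loader was previously gated on cfg.model_config_type == "dbrx"; after this patch it is opt-in via the new qlora_fsdp_alt_loader flag and still requires fsdp_config.fsdp_cpu_ram_efficient_loading, per the load_model condition above. A minimal config sketch for enabling it on a non-DBRX model follows; the base_model and FSDP settings are illustrative placeholders, not part of this patch:

    base_model: mistralai/Mixtral-8x7B-v0.1   # hypothetical example model
    load_in_4bit: true
    adapter: qlora
    qlora_fsdp_alt_loader: true               # new flag introduced by this patch
    fsdp:
      - full_shard
      - auto_wrap
    fsdp_config:
      fsdp_cpu_ram_efficient_loading: true    # required by the gating condition in load_model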