Compare commits
1 Commits
devstral-s
...
fsdp-qdora
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7a7c56f018 |
@@ -22,6 +22,7 @@ wandb_watch:
|
|||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
|
qlora_fsdp_alt_loader: true
|
||||||
adapter: lora
|
adapter: lora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
lora_r: 8
|
lora_r: 8
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ wandb_watch:
|
|||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
|
qlora_fsdp_alt_loader: true
|
||||||
adapter: lora
|
adapter: lora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
lora_r: 8
|
lora_r: 8
|
||||||
|
|||||||
@@ -188,6 +188,7 @@ class LoraConfig(BaseModel):
|
|||||||
peft_use_dora: Optional[bool] = None
|
peft_use_dora: Optional[bool] = None
|
||||||
peft_use_rslora: Optional[bool] = None
|
peft_use_rslora: Optional[bool] = None
|
||||||
peft_layer_replication: Optional[List[Tuple[int, int]]] = None
|
peft_layer_replication: Optional[List[Tuple[int, int]]] = None
|
||||||
|
qlora_fsdp_alt_loader: Optional[bool] = None
|
||||||
|
|
||||||
lora_on_cpu: Optional[bool] = None
|
lora_on_cpu: Optional[bool] = None
|
||||||
gptq: Optional[bool] = None
|
gptq: Optional[bool] = None
|
||||||
|
|||||||
@@ -70,6 +70,7 @@ def load_and_quantize(
|
|||||||
to_meta: bool = False,
|
to_meta: bool = False,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
quant_method: str = "bnb",
|
quant_method: str = "bnb",
|
||||||
|
is_dora: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`.
|
Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`.
|
||||||
@@ -108,6 +109,12 @@ def load_and_quantize(
|
|||||||
# FSDP only syncs parameters and buffers, so the quant_state isn't copied. This
|
# FSDP only syncs parameters and buffers, so the quant_state isn't copied. This
|
||||||
# workaround quantizes Params4bit to initialize quant_state on all ranks, then
|
# workaround quantizes Params4bit to initialize quant_state on all ranks, then
|
||||||
# replaces Params4bit's data with a meta tensor to free memory on non-rank 0.
|
# replaces Params4bit's data with a meta tensor to free memory on non-rank 0.
|
||||||
|
if is_dora:
|
||||||
|
setattr(
|
||||||
|
submodule,
|
||||||
|
"dora_scale",
|
||||||
|
value.norm(p=2, dim=1).to(dtype=dtype).to("cpu"),
|
||||||
|
)
|
||||||
value = type(param)(
|
value = type(param)(
|
||||||
value.to(device=device, dtype=dtype).data, **param.__dict__
|
value.to(device=device, dtype=dtype).data, **param.__dict__
|
||||||
).cuda(device)
|
).cuda(device)
|
||||||
@@ -177,6 +184,7 @@ def load_sharded_model_quant(
|
|||||||
with init_empty_weights():
|
with init_empty_weights():
|
||||||
model = AutoModelForCausalLM.from_config(
|
model = AutoModelForCausalLM.from_config(
|
||||||
model_config,
|
model_config,
|
||||||
|
attn_implementation=model_config._attn_implementation, # pylint: disable=protected-access
|
||||||
trust_remote_code=cfg.trust_remote_code,
|
trust_remote_code=cfg.trust_remote_code,
|
||||||
)
|
)
|
||||||
if hasattr(model, "transformer"):
|
if hasattr(model, "transformer"):
|
||||||
@@ -249,6 +257,7 @@ def load_sharded_model_quant(
|
|||||||
to_meta=(low_memory and cfg.local_rank != 0),
|
to_meta=(low_memory and cfg.local_rank != 0),
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
quant_method=quant_method,
|
quant_method=quant_method,
|
||||||
|
is_dora=cfg.peft_use_dora,
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.local_rank == 0 and verbose:
|
if cfg.local_rank == 0 and verbose:
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ from transformers import ( # noqa: F401
|
|||||||
PreTrainedTokenizerBase,
|
PreTrainedTokenizerBase,
|
||||||
)
|
)
|
||||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||||
|
from transformers.quantizers import AutoHfQuantizer
|
||||||
|
|
||||||
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
||||||
from axolotl.monkeypatch.multipack import (
|
from axolotl.monkeypatch.multipack import (
|
||||||
@@ -568,7 +569,7 @@ def load_model(
|
|||||||
elif (
|
elif (
|
||||||
qlora_fsdp
|
qlora_fsdp
|
||||||
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
|
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
|
||||||
and cfg.model_config_type == "dbrx"
|
and cfg.qlora_fsdp_alt_loader
|
||||||
):
|
):
|
||||||
quant_storage = cfg.torch_dtype
|
quant_storage = cfg.torch_dtype
|
||||||
model = load_sharded_model_quant(
|
model = load_sharded_model_quant(
|
||||||
@@ -577,6 +578,11 @@ def load_model(
|
|||||||
cfg,
|
cfg,
|
||||||
quant_storage=quant_storage,
|
quant_storage=quant_storage,
|
||||||
)
|
)
|
||||||
|
if model_kwargs["quantization_config"]:
|
||||||
|
hf_quantizer = AutoHfQuantizer.from_config(
|
||||||
|
model_kwargs["quantization_config"]
|
||||||
|
)
|
||||||
|
model.hf_quantizer = hf_quantizer
|
||||||
skip_move_to_device = True
|
skip_move_to_device = True
|
||||||
elif (
|
elif (
|
||||||
model_config.model_type == "llama"
|
model_config.model_type == "llama"
|
||||||
@@ -1003,3 +1009,10 @@ def ensure_dtype(model, dtype=torch.bfloat16):
|
|||||||
module.to(dtype)
|
module.to(dtype)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
pass
|
pass
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
try:
|
||||||
|
if param.data.dtype != dtype:
|
||||||
|
print(f"Converting module {name}: {param.data.dtype} -> {dtype}")
|
||||||
|
param.data = param.data.to(dtype)
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|||||||
Reference in New Issue
Block a user