Compare commits
1 Commits
revert-mul
...
lora-quant
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1a22d16842 |
@@ -55,13 +55,16 @@ def dequantize(
|
|||||||
target_device = W.device
|
target_device = W.device
|
||||||
|
|
||||||
# Extract quantization state
|
# Extract quantization state
|
||||||
|
nested = False
|
||||||
if not isinstance(quant_state, list):
|
if not isinstance(quant_state, list):
|
||||||
# New style quant_state class
|
# New style quant_state class
|
||||||
absmax = quant_state.absmax.to(target_device)
|
absmax = quant_state.absmax.to(target_device)
|
||||||
shape = quant_state.shape
|
shape = quant_state.shape
|
||||||
dtype = quant_state.dtype
|
dtype = quant_state.dtype
|
||||||
blocksize = quant_state.blocksize
|
blocksize = quant_state.blocksize
|
||||||
offset = quant_state.offset.to(target_device)
|
if quant_state.nested:
|
||||||
|
nested = True
|
||||||
|
offset = quant_state.offset.to(target_device)
|
||||||
state2 = quant_state.state2
|
state2 = quant_state.state2
|
||||||
absmax2 = state2.absmax.to(target_device)
|
absmax2 = state2.absmax.to(target_device)
|
||||||
code2 = state2.code.to(target_device)
|
code2 = state2.code.to(target_device)
|
||||||
@@ -115,7 +118,8 @@ def dequantize(
|
|||||||
ctypes.c_int(n_elements_absmax),
|
ctypes.c_int(n_elements_absmax),
|
||||||
)
|
)
|
||||||
|
|
||||||
out_absmax += offset
|
if nested:
|
||||||
|
out_absmax += offset
|
||||||
|
|
||||||
# Choose appropriate dequantization function
|
# Choose appropriate dequantization function
|
||||||
fx = (
|
fx = (
|
||||||
|
|||||||
Reference in New Issue
Block a user