handle empty offset for quant state

This commit is contained in:
Wing Lian
2025-05-01 13:01:00 -04:00
parent fee3c13bb5
commit 1a22d16842

View File

@@ -55,13 +55,16 @@ def dequantize(
     target_device = W.device
     # Extract quantization state
+    nested = False
     if not isinstance(quant_state, list):
         # New style quant_state class
         absmax = quant_state.absmax.to(target_device)
         shape = quant_state.shape
         dtype = quant_state.dtype
         blocksize = quant_state.blocksize
-        offset = quant_state.offset.to(target_device)
+        if quant_state.nested:
+            nested = True
+            offset = quant_state.offset.to(target_device)
         state2 = quant_state.state2
         absmax2 = state2.absmax.to(target_device)
         code2 = state2.code.to(target_device)
@@ -115,7 +118,8 @@ def dequantize(
         ctypes.c_int(n_elements_absmax),
     )
-    out_absmax += offset
+    if nested:
+        out_absmax += offset
     # Choose appropriate dequantization function
     fx = (