diff --git a/src/axolotl/kernels/quantize.py b/src/axolotl/kernels/quantize.py index b61603fbc..bff9a2bd1 100644 --- a/src/axolotl/kernels/quantize.py +++ b/src/axolotl/kernels/quantize.py @@ -55,13 +55,16 @@ def dequantize( target_device = W.device # Extract quantization state + nested = False if not isinstance(quant_state, list): # New style quant_state class absmax = quant_state.absmax.to(target_device) shape = quant_state.shape dtype = quant_state.dtype blocksize = quant_state.blocksize - offset = quant_state.offset.to(target_device) + if quant_state.nested: + nested = True + offset = quant_state.offset.to(target_device) state2 = quant_state.state2 absmax2 = state2.absmax.to(target_device) code2 = state2.code.to(target_device) @@ -115,7 +118,8 @@ def dequantize( ctypes.c_int(n_elements_absmax), ) - out_absmax += offset + if nested: + out_absmax += offset # Choose appropriate dequantization function fx = (