模型量化

2 minute read

Published:

模型量化分为后量化(Post-training Quantization, PTQ)和训练量化(Quantization aware training,QAT)。PTQ是直接将模型浮点权重转换为整数,QAT是将模型插入量化节点,之后再fine-tune调整模型的权重。

本文以权重量化为int4bit为例。

量化分为对称量化和非对称量化。对于量化为int4bit来说,对称量化是将int4bit的值域设为[-8, 7],即-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7,这16个数。非对称量化是将int4bit的值域设为[0, 15],即0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15这16个数字。

以非对称量化为例,最简单的方法就是把 ”浮点权重的值域最大值换成15,浮点权重值域的最小值换成0“,如下图所示。

int4有{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}个数值表示,int4的每个数值间隔是1。所以对于中间的浮点数来说,对应到int4最简单的方法是将浮点数也切分成15份。下面介绍一种最简单的切分方法,将每一格的大小都设为一样的。只需要将浮点的值域除以15就可以,在浮点数上,每一格的大小也称为scale,如下图所示。

对于将浮点等分切块的方法,scale值是固定的。将每一个scale内的值映射到对应的int值上,如上图最左侧橘黄色的float值映射到int的橘黄色值,取四舍五入为0。float值上黄色的scale范围,映射到int的黄色值上,取四舍五入为5。

下面我们看一下具体的计算过程:

1.对称量化

假设浮点权重为weight_float = [0.1, 0.2, 1.2, 3, 2.1, -2.1, -3.5]

(1) 首先计算浮点权重绝对值的最大值 range_weight_float = max(abs(weight_float))

在这个例子中,range_weight_float=max(abs([0.1, 0.2, 1.2, 3, 2.1, -2.1, -3.5])) = 3.5

(2) 计算scale

scale = (2^(nbit-1) - 1) / range_weight_float,因为要减去符号位,因此是nbit-1。

在这个例子中,scale = (2^3 - 1) / 3.5 = 2

(3) 浮点量化到定点

weight_int = round(scale*weight_float)

在这个例子中,weight_int = round(2 * [0.1, 0.2, 1.2, 3, 2.1, -2.1, -3.5]) = [0, 0, 2, 6, 4, -4, -7]

量化大部分情况下还需要反量化(de-quantization)过程,即将量化后的数再转换为浮点数。weight_float_from_int = weight_int / scale

在此例子中,weight_float_from_int = [0, 0, 2, 6, 4, -4, -7] / 2 = [0, 0, 1, 3, 2, -2, -3.5]

下面这个表是float、量化后与反量化后的权重情况 | INT4bit(weight_int) | 0 | 0 | 2 | 6 | 4 | -4 | -7 | | —————————————– | —- | —- | —- | —- | —- | —- | —- | | weight_float_from_int | 0 | 0 | 1 | 3 | 2 | -2 | -3.5 | | weight_float | 0.1 | 0.1 | 1.2 | 3 | 2.1 | -2.1 | -3.5 | | abs(weight_float - weight_float_from_int) | 0.1 | 0.1 | 0.2 | 0 | 0.1 | 0.1 | 0 |

2.非对称量化

假设浮点权重为weight_float = [0.1, 0.2, 1.2, 3, 2.1, -2.1, -3.5]

(1)计算浮点权重的值域 range_weight_float = max(weight_float) - min(weight_float) = 3 - (-3.5) = 6.5

(2)计算scale:scale = (2^nbit - 1) / range_weight_float = (2^4 - 1) / 6.5 = 15 / 6.5 = 2.3077

(3)计算zero-point:zero_point = round(min(weight_float) * scale) = round(-3.5 * 2.3077) = -8

(4)量化浮点到定点:weight_int = round(scale * weight_float) - zero_point = round(2.3077 * [0.1, 0.2, 1.2, 3, 2.1, -2.1, -3.5]) - (-8) = [8, 8, 11, 15, 13, 3, 0]

如果有反量化过程,weight_float_from_int = (weight_int + zero_point) / scale = ([8, 8, 11, 15, 13, 3, 0] + (-8)) / 2.3077 = [0, 0, 1.3, 3.033, 2.167, -2.167, -3.467]

weight_int8811151330
weight_float_from_int001.33.0332.167-2.167-3.467
weight_float0.10.21.232.1-2.1-3.5
abs(weight_float - weight_float_from_int)0.10.20.10.0330.0670.0670.033