Skip to content

Commit f54331a

Browse files
authored
Add More Observers (PaddlePaddle#1690)
1 parent 7263316 commit f54331a

File tree

5 files changed (+362 additions, -11 deletions)

paddleslim/quant/observers/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,5 +14,10 @@
1414

1515
from .hist import HistObserver
1616
from .kl import KLObserver
17+
from .mse import MSEObserver
18+
from .emd import EMDObserver
19+
from .avg import AVGObserver
1720

18-
__all__ = ["HistObserver", "KLObserver"]
21+
__all__ = [
22+
"HistObserver", "KLObserver", "MSEObserver", "EMDObserver", "AVGObserver"
23+
]

paddleslim/quant/observers/avg.py

Lines changed: 105 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,105 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import numpy as np
16+
import paddle
17+
from .uniform import UniformObserver
18+
from paddle.quantization.factory import ObserverFactory
19+
20+
21+
class AVGObserver(ObserverFactory):
    r"""
    Observer factory that builds :class:`AVGObserverLayer` instances.

    The created layer computes, for every observed tensor, the maximum
    absolute value of each slice along the first axis (e.g. each sample of a
    batch) and averages those maxima. The mean of these per-call averages over
    all observed calls is used as the upper bound of the quantization range;
    the lower bound is always 0.

    Args:
        quant_bits (int, optional): Number of bits used to represent a
            quantized integer. Default is 8.

    Examples:
        .. code-block:: python

            from paddle.quantization import QuantConfig
            from paddleslim.quant.observers import AVGObserver

            observer = AVGObserver(quant_bits=8)
            q_config = QuantConfig(activation=observer, weight=observer)
    """

    def __init__(self, quant_bits=8):
        super(AVGObserver, self).__init__(quant_bits=quant_bits)

    def _get_class(self):
        """Return the observer layer class this factory instantiates."""
        return AVGObserverLayer
42+
43+
44+
class AVGObserverLayer(UniformObserver):
    """Observer layer that averages per-call absolute-maximum statistics.

    Every ``forward`` call reduces the input to one number: the mean, over
    the first axis, of each slice's maximum absolute value. These numbers are
    collected in ``_avg_list``; their mean becomes the upper bound of the
    quantization range, while the lower bound is always 0.

    Args:
        layer: The layer being observed (not used by this class).
        quant_bits (int, optional): Number of quantization bits. Default 8.
    """

    def __init__(self, layer, quant_bits=8):
        super(AVGObserverLayer, self).__init__(quant_bits=quant_bits)
        self._quant_bits = quant_bits
        # One entry per ``forward`` call: the averaged abs-max of that input.
        self._avg_list = []

    def forward(self, inputs):
        """Record statistics of ``inputs`` and pass them through unchanged.

        Cached range/scale values are reset so that ``scales()`` and
        ``zero_points()`` recompute them lazily, including this input.
        """
        self._scale = self._zero_point = None
        self._min = self._max = None

        lower, averaged_max = self.cal_min_max(inputs)
        self._avg_min, self._avg_max = lower, averaged_max
        self._avg_list.append(averaged_max)
        return inputs

    def cal_min_max(self, inputs):
        """Return ``(0, mean over axis 0 of per-slice max |inputs|)``.

        The tensor is reshaped to ``(inputs.shape[0], -1)`` so that each row
        holds one slice along the first axis; the second item is a float.
        """
        leading_dim = inputs.shape[0]
        magnitudes = paddle.abs(inputs.reshape((leading_dim, -1)))
        row_maxima = paddle.max(magnitudes, axis=1)
        return 0, float(paddle.mean(row_maxima))

    def cal_thresholds(self):
        """Average all recorded values into the range, then derive scale/zero point."""
        history = paddle.to_tensor(self._avg_list)
        self._min = self._avg_min
        self._max = paddle.mean(history)
        self._scale, self._zero_point = self.cal_scales_zero_points()

    def min_value(self) -> float:
        """Return the lower bound of the quantization range."""
        return self._min

    def max_value(self) -> float:
        """Return the upper bound of the quantization range."""
        return self._max

    def bit_length(self):
        """Return the number of bits of the quantized data."""
        return self._quant_bits

    def quant_axis(self):
        """Return the quantization axis (-1; a single tensor-wide range is used)."""
        return -1

    def scales(self):
        """Return the scale, computing thresholds first if none is cached."""
        if self._scale is None:
            self.cal_thresholds()
        return self._scale

    def zero_points(self):
        """Return the zero point, computing thresholds first if none is cached."""
        if self._zero_point is None:
            self.cal_thresholds()
        return self._zero_point

paddleslim/quant/observers/emd.py

Lines changed: 118 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,118 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import numpy as np
16+
import paddle
17+
from .uniform import UniformObserver
18+
from paddle.quantization.factory import ObserverFactory
19+
20+
21+
class EMDObserver(ObserverFactory):
    r"""
    Observer factory that builds :class:`EMDObserverLayer` instances.

    The created layer evaluates candidate clipping thresholds derived from
    the observed tensor's range and scores each with the loss
    ``|mean(x) - mean(q(x))| + |std(x) - std(q(x))|``, where ``q(x)`` is the
    quantize-dequantize result of ``x``.

    Args:
        quant_bits (int, optional): Number of bits used to represent a
            quantized integer. Default is 8.

    Examples:
        .. code-block:: python

            from paddle.quantization import QuantConfig
            from paddleslim.quant.observers import EMDObserver

            observer = EMDObserver(quant_bits=8)
            q_config = QuantConfig(activation=observer, weight=observer)
    """

    def __init__(self, quant_bits=8):
        super(EMDObserver, self).__init__(quant_bits=quant_bits)

    def _get_class(self):
        """Return the observer layer class this factory instantiates."""
        return EMDObserverLayer
42+
43+
44+
class EMDObserverLayer(UniformObserver):
    """Observer that picks the clipping threshold minimizing an "EMD" loss.

    On every ``forward`` call, candidate thresholds ``s * max(|x|)`` for
    ``s`` in {0.30, 0.32, ..., 1.00} are evaluated. For each candidate the
    input is quantized to integers in ``[-qmax - 1, qmax]`` and dequantized,
    and the loss ``|mean(x) - mean(q(x))| + |std(x) - std(q(x))|`` is
    computed. The candidate with the smallest loss (the larger one on ties)
    becomes the upper bound of the quantization range; the lower bound is
    always 0.

    Only the most recent input determines the threshold: each ``forward``
    call overwrites the previous result.

    Args:
        layer: The layer being observed (not used by this class).
        quant_bits (int, optional): Number of quantization bits. Default 8.
    """

    def __init__(self, layer, quant_bits=8):
        super(EMDObserverLayer, self).__init__(quant_bits=quant_bits)
        self._quant_bits = quant_bits
        # Loss of the threshold selected by the most recent ``forward`` call.
        self._calibration_loss = float('inf')
        # Integer range of the quantized values, provided by the base class.
        self.qmin, self.qmax = self.qmin_qmax

    def forward(self, inputs):
        """Observe ``inputs`` and return them unchanged.

        Cached range/scale values are cleared so that the next call to
        ``scales()`` or ``zero_points()`` recomputes them from this input.
        """
        self._scale = None
        self._zero_point = None
        self._min = None
        self._max = None
        self._emd_min, self._emd_max = self.cal_min_max(inputs)

        return inputs

    def cal_min_max(self, inputs):
        """Search the clipping threshold that minimizes the loss on ``inputs``.

        Args:
            inputs (Tensor): The tensor being observed.

        Returns:
            tuple: ``(0, best_threshold)`` with ``best_threshold`` a float.
        """
        # The threshold must bound the magnitude of the data, so use the
        # absolute maximum; a signed max under-estimates it for inputs whose
        # largest magnitude is negative (and goes negative for all-negative
        # inputs).
        abs_max_value = float(paddle.max(paddle.abs(paddle.flatten(inputs))))
        # Avoid a zero threshold (division by zero below) for all-zero inputs.
        abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value

        # Statistics of the original input do not depend on the candidate.
        inputs_mean = paddle.mean(inputs)
        inputs_std = paddle.std(inputs)

        best_scale = abs_max_value
        best_loss = float('inf')
        # Integer percentages give an exact candidate set; repeatedly adding
        # 0.02 in floating point makes the last candidate rounding-dependent.
        for percent in range(30, 101, 2):
            scale = percent / 100.0 * abs_max_value
            quant_var = paddle.clip(
                paddle.round(inputs / scale * self.qmax), -self.qmax - 1,
                self.qmax)
            quant_dequant_var = quant_var / self.qmax * scale

            emd_loss = paddle.abs(inputs_mean - paddle.mean(
                quant_dequant_var)) + paddle.abs(inputs_std - paddle.std(
                    quant_dequant_var))
            emd_loss = float(emd_loss)
            # Remember the scale that produced the best loss, not just the
            # loss itself, so the search result is actually returned.
            if emd_loss <= best_loss:
                best_loss = emd_loss
                best_scale = scale

        self._calibration_loss = best_loss
        return 0, best_scale

    def cal_thresholds(self):
        """Set the range from the last search and derive scale/zero point."""
        self._min, self._max = self._emd_min, self._emd_max
        self._scale, self._zero_point = self.cal_scales_zero_points()

    def min_value(self) -> float:
        """Return the lower bound of the quantization range."""
        return self._min

    def max_value(self) -> float:
        """Return the upper bound of the quantization range."""
        return self._max

    def bit_length(self):
        """Return the bit length of quantized data."""
        return self._quant_bits

    def quant_axis(self):
        """Return the quantization axis (-1; a single tensor-wide threshold is used)."""
        return -1

    def scales(self):
        """Return the scale, computing thresholds first if none is cached."""
        if self._scale is None:
            self.cal_thresholds()
        return self._scale

    def zero_points(self):
        """Return the zero point, computing thresholds first if none is cached."""
        if self._zero_point is None:
            self.cal_thresholds()
        return self._zero_point

paddleslim/quant/observers/mse.py

Lines changed: 115 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,115 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import numpy as np
16+
import paddle
17+
from .uniform import UniformObserver
18+
from paddle.quantization.factory import ObserverFactory
19+
20+
21+
class MSEObserver(ObserverFactory):
    r"""
    Observer factory that builds :class:`MSEObserverLayer` instances.

    The created layer evaluates candidate clipping thresholds that are
    fractions of the observed tensor's maximum absolute value and scores
    each by the mean squared error between the tensor and its
    quantize-dequantize result.

    Args:
        quant_bits (int, optional): Number of bits used to represent a
            quantized integer. Default is 8.

    Examples:
        .. code-block:: python

            from paddle.quantization import QuantConfig
            from paddleslim.quant.observers import MSEObserver

            observer = MSEObserver(quant_bits=8)
            q_config = QuantConfig(activation=observer, weight=observer)
    """

    def __init__(self, quant_bits=8):
        super(MSEObserver, self).__init__(quant_bits=quant_bits)

    def _get_class(self):
        """Return the observer layer class this factory instantiates."""
        return MSEObserverLayer
42+
43+
44+
class MSEObserverLayer(UniformObserver):
    """Observer that picks the clipping threshold minimizing the MSE.

    On every ``forward`` call, candidate thresholds ``s * max(|x|)`` for
    ``s`` in {0.30, 0.32, ..., 1.00} are evaluated. For each candidate the
    input is quantized to integers in ``[-qmax - 1, qmax]`` and dequantized,
    and the mean squared error against the original input is measured. The
    candidate with the smallest error (the larger one on ties) becomes the
    upper bound of the quantization range; the lower bound is always 0.

    Only the most recent input determines the threshold: each ``forward``
    call overwrites the previous result.

    Args:
        layer: The layer being observed (not used by this class).
        quant_bits (int, optional): Number of quantization bits. Default 8.
    """

    def __init__(self, layer, quant_bits=8):
        super(MSEObserverLayer, self).__init__(quant_bits=quant_bits)
        # ``bit_length()`` (and therefore ``qmin_qmax``) reads ``_quant_bits``,
        # as in the sibling observers; ``quant_bits`` is kept as a public
        # alias for backward compatibility.
        self._quant_bits = quant_bits
        self.quant_bits = quant_bits
        # Loss of the threshold selected by the most recent ``forward`` call.
        self.calibration_loss = float('inf')
        # Integer range of the quantized values, provided by the base class.
        self.qmin, self.qmax = self.qmin_qmax

    def forward(self, inputs):
        """Observe ``inputs`` and return them unchanged.

        Cached range/scale values are cleared so that the next call to
        ``scales()`` or ``zero_points()`` recomputes them from this input.
        """
        self._scale = None
        self._zero_point = None
        self._min = None
        self._max = None

        self._mse_min, self._mse_max = self.cal_min_max(inputs)

        return inputs

    def cal_min_max(self, inputs):
        """Search the clipping threshold that minimizes the MSE on ``inputs``.

        Args:
            inputs (Tensor): The tensor being observed.

        Returns:
            tuple: ``(0, best_threshold)`` with ``best_threshold`` a float.
        """
        abs_max_value = float(paddle.max(paddle.abs(inputs.flatten())))
        # Avoid a zero threshold (division by zero below) for all-zero inputs.
        abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value

        best_scale = abs_max_value
        best_loss = float('inf')
        # Integer percentages give an exact candidate set; repeatedly adding
        # 0.02 in floating point makes the last candidate rounding-dependent.
        for percent in range(30, 101, 2):
            scale = percent / 100.0 * abs_max_value
            quant_var = paddle.clip(
                paddle.round(inputs / scale * self.qmax), -self.qmax - 1,
                self.qmax)
            quant_dequant_var = quant_var / self.qmax * scale

            mse_loss = float(((inputs - quant_dequant_var)**2).mean())
            # Remember the scale that produced the best loss, not just the
            # loss itself, so the search result is actually returned.
            if mse_loss <= best_loss:
                best_loss = mse_loss
                best_scale = scale

        self.calibration_loss = best_loss
        return 0, best_scale

    def cal_thresholds(self):
        """Set the range from the last MSE search and derive scale/zero point."""
        self._min, self._max = self._mse_min, self._mse_max
        self._scale, self._zero_point = self.cal_scales_zero_points()

    def min_value(self) -> float:
        """Return the lower bound of the quantization range."""
        return self._min

    def max_value(self) -> float:
        """Return the upper bound of the quantization range."""
        return self._max

    def bit_length(self):
        """Return the bit length of quantized data."""
        return self._quant_bits

    def quant_axis(self):
        """Return the quantization axis (-1; a single tensor-wide threshold is used)."""
        return -1

    def scales(self):
        """Return the scale, computing thresholds first if none is cached."""
        if self._scale is None:
            self.cal_thresholds()
        return self._scale

    def zero_points(self):
        """Return the zero point, computing thresholds first if none is cached."""
        if self._zero_point is None:
            self.cal_thresholds()
        return self._zero_point

0 commit comments

Comments
 (0)