
Commit 1d07d79

Fix pruning walker. (PaddlePaddle#277)
1 parent 9d8730f commit 1d07d79

File tree: 2 files changed (+106, -1 lines)

2 files changed

+106
-1
lines changed

paddleslim/prune/prune_walker.py

Lines changed: 101 additions & 1 deletion
@@ -23,6 +23,8 @@
 
 PRUNE_WORKER = Registry('prune_worker')
 
+SKIP_OPS = ["conditional_block"]
+
 
 class PruneWorker(object):
     def __init__(self, op, pruned_params=[], visited={}):
@@ -72,6 +74,9 @@ def _prune_op(self, op, var, pruned_axis, pruned_idx, visited=None):
             self.visited = visited
         cls = PRUNE_WORKER.get(op.type())
         if cls is None:
+            if op.type() in SKIP_OPS:
+                _logger.warn("Skip operator [{}]".format(op.type()))
+                return
             _logger.warn(
                 "{} op will be pruned by default walker to keep the shapes of input and output being same because its walker is not registered.".
                 format(op.type()))
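
For context on the lookup path this change touches: op types are resolved to walker classes through the module's Registry, and unregistered ops come back as None. The sketch below is a simplified stand-in, not PaddleSlim's actual Registry implementation; it only illustrates why an op listed in SKIP_OPS (such as conditional_block) now short-circuits instead of falling through to the default "keep shapes the same" walker.

# Simplified, illustrative registry sketch (not PaddleSlim's real Registry).
class Registry(object):
    def __init__(self, name):
        self.name = name
        self._table = {}

    def register(self, cls):
        # The walker class name doubles as the op type, e.g. "conv2d_transpose".
        self._table[cls.__name__] = cls
        return cls

    def get(self, op_type):
        # None means "no walker registered"; the new SKIP_OPS check decides
        # whether to silently skip or use the default walker.
        return self._table.get(op_type)


PRUNE_WORKER = Registry('prune_worker')


@PRUNE_WORKER.register
class conv2d_transpose(object):  # stand-in for the real PruneWorker subclass
    pass


assert PRUNE_WORKER.get("conv2d_transpose") is conv2d_transpose
assert PRUNE_WORKER.get("conditional_block") is None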
@@ -149,6 +154,71 @@ def _prune(self, var, pruned_axis, pruned_idx):
                 self._prune_op(op, output_var, channel_axis, pruned_idx)
 
 
+@PRUNE_WORKER.register
+class conv2d_transpose(PruneWorker):
+    def __init__(self, op, pruned_params, visited={}):
+        super(conv2d_transpose, self).__init__(op, pruned_params, visited)
+
+    def _prune(self, var, pruned_axis, pruned_idx):
+        data_format = self.op.attr("data_format")
+        channel_axis = 1
+        if data_format == "NHWC":
+            channel_axis = 3
+        if var in self.op.inputs("Input"):
+            assert pruned_axis == channel_axis, "The Input of conv2d can only be pruned at channel axis, but got {}; var: {}".format(
+                pruned_axis, var.name())
+            filter_var = self.op.inputs("Filter")[0]
+            self._visit(filter_var, 0)
+            self.pruned_params.append((filter_var, 0, pruned_idx))
+            for op in filter_var.outputs():
+                self._prune_op(op, filter_var, 0, pruned_idx)
+
+        elif var in self.op.inputs("Filter"):
+            assert pruned_axis in [0, 1]
+
+            self.pruned_params.append((var, pruned_axis, pruned_idx))
+
+            for op in var.outputs():
+                self._prune_op(op, var, pruned_axis, pruned_idx)
+
+            if pruned_axis == 1:
+                if len(self.op.inputs("Bias")) > 0:
+                    self.pruned_params.append(
+                        (self.op.inputs("Bias"), channel_axis, pruned_idx))
+                output_var = self.op.outputs("Output")[0]
+                self._visit(output_var, channel_axis)
+                next_ops = output_var.outputs()
+                for op in next_ops:
+                    self._prune_op(op, output_var, channel_axis, pruned_idx)
+
+            elif pruned_axis == 0:
+                input_var = self.op.inputs("Input")[0]
+                self._visit(input_var, channel_axis)
+                pre_ops = input_var.inputs()
+                for op in pre_ops:
+                    self._prune_op(op, input_var, channel_axis, pruned_idx)
+        elif var in self.op.outputs("Output"):
+            assert pruned_axis == channel_axis, "pruned_axis: {}; var: {}".format(
+                pruned_axis, var.name())
+
+            filter_var = self.op.inputs("Filter")[0]
+            self._visit(filter_var, 1)
+
+            self.pruned_params.append((filter_var, 1, pruned_idx))
+
+            for op in filter_var.outputs():
+                self._prune_op(op, filter_var, 1, pruned_idx)
+
+            if len(self.op.inputs("Bias")) > 0:
+                self.pruned_params.append(
+                    (self.op.inputs("Bias")[0], channel_axis, pruned_idx))
+
+            output_var = self.op.outputs("Output")[0]
+            next_ops = output_var.outputs()
+            for op in next_ops:
+                self._prune_op(op, output_var, channel_axis, pruned_idx)
+
+
 @PRUNE_WORKER.register
 class batch_norm(PruneWorker):
     def __init__(self, op, pruned_params, visited):
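
The axis choices in the new conv2d_transpose walker mirror the weight layout of a transposed convolution: unlike conv2d, axis 0 of the Filter tracks the op's Input channels and axis 1 tracks its Output channels, which is also what the code above encodes. The following sketch only illustrates that shape bookkeeping; the sizes are made up, and it assumes the [in_channels, out_channels/groups, kH, kW] filter layout with groups=1 and NCHW data_format.

import numpy as np

# Illustrative shapes only, assuming the [in_channels, out_channels, kH, kW]
# filter layout of a transposed convolution with groups=1.
in_channels, out_channels, k = 8, 16, 3
filter_w = np.zeros((in_channels, out_channels, k, k))

# Pruning the op's Input at its channel axis removes entries along filter axis 0.
pruned_in = [0, 3]
assert np.delete(filter_w, pruned_in, axis=0).shape == (6, 16, 3, 3)

# Pruning the op's Output at its channel axis removes entries along filter axis 1
# (and the matching Bias entries).
pruned_out = [1, 2, 5]
assert np.delete(filter_w, pruned_out, axis=1).shape == (8, 13, 3, 3)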
@@ -267,7 +337,7 @@ def __init__(self, op, pruned_params, visited):
 
     def _prune(self, var, pruned_axis, pruned_idx):
         if var in self.op.all_outputs():
-            for in_var in self.op.inputs():
+            for in_var in self.op.all_inputs():
                 if len(in_var.shape()) == len(var.shape()):
                     pre_ops = in_var.inputs()
                     for op in pre_ops:
@@ -549,3 +619,33 @@ def _prune(self, var, pruned_axis, pruned_idx):
             self.pruned_params.append((moment1_var, pruned_axis, pruned_idx))
             moment2_var = self.op.inputs("Moment2")[0]
             self.pruned_params.append((moment2_var, pruned_axis, pruned_idx))
+
+
+@PRUNE_WORKER.register
+class affine_channel(PruneWorker):
+    def __init__(self, op, pruned_params, visited):
+        super(affine_channel, self).__init__(op, pruned_params, visited)
+
+    def _prune(self, var, pruned_axis, pruned_idx):
+        if (var not in self.op.outputs("Out")) and (
+                var not in self.op.inputs("X")):
+            return
+
+        if var in self.op.outputs("Out"):
+            in_var = self.op.inputs("X")[0]
+            self._visit(in_var, pruned_axis)
+            pre_ops = in_var.inputs()
+            for op in pre_ops:
+                self._prune_op(op, in_var, pruned_axis, pruned_idx)
+
+        for param in ["Scale", "Bias"]:
+            param_var = self.op.inputs(param)[0]
+            for op in param_var.outputs():
+                self._prune_op(op, param_var, 0, pruned_idx)
+            self.pruned_params.append((param_var, 0, pruned_idx))
+
+        out_var = self.op.outputs("Out")[0]
+        self._visit(out_var, pruned_axis)
+        next_ops = out_var.outputs()
+        for op in next_ops:
+            self._prune_op(op, out_var, pruned_axis, pruned_idx)
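
As background, affine_channel applies a per-channel scale and shift to its input, so its Scale and Bias are 1-D vectors indexed by channel; that is why the new walker prunes them along axis 0 with the same indices used for the feature map's channel axis. The sketch below is illustrative only, with made-up shapes and NCHW layout assumed.

import numpy as np

# Illustrative affine_channel semantics in NCHW layout: Out = Scale * X + Bias,
# broadcast over the channel dimension C.
n, c, h, w = 2, 8, 4, 4
x = np.random.rand(n, c, h, w)
scale, bias = np.random.rand(c), np.random.rand(c)
out = scale.reshape(1, c, 1, 1) * x + bias.reshape(1, c, 1, 1)

# Removing channels [1, 5] from X stays consistent only if the same entries are
# removed from Scale and Bias along their single axis (axis 0).
idx = [1, 5]
x_p = np.delete(x, idx, axis=1)
scale_p, bias_p = np.delete(scale, idx), np.delete(bias, idx)
out_p = scale_p.reshape(1, -1, 1, 1) * x_p + bias_p.reshape(1, -1, 1, 1)
assert out_p.shape == (n, c - len(idx), h, w)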

paddleslim/prune/pruner.py

Lines changed: 5 additions & 0 deletions
@@ -90,6 +90,11 @@ def prune(self,
         visited = {}
         pruned_params = []
         for param, ratio in zip(params, ratios):
+            if graph.var(param) is None:
+                _logger.warn(
+                    "Variable[{}] to be pruned is not in current graph.".
+                    format(param))
+                continue
             group = collect_convs([param], graph, visited)[0]  # [(name, axis)]
             if group is None or len(group) == 0:
                 continue
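
With this guard in place, a parameter name that is not present in the graph only produces a warning and is skipped, instead of failing the whole prune call. A minimal usage sketch follows; the Pruner.prune argument list reflects the PaddleSlim 1.x static-graph API of the era and may differ across versions, and the layer and parameter names are hypothetical.

import paddle.fluid as fluid
from paddleslim.prune import Pruner

# Build a tiny static-graph program to prune (fluid 1.x style API assumed).
train_program, startup = fluid.Program(), fluid.Program()
with fluid.program_guard(train_program, startup):
    image = fluid.data(name="image", shape=[None, 3, 32, 32], dtype="float32")
    conv = fluid.layers.conv2d(
        image,
        num_filters=8,
        filter_size=3,
        param_attr=fluid.ParamAttr(name="conv1_weights"))

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup)

pruner = Pruner()
# "no_such_param" is deliberately not in the program; after this change it only
# logs "Variable[no_such_param] to be pruned is not in current graph." and is skipped.
pruned_program = pruner.prune(
    train_program,
    fluid.global_scope(),
    params=["conv1_weights", "no_such_param"],
    ratios=[0.5, 0.5],
    place=place)[0]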
