|
23 | 23 |
|
24 | 24 | PRUNE_WORKER = Registry('prune_worker') |
25 | 25 |
|
| 26 | +SKIP_OPS = ["conditional_block"] |
| 27 | + |
26 | 28 |
|
27 | 29 | class PruneWorker(object): |
28 | 30 | def __init__(self, op, pruned_params=[], visited={}): |
@@ -72,6 +74,9 @@ def _prune_op(self, op, var, pruned_axis, pruned_idx, visited=None): |
72 | 74 | self.visited = visited |
73 | 75 | cls = PRUNE_WORKER.get(op.type()) |
74 | 76 | if cls is None: |
| 77 | + if op.type() in SKIP_OPS: |
| 78 | + _logger.warn("Skip operator [{}]".format(op.type())) |
| 79 | + return |
75 | 80 | _logger.warn( |
76 | 81 | "{} op will be pruned by default walker to keep the shapes of input and output being same because its walker is not registered.". |
77 | 82 | format(op.type())) |
@@ -149,6 +154,71 @@ def _prune(self, var, pruned_axis, pruned_idx): |
149 | 154 | self._prune_op(op, output_var, channel_axis, pruned_idx) |
150 | 155 |
|
151 | 156 |
|
@PRUNE_WORKER.register
class conv2d_transpose(PruneWorker):
    """Prune worker for the ``conv2d_transpose`` operator.

    For a transposed convolution the filter is laid out as
    ``(in_channels, out_channels/groups, kH, kW)`` — the reverse of a
    regular ``conv2d`` — so pruning the operator's input channels maps to
    filter axis 0 and pruning its output channels maps to filter axis 1.
    (Layout assumption inferred from how this worker walks the graph;
    TODO confirm against the Paddle op definition.)
    """

    def __init__(self, op, pruned_params, visited=None):
        # `visited=None` instead of a mutable `{}` default so separate
        # workers never silently share one visited dict across calls.
        if visited is None:
            visited = {}
        super(conv2d_transpose, self).__init__(op, pruned_params, visited)

    def _prune(self, var, pruned_axis, pruned_idx):
        """Propagate the pruning of ``var`` through this operator.

        Args:
            var: the pruned variable; one of this op's Input, Filter or
                Output variables.
            pruned_axis (int): the axis of ``var`` that was pruned.
            pruned_idx (list<int>): indices pruned on that axis.
        """
        data_format = self.op.attr("data_format")
        # Channel axis of the activation tensors depends on layout.
        channel_axis = 3 if data_format == "NHWC" else 1

        if var in self.op.inputs("Input"):
            # Fixed message: this op is conv2d_transpose, not conv2d.
            assert pruned_axis == channel_axis, (
                "The Input of conv2d_transpose can only be pruned at channel "
                "axis, but got {}; var: {}".format(pruned_axis, var.name()))
            # Input channels correspond to filter axis 0.
            filter_var = self.op.inputs("Filter")[0]
            self._visit(filter_var, 0)
            self.pruned_params.append((filter_var, 0, pruned_idx))
            for op in filter_var.outputs():
                self._prune_op(op, filter_var, 0, pruned_idx)

        elif var in self.op.inputs("Filter"):
            assert pruned_axis in [0, 1]
            self.pruned_params.append((var, pruned_axis, pruned_idx))
            for op in var.outputs():
                self._prune_op(op, var, pruned_axis, pruned_idx)

            if pruned_axis == 1:
                # Filter axis 1 holds output channels: Bias and the Output
                # tensor must be pruned in step.
                if len(self.op.inputs("Bias")) > 0:
                    # Bug fix: append the Bias variable itself ([0]), not
                    # the whole list returned by inputs("Bias"); matches
                    # the Output branch below.
                    self.pruned_params.append(
                        (self.op.inputs("Bias")[0], channel_axis, pruned_idx))
                output_var = self.op.outputs("Output")[0]
                self._visit(output_var, channel_axis)
                for op in output_var.outputs():
                    self._prune_op(op, output_var, channel_axis, pruned_idx)

            elif pruned_axis == 0:
                # Filter axis 0 holds input channels: walk backwards to the
                # producers of Input.
                input_var = self.op.inputs("Input")[0]
                self._visit(input_var, channel_axis)
                for op in input_var.inputs():
                    self._prune_op(op, input_var, channel_axis, pruned_idx)

        elif var in self.op.outputs("Output"):
            assert pruned_axis == channel_axis, (
                "pruned_axis: {}; var: {}".format(pruned_axis, var.name()))
            # Output channels correspond to filter axis 1.
            filter_var = self.op.inputs("Filter")[0]
            self._visit(filter_var, 1)
            self.pruned_params.append((filter_var, 1, pruned_idx))
            for op in filter_var.outputs():
                self._prune_op(op, filter_var, 1, pruned_idx)

            if len(self.op.inputs("Bias")) > 0:
                self.pruned_params.append(
                    (self.op.inputs("Bias")[0], channel_axis, pruned_idx))

            output_var = self.op.outputs("Output")[0]
            for op in output_var.outputs():
                self._prune_op(op, output_var, channel_axis, pruned_idx)
| 221 | + |
152 | 222 | @PRUNE_WORKER.register |
153 | 223 | class batch_norm(PruneWorker): |
154 | 224 | def __init__(self, op, pruned_params, visited): |
@@ -267,7 +337,7 @@ def __init__(self, op, pruned_params, visited): |
267 | 337 |
|
268 | 338 | def _prune(self, var, pruned_axis, pruned_idx): |
269 | 339 | if var in self.op.all_outputs(): |
270 | | - for in_var in self.op.inputs(): |
| 340 | + for in_var in self.op.all_inputs(): |
271 | 341 | if len(in_var.shape()) == len(var.shape()): |
272 | 342 | pre_ops = in_var.inputs() |
273 | 343 | for op in pre_ops: |
@@ -549,3 +619,33 @@ def _prune(self, var, pruned_axis, pruned_idx): |
549 | 619 | self.pruned_params.append((moment1_var, pruned_axis, pruned_idx)) |
550 | 620 | moment2_var = self.op.inputs("Moment2")[0] |
551 | 621 | self.pruned_params.append((moment2_var, pruned_axis, pruned_idx)) |
| 622 | + |
| 623 | + |
@PRUNE_WORKER.register
class affine_channel(PruneWorker):
    """Prune worker for the ``affine_channel`` operator."""

    def __init__(self, op, pruned_params, visited):
        super(affine_channel, self).__init__(op, pruned_params, visited)

    def _prune(self, var, pruned_axis, pruned_idx):
        """Propagate the pruning of ``var`` through this operator.

        Scale and Bias are pruned on axis 0 (presumably one element per
        channel — confirm against the op definition), while the pruning is
        walked backwards through X's producers and forwards through Out's
        consumers on ``pruned_axis``.
        """
        is_output = var in self.op.outputs("Out")
        is_input = var in self.op.inputs("X")
        if not (is_output or is_input):
            # Triggered via Scale/Bias or an unrelated variable: nothing
            # to propagate from here.
            return

        if is_output:
            # Walk backwards through whatever produced X.
            x_var = self.op.inputs("X")[0]
            self._visit(x_var, pruned_axis)
            for pre_op in x_var.inputs():
                self._prune_op(pre_op, x_var, pruned_axis, pruned_idx)

        # The per-op parameters are always pruned on axis 0.
        for param_name in ("Scale", "Bias"):
            param_var = self.op.inputs(param_name)[0]
            for consumer in param_var.outputs():
                self._prune_op(consumer, param_var, 0, pruned_idx)
            self.pruned_params.append((param_var, 0, pruned_idx))

        # Forward the pruning to every consumer of Out.
        out_var = self.op.outputs("Out")[0]
        self._visit(out_var, pruned_axis)
        for nxt_op in out_var.outputs():
            self._prune_op(nxt_op, out_var, pruned_axis, pruned_idx)
0 commit comments