@@ -91,6 +91,16 @@ class OperatorConfig(NamedTuple):
     operators: list[OperatorPatternType]
 
 
+def is_relu_node(node: Node) -> bool:
+    """
+    Check if a given node is a relu node
+    """
+    return node.op == "call_function" and node.target in [
+        torch.ops.aten.relu.default,
+        torch.ops.aten.relu_.default,
+    ]
+
+
 def _is_annotated(nodes: list[Node]):
     """
     Given a list of nodes (that represents an operator pattern),
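
For context (not part of the patch): `is_relu_node` matches ATen-level `call_function` nodes, so it assumes the graph has already been lowered to ATen IR (for example by `torch.export`), not a plain `symbolic_trace` graph. A minimal sketch of exercising it; the module `M` and the example input are made up for illustration:

```python
# Minimal sketch, not part of the PR. Assumes `is_relu_node` from this file is in
# scope and a PyTorch build where torch.export produces ATen IR.
import torch


class M(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x)


ep = torch.export.export(M(), (torch.randn(1, 3),))
relu_nodes = [n for n in ep.graph.nodes if is_relu_node(n)]
# Expect a node whose target is torch.ops.aten.relu.default
print(relu_nodes)
```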
@@ -231,10 +241,7 @@ def _annotate_linear_relu(
     weight_qspec = get_weight_qspec(quantization_config)
     bias_qspec = get_bias_qspec(quantization_config)
     for node in gm.graph.nodes:
-        if node.op != "call_function" or node.target not in [
-            torch.ops.aten.relu.default,
-            torch.ops.aten.relu_.default,
-        ]:
+        if not is_relu_node(node):
             continue
         relu_node = node
         maybe_linear_node = node.args[0]
@@ -285,21 +292,28 @@ def _annotate_linear_relu(
     return annotated_partitions
 
 
-@register_annotator("conv")
-def _annotate_conv(
+def _do_annotate_conv(
     gm: torch.fx.GraphModule,
     quantization_config: Optional[QuantizationConfig],
     filter_fn: Optional[Callable[[Node], bool]] = None,
+    is_conv_transpose: bool = False,
 ) -> Optional[list[list[Node]]]:
     annotated_partitions = []
+    is_conv_node = _is_conv_transpose_node if is_conv_transpose else _is_conv_node
+
     for n in gm.graph.nodes:
-        if n.op != "call_function" or n.target not in [
-            torch.ops.aten.conv1d.default,
-            torch.ops.aten.conv2d.default,
-        ]:
+        if not is_conv_node(n):
             continue
         conv_node = n
 
+        # This is hacky!
+        # We do not want to annotate conv node independently if there is a conv + relu pattern
+        # So we skip if the conv node is consumed by a single relu node
+        if len(conv_node.users) == 1:
+            user = list(conv_node.users.keys())[0]
+            if is_relu_node(user):
+                continue
+
         input_qspec_map = {}
         input_act = conv_node.args[0]
         assert isinstance(input_act, Node)
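
The refactored loop delegates the op check to `_is_conv_node` / `_is_conv_transpose_node`, which are defined elsewhere in the file and not shown in this diff. A plausible sketch of those helpers, mirroring the inline check that was removed; the conv_transpose overload names are assumptions, not taken from this diff:

```python
# Plausible sketch only; the real helpers live elsewhere in this file.
def _is_conv_node(n: Node) -> bool:
    # Mirrors the op list from the removed inline check.
    return n.op == "call_function" and n.target in [
        torch.ops.aten.conv1d.default,
        torch.ops.aten.conv2d.default,
    ]


def _is_conv_transpose_node(n: Node) -> bool:
    # The overload names below are assumptions.
    return n.op == "call_function" and n.target in [
        torch.ops.aten.conv_transpose1d.default,
        torch.ops.aten.conv_transpose2d.input,
    ]
```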
@@ -341,10 +355,7 @@ def _do_annotate_conv_relu(
 ):
     annotated_partitions = []
     for n in gm.graph.nodes:
-        if n.op != "call_function" or n.target not in [
-            torch.ops.aten.relu.default,
-            torch.ops.aten.relu_.default,
-        ]:
+        if not is_relu_node(n):
             continue
         relu_node = n
         maybe_conv_node = n.args[0]
@@ -393,6 +404,26 @@ def _do_annotate_conv_relu(
     return annotated_partitions
 
 
+@register_annotator("conv")
+def _annotate_conv(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[list[list[Node]]]:
+    return _do_annotate_conv(
+        gm, quantization_config, filter_fn, is_conv_transpose=False
+    )
+
+
+@register_annotator("conv_transpose")
+def _annotate_transpose_conv(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[list[list[Node]]]:
+    return _do_annotate_conv(gm, quantization_config, filter_fn, is_conv_transpose=True)
+
+
 @register_annotator("conv_relu")
 def _annotate_conv_relu(
     gm: torch.fx.GraphModule,
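
The two new wrappers only dispatch into `_do_annotate_conv` with the appropriate flag; `register_annotator` itself is untouched by this diff. A rough sketch of the registry pattern it presumably implements (the `OP_TO_ANNOTATOR` name is an assumption, and `Callable` comes from this file's existing imports):

```python
# Rough sketch of the presumed registry mechanism; OP_TO_ANNOTATOR is an assumed name.
OP_TO_ANNOTATOR: dict[str, Callable] = {}


def register_annotator(op: str):
    def decorator(annotator: Callable):
        OP_TO_ANNOTATOR[op] = annotator
        return annotator

    return decorator


# Presumed usage: look up an annotator by name and run it over the graph, e.g.
# OP_TO_ANNOTATOR["conv_transpose"](gm, quantization_config)
```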
@@ -744,10 +775,7 @@ def _annotate_add_relu(  # noqa: C901
 ) -> Optional[list[list[Node]]]:
     annotated_partitions = []
     for node in gm.graph.nodes:
-        if node.op != "call_function" or node.target not in [
-            torch.ops.aten.relu.default,
-            torch.ops.aten.relu_.default,
-        ]:
+        if not is_relu_node(node):
             continue
         relu_node = node
         maybe_add = node.args[0]
@@ -872,10 +900,7 @@ def _annotate_mul_relu(  # noqa: C901
 ) -> Optional[list[list[Node]]]:
     annotated_partitions = []
     for node in gm.graph.nodes:
-        if node.op != "call_function" or node.target not in [
-            torch.ops.aten.relu.default,
-            torch.ops.aten.relu_.default,
-        ]:
+        if not is_relu_node(node):
             continue
         relu_node = node
         maybe_mul = node.args[0]