File tree Expand file tree Collapse file tree 2 files changed +20
-2
lines changed Expand file tree Collapse file tree 2 files changed +20
-2
lines changed Original file line number Diff line number Diff line change @@ -61,7 +61,7 @@ def __init__(
61
61
self .buffer_limit = buffer_limit
62
62
self .require_dynamic_shapes = require_dynamic_shape
63
63
64
- def op_node_is_compatible (
64
+ def op_node_is_compatible ( # noqa: C901: Function is too complex
65
65
self , node : torch .fx .Node , features : Optional [OpFeatures ] = None
66
66
) -> Tuple [bool , str ]:
67
67
"""
@@ -98,8 +98,12 @@ def op_node_is_compatible(
98
98
and utils .is_tensor_node (arg )
99
99
and i not in features .skip_limits_check
100
100
):
101
+ # Check for bool inputs
102
+ if utils .tensor_node_is_bool (arg ):
103
+ return False , "contains bool tensor"
104
+
101
105
# Check for high dimensional tensors
102
- if utils .is_tensor_node ( arg ) and utils . tensor_node_is_high_dim (arg ):
106
+ if utils .tensor_node_is_high_dim (arg ):
103
107
return False , "contains high dim tensor"
104
108
105
109
arg_texture_layouts = utils .possible_node_memory_layouts (
Original file line number Diff line number Diff line change @@ -80,6 +80,20 @@ def is_tensor_node(node: torch.fx.Node) -> bool:
80
80
return False
81
81
82
82
83
+ def tensor_node_is_bool (node : torch .fx .Node ) -> bool :
84
+ """
85
+ Returns true if a given node contains a tensor with bool dtype
86
+ """
87
+ if isinstance (node .meta ["val" ], FakeTensor ):
88
+ return node .meta ["val" ].dtype == torch .bool
89
+ if isinstance (node .meta ["val" ], list ) or isinstance (node .meta ["val" ], tuple ):
90
+ for fake_tensor in node .meta ["val" ]:
91
+ if isinstance (fake_tensor , FakeTensor ):
92
+ if fake_tensor .dtype == torch .bool :
93
+ return True
94
+ return False
95
+
96
+
83
97
##
84
98
## Memory Layout, Storage Type Determination
85
99
##
You can’t perform that action at this time.
0 commit comments