We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent a6cd1dc commit 1873a6aCopy full SHA for 1873a6a
backends/arm/quantizer/arm_quantizer.py
@@ -191,14 +191,16 @@ def _get_module_type_filter(tp: Callable) -> NodeFilterType:
191
True # the node is from the submodule `Sub` (same for `Block` and `Linear` as well)
192
"""
193
194
+ tp_str = tp.__module__ + "." + tp.__qualname__
195
+
196
def module_type_filter(n: Node) -> bool:
197
# node_stack example: {
198
# 'L__self___sub': ("L['self'].sub", <class '....Sub'>),
199
# 'L__self___sub_linear': ("L['self'].sub.linear", <class 'torch.nn.modules.linear.Linear'>)
200
# }
201
nn_module_stack = n.meta.get("nn_module_stack", {})
202
types = [t for _, t in nn_module_stack.values()]
- return tp in types
203
+ return tp_str in types
204
205
return module_type_filter
206
0 commit comments