Skip to content

Commit 3a90aa6

Browse files
jorgep31415facebook-github-bot
authored andcommitted
Fix floor_div delegate test (#2101)
Summary: Pull Request resolved: #2101 In [PR #2062](#2062), we introduced the partitioner and removed this failing test. The test fails because we were using the wrong op name. We fix it to that from [PR #1737](#1737). ghstack-source-id: 216477874 exported-using-ghexport Reviewed By: SS-JIA Differential Revision: D54206402 fbshipit-source-id: 0c9ae2af9a380e8aa0e28b107a33ccd36a89033e
1 parent 722af90 commit 3a90aa6

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
2626
supported = node.op == "call_function" and node.target in [
2727
exir_ops.edge.aten.add.Tensor,
2828
exir_ops.edge.aten.div.Tensor,
29+
exir_ops.edge.aten.div.Tensor_mode,
2930
exir_ops.edge.aten.mul.Tensor,
3031
exir_ops.edge.aten.sub.Tensor,
3132
exir_ops.edge.aten.pow.Tensor_Tensor,
32-
exir_ops.edge.aten.floor_divide.default,
3333
]
3434
return supported
3535

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,26 @@ def forward(self, x, y):
199199

200200
self.lower_module_and_test_output(arithmetic_module, model_inputs)
201201

202+
def test_vulkan_backend_floor_div(self):
203+
class FloorDivModule(torch.nn.Module):
204+
def __init__(self):
205+
super().__init__()
206+
207+
def forward(self, x, y):
208+
z = x // y
209+
return z
210+
211+
floor_div_module = FloorDivModule()
212+
model_inputs = (
213+
torch.rand(size=(2, 3), dtype=torch.float32) * 10.0,
214+
torch.rand(size=(2, 3), dtype=torch.float32) + 1.0,
215+
)
216+
217+
# absolute tolerance is 1 because of flooring
218+
self.lower_module_and_test_output(
219+
floor_div_module, model_inputs, atol=1.0 + 1e-03
220+
)
221+
202222
def test_vulkan_backend_pow(self):
203223
class PowModule(torch.nn.Module):
204224
def __init__(self):

0 commit comments

Comments
 (0)