
Commit c2ad2d5

Yujie Hui authored and facebook-github-bot committed
aten.minimum.default
Summary: Implement the aten.minimum.default op, needed by the OCR word detector full model.

Reviewed By: jorgep31415

Differential Revision: D59541188
1 parent 4027c1b commit c2ad2d5
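For reference, a minimal eager-mode sketch of the semantics this op must match (illustrative only, not part of the diff): torch.minimum computes the elementwise minimum of two tensors with standard broadcasting, which is what the new Vulkan binary-op shader variant implements via min(X, Y).

import torch

# aten.minimum.default: elementwise minimum with standard broadcasting.
x = torch.rand(3, 5, 6, 4)  # same shape as the sample input in the new test below
y = torch.rand(6, 4)        # broadcasts against the trailing dims of x

out = torch.minimum(x, y)
assert out.shape == (3, 5, 6, 4)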

File tree

backends/vulkan/partitioner/supported_ops.py
backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml
backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp
backends/vulkan/test/op_tests/cases.py
backends/vulkan/test/test_vulkan_delegate.py

5 files changed: +40 -0 lines changed

backends/vulkan/partitioner/supported_ops.py

Lines changed: 1 addition & 0 deletions
@@ -45,6 +45,7 @@ def __contains__(self, op):
 BINARY_OPS = [
     exir_ops.edge.aten.add.Tensor,
     exir_ops.edge.aten.sub.Tensor,
+    exir_ops.edge.aten.minimum.default,
     exir_ops.edge.aten.mul.Tensor,
     exir_ops.edge.aten.div.Tensor,
     exir_ops.edge.aten.div.Tensor_mode,

backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml

Lines changed: 2 additions & 0 deletions
@@ -28,3 +28,5 @@ binary_op:
       OPERATOR: pow(X, Y)
     - NAME: binary_floor_divide
       OPERATOR: floor(X / Y)
+    - NAME: binary_minimum
+      OPERATOR: min(X, Y)

backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp

Lines changed: 2 additions & 0 deletions
@@ -118,6 +118,7 @@ DEFINE_BINARY_OP_WITH_ALPHA_FN(floor_divide);
 DEFINE_BINARY_OP_FN(mul);
 DEFINE_BINARY_OP_FN(div);
 DEFINE_BINARY_OP_FN(pow);
+DEFINE_BINARY_OP_FN(minimum);
 
 REGISTER_OPERATORS {
   VK_REGISTER_OP(aten.add.Tensor, add);
@@ -126,6 +127,7 @@ REGISTER_OPERATORS {
   VK_REGISTER_OP(aten.div.Tensor, div);
   VK_REGISTER_OP(aten.div.Tensor_mode, floor_divide);
   VK_REGISTER_OP(aten.pow.Tensor_Tensor, pow);
+  VK_REGISTER_OP(aten.minimum.default, minimum);
 }
 
 } // namespace vkcompute

backends/vulkan/test/op_tests/cases.py

Lines changed: 16 additions & 0 deletions
@@ -1022,3 +1022,19 @@ def get_constant_pad_nd_inputs():
         ]
     )
     return test_suite
+
+
+@register_test_suite("aten.minimum.default")
+def get_minimum_inputs():
+    test_suite = VkTestSuite(
+        [
+            ((M1, M2), (M2)),
+            ((M1, M2), (M1, M2)),
+            ((M1, M2, M), (M2, M)),
+            ((M1, M1, S1, S2), (M1, M1, S1, S2)),
+            ((S1, S1, S2, S), (S1, S2, S)),
+            ((M1, S1, S2), (L, M1, S1, S2)),
+            ((S1, S2), (L, M1, S1, S2)),
+        ]
+    )
+    return test_suite
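
The shape pairs above exercise broadcasting between the two inputs. A quick sanity check that each pair is broadcast-compatible under PyTorch rules (illustrative only; the placeholder sizes below stand in for the M, M1, M2, S, S1, S2, L constants defined elsewhere in cases.py and are not part of this diff):

import torch

# Placeholder sizes; the real M, M1, M2, S, S1, S2, L constants live in cases.py.
M, M1, M2, S, S1, S2, L = 7, 5, 3, 4, 2, 6, 8

shape_pairs = [
    ((M1, M2), (M2,)),
    ((M1, M2), (M1, M2)),
    ((M1, M2, M), (M2, M)),
    ((M1, M1, S1, S2), (M1, M1, S1, S2)),
    ((S1, S1, S2, S), (S1, S2, S)),
    ((M1, S1, S2), (L, M1, S1, S2)),
    ((S1, S2), (L, M1, S1, S2)),
]

# Every pair must be broadcast-compatible for torch.minimum to apply.
for a, b in shape_pairs:
    out = torch.minimum(torch.rand(a), torch.rand(b))
    print(a, b, "->", tuple(out.shape))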

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 19 additions & 0 deletions
@@ -1074,6 +1074,25 @@ def forward(self, x):
             memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
         )
 
+    def test_vulkan_backend_minimum(self):
+        class MinimumModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, y):
+                return torch.minimum(x, y)
+
+        sample_inputs = (
+            torch.rand(size=(3, 5, 6, 4), dtype=torch.float32),
+            torch.rand(size=(6, 4), dtype=torch.float32),
+        )
+
+        self.lower_module_and_test_output(
+            MinimumModule(),
+            sample_inputs,
+            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        )
+
     def test_vulkan_backend_reshape(self):
         class ReshapeModule(torch.nn.Module):
             def __init__(self):
