Skip to content

Commit 935f873

Browse files
yipjustinfacebook-github-bot
authored andcommitted
ExecuTorch Vulkan floor_div (#1737)
Summary: Pull Request resolved: #1737 X-link: pytorch/pytorch#118428 Add a new operator "floor_div" to ET-Vulkan. bypass-github-pytorch-ci-checks Reviewed By: SS-JIA Differential Revision: D53072722 fbshipit-source-id: 956c0b5c79a4dfb97c506980d2e82fb459772000
1 parent fc0e625 commit 935f873

File tree

5 files changed

+53
-11
lines changed

5 files changed

+53
-11
lines changed

backends/vulkan/VulkanBackend.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@ class VulkanBackend final : public PyTorchBackendInterface {
4848
vk_arithmetic_op_type_div): {
4949
return at::native::vulkan::arithmetic::OpType::DIV;
5050
}
51+
case (at::vulkan::delegate::VkArithmeticOpType::
52+
vk_arithmetic_op_type_floor_div): {
53+
return at::native::vulkan::arithmetic::OpType::FLOOR_DIV;
54+
}
5155
}
5256
}
5357

backends/vulkan/serialization/schema/schema.fbs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ enum VkArithmeticOpType : short {
3434
vk_arithmetic_op_type_sub = 1,
3535
vk_arithmetic_op_type_mul = 2,
3636
vk_arithmetic_op_type_div = 3,
37+
vk_arithmetic_op_type_floor_div = 4,
3738
}
3839

3940
table VkArithmeticNode {

backends/vulkan/serialization/vulkan_graph_schema.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class VkArithmeticOpType(IntEnum):
4646
vk_arithmetic_op_type_sub = 1
4747
vk_arithmetic_op_type_mul = 2
4848
vk_arithmetic_op_type_div = 3
49+
vk_arithmetic_op_type_floor_div = 4
4950

5051

5152
@dataclass

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,12 @@
2727

2828

2929
class TestBackends(unittest.TestCase):
30-
def assert_outputs_equal(self, model_output, ref_output):
30+
def assert_outputs_equal(self, model_output, ref_output, atol=1e-03, rtol=1e-03):
3131
"""
3232
Helper testing function that asserts that the model output and the reference output
3333
are equal with some tolerance. Due to numerical differences between eager mode and
34-
the Vulkan's backend, we relax the detal such that absolute tolerance is 1e-3. and
35-
relative tolerance is 1e-3.
34+
the Vulkan's backend, we relax the detal such that default absolute
35+
tolerance is 1e-3. and default relative tolerance is 1e-3.
3636
"""
3737

3838
# Compare the result from executor and eager mode direclty
@@ -41,20 +41,20 @@ def assert_outputs_equal(self, model_output, ref_output):
4141
self.assertTrue(len(ref_output) == len(model_output))
4242
for i in range(len(ref_output)):
4343
self.assertTrue(
44-
torch.allclose(
45-
model_output[i], ref_output[i], atol=1e-03, rtol=1e-03
46-
)
44+
torch.allclose(model_output[i], ref_output[i], atol=atol, rtol=rtol)
4745
)
4846
else:
4947
# If one output, eager returns tensor while executor tuple of size 1
5048
self.assertTrue(
51-
torch.allclose(model_output[0], ref_output, atol=1e-03, rtol=1e-03)
49+
torch.allclose(model_output[0], ref_output, atol=atol, rtol=rtol)
5250
)
5351

5452
def lower_module_and_test_output(
5553
self,
5654
module: torch.nn.Module,
5755
sample_inputs: Tuple[torch.Tensor],
56+
atol=1e-03,
57+
rtol=1e-01,
5858
):
5959
"""
6060
Helper testing function that takes a torch.nn.Module and lowers it to Vulkan with
@@ -92,7 +92,7 @@ def forward(self, *args):
9292
model_output = executorch_module.run_method("forward", tuple(inputs_flattened))
9393
ref_output = module(*sample_inputs)
9494

95-
self.assert_outputs_equal(model_output, ref_output)
95+
self.assert_outputs_equal(model_output, ref_output, atol=atol, rtol=rtol)
9696

9797
def test_vulkan_backend_add(self):
9898
# This test is the simplest test by manually lowering some submodules, we can use paritioner for auto detecting lowerable parts
@@ -192,6 +192,26 @@ def forward(self, x, y):
192192

193193
self.lower_module_and_test_output(div_module, model_inputs)
194194

195+
def test_vulkan_backend_floor_div(self):
196+
class FloorDivModule(torch.nn.Module):
197+
def __init__(self):
198+
super().__init__()
199+
200+
def forward(self, x, y):
201+
z = x // y
202+
return z
203+
204+
floor_div_module = FloorDivModule()
205+
model_inputs = (
206+
torch.rand(size=(2, 3), dtype=torch.float32) * 10.0,
207+
torch.rand(size=(2, 3), dtype=torch.float32) + 1.0,
208+
)
209+
210+
# absolute tolerance is 1 because of flooring
211+
self.lower_module_and_test_output(
212+
floor_div_module, model_inputs, atol=1.0 + 1e-03
213+
)
214+
195215
def test_vulkan_backend_arithmetic(self):
196216
class ArithmeticModule(torch.nn.Module):
197217
def __init__(self):

backends/vulkan/vulkan_preprocess.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import ctypes
8-
from typing import final, List
8+
from typing import Dict, final, List
99

1010
import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema
1111
from executorch.backends.vulkan.serialization.vulkan_graph_serialize import (
@@ -20,14 +20,17 @@
2020
)
2121
from torch import dtype, float32, Tensor
2222
from torch.fx import Node
23+
from torch.fx.node import Argument
2324

2425
DEFAULT_DEBUG_HANDLE = 65535
2526

2627

2728
@final
2829
class VulkanBackend(BackendDetails):
2930
@staticmethod
30-
def get_vk_op_type(target_name: str) -> vk_graph_schema.VkArithmeticOpType:
31+
def get_vk_op_type(
32+
target_name: str, kwargs: Dict[str, "Argument"]
33+
) -> vk_graph_schema.VkArithmeticOpType:
3134
if target_name == "aten.add.Tensor":
3235
return vk_graph_schema.VkArithmeticOpType.vk_arithmetic_op_type_add
3336
elif target_name == "aten.sub.Tensor":
@@ -36,6 +39,17 @@ def get_vk_op_type(target_name: str) -> vk_graph_schema.VkArithmeticOpType:
3639
return vk_graph_schema.VkArithmeticOpType.vk_arithmetic_op_type_mul
3740
elif target_name == "aten.div.Tensor":
3841
return vk_graph_schema.VkArithmeticOpType.vk_arithmetic_op_type_div
42+
elif target_name == "aten.div.Tensor_mode":
43+
if kwargs.get("rounding_mode", None) == "floor":
44+
return (
45+
vk_graph_schema.VkArithmeticOpType.vk_arithmetic_op_type_floor_div
46+
)
47+
48+
raise AssertionError(
49+
f"Invalid node kwargs for vulkan_preprocess (target_name: {target_name}, "
50+
f"kwargs: {kwargs})"
51+
)
52+
3953
else:
4054
raise AssertionError(
4155
f"Invalid node target name for vulkan_preprocess ({target_name})"
@@ -101,7 +115,9 @@ def assign_non_const_vk_value_id(node: Node) -> int:
101115
input1_id=node_vk_value_ids[node.all_input_nodes[0]],
102116
input2_id=node_vk_value_ids[node.all_input_nodes[1]],
103117
output_id=assign_non_const_vk_value_id(node),
104-
op_type=VulkanBackend.get_vk_op_type(node.target.__name__),
118+
op_type=VulkanBackend.get_vk_op_type(
119+
target_name=node.target.__name__, kwargs=node.kwargs
120+
),
105121
flags=0,
106122
),
107123
debug_handle=node.meta.get(

0 commit comments

Comments
 (0)