ExecuTorch Vulkan floor_div #1737
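Status: closed · 1 commit.

Summary: this pull request adds floor-division support to the ExecuTorch Vulkan delegate. It introduces a new vk_arithmetic_op_type_floor_div value in the FlatBuffer schema and its Python mirror, a FLOOR_DIV dispatch case in VulkanBackend.cpp, kwargs-aware op-type resolution in vulkan_preprocess.py (mapping aten.div.Tensor_mode with rounding_mode="floor"), and a test whose absolute tolerance is relaxed to just over 1.0 to account for flooring.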

4 changes: 4 additions & 0 deletions backends/vulkan/VulkanBackend.cpp
@@ -48,6 +48,10 @@ class VulkanBackend final : public PyTorchBackendInterface {
           vk_arithmetic_op_type_div): {
         return at::native::vulkan::arithmetic::OpType::DIV;
       }
+      case (at::vulkan::delegate::VkArithmeticOpType::
+                vk_arithmetic_op_type_floor_div): {
+        return at::native::vulkan::arithmetic::OpType::FLOOR_DIV;
+      }
     }
   }

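For context, FLOOR_DIV implements floor division, which rounds the quotient toward negative infinity rather than truncating toward zero. A quick eager-mode reference in plain PyTorch (not part of this diff) showing the behavior the backend op is expected to match:

    import torch

    a = torch.tensor([7.0, -7.0])
    b = torch.tensor([2.0, 2.0])

    # Floor division rounds toward negative infinity: -7/2 = -3.5 -> -4.
    print(a // b)                                  # tensor([ 3., -4.])
    print(torch.div(a, b, rounding_mode="floor"))  # equivalent spelling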
1 change: 1 addition & 0 deletions backends/vulkan/serialization/schema/schema.fbs
@@ -34,6 +34,7 @@ enum VkArithmeticOpType : short {
   vk_arithmetic_op_type_sub = 1,
   vk_arithmetic_op_type_mul = 2,
   vk_arithmetic_op_type_div = 3,
+  vk_arithmetic_op_type_floor_div = 4,
 }
 
 table VkArithmeticNode {
1 change: 1 addition & 0 deletions backends/vulkan/serialization/vulkan_graph_schema.py
@@ -46,6 +46,7 @@ class VkArithmeticOpType(IntEnum):
     vk_arithmetic_op_type_sub = 1
     vk_arithmetic_op_type_mul = 2
     vk_arithmetic_op_type_div = 3
+    vk_arithmetic_op_type_floor_div = 4
 
 
 @dataclass
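The enum values in schema.fbs and vulkan_graph_schema.py are maintained by hand, so they must stay in sync; a mismatch would silently mis-tag serialized ops. A minimal sanity check, assuming the import path used elsewhere in this PR:

    import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema

    # Must match the value assigned in schema.fbs.
    assert vk_graph_schema.VkArithmeticOpType.vk_arithmetic_op_type_floor_div == 4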
36 changes: 28 additions & 8 deletions backends/vulkan/test/test_vulkan_delegate.py
@@ -27,12 +27,12 @@
 
 
 class TestBackends(unittest.TestCase):
-    def assert_outputs_equal(self, model_output, ref_output):
+    def assert_outputs_equal(self, model_output, ref_output, atol=1e-03, rtol=1e-03):
         """
         Helper testing function that asserts that the model output and the reference output
         are equal with some tolerance. Due to numerical differences between eager mode and
-        the Vulkan's backend, we relax the detal such that absolute tolerance is 1e-3. and
-        relative tolerance is 1e-3.
+        the Vulkan backend, we relax the delta such that the default absolute
+        tolerance is 1e-3 and the default relative tolerance is 1e-3.
         """
 
         # Compare the result from executor and eager mode directly
@@ -41,20 +41,20 @@ def assert_outputs_equal(self, model_output, ref_output):
             self.assertTrue(len(ref_output) == len(model_output))
             for i in range(len(ref_output)):
                 self.assertTrue(
-                    torch.allclose(
-                        model_output[i], ref_output[i], atol=1e-03, rtol=1e-03
-                    )
+                    torch.allclose(model_output[i], ref_output[i], atol=atol, rtol=rtol)
                 )
         else:
             # If there is one output, eager returns a tensor while the executor returns a tuple of size 1
             self.assertTrue(
-                torch.allclose(model_output[0], ref_output, atol=1e-03, rtol=1e-03)
+                torch.allclose(model_output[0], ref_output, atol=atol, rtol=rtol)
             )
 
     def lower_module_and_test_output(
         self,
         module: torch.nn.Module,
         sample_inputs: Tuple[torch.Tensor],
+        atol=1e-03,
+        rtol=1e-01,
     ):
         """
         Helper testing function that takes a torch.nn.Module and lowers it to Vulkan with
@@ -92,7 +92,7 @@ def forward(self, *args):
         model_output = executorch_module.run_method("forward", tuple(inputs_flattened))
         ref_output = module(*sample_inputs)
 
-        self.assert_outputs_equal(model_output, ref_output)
+        self.assert_outputs_equal(model_output, ref_output, atol=atol, rtol=rtol)
 
     def test_vulkan_backend_add(self):
         # This test is the simplest test by manually lowering some submodules; we can use the partitioner for auto-detecting lowerable parts
@@ -192,6 +192,26 @@ def forward(self, x, y):
 
         self.lower_module_and_test_output(div_module, model_inputs)
 
+    def test_vulkan_backend_floor_div(self):
+        class FloorDivModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, y):
+                z = x // y
+                return z
+
+        floor_div_module = FloorDivModule()
+        model_inputs = (
+            torch.rand(size=(2, 3), dtype=torch.float32) * 10.0,
+            torch.rand(size=(2, 3), dtype=torch.float32) + 1.0,
+        )
+
+        # absolute tolerance is 1 because of flooring
+        self.lower_module_and_test_output(
+            floor_div_module, model_inputs, atol=1.0 + 1e-03
+        )
+
     def test_vulkan_backend_arithmetic(self):
         class ArithmeticModule(torch.nn.Module):
             def __init__(self):
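The atol of just over 1.0 is needed because floor is discontinuous: when a quotient lands near an integer boundary, a tiny difference between the eager and Vulkan division results can flip the floored value by a whole unit. A small illustration in plain PyTorch (not part of this diff), with a hand-picked boundary value:

    import torch

    # A quotient that lands just below an integer boundary.
    x = torch.tensor([5.999999], dtype=torch.float64)
    y = torch.tensor([2.0], dtype=torch.float64)

    eager = x // y                      # floor(2.9999995) -> 2.0
    # Simulate a tiny backend-side difference in the computed quotient.
    shifted = (x / y + 1e-6).floor()    # floor(3.0000005) -> 3.0

    # The outputs differ by a full 1.0 even though the raw quotients
    # differ by only 1e-6, hence atol = 1.0 + 1e-03 in the test above.
    print(eager.item(), shifted.item())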
22 changes: 19 additions & 3 deletions backends/vulkan/vulkan_preprocess.py
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import ctypes
-from typing import final, List
+from typing import Dict, final, List
 
 import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema
 from executorch.backends.vulkan.serialization.vulkan_graph_serialize import (
@@ -20,14 +20,17 @@
 )
 from torch import dtype, float32, Tensor
 from torch.fx import Node
+from torch.fx.node import Argument
 
 DEFAULT_DEBUG_HANDLE = 65535
 
 
 @final
 class VulkanBackend(BackendDetails):
     @staticmethod
-    def get_vk_op_type(target_name: str) -> vk_graph_schema.VkArithmeticOpType:
+    def get_vk_op_type(
+        target_name: str, kwargs: Dict[str, "Argument"]
+    ) -> vk_graph_schema.VkArithmeticOpType:
         if target_name == "aten.add.Tensor":
             return vk_graph_schema.VkArithmeticOpType.vk_arithmetic_op_type_add
         elif target_name == "aten.sub.Tensor":
@@ -36,6 +39,17 @@ def get_vk_op_type(target_name: str) -> vk_graph_schema.VkArithmeticOpType:
             return vk_graph_schema.VkArithmeticOpType.vk_arithmetic_op_type_mul
         elif target_name == "aten.div.Tensor":
             return vk_graph_schema.VkArithmeticOpType.vk_arithmetic_op_type_div
+        elif target_name == "aten.div.Tensor_mode":
+            if kwargs.get("rounding_mode", None) == "floor":
+                return (
+                    vk_graph_schema.VkArithmeticOpType.vk_arithmetic_op_type_floor_div
+                )
+
+            raise AssertionError(
+                f"Invalid node kwargs for vulkan_preprocess (target_name: {target_name}, "
+                f"kwargs: {kwargs})"
+            )
+
         else:
             raise AssertionError(
                 f"Invalid node target name for vulkan_preprocess ({target_name})"
@@ -101,7 +115,9 @@ def assign_non_const_vk_value_id(node: Node) -> int:
                     input1_id=node_vk_value_ids[node.all_input_nodes[0]],
                     input2_id=node_vk_value_ids[node.all_input_nodes[1]],
                     output_id=assign_non_const_vk_value_id(node),
-                    op_type=VulkanBackend.get_vk_op_type(node.target.__name__),
+                    op_type=VulkanBackend.get_vk_op_type(
+                        target_name=node.target.__name__, kwargs=node.kwargs
+                    ),
                     flags=0,
                 ),
                 debug_handle=node.meta.get(
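The op-type resolution can be exercised directly. A minimal usage sketch grounded in the branch added above (module paths as in this PR):

    import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema
    from executorch.backends.vulkan.vulkan_preprocess import VulkanBackend

    # aten.div.Tensor_mode resolves to FLOOR_DIV only when rounding_mode is "floor".
    op = VulkanBackend.get_vk_op_type(
        target_name="aten.div.Tensor_mode", kwargs={"rounding_mode": "floor"}
    )
    assert op == vk_graph_schema.VkArithmeticOpType.vk_arithmetic_op_type_floor_div

    # Any other rounding_mode (e.g. "trunc" or None) raises an AssertionError,
    # as does an unrecognized target name.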