Commit 46cb44f

[ET-VK] Introduce Vulkan partitioner
## Context

Introduce `VulkanPartitioner`. I based the implementation of `VulkanPartitioner` on `ArmPartitioner`.

Differential Revision: [D54128090](https://our.internmc.facebook.com/intern/diff/D54128090/)

ghstack-source-id: 216263900
Pull Request resolved: #2062
1 parent 906a716 commit 46cb44f
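
For orientation (editorial note, not part of the commit): the updated test below lowers models through `torch.export` and `to_edge`, then hands the edge program to the new partitioner. A minimal sketch of that flow, where `AddModule` is a hypothetical stand-in model and the `export`/`to_edge`/`to_backend` calls are the ones used in `test_vulkan_delegate.py`:

```python
import torch
from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
from executorch.exir import to_edge
from torch.export import export

class AddModule(torch.nn.Module):  # hypothetical example model
    def forward(self, x, y):
        return x + y  # aten.add.Tensor is in VulkanSupportedOperators

# Export -> edge dialect -> delegate supported subgraphs to Vulkan -> executorch
model = AddModule()
sample_inputs = (torch.rand(2, 3), torch.rand(2, 3))
edge = to_edge(export(model, sample_inputs))
edge = edge.to_backend(VulkanPartitioner())
executorch_program = edge.to_executorch()
```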

File tree

4 files changed: +116 additions, -44 deletions

backends/vulkan/partitioner/TARGETS
backends/vulkan/partitioner/vulkan_partitioner.py
backends/vulkan/test/TARGETS
backends/vulkan/test/test_vulkan_delegate.py

backends/vulkan/partitioner/TARGETS

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+
+oncall("executorch")
+
+runtime.python_library(
+    name = "vulkan_partitioner",
+    srcs = [
+        "vulkan_partitioner.py",
+    ],
+    visibility = [
+        "//executorch/...",
+        "@EXECUTORCH_CLIENTS",
+    ],
+    deps = [
+        "//executorch/backends/vulkan:vulkan_preprocess",
+        "//executorch/exir:delegate",
+        "//executorch/exir:lib",
+        "//executorch/exir/backend:partitioner",
+        "//executorch/exir/backend:utils",
+        "//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
+    ],
+)
backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 63 additions & 0 deletions

@@ -0,0 +1,63 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import final, List, Optional
+
+import torch
+from executorch.backends.vulkan.vulkan_preprocess import VulkanBackend
+from executorch.exir.backend.compile_spec_schema import CompileSpec
+from executorch.exir.backend.partitioner import (
+    DelegationSpec,
+    Partitioner,
+    PartitionResult,
+)
+from executorch.exir.dialects._ops import ops as exir_ops
+from torch.export.exported_program import ExportedProgram
+from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
+
+from torch.fx.passes.operator_support import OperatorSupportBase
+
+
+class VulkanSupportedOperators(OperatorSupportBase):
+    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
+        supported = node.op == "call_function" and node.target in [
+            exir_ops.edge.aten.add.Tensor,
+            exir_ops.edge.aten.div.Tensor,
+            exir_ops.edge.aten.mul.Tensor,
+            exir_ops.edge.aten.sub.Tensor,
+            exir_ops.edge.aten.pow.Tensor_Tensor,
+            exir_ops.edge.aten.floor_divide.default,
+        ]
+        return supported
+
+
+@final
+class VulkanPartitioner(Partitioner):
+    def __init__(self, compile_spec: Optional[List[CompileSpec]] = None) -> None:
+        if compile_spec is None:
+            compile_spec = []
+        self.delegation_spec = DelegationSpec(VulkanBackend.__name__, compile_spec)
+
+    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
+        # Run the CapabilityBasedPartitioner to return the largest possible
+        # subgraphs containing the nodes with the tags
+        partition_tags = {}
+
+        capability_partitioner = CapabilityBasedPartitioner(
+            exported_program.graph_module,
+            VulkanSupportedOperators(),
+            allows_single_node_partition=True,
+        )
+        partition_list = capability_partitioner.propose_partitions()
+        for partition in partition_list:
+            for node in partition.nodes:
+                tag = f"tag{partition.id}"
+                node.meta["delegation_tag"] = tag
+                partition_tags[tag] = self.delegation_spec
+
+        return PartitionResult(
+            tagged_exported_program=exported_program, partition_tags=partition_tags
+        )
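
As a hedged aside (editorial note, not part of the commit): `partition()` marks every node claimed for Vulkan with a `delegation_tag` entry in `node.meta` and records each tag in `partition_tags`. A sketch of inspecting that output directly, assuming `EdgeProgramManager.exported_program()` yields the edge-dialect program that `to_backend` would pass to the partitioner, and using a hypothetical `MulModule`:

```python
import torch
from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
from executorch.exir import to_edge
from torch.export import export

class MulModule(torch.nn.Module):  # hypothetical example model
    def forward(self, x, y):
        return x * y  # aten.mul.Tensor is in the supported-ops list

edge = to_edge(export(MulModule(), (torch.rand(2, 3), torch.rand(2, 3))))
result = VulkanPartitioner().partition(edge.exported_program())

# Nodes claimed for Vulkan now carry a "delegation_tag"; unclaimed nodes do not.
for node in result.tagged_exported_program.graph_module.graph.nodes:
    if "delegation_tag" in node.meta:
        print(node.name, "->", node.meta["delegation_tag"])
print(result.partition_tags)  # tag -> DelegationSpec("VulkanBackend", [])
```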

backends/vulkan/test/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ python_unittest(
     deps = [
         "//caffe2:torch",
         "//executorch/backends/vulkan:vulkan_preprocess",
+        "//executorch/backends/vulkan/partitioner:vulkan_partitioner",
         "//executorch/exir:lib",
         "//executorch/exir/backend:backend_api",
         "//executorch/extension/pybindings:portable_lib",  # @manual

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 30 additions & 44 deletions
@@ -8,14 +8,13 @@
 import unittest
 from typing import Tuple
 
-import executorch.exir as exir
 import torch
 
-# import the vulkan backend implementation
+from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
 from executorch.backends.vulkan.vulkan_preprocess import VulkanBackend
 
-from executorch.exir import ExecutorchProgram
-from executorch.exir.backend.backend_api import to_backend
+from executorch.exir import EdgeProgramManager, to_edge
+from torch.export import export, ExportedProgram
 
 ctypes.CDLL("libvulkan.so.1")
 
@@ -51,7 +50,7 @@ def assert_outputs_equal(self, model_output, ref_output, atol=1e-03, rtol=1e-03)
 
     def lower_module_and_test_output(
         self,
-        module: torch.nn.Module,
+        model: torch.nn.Module,
         sample_inputs: Tuple[torch.Tensor],
         atol=1e-03,
         rtol=1e-01,
@@ -61,36 +60,23 @@ def lower_module_and_test_output(
         the given sample inputs. It then runs the lowered module and compares its
         outputs with the outputs of the eager module.
         """
-        edgeir_m = exir.capture(module, sample_inputs, exir.CaptureConfig()).to_edge()
-        lowered_module = to_backend("VulkanBackend", edgeir_m.exported_program, [])
+        program: ExportedProgram = export(model, sample_inputs)
+        edge_program: EdgeProgramManager = to_edge(program)
+        edge_program = edge_program.to_backend(VulkanPartitioner())
 
-        class WrappedModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.one_module = lowered_module
-
-            def forward(self, *args):
-                return self.one_module(*args)
+        executorch_program = edge_program.to_executorch()
 
-        executorch_program: ExecutorchProgram = (
-            exir.capture(WrappedModule(), sample_inputs, exir.CaptureConfig())
-            .to_edge()
-            .to_executorch()
-        )
-
-        # Assert the backend name is vulkan
         self.assertEqual(
-            executorch_program.program.execution_plan[0].delegates[0].id,
+            executorch_program.executorch_program.execution_plan[0].delegates[0].id,
             VulkanBackend.__name__,
         )
 
-        # Test the model with executor
         executorch_module = _load_for_executorch_from_buffer(executorch_program.buffer)
         # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
         inputs_flattened, _ = tree_flatten(sample_inputs)
 
         model_output = executorch_module.run_method("forward", tuple(inputs_flattened))
-        ref_output = module(*sample_inputs)
+        ref_output = model(*sample_inputs)
 
         self.assert_outputs_equal(model_output, ref_output, atol=atol, rtol=rtol)
 
@@ -192,26 +178,6 @@ def forward(self, x, y):
 
         self.lower_module_and_test_output(div_module, model_inputs)
 
-    def test_vulkan_backend_floor_div(self):
-        class FloorDivModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-
-            def forward(self, x, y):
-                z = x // y
-                return z
-
-        floor_div_module = FloorDivModule()
-        model_inputs = (
-            torch.rand(size=(2, 3), dtype=torch.float32) * 10.0,
-            torch.rand(size=(2, 3), dtype=torch.float32) + 1.0,
-        )
-
-        # absolute tolerance is 1 because of flooring
-        self.lower_module_and_test_output(
-            floor_div_module, model_inputs, atol=1.0 + 1e-03
-        )
-
     def test_vulkan_backend_arithmetic(self):
         class ArithmeticModule(torch.nn.Module):
             def __init__(self):
@@ -249,3 +215,23 @@ def forward(self, x, y):
         )
 
         self.lower_module_and_test_output(pow_module, model_inputs)
+
+    def test_vulkan_backend_partial(self):
+        class SimpleModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(10, 10)
+                self.offset_1 = self.weight = torch.rand(
+                    size=(2, 10), dtype=torch.float32
+                )
+                self.offset_2 = self.weight = torch.rand(
+                    size=(2, 10), dtype=torch.float32
+                )
+
+            def forward(self, x):
+                return self.linear(x + self.offset_1) - self.offset_2
+
+        model = SimpleModel()
+        model_inputs = (torch.rand(size=(2, 10), dtype=torch.float32),)
+
+        self.lower_module_and_test_output(model, model_inputs)
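
The new `test_vulkan_backend_partial` case exercises partial delegation: `aten.linear` is not in `VulkanSupportedOperators`, so only the surrounding add and sub ops should be tagged and lowered while the linear stays in the portable runtime. A hedged sketch (editorial note, not part of the commit) of double-checking that on the lowered program, reusing the `executorch_program` handle from `lower_module_and_test_output`:

```python
# Each proposed partition becomes one delegate entry in the execution plan;
# with VulkanPartitioner, every entry's id should be "VulkanBackend".
delegates = executorch_program.executorch_program.execution_plan[0].delegates
print(len(delegates), [d.id for d in delegates])
```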
