Skip to content

[ET-VK] Introduce Vulkan partitioner #2062

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions backends/vulkan/partitioner/TARGETS
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

oncall("executorch")

runtime.python_library(
name = "vulkan_partitioner",
srcs = [
"vulkan_partitioner.py",
],
visibility = [
"//executorch/...",
"@EXECUTORCH_CLIENTS",
],
deps = [
"//executorch/backends/vulkan:vulkan_preprocess",
"//executorch/exir:delegate",
"//executorch/exir:lib",
"//executorch/exir/backend:partitioner",
"//executorch/exir/backend:utils",
"//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
],
)
63 changes: 63 additions & 0 deletions backends/vulkan/partitioner/vulkan_partitioner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import final, List, Optional

import torch
from executorch.backends.vulkan.vulkan_preprocess import VulkanBackend
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import (
DelegationSpec,
Partitioner,
PartitionResult,
)
from executorch.exir.dialects._ops import ops as exir_ops
from torch.export.exported_program import ExportedProgram
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner

from torch.fx.passes.operator_support import OperatorSupportBase


class VulkanSupportedOperators(OperatorSupportBase):
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
supported = node.op == "call_function" and node.target in [
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.div.Tensor,
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.sub.Tensor,
exir_ops.edge.aten.pow.Tensor_Tensor,
exir_ops.edge.aten.floor_divide.default,
]
return supported


@final
class VulkanPartitioner(Partitioner):
def __init__(self, compile_spec: Optional[List[CompileSpec]] = None) -> None:
if compile_spec is None:
compile_spec = []
self.delegation_spec = DelegationSpec(VulkanBackend.__name__, compile_spec)

def partition(self, exported_program: ExportedProgram) -> PartitionResult:
# Run the CapabilityBasedPartitioner to return the largest possible
# subgraphs containing the nodes with the tags
partition_tags = {}

capability_partitioner = CapabilityBasedPartitioner(
exported_program.graph_module,
VulkanSupportedOperators(),
allows_single_node_partition=True,
)
partition_list = capability_partitioner.propose_partitions()
for partition in partition_list:
for node in partition.nodes:
tag = f"tag{partition.id}"
node.meta["delegation_tag"] = tag
partition_tags[tag] = self.delegation_spec

return PartitionResult(
tagged_exported_program=exported_program, partition_tags=partition_tags
)
2 changes: 1 addition & 1 deletion backends/vulkan/test/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ python_unittest(
deps = [
"//caffe2:torch",
"//executorch/backends/vulkan:vulkan_preprocess",
"//executorch/backends/vulkan/partitioner:vulkan_partitioner",
"//executorch/exir:lib",
"//executorch/exir/backend:backend_api",
"//executorch/extension/pybindings:portable_lib", # @manual
"//executorch/extension/pytree:pylib",
"//executorch/kernels/portable:custom_ops_generated_lib",
Expand Down
74 changes: 30 additions & 44 deletions backends/vulkan/test/test_vulkan_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,13 @@
import unittest
from typing import Tuple

import executorch.exir as exir
import torch

# import the vulkan backend implementation
from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
from executorch.backends.vulkan.vulkan_preprocess import VulkanBackend

from executorch.exir import ExecutorchProgram
from executorch.exir.backend.backend_api import to_backend
from executorch.exir import EdgeProgramManager, to_edge
from torch.export import export, ExportedProgram

ctypes.CDLL("libvulkan.so.1")

Expand Down Expand Up @@ -51,7 +50,7 @@ def assert_outputs_equal(self, model_output, ref_output, atol=1e-03, rtol=1e-03)

def lower_module_and_test_output(
self,
module: torch.nn.Module,
model: torch.nn.Module,
sample_inputs: Tuple[torch.Tensor],
atol=1e-03,
rtol=1e-01,
Expand All @@ -61,36 +60,23 @@ def lower_module_and_test_output(
the given sample inputs. It then runs the lowered module and compares its
outputs with the outputs of the eager module.
"""
edgeir_m = exir.capture(module, sample_inputs, exir.CaptureConfig()).to_edge()
lowered_module = to_backend("VulkanBackend", edgeir_m.exported_program, [])
program: ExportedProgram = export(model, sample_inputs)
edge_program: EdgeProgramManager = to_edge(program)
edge_program = edge_program.to_backend(VulkanPartitioner())

class WrappedModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.one_module = lowered_module

def forward(self, *args):
return self.one_module(*args)
executorch_program = edge_program.to_executorch()

executorch_program: ExecutorchProgram = (
exir.capture(WrappedModule(), sample_inputs, exir.CaptureConfig())
.to_edge()
.to_executorch()
)

# Assert the backend name is vulkan
self.assertEqual(
executorch_program.program.execution_plan[0].delegates[0].id,
executorch_program.executorch_program.execution_plan[0].delegates[0].id,
VulkanBackend.__name__,
)

# Test the model with executor
executorch_module = _load_for_executorch_from_buffer(executorch_program.buffer)
# pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
inputs_flattened, _ = tree_flatten(sample_inputs)

model_output = executorch_module.run_method("forward", tuple(inputs_flattened))
ref_output = module(*sample_inputs)
ref_output = model(*sample_inputs)

self.assert_outputs_equal(model_output, ref_output, atol=atol, rtol=rtol)

Expand Down Expand Up @@ -192,26 +178,6 @@ def forward(self, x, y):

self.lower_module_and_test_output(div_module, model_inputs)

def test_vulkan_backend_floor_div(self):
class FloorDivModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, y):
z = x // y
return z

floor_div_module = FloorDivModule()
model_inputs = (
torch.rand(size=(2, 3), dtype=torch.float32) * 10.0,
torch.rand(size=(2, 3), dtype=torch.float32) + 1.0,
)

# absolute tolerance is 1 because of flooring
self.lower_module_and_test_output(
floor_div_module, model_inputs, atol=1.0 + 1e-03
)

def test_vulkan_backend_arithmetic(self):
class ArithmeticModule(torch.nn.Module):
def __init__(self):
Expand Down Expand Up @@ -249,3 +215,23 @@ def forward(self, x, y):
)

self.lower_module_and_test_output(pow_module, model_inputs)

def test_vulkan_backend_partial(self):
class SimpleModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(10, 10)
self.offset_1 = self.weight = torch.rand(
size=(2, 10), dtype=torch.float32
)
self.offset_2 = self.weight = torch.rand(
size=(2, 10), dtype=torch.float32
)

def forward(self, x):
return self.linear(x + self.offset_1) - self.offset_2

model = SimpleModel()
model_inputs = (torch.rand(size=(2, 10), dtype=torch.float32),)

self.lower_module_and_test_output(model, model_inputs)