
[ET-VK][14/n] Add operators to Partitioner #3407

Closed · wants to merge 4 commits
6 changes: 6 additions & 0 deletions backends/vulkan/partitioner/vulkan_partitioner.py
@@ -56,6 +56,12 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
exir_ops.edge.aten.select_copy.int,
exir_ops.edge.aten.unsqueeze_copy.default,
exir_ops.edge.aten.view_copy.default,
# Copy-related operators
exir_ops.edge.aten.clone.default,
exir_ops.edge.aten.cat.default,
exir_ops.edge.aten.split_with_sizes_copy.default,
exir_ops.edge.aten.split.Tensor,
exir_ops.edge.aten.slice_copy.Tensor,
# Other
operator.getitem,
exir_ops.edge.aten.full.default,
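
Note: with these ops added to is_node_supported, any aten.cat / aten.clone / split / slice node in an exported graph becomes eligible for Vulkan delegation. A minimal lowering sketch, assuming the standard ExecuTorch to_edge/to_backend entry points (the exact call sites are not part of this diff):

import torch
from torch.export import export
from executorch.exir import to_edge
from executorch.backends.vulkan.partitioner.vulkan_partitioner import (
    VulkanPartitioner,
)

class CatModule(torch.nn.Module):
    def forward(self, x, y):
        # aten.cat.default is now accepted by is_node_supported().
        return torch.cat([x, y], dim=1)

inputs = (torch.randn(3, 6, 2, 7), torch.randn(3, 1, 2, 7))
edge = to_edge(export(CatModule(), inputs))
# Nodes whose targets appear in the supported-ops list are delegated to Vulkan.
edge = edge.to_backend(VulkanPartitioner())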
5 changes: 3 additions & 2 deletions backends/vulkan/runtime/graph/ops/impl/Split.cpp
@@ -106,7 +106,7 @@ void add_split_with_sizes_default_node(
add_split_with_sizes_default_node(graph, in, split_sizes, dim, out);
}

-void split_with_sizes_default(
+void split_with_sizes_copy_default(
ComputeGraph& graph,
const std::vector<ValueRef>& args) {
add_split_with_sizes_default_node(graph, args[0], args[1], args[2], args[3]);
@@ -134,7 +134,8 @@ void split_tensor(ComputeGraph& graph, const std::vector<ValueRef>& args) {
}

REGISTER_OPERATORS {
-VK_REGISTER_OP(aten.split_with_sizes.default, split_with_sizes_default);
+VK_REGISTER_OP(
+    aten.split_with_sizes_copy.default, split_with_sizes_copy_default);
VK_REGISTER_OP(aten.split.Tensor, split_tensor);
}

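
For context, the registration is renamed because the edge dialect rewrites view-producing ops into their _copy variants, so an exported torch.split with explicit sizes is expected to arrive as aten.split_with_sizes_copy.default. A quick way to inspect this (a sketch; the exact op name in your graph depends on the ExecuTorch version):

import torch
from torch.export import export
from executorch.exir import to_edge

class SplitModule(torch.nn.Module):
    def forward(self, x):
        return torch.split(x, (3, 6, 1, 3), dim=1)

ep = to_edge(export(SplitModule(), (torch.randn(3, 13, 7, 3),)))
# Print call_function targets; expect aten.split_with_sizes_copy.default.
for node in ep.exported_program().graph.nodes:
    if node.op == "call_function":
        print(node.target)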
9 changes: 4 additions & 5 deletions backends/vulkan/serialization/vulkan_graph_builder.py
@@ -133,9 +133,9 @@ def create_node_value(self, node: Node) -> int:
new_id = self.create_tensor_value(spec, constant_id)
self.node_to_value_ids[node] = new_id
return new_id
-elif isinstance(spec, tuple):
-    # Create a Value for each element in the tuple, wrap Values in a
-    # ValueList, and map the Node to the ValueList id.
+elif isinstance(spec, list) or isinstance(spec, tuple):
+    # pyre-ignore[6]: pyre has a hard time inferring the Node type
+    # inside the container.
new_id = self.create_value_list_value(spec)
self.node_to_value_ids[node] = new_id
return new_id
@@ -202,7 +202,7 @@ def create_scalar_list_value(self, arg: List[_ScalarType]) -> int:
)
return new_id

-def create_value_list_value(self, arg: List[Node] | tuple) -> int:
+def create_value_list_value(self, arg: tuple | list) -> int:
self.values.append(
vk_graph_schema.VkValue(
vk_graph_schema.ValueList(
@@ -242,7 +242,6 @@ def get_or_create_value_for(self, arg: _Argument):
# pyre-ignore[6]
return self.create_scalar_list_value(arg)
elif isinstance(arg, list) and isinstance(arg[0], Node):
-# pyre-ignore[6]
return self.create_value_list_value(arg)
elif isinstance(arg, str):
return self.create_string_value(arg)
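
The builder changes above let a node whose output spec is a Python list or tuple (e.g., the multi-tensor output of split) map to one Value per element plus a wrapping ValueList. A stripped-down illustration of that bookkeeping (names mirror the diff; the bodies are simplified stand-ins, not the real serialization code):

class TinyBuilder:
    def __init__(self):
        self.values = []  # flat value table; an id is an index into it
        self.node_to_value_ids = {}

    def create_tensor_value(self, spec):
        self.values.append(("tensor", spec))
        return len(self.values) - 1

    def create_value_list_value(self, arg):
        # One Value per element, then a ValueList referencing them by id.
        ids = [self.create_tensor_value(s) for s in arg]
        self.values.append(("value_list", ids))
        return len(self.values) - 1

    def create_node_value(self, node, spec):
        # Mirrors the new list-or-tuple branch in create_node_value above.
        if isinstance(spec, (list, tuple)):
            new_id = self.create_value_list_value(spec)
        else:
            new_id = self.create_tensor_value(spec)
        self.node_to_value_ids[node] = new_id
        return new_id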
2 changes: 1 addition & 1 deletion backends/vulkan/test/op_tests/cases.py
@@ -575,6 +575,6 @@ def get_split_tensor_inputs():
"aten.clone.default": get_clone_inputs(),
"aten.repeat.default": get_repeat_inputs(),
"aten.cat.default": get_cat_inputs(),
"aten.split_with_sizes.default": get_split_with_sizes_inputs(),
"aten.split_with_sizes_copy.default": get_split_with_sizes_inputs(),
"aten.split.Tensor": get_split_tensor_inputs(),
}
124 changes: 124 additions & 0 deletions backends/vulkan/test/test_vulkan_delegate.py
@@ -80,6 +80,10 @@ def run_test(memory_layout):
compile_options = {
"memory_layout_override": memory_layout,
}

# At minimum, the model should run in eager mode.
model(*sample_inputs)

program: ExportedProgram = export(
model, sample_inputs, dynamic_shapes=dynamic_shapes
)
@@ -798,3 +802,123 @@ def forward(self, x):
sample_inputs,
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
)

def DISABLED_test_vulkan_backend_permute_copy(self):
# aten.permute_copy.default is not yet enabled in the partitioner
class PermuteModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.permute(x, [3, 0, 2, 1])

sample_inputs = (torch.randn(size=(3, 6, 2, 7), dtype=torch.float32),)

self.lower_module_and_test_output(
PermuteModule(),
sample_inputs,
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
)

def test_vulkan_backend_cat(self):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, y, z, w):
return torch.cat([x, y, z, w], dim=1)

sample_inputs = (
torch.randn(size=(3, 6, 2, 7), dtype=torch.float32),
torch.randn(size=(3, 1, 2, 7), dtype=torch.float32),
torch.randn(size=(3, 9, 2, 7), dtype=torch.float32),
torch.randn(size=(3, 3, 2, 7), dtype=torch.float32),
)

self.lower_module_and_test_output(
TestModule(),
sample_inputs,
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
)

def test_vulkan_backend_slice(self):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return x[:, 2:9:2, :]

sample_inputs = (torch.randn(size=(3, 13, 7, 3), dtype=torch.float32),)

self.lower_module_and_test_output(
TestModule(),
sample_inputs,
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
)

def test_vulkan_backend_split_with_sizes(self):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.split(x, (3, 6, 1, 3), dim=1)

sample_inputs = (torch.randn(size=(3, 13, 7, 3), dtype=torch.float32),)

self.lower_module_and_test_output(
TestModule(),
sample_inputs,
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
)

def test_vulkan_backend_split_tensor(self):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.tensor_split(x, 2, dim=1)

sample_inputs = (torch.randn(size=(3, 14, 7, 3), dtype=torch.float32),)

self.lower_module_and_test_output(
TestModule(),
sample_inputs,
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
)

def test_vulkan_backend_clone(self):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.clone(x)

sample_inputs = (torch.randn(size=(3, 14, 7, 3), dtype=torch.float32),)

self.lower_module_and_test_output(
TestModule(),
sample_inputs,
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
)

def DISABLED_test_vulkan_backend_t_default(self):
# aten.permute_copy.default is not yet enabled in the partitioner
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
# torch.t is actually exported as aten::permute.
return torch.t(x)

sample_inputs = (torch.randn(size=(3, 14), dtype=torch.float32),)

self.lower_module_and_test_output(
TestModule(),
sample_inputs,
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
)
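
A note on the DISABLED_ prefix: Python's unittest only collects methods whose names start with test, so renaming a method to DISABLED_test_... keeps it out of the run until aten.permute_copy.default is enabled in the partitioner. To run the new tests (module path taken from the diff header; a Vulkan-capable build is assumed):

import unittest

# Load and run the suite added in this PR.
suite = unittest.defaultTestLoader.loadTestsFromName(
    "backends.vulkan.test.test_vulkan_delegate"
)
unittest.TextTestRunner(verbosity=2).run(suite)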