Commit 28be9d6

yipjustin authored and facebook-github-bot committed
Improve codegen for aten.permute (#3087)
Summary:
Pull Request resolved: #3087

The generated test code uses the CPU result as the reference implementation. The tricky part is that, for some indexing operations such as `permute`, the CPU implementation modifies the strides, so the returned tensor has non-contiguous strides. When a `vk_out` tensor is created from this non-contiguous tensor with `at::empty_like`, `vk_out` inherits those strides, which leads to wrong results when data is moved back from staging.

As a fix, append `.contiguous()` after `at::empty_like` so that `vk_out` reverts to the default strides.

ghstack-source-id: 222417364
Reviewed By: SS-JIA
Differential Revision: D56095204
fbshipit-source-id: d42777ec876e47465c892331b5f854203c9fb8ef
1 parent de00717 commit 28be9d6
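
The stride issue described in the commit message can be illustrated without any tensor library. The following is a minimal sketch, assuming row-major default strides counted in elements; the helper names `contiguous_strides`, `permute_view`, and `is_contiguous` are hypothetical and are not part of the codebase touched by this commit. It shows that permuting as a view only reorders sizes and strides, so the result no longer matches the default layout for its shape.

```python
from typing import List, Sequence, Tuple


def contiguous_strides(sizes: Sequence[int]) -> List[int]:
    """Default row-major strides (in elements) for the given sizes."""
    strides = [1] * len(sizes)
    for i in range(len(sizes) - 2, -1, -1):
        strides[i] = strides[i + 1] * sizes[i + 1]
    return strides


def permute_view(
    sizes: Sequence[int], strides: Sequence[int], dims: Sequence[int]
) -> Tuple[List[int], List[int]]:
    """Permute as a view: reorder sizes and strides, move no data."""
    return [sizes[d] for d in dims], [strides[d] for d in dims]


def is_contiguous(sizes: Sequence[int], strides: Sequence[int]) -> bool:
    """Simplified check: strides must equal the default row-major strides.

    Assumption: size-1 dimensions are not special-cased here.
    """
    return list(strides) == contiguous_strides(sizes)


sizes = [2, 3, 4]
strides = contiguous_strides(sizes)

# Permuting reorders the strides, so they stop matching the default
# layout for the permuted shape.
p_sizes, p_strides = permute_view(sizes, strides, [2, 0, 1])

# An output buffer that copies p_strides inherits the non-default
# layout; recomputing default strides for p_sizes restores a dense,
# row-major layout, which is the effect the commit is after.
fixed_strides = contiguous_strides(p_sizes)
```

Here `p_strides` is what an output tensor would inherit if it copied the layout of the permuted reference, while `fixed_strides` corresponds to the default layout for the same shape.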

3 files changed: +5 -1 lines changed

backends/vulkan/runtime/graph/ops/impl/Permute.cpp

Lines changed: 1 addition & 0 deletions
@@ -96,6 +96,7 @@ void permute(ComputeGraph& graph, const std::vector<ValueRef>& args) {
 }

 REGISTER_OPERATORS {
+  VK_REGISTER_OP(aten.permute.default, permute);
   VK_REGISTER_OP(aten.permute_copy.default, permute);
 }

backends/vulkan/test/op_tests/cases.py

Lines changed: 1 addition & 0 deletions
@@ -206,5 +206,6 @@ def get_permute_inputs():
     "aten.full.default": get_full_inputs(),
     "aten.select.int": get_select_int_inputs(),
     "aten.select_copy.int": get_select_int_inputs(),
+    "aten.permute.default": get_permute_inputs(),
     "aten.permute_copy.default": get_permute_inputs(),
 }

backends/vulkan/test/op_tests/utils/codegen.py

Lines changed: 3 additions & 1 deletion
@@ -310,7 +310,9 @@ def declare_vk_out_for(self, ref: Union[ValueRef, List[ValueRef]]) -> str:
                 ret_str += self.declare_vk_out_for(r)
             return ret_str

-        return f"at::Tensor vk_{ref.name} = at::empty_like({ref.src_cpp_name});\n"
+        ret_str = f"at::Tensor vk_{ref.name} = at::empty_like({ref.src_cpp_name})"
+        ret_str += ".contiguous();\n"
+        return ret_str

     def copy_from_staging(self, ref: ValueRefList) -> str:
         if isinstance(ref, list):
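
To see what the changed helper emits, here is a standalone sketch of the same string-building logic. It is an approximation under stated assumptions: the `ValueRef` dataclass below is a hypothetical stand-in that models only the `name` and `src_cpp_name` fields used in the hunk, and the function is written as a free function rather than as a method of the real generator class.

```python
from dataclasses import dataclass
from typing import List, Union


@dataclass
class ValueRef:
    # Hypothetical stand-in: only the two fields used in the diff.
    name: str
    src_cpp_name: str


def declare_vk_out_for(ref: Union[ValueRef, List[ValueRef]]) -> str:
    """Emit C++ declarations for the Vulkan output tensor(s)."""
    if isinstance(ref, list):
        # One declaration per element, concatenated in order.
        return "".join(declare_vk_out_for(r) for r in ref)

    # empty_like may carry over the reference tensor's strides;
    # the trailing .contiguous() requests the default layout instead.
    ret_str = f"at::Tensor vk_{ref.name} = at::empty_like({ref.src_cpp_name})"
    ret_str += ".contiguous();\n"
    return ret_str


generated = declare_vk_out_for(ValueRef(name="out", src_cpp_name="out"))
```

With the illustrative names above, `generated` holds the single C++ line `at::Tensor vk_out = at::empty_like(out).contiguous();` followed by a newline.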
