
Commit 6376572

[ET-VK] Add convolution cases to codegen
TSIA

Differential Revision: [D55829466](https://our.internmc.facebook.com/intern/diff/D55829466/)
ghstack-source-id: 221716571
Pull Request resolved: #2920
Parent: 1300374

File tree: 3 files changed, +98 −3 lines


backends/vulkan/test/op_tests/cases.py

Lines changed: 65 additions & 0 deletions
```diff
@@ -54,13 +54,78 @@ def get_pool2d_inputs():
     return test_suite
 
 
+def get_conv2d_inputs():
+    test_suite = VkTestSuite(
+        [
+            (
+                (1, 6, 40, 50),
+                (8, 6, 3, 3),
+                (8,),
+                [1, 2],
+                [2, 3],
+                [1, 1],
+                False,
+                [0, 0],
+                1,
+            ),
+            (
+                (1, 6, 40, 50),
+                (6, 8, 3, 3),
+                (8,),
+                [1, 2],
+                [2, 3],
+                [1, 1],
+                True,
+                [0, 1],
+                1,
+            ),
+            (
+                (1, 8, 72, 96),
+                (8, 1, 3, 3),
+                (8,),
+                [1, 1],
+                [1, 1],
+                [1, 1],
+                False,
+                [0, 0],
+                8,
+            ),
+            (
+                (1, 8, 72, 96),
+                (8, 8, 1, 1),
+                (8,),
+                [1, 1],
+                [1, 1],
+                [1, 1],
+                False,
+                [0, 0],
+                1,
+            ),
+            (
+                (1, 6, 40, 50),
+                (8, 6, 3, 3),
+                None,
+                [1, 2],
+                [2, 3],
+                [1, 1],
+                False,
+                [0, 0],
+                1,
+            ),
+        ]
+    )
+    test_suite.supports["layouts"] = ["api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED"]
+    return test_suite
+
+
 test_suites = {
     "aten.add.Tensor": get_binary_elementwise_inputs(),
     "aten.sub.Tensor": get_binary_elementwise_inputs(),
     "aten.div.Tensor": get_binary_elementwise_inputs(),
     "aten.mul.Tensor": get_binary_elementwise_inputs(),
     "aten.mm.default": get_mm_inputs(),
     "aten.max_pool2d_with_indices.default": get_pool2d_inputs(),
+    "aten.convolution.default": get_conv2d_inputs(),
 }
 
 prepacked_args = {"aten.mm.default": {"mat2"}}
```
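Each tuple in get_conv2d_inputs() lists the arguments of aten.convolution.default in schema order: input sizes, weight sizes, optional bias sizes (None in the last case), stride, padding, dilation, transposed, output_padding, and groups. As a rough standalone sanity check (not part of this diff), the first case can be replayed eagerly in PyTorch:

```python
import torch

# First test case from get_conv2d_inputs(), run eagerly for reference.
# Argument order follows the aten.convolution.default schema.
inp = torch.rand(1, 6, 40, 50)
weight = torch.rand(8, 6, 3, 3)
bias = torch.rand(8)

out = torch.ops.aten.convolution.default(
    inp,
    weight,
    bias,
    [1, 2],   # stride
    [2, 3],   # padding
    [1, 1],   # dilation
    False,    # transposed
    [0, 0],   # output_padding
    1,        # groups
)
print(out.shape)  # NCHW output; spatial sizes follow the usual conv2d size formula
```

The third case (groups=8 with an (8, 1, 3, 3) weight) exercises depthwise convolution, and the second (transposed=True with a (6, 8, 3, 3) weight) exercises transposed convolution.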

backends/vulkan/test/op_tests/utils/codegen.py

Lines changed: 22 additions & 1 deletion
```diff
@@ -12,8 +12,10 @@
     AT_INT_ARRAY_REF,
     AT_SCALAR,
     AT_TENSOR,
+    AT_TENSOR_OPT,
     BOOL,
     CppTestFileGen,
+    INT,
     TENSOR_TUPLE,
     TestSuite,
     TestSuiteGen,
@@ -96,7 +98,7 @@ def __init__(self, op_reg_name: str, f: NativeFunction, suite_def: TestSuite):
                 ATenArg(name=arg.name, cpp_type=cpp_type, default=arg.default)
             )
 
-            requires_prepack = "weight" in arg.name
+            requires_prepack = "weight" in arg.name or "bias" in arg.name
             supports_prepack = False
             if arg.name in self.suite_def.prepacked_args:
                 supports_prepack = True
@@ -173,6 +175,23 @@ def create_value_for(self, ref: ValueRefList) -> str:
         prepack = self.prepack_ref(ref)
 
         cpp_type = "IOValueRef" if (ref.is_in and not prepack) else "ValueRef"
+
+        if ref.src_cpp_type == AT_TENSOR_OPT:
+            ret_str = f"{cpp_type} {ref.name} = "
+            ret_str += f"!{ref.src_cpp_name}.has_value() ? "
+            ret_str += f"{self.graph}{self.dot}add_none() : "
+            if not prepack:
+                ret_str += f"{self.graph}{self.dot}"
+                ret_str += "add_input_tensor(" if ref.is_in else "add_tensor("
+                ret_str += f"{ref.src_cpp_name}->sizes().vec(), "
+                ret_str += f"from_at_scalartype({ref.src_cpp_name}->scalar_type())); \n"
+            elif prepack:
+                ret_str += f"{self.graph}{self.dot}"
+                ret_str += f"add_tensorref({ref.src_cpp_name}->sizes().vec(), "
+                ret_str += f"from_at_scalartype({ref.src_cpp_name}->scalar_type()), "
+                ret_str += f"{ref.src_cpp_name}->const_data_ptr()); \n"
+            return ret_str
+
         ret_str = f"{cpp_type} {ref.name} = {self.graph}{self.dot}"
         if ref.src_cpp_type == AT_TENSOR and not prepack:
             ret_str += "add_input_tensor(" if ref.is_in else "add_tensor("
@@ -189,6 +208,8 @@ def create_value_for(self, ref: ValueRefList) -> str:
             ret_str += f"add_scalar_list({ref.src_cpp_name}.vec()); \n"
         elif ref.src_cpp_type == BOOL:
             ret_str += f"add_scalar<bool>({ref.src_cpp_name}); \n"
+        elif ref.src_cpp_type == INT:
+            ret_str += f"add_scalar<int64_t>({ref.src_cpp_name}); \n"
         elif ref.src_cpp_type == TENSOR_TUPLE:
             ret_str += f"add_value_list({{{ref.name}_first, {ref.name}_second}}); \n"
         else:
```
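The notable addition in create_value_for is the ::std::optional&lt;at::Tensor&gt; path: the generated C++ checks has_value() and registers either a none value or a tensor with the graph, and since bias arguments now require prepacking, the prepack branch emits an add_tensorref call. A minimal standalone sketch of the string that branch assembles (graph, dot, and bias below are illustrative stand-ins, not values taken from the generator):

```python
# Sketch only: mimics the new AT_TENSOR_OPT prepack branch with hard-coded
# stand-ins for self.graph, self.dot, ref.name and ref.src_cpp_name.
graph, dot = "graph", "."
name = src_cpp_name = "bias"

ret_str = f"ValueRef {name} = "
ret_str += f"!{src_cpp_name}.has_value() ? "
ret_str += f"{graph}{dot}add_none() : "
ret_str += f"{graph}{dot}add_tensorref({src_cpp_name}->sizes().vec(), "
ret_str += f"from_at_scalartype({src_cpp_name}->scalar_type()), "
ret_str += f"{src_cpp_name}->const_data_ptr()); \n"
print(ret_str)
# -> ValueRef bias = !bias.has_value() ? graph.add_none() : graph.add_tensorref(...)
```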

backends/vulkan/test/op_tests/utils/codegen_base.py

Lines changed: 11 additions & 2 deletions
```diff
@@ -15,10 +15,12 @@
 ## ATen code patterns ##
 ########################
 
-AT_TENSOR = "at::Tensor"
-AT_SCALAR = "at::Scalar"
 AT_INT_ARRAY_REF = "at::IntArrayRef"
+AT_SCALAR = "at::Scalar"
+AT_TENSOR = "at::Tensor"
+AT_TENSOR_OPT = "::std::optional<at::Tensor>"
 BOOL = "bool"
+INT = "int64_t"
 TENSOR_TUPLE = "::std::tuple<at::Tensor,at::Tensor>"
 
 ###########################
@@ -116,12 +118,19 @@ def create_input_data(self, arg: Argument, data: Any) -> str:
 
         if cpp_type == AT_TENSOR:
            ret_str += f"make_rand_tensor({init_list_str(data)}, test_dtype);"
+        elif cpp_type == AT_TENSOR_OPT:
+            if str(data) == "None":
+                ret_str += "std::nullopt;"
+            else:
+                ret_str += f"make_rand_tensor({init_list_str(data)}, test_dtype);"
         elif cpp_type == AT_SCALAR:
             ret_str += f"{data};"
         elif cpp_type == AT_INT_ARRAY_REF:
             ret_str += f"{init_list_str(data)};"
         elif cpp_type == BOOL:
             ret_str += f"{str(data).lower()};"
+        elif cpp_type == INT:
+            ret_str += f"{str(data).lower()};"
         else:
             raise RuntimeError(f"Unsupported cpp type {cpp_type}")
         return ret_str + "\n"
```
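With these branches, a test case that passes None for the optional bias is emitted as std::nullopt, a shape tuple still becomes a random tensor, and plain integers such as groups are emitted verbatim. A rough standalone sketch of that mapping (the helper below and the exact brace formatting are illustrative, not the generator's own code):

```python
# Sketch only: approximates what the new AT_TENSOR_OPT and INT branches append
# to the generated C++ initializer for a given test-case value.
def opt_tensor_init(data):
    if str(data) == "None":
        return "std::nullopt;"
    # init_list_str(data) presumably renders an initializer list such as "{8}"
    return f"make_rand_tensor({{{', '.join(str(d) for d in data)}}}, test_dtype);"

print(opt_tensor_init(None))   # -> std::nullopt;
print(opt_tensor_init((8,)))   # -> make_rand_tensor({8}, test_dtype);
print(str(1).lower() + ";")    # INT branch: groups=1 becomes "1;"
```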
