
Commit 06c2a13

Update on "[ET-VK] Enable FP16 type in operators"
Differential Revision: [D56189470](https://our.internmc.facebook.com/intern/diff/D56189470/) [ghstack-poisoned]
2 parents: ab6f4a8 + 7b2a9fc

8 files changed: 43 additions, 34 deletions

backends/vulkan/runtime/api/Adapter.cpp

Lines changed: 2 additions & 0 deletions
@@ -6,6 +6,8 @@
  * LICENSE file in the root directory of this source tree.
  */
 
+// NOLINTBEGIN(clang-diagnostic-missing-field-initializers)
+
 #include <executorch/backends/vulkan/runtime/api/Adapter.h>
 
 #include <bitset>

backends/vulkan/runtime/api/Tensor.cpp

Lines changed: 8 additions & 1 deletion
@@ -228,7 +228,14 @@ vTensor::vTensor(
           memory_layout_,
           gpu_sizes_,
           dtype_,
-          allocate_memory)) {}
+          allocate_memory)) {
+  if (dtype == api::kHalf) {
+    VK_CHECK_COND(
+        api::context()->adapter_ptr()->has_16bit_storage(),
+        "Half dtype is only available if the physical device supports float16 "
+        "storage buffers!");
+  }
+}
 
 vTensor::vTensor(
     api::Context* const context,

backends/vulkan/runtime/api/gen_vulkan_spv.py

Lines changed: 1 addition & 2 deletions
@@ -92,9 +92,8 @@ def define_variable(name: str) -> str:
 def get_buffer_scalar_type(dtype: str) -> str:
     if dtype == "half":
         return "float16_t"
-    # TODO(ssjia): use int8_t for int8 types
     elif dtype[-1] == "8":
-        return dtype[:-1]
+        return dtype + "_t"
 
     return dtype
 
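For reference, a minimal standalone sketch of the updated dtype-to-GLSL-scalar mapping. The function body mirrors the hunk above; the example inputs at the bottom are assumptions for illustration, not cases taken from the repo.

def get_buffer_scalar_type(dtype: str) -> str:
    # "half" tensors map to the GLSL 16-bit float scalar type.
    if dtype == "half":
        return "float16_t"
    # 8-bit dtypes now keep their width and gain a "_t" suffix,
    # e.g. "int8" -> "int8_t" instead of the old truncation to "int".
    elif dtype[-1] == "8":
        return dtype + "_t"
    return dtype

assert get_buffer_scalar_type("half") == "float16_t"
assert get_buffer_scalar_type("int8") == "int8_t"
assert get_buffer_scalar_type("uint8") == "uint8_t"
assert get_buffer_scalar_type("float") == "float"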

backends/vulkan/test/op_tests/cases.py

Lines changed: 11 additions & 6 deletions
@@ -21,7 +21,7 @@
 
 
 def get_binary_elementwise_inputs():
-    return VkTestSuite(
+    test_suite = VkTestSuite(
         [
             ((M1, M2), (M1, M2)),
             ((M1, M2), (M1, 1), 2.0),
@@ -31,6 +31,11 @@ def get_binary_elementwise_inputs():
             ((S, S1, S2), (S, 1, S2), 2.0),
         ]
     )
+    test_suite.layouts = [
+        "api::kWidthPacked",
+        "api::kChannelsPacked",
+    ]
+    return test_suite
 
 
 def get_mm_inputs():
@@ -41,7 +46,12 @@ def get_mm_inputs():
         ],
     )
     test_suite.prepacked_args = ["mat2"]
+    # ATen matmul doesn't support half
     test_suite.dtypes = ["at::kFloat"]
+    test_suite.layouts = [
+        "api::kWidthPacked",
+        "api::kChannelsPacked",
+    ]
     return test_suite
 
 
@@ -51,7 +61,6 @@ def get_pool2d_inputs():
             ((S, M1, M2), [2, 2], [1, 1], [0, 0], [1, 1]),
         ]
     )
-    test_suite.supports["layouts"] = ["api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED"]
     return test_suite
 
 
@@ -115,7 +124,6 @@ def get_conv2d_inputs():
             ),
         ]
     )
-    test_suite.supports["layouts"] = ["api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED"]
     return test_suite
 
 
@@ -127,7 +135,6 @@ def get_native_layer_norm_inputs():
             ((S, XL, M1, M2), [M2], (M2), (M2), 0.001),
         ]
     )
-    test_suite.supports["layouts"] = ["api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED"]
     return test_suite
 
 
@@ -139,7 +146,6 @@ def get_full_inputs():
             ([L, M, M1, M2], 2.72),
         ]
     )
-    test_suite.supports["layouts"] = ["api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED"]
     return test_suite
 
 
@@ -162,7 +168,6 @@ def get_select_int_inputs():
             ((8, 6, 1, 1), 1, 4),
         ]
     )
-    test_suite.supports["layouts"] = ["api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED"]
     return test_suite
 
 
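As a quick orientation, a minimal sketch of the configuration style these hunks move to: VkTestSuite seeds per-instance defaults and each get_*_inputs() helper overrides only what its op needs. The import path and the placeholder sizes below are assumptions made to keep the example self-contained; the real shape constants (M1, M2, S, ...) are defined in cases.py.

from executorch.backends.vulkan.test.op_tests.utils.codegen import VkTestSuite  # path assumed

M1, M2 = 37, 41  # placeholder sizes standing in for the constants in cases.py

suite = VkTestSuite([((M1, M2), (M1, M2))])

# Defaults seeded by VkTestSuite.__init__ / TestSuite.__init__ (see the
# codegen.py and codegen_base.py hunks below):
#   suite.dtypes        == ["at::kFloat", "at::kHalf"]
#   suite.storage_types == ["api::kTexture3D"]
#   suite.layouts       == ["api::kChannelsPacked"]

# Per-op overrides are plain attribute assignments instead of writes into a
# shared `supports` dict:
suite.layouts = ["api::kWidthPacked", "api::kChannelsPacked"]
suite.dtypes = ["at::kFloat"]  # e.g. when ATen has no half kernel for the op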

backends/vulkan/test/op_tests/targets.bzl

Lines changed: 1 addition & 1 deletion
@@ -92,6 +92,6 @@ def define_common_targets(is_fbcode = False):
         deps = [
             "//third-party/googletest:gtest_main",
             "//executorch/backends/vulkan:vulkan_graph_runtime",
-            ":all_aten_ops_lib",
+            runtime.external_dep_location("libtorch"),
         ],
     )

backends/vulkan/test/op_tests/utils/codegen.py

Lines changed: 13 additions & 15 deletions
@@ -39,13 +39,10 @@
 
 @dataclass
 class VkTestSuite(TestSuite):
-    supports = {
-        "storage_types": ["api::StorageType::TEXTURE_3D"],
-        "layouts": [
-            "api::GPUMemoryLayout::TENSOR_WIDTH_PACKED",
-            "api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED",
-        ],
-    }
+    def __init__(self, input_cases: List[Any]):
+        super().__init__(input_cases)
+        self.storage_types: List[str] = ["api::kTexture3D"]
+        self.layouts: List[str] = ["api::kChannelsPacked"]
 
 
 ##########################
@@ -376,7 +373,7 @@ def gen_graph_exec_code(self) -> str:
         return graph_exec
 
     def gen_conditional_skips(self) -> str:
-        skips = f"if (test_dtype == at::kHalf && "
+        skips = "if (test_dtype == at::kHalf && "
         skips += f"!{self.graph}{self.dot}context()->adapter_ptr()->has_16bit_storage()) {{\n"
         skips += " GTEST_SKIP();"
         skips += "}\n"
@@ -458,14 +455,13 @@ def generate_fixture_cpp(self) -> str:
         )
 
     def gen_parameterization(self) -> str:
-        # pyre-ignore
         dtypes = self.suite_def.dtypes
-        storage_types = self.suite_def.supports["storage_types"]
-        layouts = self.suite_def.supports["layouts"]
+        storage_types = self.suite_def.storage_types
+        layouts = self.suite_def.layouts
 
         return f"""
 INSTANTIATE_TEST_SUITE_P(
-    StorageLayoutCombos_{self.op_name},
+    Combos_{self.op_name},
     GeneratedOpsTest_{self.op_name},
     ::testing::Combine(
         ::testing::Values({', '.join(dtypes)}),
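Purely for illustration, roughly what gen_parameterization() now emits for a hypothetical op named "add" using the VkTestSuite defaults. The op name and the two trailing ::testing::Values(...) lines are inferred, since this hunk cuts off after the dtype values.

# Hypothetical generated parameterization for an op named "add" (assumed),
# with the default dtypes/storage_types/layouts from VkTestSuite above.
# The two trailing Values(...) lines are inferred, not shown in this hunk.
expected_snippet = """
INSTANTIATE_TEST_SUITE_P(
    Combos_add,
    GeneratedOpsTest_add,
    ::testing::Combine(
        ::testing::Values(at::kFloat, at::kHalf),
        ::testing::Values(api::kTexture3D),
        ::testing::Values(api::kChannelsPacked)));
"""
print(expected_snippet)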
@@ -513,9 +509,11 @@ def gen_parameterization(self) -> str:
     return true;
   }
   bool is_close = at::allclose(t1, t2, rtol, atol);
-  if (!is_close) {
-    std::cout << "t1:" << t1 << std::endl;
-    std::cout << "t2:" << t2 << std::endl;
+  if (!is_close && t1.numel() < 500) {
+    std::cout << "reference: " << std::endl;
+    print(t1, 150);
+    std::cout << "vulkan: " << std::endl;
+    print(t2, 150);
   }
   return is_close;
 }

backends/vulkan/test/op_tests/utils/codegen_base.py

Lines changed: 5 additions & 5 deletions
@@ -34,12 +34,12 @@
 ###########################
 
 
-@dataclass
 class TestSuite:
-    input_cases: List[Any]
-    prepacked_args = []
-    requires_prepack = False
-    dtypes = ["at::kFloat", "at::kHalf"]
+    def __init__(self, input_cases: List[Any]):
+        self.input_cases: List[Any] = input_cases
+        self.prepacked_args: List[str] = []
+        self.requires_prepack: bool = False
+        self.dtypes: List[str] = ["at::kFloat", "at::kHalf"]
 
     def supports_prepack(self):
         return len(self.prepacked_args) > 0
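A plausible reading of this refactor (not stated in the commit message, so treat it as an assumption): class-level mutable defaults such as prepacked_args = [] and the old supports dict are shared by every instance, so one suite mutating them leaks into every other suite, whereas attributes set in __init__ give each suite its own copy. A minimal sketch of the pitfall, with hypothetical class names:

from typing import List


class SharedDefaults:
    # Class attribute: one list object shared by every instance.
    prepacked_args: List[str] = []


class PerInstanceDefaults:
    def __init__(self) -> None:
        # Instance attribute: a fresh list per suite.
        self.prepacked_args: List[str] = []


a, b = SharedDefaults(), SharedDefaults()
a.prepacked_args.append("mat2")
assert b.prepacked_args == ["mat2"]  # the mutation leaks into the other suite

c, d = PerInstanceDefaults(), PerInstanceDefaults()
c.prepacked_args.append("mat2")
assert d.prepacked_args == []  # isolated per instance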

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 2 additions & 4 deletions
@@ -54,8 +54,6 @@ def assert_outputs_equal(
                 )
             )
         else:
-            print(model_output[0])
-            print(ref_output)
             # If one output, eager returns tensor while executor tuple of size 1
             self.assertTrue(
                 torch.allclose(model_output[0], ref_output, atol=atol, rtol=rtol)
@@ -200,8 +198,8 @@ def forward(self, x, y):
 
         sub_module = SubModule()
         sample_inputs = (
-            torch.rand(size=(2, 3), dtype=torch.float16),
-            torch.rand(size=(2, 3), dtype=torch.float16),
+            torch.rand(size=(2, 3), dtype=torch.float32),
+            torch.rand(size=(2, 3), dtype=torch.float32),
         )
 
         self.lower_module_and_test_output(sub_module, sample_inputs)
