Commit 066813e

[ET-VK] Enable FP16 type in operators

Pull Request resolved: #3059

## Context

Enable half-precision shader computation using the `GL_EXT_shader_16bit_storage` extension, which was enabled in the change directly below this one in the stack.

ghstack-source-id: 222700889
Differential Revision: [D56189470](https://our.internmc.facebook.com/intern/diff/D56189470/)

1 parent 16a9e41 commit 066813e

File tree

11 files changed (+121, -68 lines)

backends/vulkan/runtime/api/gen_vulkan_spv.py

Lines changed: 2 additions & 4 deletions

```diff
@@ -90,12 +90,10 @@ def define_variable(name: str) -> str:
 
 
 def get_buffer_scalar_type(dtype: str) -> str:
-    # TODO(ssjia): use float16_t for half types
     if dtype == "half":
-        return "float"
-    # TODO(ssjia): use int8_t for int8 types
+        return "float16_t"
     elif dtype[-1] == "8":
-        return dtype[:-1]
+        return dtype + "_t"
 
     return dtype
```
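Assembled from the hunk above, the updated helper now reads roughly as follows: it maps `half` to GLSL's explicit `float16_t` (rather than widening to `float`) and 8-bit dtypes to their explicit-size `*_t` forms. A sketch, not a verbatim copy of the file:

```python
def get_buffer_scalar_type(dtype: str) -> str:
    # GLSL scalar type used for buffer (SSBO) members holding tensor data.
    # float16_t is only valid when GL_EXT_shader_16bit_storage is enabled,
    # which the shader templates below now require for half tensors.
    if dtype == "half":
        return "float16_t"
    elif dtype[-1] == "8":
        # e.g. "int8" -> "int8_t", "uint8" -> "uint8_t"
        return dtype + "_t"
    return dtype


assert get_buffer_scalar_type("half") == "float16_t"
assert get_buffer_scalar_type("int8") == "int8_t"
assert get_buffer_scalar_type("float") == "float"
```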

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_prepack_weights.glsl

Lines changed: 3 additions & 0 deletions

```diff
@@ -19,6 +19,9 @@
 
 #include "indexing_utils.h"
 
+$if DTYPE == "half":
+  #extension GL_EXT_shader_16bit_storage : require
+
 layout(std430) buffer;
 
 layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[2][DTYPE]} image_out;
```
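This same three-line guard is added to each of the shader templates below. `$if` is evaluated at template-expansion time, so the `#extension` directive only appears in GLSL generated for half tensors, and fp32 variants still compile on drivers without 16-bit storage support. A minimal Python sketch of the effect (hypothetical helper, not the actual template engine in gen_vulkan_spv.py):

```python
def emit_16bit_storage_guard(dtype: str) -> str:
    # Mirrors `$if DTYPE == "half":` in the templates: only half-precision
    # shader variants declare the 16-bit storage extension.
    if dtype == "half":
        return "#extension GL_EXT_shader_16bit_storage : require\n"
    return ""
```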

backends/vulkan/runtime/graph/ops/glsl/conv2d_prepack_weights.glsl

Lines changed: 3 additions & 0 deletions

```diff
@@ -19,6 +19,9 @@
 
 #include "indexing_utils.h"
 
+$if DTYPE == "half":
+  #extension GL_EXT_shader_16bit_storage : require
+
 layout(std430) buffer;
 
 layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[2][DTYPE]} image_out;
```

backends/vulkan/runtime/graph/ops/glsl/conv_transpose2d_prepack_weights.glsl

Lines changed: 3 additions & 0 deletions

```diff
@@ -19,6 +19,9 @@
 
 #include "indexing_utils.h"
 
+$if DTYPE == "half":
+  #extension GL_EXT_shader_16bit_storage : require
+
 layout(std430) buffer;
 
 layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[2][DTYPE]} image_out;
```

backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl

Lines changed: 3 additions & 0 deletions

```diff
@@ -19,6 +19,9 @@
 
 #include "indexing_utils.h"
 
+$if DTYPE == "half":
+  #extension GL_EXT_shader_16bit_storage : require
+
 layout(std430) buffer;
 
 layout(set = 0, binding = 0) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} image_in;
```

backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl

Lines changed: 15 additions & 11 deletions

```diff
@@ -20,6 +20,9 @@
 
 #include "indexing_utils.h"
 
+$if DTYPE == "half":
+  #extension GL_EXT_shader_16bit_storage : require
+
 layout(std430) buffer;
 
 layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
@@ -52,20 +55,21 @@ void main() {
   const ivec4 buf_indices =
       base_index + ivec4(0, 1, 2, 3) * get_packed_stride(cpu_sizes.data);
 
-  SCALAR_T val_x = SCALAR_T(buffer_in.data[buf_indices.x]);
-  SCALAR_T val_y = SCALAR_T(buffer_in.data[buf_indices.y]);
-  SCALAR_T val_z = SCALAR_T(buffer_in.data[buf_indices.z]);
-  SCALAR_T val_w = SCALAR_T(buffer_in.data[buf_indices.w]);
-
-  VEC4_T texel = VEC4_T(val_x, val_y, val_z, val_w);
-
   const int packed_dim_size = get_packed_dim(cpu_sizes.data);
   int packed_idx = get_packed_dim(idx);
 
-  if (packed_idx + 3 >= packed_dim_size) {
-    ivec4 packed_ind = ivec4(packed_idx) + ivec4(0, 1, 2, 3);
-    VEC4_T valid_idx = VEC4_T(lessThan(packed_ind, ivec4(packed_dim_size)));
-    texel = texel * valid_idx;
+  VEC4_T texel = VEC4_T(0);
+  if (packed_idx < packed_dim_size) {
+    texel.x = SCALAR_T(buffer_in.data[buf_indices.x]);
+  }
+  if (packed_idx + 1 < packed_dim_size) {
+    texel.y = SCALAR_T(buffer_in.data[buf_indices.y]);
+  }
+  if (packed_idx + 2 < packed_dim_size) {
+    texel.z = SCALAR_T(buffer_in.data[buf_indices.z]);
+  }
+  if (packed_idx + 3 < packed_dim_size) {
+    texel.w = SCALAR_T(buffer_in.data[buf_indices.w]);
   }
 
   imageStore(image_out, ${get_pos[NDIM]("pos")}, texel);
```
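The second hunk also changes how the last, partially filled texel along the packed dimension is handled: instead of unconditionally reading four buffer elements and zeroing the invalid lanes afterwards, each lane is now read only if it is in bounds, so the shader never reads past the end of the staging buffer. A small Python model of the new gather (names are illustrative, not from the shader):

```python
def gather_texel(buffer_in, buf_indices, packed_idx, packed_dim_size):
    # Each texel packs 4 consecutive elements along the packed dimension.
    # Lane i is filled only when packed_idx + i is in bounds; out-of-range
    # lanes stay zero rather than being read and masked after the fact.
    texel = [0.0, 0.0, 0.0, 0.0]
    for i in range(4):
        if packed_idx + i < packed_dim_size:
            texel[i] = buffer_in[buf_indices[i]]
    return texel
```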

backends/vulkan/test/op_tests/cases.py

Lines changed: 13 additions & 15 deletions

```diff
@@ -21,7 +21,7 @@
 
 
 def get_binary_elementwise_inputs():
-    return VkTestSuite(
+    test_suite = VkTestSuite(
         [
             ((M1, M2), (M1, M2)),
             ((M1, M2), (M1, 1), 2.0),
@@ -31,6 +31,11 @@ def get_binary_elementwise_inputs():
             ((S, S1, S2), (S, 1, S2), 2.0),
         ]
     )
+    test_suite.layouts = [
+        "api::kWidthPacked",
+        "api::kChannelsPacked",
+    ]
+    return test_suite
 
 
 def get_mm_inputs():
@@ -41,6 +46,12 @@ def get_mm_inputs():
         ],
     )
     test_suite.prepacked_args = ["mat2"]
+    # ATen matmul doesn't support half
+    test_suite.dtypes = ["at::kFloat"]
+    test_suite.layouts = [
+        "api::kWidthPacked",
+        "api::kChannelsPacked",
+    ]
     return test_suite
 
 
@@ -50,7 +61,6 @@ def get_pool2d_inputs():
            ((S, M1, M2), [2, 2], [1, 1], [0, 0], [1, 1]),
        ]
    )
-    test_suite.supports["layouts"] = ["api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED"]
     return test_suite
 
 
@@ -114,7 +124,6 @@ def get_conv2d_inputs():
            ),
        ]
    )
-    test_suite.supports["layouts"] = ["api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED"]
     return test_suite
 
 
@@ -123,10 +132,9 @@ def get_native_layer_norm_inputs():
         [
             ((S1, S2), [S2], (S2), (S2), 0.001),
             ((M, M1, M2), [M2], (M2), (M2), 0.001),
-            ((L, XL, M1, M2), [M2], (M2), (M2), 0.001),
+            ((S, XL, M1, M2), [M2], (M2), (M2), 0.001),
         ]
     )
-    test_suite.supports["layouts"] = ["api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED"]
     return test_suite
 
 
@@ -138,7 +146,6 @@ def get_full_inputs():
            ([L, M, M1, M2], 2.72),
        ]
    )
-    test_suite.supports["layouts"] = ["api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED"]
     return test_suite
 
 
@@ -161,7 +168,6 @@ def get_select_int_inputs():
            ((8, 6, 1, 1), 1, 4),
        ]
    )
-    test_suite.supports["layouts"] = ["api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED"]
     return test_suite
 
 
@@ -177,11 +183,3 @@ def get_select_int_inputs():
     "aten.full.default": get_full_inputs(),
     "aten.select.int": get_select_int_inputs(),
 }
-
-prepacked_args = {"aten.mm.default": {"mat2"}}
-
-support_exceptions = {
-    "aten.max_pool2d_with_indices.default": {
-        "layouts": ["api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED"]
-    },
-}
```
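With the old class-level `supports` dict and the module-level `prepacked_args` / `support_exceptions` maps gone, each suite now carries its own configuration. A hypothetical suite for illustration (attribute names are taken from the diffs above; the default dtype list including `at::kHalf` is an assumption inferred from the mm suite explicitly opting out of half):

```python
def get_example_op_inputs():
    test_suite = VkTestSuite([((M1, M2), (M1, M2))])
    # Assumption: the TestSuite default covers both float and half, so ops
    # whose ATen reference lacks half support opt out, as get_mm_inputs()
    # does above.
    test_suite.dtypes = ["at::kFloat", "at::kHalf"]
    test_suite.storage_types = ["api::kTexture3D"]
    test_suite.layouts = ["api::kWidthPacked", "api::kChannelsPacked"]
    return test_suite
```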

backends/vulkan/test/op_tests/targets.bzl

Lines changed: 21 additions & 1 deletion

```diff
@@ -1,4 +1,6 @@
 load("@fbsource//tools/build_defs:platform_defs.bzl", "ANDROID")
+load("@fbsource//xplat/caffe2:pt_defs.bzl", "get_pt_ops_deps")
+load("@fbsource//xplat/caffe2:pt_ops.bzl", "pt_operator_library")
 load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
 
 def define_common_targets(is_fbcode = False):
@@ -43,6 +45,24 @@ def define_common_targets(is_fbcode = False):
         default_outs = ["."],
     )
 
+    pt_operator_library(
+        name = "all_aten_ops",
+        check_decl = False,
+        include_all_operators = True,
+    )
+
+    runtime.cxx_library(
+        name = "all_aten_ops_lib",
+        srcs = [],
+        define_static_target = False,
+        exported_deps = get_pt_ops_deps(
+            name = "pt_ops_full",
+            deps = [
+                ":all_aten_ops",
+            ],
+        ),
+    )
+
     runtime.cxx_binary(
         name = "compute_graph_op_tests_bin",
         srcs = [
@@ -52,7 +72,7 @@
         deps = [
             "//third-party/googletest:gtest_main",
             "//executorch/backends/vulkan:vulkan_graph_runtime",
-            runtime.external_dep_location("libtorch"),
+            ":all_aten_ops_lib",
         ],
     )
```

backends/vulkan/test/op_tests/utils/codegen.py

Lines changed: 35 additions & 18 deletions

```diff
@@ -39,13 +39,10 @@
 
 @dataclass
 class VkTestSuite(TestSuite):
-    supports = {
-        "storage_types": ["api::StorageType::TEXTURE_3D"],
-        "layouts": [
-            "api::GPUMemoryLayout::TENSOR_WIDTH_PACKED",
-            "api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED",
-        ],
-    }
+    def __init__(self, input_cases: List[Any]):
+        super().__init__(input_cases)
+        self.storage_types: List[str] = ["api::kTexture3D"]
+        self.layouts: List[str] = ["api::kChannelsPacked"]
 
 
 ##########################
@@ -88,7 +85,6 @@ def __init__(self, op_reg_name: str, f: NativeFunction, suite_def: TestSuite):
         self.dot = "->"
 
         self.args = []
-        self.out = None
         self.refs = {}
 
         self.should_prepack = False
@@ -288,6 +284,7 @@ def set_output(self, ref: ValueRefList) -> str:
         return ret_str
 
     def virtual_resize(self, ref: ValueRefList) -> str:
+        assert isinstance(ref, ValueRef)
         assert ref.src_cpp_type == AT_TENSOR and ref.is_in
         if self.prepack_ref(ref):
             return ""
@@ -296,6 +293,7 @@ def virtual_resize(self, ref: ValueRefList) -> str:
         return ret_str
 
     def copy_into_staging(self, ref: ValueRefList) -> str:
+        assert isinstance(ref, ValueRef)
         assert ref.src_cpp_type == AT_TENSOR and ref.is_in
         if self.prepack_ref(ref):
             return ""
@@ -336,7 +334,7 @@ def check_graph_out(self, ref: ValueRefList) -> str:
                 ret_str += self.check_graph_out(r)
             return ret_str
 
-        return f"EXPECT_TRUE(check_close({ref.src_cpp_name}, vk_{ref.name}));\n"
+        return f"EXPECT_TRUE(check_close({ref.src_cpp_name}, vk_{ref.name}, rtol, atol));\n"
 
     ## Top level code generation
 
@@ -374,11 +372,19 @@ def gen_graph_exec_code(self) -> str:
 
         return graph_exec
 
+    def gen_conditional_skips(self) -> str:
+        skips = "if (test_dtype == at::kHalf && "
+        skips += f"!{self.graph}{self.dot}context()->adapter_ptr()->has_16bit_storage()) {{\n"
+        skips += " GTEST_SKIP();"
+        skips += "}\n"
+        return skips
+
     def gen_op_check_fn(self) -> str:
         op_name = self.f.func.name.unambiguous_name()
         op_check_fn = self.gen_decl(f"check_{op_name}") + " {"
         if self.should_prepack:
             op_check_fn = self.gen_decl(f"prepacked_check_{op_name}") + " {"
+        op_check_fn += self.gen_conditional_skips()
         op_check_fn += self.gen_graph_build_code()
         op_check_fn += self.gen_graph_exec_code()
         op_check_fn += self.check_graph_out(self.refs["out"])
@@ -391,19 +397,26 @@ def gen_op_check_fn(self) -> str:
 ##################################
 
 test_fixture_template = """
-class GeneratedOpsTest_{op_name} : public ::testing::TestWithParam< ::std::tuple<api::StorageType, api::GPUMemoryLayout>> {{
+class GeneratedOpsTest_{op_name} : public ::testing::TestWithParam< ::std::tuple<at::ScalarType, api::StorageType, api::GPUMemoryLayout>> {{
  protected:
   ComputeGraph* graph;
   at::ScalarType test_dtype = at::kFloat;
+  float rtol = 1e-5;
+  float atol = 1e-5;
 
   void SetUp() override {{
     GraphConfig config;
     api::StorageType default_storage_type;
     api::GPUMemoryLayout default_memory_layout;
-    std::tie(default_storage_type, default_memory_layout) = GetParam();
+    std::tie(test_dtype, default_storage_type, default_memory_layout) = GetParam();
     config.setStorageTypeOverride(default_storage_type);
     config.setMemoryLayoutOverride(default_memory_layout);
     graph = new ComputeGraph(config);
+
+    if (test_dtype == at::kHalf) {{
+      rtol = 1e-2;
+      atol = 1e-2;
+    }}
   }}
 
   void TearDown() override {{
@@ -420,7 +433,7 @@ class GeneratedOpsTest_{op_name} : public ::testing::TestWithParam< ::std::tuple
 
 
 class VkTestSuiteGen(TestSuiteGen):
-    def __init__(self, op_reg_name: str, f: NativeFunction, inputs: List[Any]):
+    def __init__(self, op_reg_name: str, f: NativeFunction, inputs: VkTestSuite):
         super().__init__(f, inputs)
         self.op_reg_name = op_reg_name
         self.generator = ComputeGraphGen(self.op_reg_name, self.f, self.suite_def)
@@ -442,14 +455,16 @@ def generate_fixture_cpp(self) -> str:
         )
 
     def gen_parameterization(self) -> str:
-        storage_types = self.suite_def.supports["storage_types"]
-        layouts = self.suite_def.supports["layouts"]
+        dtypes = self.suite_def.dtypes
+        storage_types = self.suite_def.storage_types
+        layouts = self.suite_def.layouts
 
         return f"""
 INSTANTIATE_TEST_SUITE_P(
-    StorageLayoutCombos_{self.op_name},
+    Combos_{self.op_name},
     GeneratedOpsTest_{self.op_name},
     ::testing::Combine(
+        ::testing::Values({', '.join(dtypes)}),
         ::testing::Values({', '.join(storage_types)}),
         ::testing::Values({', '.join(layouts)})));
 """
@@ -494,9 +509,11 @@ def gen_parameterization(self) -> str:
     return true;
   }
   bool is_close = at::allclose(t1, t2, rtol, atol);
-  if (!is_close) {
-    std::cout << "t1:" << t1 << std::endl;
-    std::cout << "t2:" << t2 << std::endl;
+  if (!is_close && t1.numel() < 500) {
+    std::cout << "reference: " << std::endl;
+    print(t1, 150);
+    std::cout << "vulkan: " << std::endl;
+    print(t2, 150);
   }
   return is_close;
 }
```
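For reference, the guard produced by the new `gen_conditional_skips` hook renders as a few lines of C++ at the top of every generated check function. A self-contained sketch reproducing the generated string (parameter defaults assume the non-prepacked path, where `self.graph` is `graph` and `self.dot` is `->`):

```python
def conditional_skip_guard(graph: str = "graph", dot: str = "->") -> str:
    # Mirrors ComputeGraphGen.gen_conditional_skips() from the diff above:
    # generated tests skip at::kHalf on devices whose Vulkan adapter lacks
    # 16-bit storage support.
    skips = "if (test_dtype == at::kHalf && "
    skips += f"!{graph}{dot}context()->adapter_ptr()->has_16bit_storage()) {{\n"
    skips += " GTEST_SKIP();"
    skips += "}\n"
    return skips


print(conditional_skip_guard())
# if (test_dtype == at::kHalf && !graph->context()->adapter_ptr()->has_16bit_storage()) {
#  GTEST_SKIP();}
```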
