Skip to content

Commit 33fed6b

Browse files
committed
[ET-VK] Enable FP16 type in operators
Differential Revision: [D56189470](https://our.internmc.facebook.com/intern/diff/D56189470/) ghstack-source-id: 222684648 Pull Request resolved: #3059
1 parent 75a4d49 commit 33fed6b

12 files changed

+95
-45
lines changed

backends/vulkan/runtime/api/gen_vulkan_spv.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,8 @@ def define_variable(name: str) -> str:
9090

9191

9292
def get_buffer_scalar_type(dtype: str) -> str:
93-
# TODO(ssjia): use float16_t for half types
9493
if dtype == "half":
95-
return "float"
94+
return "float16_t"
9695
# TODO(ssjia): use int8_t for int8 types
9796
elif dtype[-1] == "8":
9897
return dtype[:-1]

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_prepack_weights.glsl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919

2020
#include "indexing_utils.h"
2121

22+
$if DTYPE == "half":
23+
#extension GL_EXT_shader_16bit_storage : require
24+
2225
layout(std430) buffer;
2326

2427
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[2][DTYPE]} image_out;

backends/vulkan/runtime/graph/ops/glsl/conv2d_prepack_weights.glsl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919

2020
#include "indexing_utils.h"
2121

22+
$if DTYPE == "half":
23+
#extension GL_EXT_shader_16bit_storage : require
24+
2225
layout(std430) buffer;
2326

2427
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[2][DTYPE]} image_out;

backends/vulkan/runtime/graph/ops/glsl/conv_transpose2d_prepack_weights.glsl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919

2020
#include "indexing_utils.h"
2121

22+
$if DTYPE == "half":
23+
#extension GL_EXT_shader_16bit_storage : require
24+
2225
layout(std430) buffer;
2326

2427
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[2][DTYPE]} image_out;

backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919

2020
#include "indexing_utils.h"
2121

22+
$if DTYPE == "half":
23+
#extension GL_EXT_shader_16bit_storage : require
24+
2225
layout(std430) buffer;
2326

2427
layout(set = 0, binding = 0) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} image_in;

backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020

2121
#include "indexing_utils.h"
2222

23+
$if DTYPE == "half":
24+
#extension GL_EXT_shader_16bit_storage : require
25+
2326
layout(std430) buffer;
2427

2528
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
@@ -52,20 +55,21 @@ void main() {
5255
const ivec4 buf_indices =
5356
base_index + ivec4(0, 1, 2, 3) * get_packed_stride(cpu_sizes.data);
5457

55-
SCALAR_T val_x = SCALAR_T(buffer_in.data[buf_indices.x]);
56-
SCALAR_T val_y = SCALAR_T(buffer_in.data[buf_indices.y]);
57-
SCALAR_T val_z = SCALAR_T(buffer_in.data[buf_indices.z]);
58-
SCALAR_T val_w = SCALAR_T(buffer_in.data[buf_indices.w]);
59-
60-
VEC4_T texel = VEC4_T(val_x, val_y, val_z, val_w);
61-
6258
const int packed_dim_size = get_packed_dim(cpu_sizes.data);
6359
int packed_idx = get_packed_dim(idx);
6460

65-
if (packed_idx + 3 >= packed_dim_size) {
66-
ivec4 packed_ind = ivec4(packed_idx) + ivec4(0, 1, 2, 3);
67-
VEC4_T valid_idx = VEC4_T(lessThan(packed_ind, ivec4(packed_dim_size)));
68-
texel = texel * valid_idx;
61+
VEC4_T texel = VEC4_T(0);
62+
if (packed_idx < packed_dim_size) {
63+
texel.x = SCALAR_T(buffer_in.data[buf_indices.x]);
64+
}
65+
if (packed_idx + 1 < packed_dim_size) {
66+
texel.y = SCALAR_T(buffer_in.data[buf_indices.y]);
67+
}
68+
if (packed_idx + 2 < packed_dim_size) {
69+
texel.z = SCALAR_T(buffer_in.data[buf_indices.z]);
70+
}
71+
if (packed_idx + 3 < packed_dim_size) {
72+
texel.w = SCALAR_T(buffer_in.data[buf_indices.w]);
6973
}
7074

7175
imageStore(image_out, ${get_pos[NDIM]("pos")}, texel);

backends/vulkan/test/op_tests/cases.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def get_mm_inputs():
4141
],
4242
)
4343
test_suite.prepacked_args = ["mat2"]
44+
test_suite.dtypes = ["at::kFloat"]
4445
return test_suite
4546

4647

@@ -123,7 +124,7 @@ def get_native_layer_norm_inputs():
123124
[
124125
((S1, S2), [S2], (S2), (S2), 0.001),
125126
((M, M1, M2), [M2], (M2), (M2), 0.001),
126-
((L, XL, M1, M2), [M2], (M2), (M2), 0.001),
127+
((S, XL, M1, M2), [M2], (M2), (M2), 0.001),
127128
]
128129
)
129130
test_suite.supports["layouts"] = ["api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED"]
@@ -177,11 +178,3 @@ def get_select_int_inputs():
177178
"aten.full.default": get_full_inputs(),
178179
"aten.select.int": get_select_int_inputs(),
179180
}
180-
181-
prepacked_args = {"aten.mm.default": {"mat2"}}
182-
183-
support_exceptions = {
184-
"aten.max_pool2d_with_indices.default": {
185-
"layouts": ["api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED"]
186-
},
187-
}

backends/vulkan/test/op_tests/targets.bzl

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
load("@fbsource//tools/build_defs:platform_defs.bzl", "ANDROID")
2+
load("@fbsource//xplat/caffe2:pt_defs.bzl", "get_pt_ops_deps")
3+
load("@fbsource//xplat/caffe2:pt_ops.bzl", "pt_operator_library")
24
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
35

46
def define_common_targets(is_fbcode = False):
@@ -43,6 +45,24 @@ def define_common_targets(is_fbcode = False):
4345
default_outs = ["."],
4446
)
4547

48+
pt_operator_library(
49+
name = "all_aten_ops",
50+
check_decl = False,
51+
include_all_operators = True,
52+
)
53+
54+
runtime.cxx_library(
55+
name = "all_aten_ops_lib",
56+
srcs = [],
57+
define_static_target = False,
58+
exported_deps = get_pt_ops_deps(
59+
name = "pt_ops_full",
60+
deps = [
61+
":all_aten_ops",
62+
],
63+
),
64+
)
65+
4666
runtime.cxx_binary(
4767
name = "compute_graph_op_tests_bin",
4868
srcs = [
@@ -52,7 +72,7 @@ def define_common_targets(is_fbcode = False):
5272
deps = [
5373
"//third-party/googletest:gtest_main",
5474
"//executorch/backends/vulkan:vulkan_graph_runtime",
55-
runtime.external_dep_location("libtorch"),
75+
":all_aten_ops_lib",
5676
],
5777
)
5878

@@ -72,6 +92,6 @@ def define_common_targets(is_fbcode = False):
7292
deps = [
7393
"//third-party/googletest:gtest_main",
7494
"//executorch/backends/vulkan:vulkan_graph_runtime",
75-
runtime.external_dep_location("libtorch"),
95+
":all_aten_ops_lib",
7696
],
7797
)

backends/vulkan/test/op_tests/utils/codegen.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@ def __init__(self, op_reg_name: str, f: NativeFunction, suite_def: TestSuite):
8888
self.dot = "->"
8989

9090
self.args = []
91-
self.out = None
9291
self.refs = {}
9392

9493
self.should_prepack = False
@@ -288,6 +287,7 @@ def set_output(self, ref: ValueRefList) -> str:
288287
return ret_str
289288

290289
def virtual_resize(self, ref: ValueRefList) -> str:
290+
assert isinstance(ref, ValueRef)
291291
assert ref.src_cpp_type == AT_TENSOR and ref.is_in
292292
if self.prepack_ref(ref):
293293
return ""
@@ -296,6 +296,7 @@ def virtual_resize(self, ref: ValueRefList) -> str:
296296
return ret_str
297297

298298
def copy_into_staging(self, ref: ValueRefList) -> str:
299+
assert isinstance(ref, ValueRef)
299300
assert ref.src_cpp_type == AT_TENSOR and ref.is_in
300301
if self.prepack_ref(ref):
301302
return ""
@@ -336,7 +337,7 @@ def check_graph_out(self, ref: ValueRefList) -> str:
336337
ret_str += self.check_graph_out(r)
337338
return ret_str
338339

339-
return f"EXPECT_TRUE(check_close({ref.src_cpp_name}, vk_{ref.name}));\n"
340+
return f"EXPECT_TRUE(check_close({ref.src_cpp_name}, vk_{ref.name}, rtol, atol));\n"
340341

341342
## Top level code generation
342343

@@ -374,11 +375,19 @@ def gen_graph_exec_code(self) -> str:
374375

375376
return graph_exec
376377

378+
def gen_conditional_skips(self) -> str:
379+
skips = f"if (test_dtype == at::kHalf && "
380+
skips += f"!{self.graph}{self.dot}context()->adapter_ptr()->has_16bit_storage()) {{\n"
381+
skips += " GTEST_SKIP();"
382+
skips += "}\n"
383+
return skips
384+
377385
def gen_op_check_fn(self) -> str:
378386
op_name = self.f.func.name.unambiguous_name()
379387
op_check_fn = self.gen_decl(f"check_{op_name}") + " {"
380388
if self.should_prepack:
381389
op_check_fn = self.gen_decl(f"prepacked_check_{op_name}") + " {"
390+
op_check_fn += self.gen_conditional_skips()
382391
op_check_fn += self.gen_graph_build_code()
383392
op_check_fn += self.gen_graph_exec_code()
384393
op_check_fn += self.check_graph_out(self.refs["out"])
@@ -391,19 +400,26 @@ def gen_op_check_fn(self) -> str:
391400
##################################
392401

393402
test_fixture_template = """
394-
class GeneratedOpsTest_{op_name} : public ::testing::TestWithParam< ::std::tuple<api::StorageType, api::GPUMemoryLayout>> {{
403+
class GeneratedOpsTest_{op_name} : public ::testing::TestWithParam< ::std::tuple<at::ScalarType, api::StorageType, api::GPUMemoryLayout>> {{
395404
protected:
396405
ComputeGraph* graph;
397406
at::ScalarType test_dtype = at::kFloat;
407+
float rtol = 1e-5;
408+
float atol = 1e-5;
398409
399410
void SetUp() override {{
400411
GraphConfig config;
401412
api::StorageType default_storage_type;
402413
api::GPUMemoryLayout default_memory_layout;
403-
std::tie(default_storage_type, default_memory_layout) = GetParam();
414+
std::tie(test_dtype, default_storage_type, default_memory_layout) = GetParam();
404415
config.setStorageTypeOverride(default_storage_type);
405416
config.setMemoryLayoutOverride(default_memory_layout);
406417
graph = new ComputeGraph(config);
418+
419+
if (test_dtype == at::kHalf) {{
420+
rtol = 1e-2;
421+
atol = 1e-2;
422+
}}
407423
}}
408424
409425
void TearDown() override {{
@@ -420,7 +436,7 @@ class GeneratedOpsTest_{op_name} : public ::testing::TestWithParam< ::std::tuple
420436

421437

422438
class VkTestSuiteGen(TestSuiteGen):
423-
def __init__(self, op_reg_name: str, f: NativeFunction, inputs: List[Any]):
439+
def __init__(self, op_reg_name: str, f: NativeFunction, inputs: VkTestSuite):
424440
super().__init__(f, inputs)
425441
self.op_reg_name = op_reg_name
426442
self.generator = ComputeGraphGen(self.op_reg_name, self.f, self.suite_def)
@@ -442,6 +458,8 @@ def generate_fixture_cpp(self) -> str:
442458
)
443459

444460
def gen_parameterization(self) -> str:
461+
# pyre-ignore
462+
dtypes = self.suite_def.dtypes
445463
storage_types = self.suite_def.supports["storage_types"]
446464
layouts = self.suite_def.supports["layouts"]
447465

@@ -450,6 +468,7 @@ def gen_parameterization(self) -> str:
450468
StorageLayoutCombos_{self.op_name},
451469
GeneratedOpsTest_{self.op_name},
452470
::testing::Combine(
471+
::testing::Values({', '.join(dtypes)}),
453472
::testing::Values({', '.join(storage_types)}),
454473
::testing::Values({', '.join(layouts)})));
455474
"""

backends/vulkan/test/op_tests/utils/codegen_base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class TestSuite:
3939
input_cases: List[Any]
4040
prepacked_args = []
4141
requires_prepack = False
42+
dtypes = ["at::kFloat", "at::kHalf"]
4243

4344
def supports_prepack(self):
4445
return len(self.prepacked_args) > 0
@@ -239,6 +240,6 @@ def generate_preamble(self) -> str:
239240
def generate_test_suites_cpp(self) -> str:
240241
return "\n".join([h.generate_suite_cpp() for h in self.suites_gens])
241242

242-
def add_suite(self, f: NativeFunction, test_suite: TestSuite) -> None:
243-
suites_gen = TestSuiteGen(f, test_suite)
243+
def add_suite(self, op_reg_name: str, f: NativeFunction, all_input_cases) -> None:
244+
suites_gen = TestSuiteGen(f, all_input_cases)
244245
self.suites_gens.append(suites_gen)

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ def assert_outputs_equal(
5454
)
5555
)
5656
else:
57+
print(model_output[0])
58+
print(ref_output)
5759
# If one output, eager returns tensor while executor tuple of size 1
5860
self.assertTrue(
5961
torch.allclose(model_output[0], ref_output, atol=atol, rtol=rtol)
@@ -198,8 +200,8 @@ def forward(self, x, y):
198200

199201
sub_module = SubModule()
200202
sample_inputs = (
201-
torch.rand(size=(2, 3), dtype=torch.float32),
202-
torch.rand(size=(2, 3), dtype=torch.float32),
203+
torch.rand(size=(2, 3), dtype=torch.float16),
204+
torch.rand(size=(2, 3), dtype=torch.float16),
203205
)
204206

205207
self.lower_module_and_test_output(sub_module, sample_inputs)

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -816,7 +816,7 @@ void run_from_gpu_test(
816816
api::ScalarType dtype = api::kFloat,
817817
api::StorageType storage_type = api::StorageType::TEXTURE_3D) {
818818
vTensor vten =
819-
vTensor(api::context(), sizes, api::kFloat, storage_type, memory_layout);
819+
vTensor(api::context(), sizes, dtype, storage_type, memory_layout);
820820

821821
std::string kernel_name("idx_fill_texture");
822822
add_memory_layout_suffix(kernel_name, vten);
@@ -838,16 +838,14 @@ void run_from_gpu_test(
838838
vten.cpu_sizes_ubo()->buffer());
839839
}
840840

841-
api::StorageBuffer staging_buffer(
842-
api::context(), api::kFloat, vten.gpu_numel());
841+
api::StorageBuffer staging_buffer(api::context(), dtype, vten.gpu_numel());
843842

844843
record_image_to_nchw_op(api::context(), vten, staging_buffer.buffer());
845844

846845
submit_to_gpu();
847846

848847
std::vector<T> data_out(staging_buffer.numel());
849-
copy_staging_to_ptr(
850-
staging_buffer, data_out.data(), sizeof(float) * staging_buffer.numel());
848+
copy_staging_to_ptr(staging_buffer, data_out.data(), staging_buffer.nbytes());
851849

852850
for (int i = 0; i < vten.numel(); i++) {
853851
CHECK_VALUE(data_out, i, i);
@@ -861,12 +859,16 @@ void run_to_gpu_test(
861859
api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED,
862860
api::ScalarType dtype = api::kFloat,
863861
api::StorageType storage_type = api::StorageType::TEXTURE_3D) {
862+
if (dtype == api::kHalf &&
863+
!api::context()->adapter_ptr()->has_16bit_storage()) {
864+
return;
865+
}
866+
864867
vTensor vten =
865868
vTensor(api::context(), sizes, api::kFloat, storage_type, memory_layout);
866869

867870
// Create and fill input staging buffer
868-
api::StorageBuffer staging_buffer_in(
869-
api::context(), api::kFloat, vten.gpu_numel());
871+
api::StorageBuffer staging_buffer_in(api::context(), dtype, vten.gpu_numel());
870872

871873
std::vector<T> data_in(staging_buffer_in.numel());
872874
for (int i = 0; i < staging_buffer_in.numel(); i++) {
@@ -876,7 +878,7 @@ void run_to_gpu_test(
876878

877879
// Output staging buffer
878880
api::StorageBuffer staging_buffer_out(
879-
api::context(), api::kFloat, vten.gpu_numel());
881+
api::context(), dtype, vten.gpu_numel());
880882

881883
// Copy data in and out of the tensor
882884
record_nchw_to_image_op(api::context(), staging_buffer_in.buffer(), vten);
@@ -888,9 +890,7 @@ void run_to_gpu_test(
888890
// Extract data from output staging buffer
889891
std::vector<T> data_out(staging_buffer_out.numel());
890892
copy_staging_to_ptr(
891-
staging_buffer_out,
892-
data_out.data(),
893-
sizeof(float) * staging_buffer_out.numel());
893+
staging_buffer_out, data_out.data(), staging_buffer_out.nbytes());
894894

895895
// All indices should be equal to the input data
896896
for (int i = 0; i < vten.numel(); i++) {
@@ -943,7 +943,7 @@ TEST(VulkanToFromGPUShaderTest, to_gpu_and_from_gpu_test_texture) {
943943

944944
for (auto& sizes : to_test) {
945945
RUN_TESTS(float, api::kFloat)
946-
RUN_TESTS(float, api::kHalf)
946+
RUN_TESTS(c10::Half, api::kHalf)
947947
}
948948
#undef RUN_TESTS
949949
}

0 commit comments

Comments
 (0)