[ET-VK] Enable FP16 type in operators #3059

Closed · wants to merge 5 commits
6 changes: 2 additions & 4 deletions backends/vulkan/runtime/api/gen_vulkan_spv.py
@@ -90,12 +90,10 @@ def define_variable(name: str) -> str:


 def get_buffer_scalar_type(dtype: str) -> str:
-    # TODO(ssjia): use float16_t for half types
     if dtype == "half":
-        return "float"
-    # TODO(ssjia): use int8_t for int8 types
+        return "float16_t"
     elif dtype[-1] == "8":
-        return dtype[:-1]
+        return dtype + "_t"
 
     return dtype
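Note: the rewritten helper resolves dtypes to their native GLSL buffer scalar types instead of widening them to 32 bits, which is what makes true fp16 (and, per the removed TODO, int8) buffers possible; the shader templates below then require the matching 16-bit storage extension. A minimal sketch of the resulting behavior (illustrative assertions, not part of the PR):

def get_buffer_scalar_type(dtype: str) -> str:
    if dtype == "half":
        return "float16_t"
    elif dtype[-1] == "8":
        return dtype + "_t"

    return dtype

assert get_buffer_scalar_type("half") == "float16_t"  # was "float" before
assert get_buffer_scalar_type("int8") == "int8_t"     # was "int" before
assert get_buffer_scalar_type("float") == "float"     # 32-bit types pass through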
(unnamed shader file)
@@ -19,6 +19,9 @@

 #include "indexing_utils.h"
 
+$if DTYPE == "half":
+  #extension GL_EXT_shader_16bit_storage : require
+
 layout(std430) buffer;
 
 layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[2][DTYPE]} image_out;
(unnamed shader file)
@@ -19,6 +19,9 @@

 #include "indexing_utils.h"
 
+$if DTYPE == "half":
+  #extension GL_EXT_shader_16bit_storage : require
+
 layout(std430) buffer;
 
 layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[2][DTYPE]} image_out;
(unnamed shader file)
@@ -19,6 +19,9 @@

 #include "indexing_utils.h"
 
+$if DTYPE == "half":
+  #extension GL_EXT_shader_16bit_storage : require
+
 layout(std430) buffer;
 
 layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[2][DTYPE]} image_out;
3 changes: 3 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl
@@ -19,6 +19,9 @@

 #include "indexing_utils.h"
 
+$if DTYPE == "half":
+  #extension GL_EXT_shader_16bit_storage : require
+
 layout(std430) buffer;
 
 layout(set = 0, binding = 0) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} image_in;
26 changes: 15 additions & 11 deletions backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl
@@ -20,6 +20,9 @@

 #include "indexing_utils.h"
 
+$if DTYPE == "half":
+  #extension GL_EXT_shader_16bit_storage : require
+
 layout(std430) buffer;
 
 layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
@@ -52,20 +55,21 @@ void main() {
   const ivec4 buf_indices =
       base_index + ivec4(0, 1, 2, 3) * get_packed_stride(cpu_sizes.data);
 
-  SCALAR_T val_x = SCALAR_T(buffer_in.data[buf_indices.x]);
-  SCALAR_T val_y = SCALAR_T(buffer_in.data[buf_indices.y]);
-  SCALAR_T val_z = SCALAR_T(buffer_in.data[buf_indices.z]);
-  SCALAR_T val_w = SCALAR_T(buffer_in.data[buf_indices.w]);
-
-  VEC4_T texel = VEC4_T(val_x, val_y, val_z, val_w);
-
   const int packed_dim_size = get_packed_dim(cpu_sizes.data);
   int packed_idx = get_packed_dim(idx);
 
-  if (packed_idx + 3 >= packed_dim_size) {
-    ivec4 packed_ind = ivec4(packed_idx) + ivec4(0, 1, 2, 3);
-    VEC4_T valid_idx = VEC4_T(lessThan(packed_ind, ivec4(packed_dim_size)));
-    texel = texel * valid_idx;
+  VEC4_T texel = VEC4_T(0);
+  if (packed_idx < packed_dim_size) {
+    texel.x = SCALAR_T(buffer_in.data[buf_indices.x]);
+  }
+  if (packed_idx + 1 < packed_dim_size) {
+    texel.y = SCALAR_T(buffer_in.data[buf_indices.y]);
+  }
+  if (packed_idx + 2 < packed_dim_size) {
+    texel.z = SCALAR_T(buffer_in.data[buf_indices.z]);
+  }
+  if (packed_idx + 3 < packed_dim_size) {
+    texel.w = SCALAR_T(buffer_in.data[buf_indices.w]);
   }
 
   imageStore(image_out, ${get_pos[NDIM]("pos")}, texel);
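The tail handling changes shape here: the old code read all four packed elements unconditionally and then zeroed the out-of-range lanes with a mask, which still performed reads past the valid region of the staging buffer on the last partial texel; the new code guards each component load so those reads never happen. A Python model of the guarded fill (hypothetical helper, for illustration only):

# Hypothetical model of the guarded texel fill above: lanes beyond the packed
# dimension stay zero and their buffer slots are never read.
def load_texel(buf, buf_indices, packed_idx, packed_dim_size):
    texel = [0.0, 0.0, 0.0, 0.0]
    for lane in range(4):
        if packed_idx + lane < packed_dim_size:
            texel[lane] = buf[buf_indices[lane]]
    return texel

# A packed dim of size 6 read at packed_idx 4 fills only the x and y lanes.
assert load_texel(list(range(8)), [4, 5, 6, 7], 4, 6) == [4, 5, 0.0, 0.0]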
28 changes: 13 additions & 15 deletions backends/vulkan/test/op_tests/cases.py
@@ -21,7 +21,7 @@


 def get_binary_elementwise_inputs():
-    return VkTestSuite(
+    test_suite = VkTestSuite(
         [
             ((M1, M2), (M1, M2)),
             ((M1, M2), (M1, 1), 2.0),
@@ -31,6 +31,11 @@ def get_binary_elementwise_inputs():
             ((S, S1, S2), (S, 1, S2), 2.0),
         ]
     )
+    test_suite.layouts = [
+        "api::kWidthPacked",
+        "api::kChannelsPacked",
+    ]
+    return test_suite


@@ -41,6 +46,12 @@ def get_mm_inputs():
         ],
     )
     test_suite.prepacked_args = ["mat2"]
+    # ATen matmul doesn't support half
+    test_suite.dtypes = ["at::kFloat"]
+    test_suite.layouts = [
+        "api::kWidthPacked",
+        "api::kChannelsPacked",
+    ]
     return test_suite


@@ -50,7 +61,6 @@ def get_pool2d_inputs():
             ((S, M1, M2), [2, 2], [1, 1], [0, 0], [1, 1]),
         ]
     )
-    test_suite.supports["layouts"] = ["api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED"]
     return test_suite


@@ -114,7 +124,6 @@ def get_conv2d_inputs():
             ),
         ]
     )
-    test_suite.supports["layouts"] = ["api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED"]
     return test_suite


@@ -123,10 +132,9 @@ def get_native_layer_norm_inputs():
         [
             ((S1, S2), [S2], (S2), (S2), 0.001),
             ((M, M1, M2), [M2], (M2), (M2), 0.001),
-            ((L, XL, M1, M2), [M2], (M2), (M2), 0.001),
+            ((S, XL, M1, M2), [M2], (M2), (M2), 0.001),
         ]
     )
-    test_suite.supports["layouts"] = ["api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED"]
     return test_suite


@@ -138,7 +146,6 @@ def get_full_inputs():
             ([L, M, M1, M2], 2.72),
         ]
     )
-    test_suite.supports["layouts"] = ["api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED"]
     return test_suite


@@ -161,7 +168,6 @@ def get_select_int_inputs():
             ((8, 6, 1, 1), 1, 4),
         ]
     )
-    test_suite.supports["layouts"] = ["api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED"]
     return test_suite


@@ -177,11 +183,3 @@ def get_select_int_inputs():
     "aten.full.default": get_full_inputs(),
     "aten.select.int": get_select_int_inputs(),
 }
-
-prepacked_args = {"aten.mm.default": {"mat2"}}
-
-support_exceptions = {
-    "aten.max_pool2d_with_indices.default": {
-        "layouts": ["api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED"]
-    },
-}
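With the module-level prepacked_args and support_exceptions dicts removed, every per-op option now lives on the suite object returned by its input generator. A hypothetical new entry would follow the same pattern (the op name and shapes below are placeholders, not ops touched by this PR):

def get_my_op_inputs():
    test_suite = VkTestSuite(
        [
            ((M1, M2), (M1, M2)),  # placeholder shapes
        ]
    )
    test_suite.dtypes = ["at::kFloat", "at::kHalf"]
    test_suite.layouts = ["api::kChannelsPacked"]
    test_suite.prepacked_args = []
    return test_suite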
22 changes: 21 additions & 1 deletion backends/vulkan/test/op_tests/targets.bzl
@@ -1,4 +1,6 @@
load("@fbsource//tools/build_defs:platform_defs.bzl", "ANDROID")
load("@fbsource//xplat/caffe2:pt_defs.bzl", "get_pt_ops_deps")
load("@fbsource//xplat/caffe2:pt_ops.bzl", "pt_operator_library")
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

def define_common_targets(is_fbcode = False):
@@ -43,6 +45,24 @@ def define_common_targets(is_fbcode = False):
         default_outs = ["."],
     )
 
+    pt_operator_library(
+        name = "all_aten_ops",
+        check_decl = False,
+        include_all_operators = True,
+    )
+
+    runtime.cxx_library(
+        name = "all_aten_ops_lib",
+        srcs = [],
+        define_static_target = False,
+        exported_deps = get_pt_ops_deps(
+            name = "pt_ops_full",
+            deps = [
+                ":all_aten_ops",
+            ],
+        ),
+    )
+
     runtime.cxx_binary(
         name = "compute_graph_op_tests_bin",
         srcs = [
@@ -52,7 +72,7 @@ def define_common_targets(is_fbcode = False):
         deps = [
             "//third-party/googletest:gtest_main",
             "//executorch/backends/vulkan:vulkan_graph_runtime",
-            runtime.external_dep_location("libtorch"),
+            ":all_aten_ops_lib",
         ],
     )

53 changes: 35 additions & 18 deletions backends/vulkan/test/op_tests/utils/codegen.py
@@ -39,13 +39,10 @@

 @dataclass
 class VkTestSuite(TestSuite):
-    supports = {
-        "storage_types": ["api::StorageType::TEXTURE_3D"],
-        "layouts": [
-            "api::GPUMemoryLayout::TENSOR_WIDTH_PACKED",
-            "api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED",
-        ],
-    }
+    def __init__(self, input_cases: List[Any]):
+        super().__init__(input_cases)
+        self.storage_types: List[str] = ["api::kTexture3D"]
+        self.layouts: List[str] = ["api::kChannelsPacked"]
 
 
 ##########################
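Besides switching to the shorter api::kTexture3D / api::kChannelsPacked enum aliases, moving these defaults from a class attribute into __init__ avoids Python's shared-mutable-class-attribute trap: the old supports dict was one object shared by every suite, so the per-suite test_suite.supports["layouts"] = ... assignments that cases.py used to make would mutate it for every suite at once. A minimal illustration of the pitfall (demonstration names only):

class Shared:
    options = {"layouts": ["default"]}  # one dict shared class-wide

a, b = Shared(), Shared()
a.options["layouts"] = ["channels_packed"]          # mutates the shared dict
assert b.options["layouts"] == ["channels_packed"]  # so b sees it too

class PerInstance:
    def __init__(self):
        self.options = {"layouts": ["default"]}  # fresh dict per instance

c, d = PerInstance(), PerInstance()
c.options["layouts"] = ["channels_packed"]
assert d.options["layouts"] == ["default"]  # d keeps its own defaults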
@@ -88,7 +85,6 @@ def __init__(self, op_reg_name: str, f: NativeFunction, suite_def: TestSuite):
         self.dot = "->"
 
         self.args = []
-        self.out = None
         self.refs = {}
 
         self.should_prepack = False
@@ -288,6 +284,7 @@ def set_output(self, ref: ValueRefList) -> str:
         return ret_str
 
     def virtual_resize(self, ref: ValueRefList) -> str:
+        assert isinstance(ref, ValueRef)
         assert ref.src_cpp_type == AT_TENSOR and ref.is_in
         if self.prepack_ref(ref):
             return ""
@@ -296,6 +293,7 @@ def virtual_resize(self, ref: ValueRefList) -> str:
         return ret_str
 
     def copy_into_staging(self, ref: ValueRefList) -> str:
+        assert isinstance(ref, ValueRef)
         assert ref.src_cpp_type == AT_TENSOR and ref.is_in
         if self.prepack_ref(ref):
             return ""
@@ -336,7 +334,7 @@ def check_graph_out(self, ref: ValueRefList) -> str:
                 ret_str += self.check_graph_out(r)
             return ret_str
 
-        return f"EXPECT_TRUE(check_close({ref.src_cpp_name}, vk_{ref.name}));\n"
+        return f"EXPECT_TRUE(check_close({ref.src_cpp_name}, vk_{ref.name}, rtol, atol));\n"

## Top level code generation

@@ -374,11 +372,19 @@ def gen_graph_exec_code(self) -> str:
 
         return graph_exec
 
+    def gen_conditional_skips(self) -> str:
+        skips = "if (test_dtype == at::kHalf && "
+        skips += f"!{self.graph}{self.dot}context()->adapter_ptr()->has_16bit_storage()) {{\n"
+        skips += " GTEST_SKIP();"
+        skips += "}\n"
+        return skips
+
     def gen_op_check_fn(self) -> str:
         op_name = self.f.func.name.unambiguous_name()
         op_check_fn = self.gen_decl(f"check_{op_name}") + " {"
         if self.should_prepack:
            op_check_fn = self.gen_decl(f"prepacked_check_{op_name}") + " {"
+        op_check_fn += self.gen_conditional_skips()
         op_check_fn += self.gen_graph_build_code()
         op_check_fn += self.gen_graph_exec_code()
         op_check_fn += self.check_graph_out(self.refs["out"])
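The emitted guard is a short C++ prelude prepended to each generated check function; shown here as the Python string gen_conditional_skips() builds when self.graph renders as "graph" and self.dot as "->" (illustrative, and note that the doubled braces in the f-string above emit literal braces):

expected_skip_guard = (
    "if (test_dtype == at::kHalf && "
    "!graph->context()->adapter_ptr()->has_16bit_storage()) {\n"
    " GTEST_SKIP();"
    "}\n"
)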
@@ -391,19 +397,26 @@ def gen_op_check_fn(self) -> str:
 ##################################
 
 test_fixture_template = """
-class GeneratedOpsTest_{op_name} : public ::testing::TestWithParam< ::std::tuple<api::StorageType, api::GPUMemoryLayout>> {{
+class GeneratedOpsTest_{op_name} : public ::testing::TestWithParam< ::std::tuple<at::ScalarType, api::StorageType, api::GPUMemoryLayout>> {{
  protected:
   ComputeGraph* graph;
+  at::ScalarType test_dtype = at::kFloat;
+  float rtol = 1e-5;
+  float atol = 1e-5;
 
   void SetUp() override {{
     GraphConfig config;
     api::StorageType default_storage_type;
     api::GPUMemoryLayout default_memory_layout;
-    std::tie(default_storage_type, default_memory_layout) = GetParam();
+    std::tie(test_dtype, default_storage_type, default_memory_layout) = GetParam();
     config.setStorageTypeOverride(default_storage_type);
     config.setMemoryLayoutOverride(default_memory_layout);
     graph = new ComputeGraph(config);
+    if (test_dtype == at::kHalf) {{
+      rtol = 1e-2;
+      atol = 1e-2;
+    }}
   }}
 
   void TearDown() override {{
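The relaxed tolerances track half precision: fp16 carries a 10-bit mantissa, so values near 1.0 are spaced about 2^-10 ≈ 1e-3 apart, and the old 1e-5 tolerance would flag even an exact computation that merely round-trips through fp16. A quick numpy illustration of that reasoning (numpy is an assumption here, used only for demonstration):

import numpy as np

x = np.float32(1.2345678)
err = abs(float(np.float16(x)) - float(x))  # pure fp32 -> fp16 rounding error
assert err > 1e-5   # already exceeds the fp32-era tolerance
assert err < 1e-2   # comfortably within the relaxed fp16 tolerance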
@@ -420,7 +433,7 @@ class GeneratedOpsTest_{op_name} : public ::testing::TestWithParam< ::std::tuple
 
 
 class VkTestSuiteGen(TestSuiteGen):
-    def __init__(self, op_reg_name: str, f: NativeFunction, inputs: List[Any]):
+    def __init__(self, op_reg_name: str, f: NativeFunction, inputs: VkTestSuite):
         super().__init__(f, inputs)
         self.op_reg_name = op_reg_name
         self.generator = ComputeGraphGen(self.op_reg_name, self.f, self.suite_def)
@@ -442,14 +455,16 @@ def generate_fixture_cpp(self) -> str:
         )
 
     def gen_parameterization(self) -> str:
-        storage_types = self.suite_def.supports["storage_types"]
-        layouts = self.suite_def.supports["layouts"]
+        dtypes = self.suite_def.dtypes
+        storage_types = self.suite_def.storage_types
+        layouts = self.suite_def.layouts
 
         return f"""
 INSTANTIATE_TEST_SUITE_P(
-    StorageLayoutCombos_{self.op_name},
+    Combos_{self.op_name},
     GeneratedOpsTest_{self.op_name},
     ::testing::Combine(
+        ::testing::Values({', '.join(dtypes)}),
         ::testing::Values({', '.join(storage_types)}),
         ::testing::Values({', '.join(layouts)})));
 """
@@ -494,9 +509,11 @@ def gen_parameterization(self) -> str:
     return true;
   }
   bool is_close = at::allclose(t1, t2, rtol, atol);
-  if (!is_close) {
-    std::cout << "t1:" << t1 << std::endl;
-    std::cout << "t2:" << t2 << std::endl;
+  if (!is_close && t1.numel() < 500) {
+    std::cout << "reference: " << std::endl;
+    print(t1, 150);
+    std::cout << "vulkan: " << std::endl;
+    print(t2, 150);
   }
   return is_close;
 }