
Commit 158e487

Update base for Update on "[ET-VK][AOT][ez] Introduce vulkan export utils lib"

## Changes

As title. Introduce a common Python utility library for scripts in the Vulkan backend.

Differential Revision: [D65291064](https://our.internmc.facebook.com/intern/diff/D65291064/)

[ghstack-poisoned]

2 parents 248a3f6 + 1972e69 · commit 158e487

9 files changed: +292, −66 lines

backends/vulkan/_passes/insert_prepack_nodes.py

Lines changed: 1 addition & 1 deletion
@@ -73,11 +73,11 @@ def prepack_not_required(node: torch.fx.Node) -> bool:
         )
         # This pass assumes that the SpecPropPass() has already been applied
         assert "spec" in node.meta
+        assert node.meta["spec"].const
         # Validate that the original node is marked as a constant. Constant tensors
         # do not participate in memory planning.
         prepack_node.meta["val"] = node.meta["val"]
         prepack_node.meta["spec"] = deepcopy(node.meta["spec"])
-        # prepack_node.meta = deepcopy(node.meta)
         # Set the mem_obj_id to -1 to indicate that this node requires a dedicated
         # memory object.
         prepack_node.meta["spec"].mem_obj_id = -1

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl

Lines changed: 10 additions & 14 deletions
@@ -91,27 +91,23 @@ void main() {

 #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require

-VEC4_T q_8w_linear(const ivec3 out_pos, const int K) {
-  u16vec3 mat1_pos = u16vec3(0, out_pos.yz);
-  u16vec3 qmat2_pos = u16vec3(0, out_pos.x * 4, 0);
+VEC4_T q_8w_linear(const u16vec3 out_pos, const uint16_t K) {
+  const uint16_t qmat2_pos_y = out_pos.x * uint16_t(4);

   VEC4_T outtex = VEC4_T(0);

   const u16vec3 scales_pos = u16vec3(out_pos.x, 0, 0);
   const VEC4_T scales = load_texel(t_scales, scales_pos);

-  for (int i = 0; i < K; i += 4) {
-    const VEC4_T mat1_tex = load_texel(t_mat1, mat1_pos);
+  for (uint16_t i = uint16_t(0), x = uint16_t(0); i < K; i += uint16_t(4), x++) {
+    const VEC4_T mat1_tex = load_texel(t_mat1, u16vec3(x, out_pos.yz));
     const VEC4_T sums = VEC4_T(
-        dot(mat1_tex, load_texel(t_qmat2, qmat2_pos)),
-        dot(mat1_tex, load_texel(t_qmat2, qmat2_pos + u16vec3(0, 1, 0))),
-        dot(mat1_tex, load_texel(t_qmat2, qmat2_pos + u16vec3(0, 2, 0))),
-        dot(mat1_tex, load_texel(t_qmat2, qmat2_pos + u16vec3(0, 3, 0))));
+        dot(mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y, 0))),
+        dot(mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y + uint16_t(1), 0))),
+        dot(mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y + uint16_t(2), 0))),
+        dot(mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y + uint16_t(3), 0))));

     outtex += sums;
-
-    mat1_pos.x++;
-    qmat2_pos.x++;
   }

   outtex *= scales;

@@ -120,12 +116,12 @@ VEC4_T q_8w_linear(const ivec3 out_pos, const int K) {
 }

 void main() {
-  const ivec3 out_pos = ivec3(gl_GlobalInvocationID);
+  const u16vec3 out_pos = u16vec3(gl_GlobalInvocationID);
   if (any(greaterThanEqual(out_pos, out_limits))) {
     return;
   }

-  VEC4_T outtex = q_8w_linear(out_pos, mat1_sizes.x);
+  VEC4_T outtex = q_8w_linear(out_pos, uint16_t(mat1_sizes.x));
   write_texel(t_out, out_pos, outtex);
 }
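The shader change narrows the inner loop to 16-bit arithmetic types and replaces the mutable texel positions with a column index `x` derived from the loop counter (`i` advances by 4 elements per texel, so `x` advances by 1). A scalar C++ model of what the rewritten loop computes may help; names such as `q8w_linear_model`, `dot4`, and `weight_row_stride` are illustrative stand-ins, not the shader's actual bindings:

```cpp
#include <array>
#include <cstdint>

using Vec4 = std::array<float, 4>;

// Scalar stand-in for GLSL's dot() on 4-wide texels.
float dot4(const Vec4& a, const Vec4& b) {
  return a[0] * b[0] + a[1] * b[1] + a[2] * b[2] + a[3] * b[3];
}

// Illustrative model of the shader's inner loop: K input features are
// consumed 4 at a time (one texel per step), so the texel column is x == i/4.
// Each output texel accumulates dot products against four consecutive weight
// rows, then is scaled.
Vec4 q8w_linear_model(
    const Vec4* mat1_row,    // K/4 activation texels (stand-in for load_texel)
    const Vec4* qmat2,       // weight texels, row-major
    int weight_row_stride,   // texels per weight row
    const Vec4& scales,
    uint16_t K,
    uint16_t qmat2_pos_y) {
  Vec4 out{0, 0, 0, 0};
  for (uint16_t i = 0, x = 0; i < K; i += 4, ++x) {
    const Vec4& m = mat1_row[x];
    for (int r = 0; r < 4; ++r) {
      // Weight texel at column x, row qmat2_pos_y + r.
      out[r] += dot4(m, qmat2[(qmat2_pos_y + r) * weight_row_stride + x]);
    }
  }
  for (int r = 0; r < 4; ++r) {
    out[r] *= scales[r];
  }
  return out;
}
```

This mirrors how each output texel covers four output channels, i.e. four rows of the quantized weight matrix.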

backends/vulkan/targets.bzl

Lines changed: 25 additions & 16 deletions
@@ -101,28 +101,37 @@ def define_common_targets(is_fbcode = False):
         "fbsource//third-party/VulkanMemoryAllocator/3.0.1:VulkanMemoryAllocator_xplat",
     ]

-    if not is_fbcode:
+    if is_fbcode:
         VK_API_DEPS += [
-            "fbsource//third-party/volk:volk",
+            "fbsource//third-party/swiftshader:swiftshader_vk_headers",
+            "fbsource//third-party/swiftshader/lib/linux-x64:libvk_swiftshader_fbcode",
+            "fbsource//third-party/swiftshader/lib/linux-x64:libvk_swiftshader_so",
         ]
+    else:
         VK_API_DEPS += select({
-            "DEFAULT": [],
-            "ovr_config//os:android": ["fbsource//third-party/toolchains:android"],
+            "DEFAULT": [
+                "fbsource//third-party/volk:volk",
+            ],
+            "ovr_config//os:android": [
+                "fbsource//third-party/volk:volk",
+                "fbsource//third-party/toolchains:android"
+            ],
+            "ovr_config//os:macos-arm64": [
+                "//third-party/khronos:moltenVK"
+            ],
         })
-        VK_API_PREPROCESSOR_FLAGS += [
-            "-DUSE_VULKAN_WRAPPER",
-            "-DUSE_VULKAN_VOLK",
-        ]
         VK_API_PREPROCESSOR_FLAGS += select({
-            "DEFAULT": [],
-            "ovr_config//os:android": ["-DVK_ANDROID_external_memory_android_hardware_buffer"],
+            "DEFAULT": [
+                "-DUSE_VULKAN_WRAPPER",
+                "-DUSE_VULKAN_VOLK",
+            ],
+            "ovr_config//os:android": [
+                "-DUSE_VULKAN_WRAPPER",
+                "-DUSE_VULKAN_VOLK",
+                "-DVK_ANDROID_external_memory_android_hardware_buffer"
+            ],
+            "ovr_config//os:macos-arm64": []
         })
-    else:
-        VK_API_DEPS += [
-            "fbsource//third-party/swiftshader:swiftshader_vk_headers",
-            "fbsource//third-party/swiftshader/lib/linux-x64:libvk_swiftshader_fbcode",
-            "fbsource//third-party/swiftshader/lib/linux-x64:libvk_swiftshader_so",
-        ]

     runtime.cxx_library(
         name = "vulkan_compute_api",

extension/llm/custom_ops/targets.bzl

Lines changed: 6 additions & 1 deletion
@@ -1,4 +1,9 @@
 load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+load(
+    "@fbsource//xplat/executorch/kernels/portable:op_registration_util.bzl",
+    "get_compiler_optimization_flags",
+)

 def define_common_targets():
     """Defines targets that should be shared between fbcode and xplat.

@@ -34,7 +39,7 @@ def define_common_targets():
         "//executorch/kernels/portable/cpu/util:reduce_util",
         "//executorch/extension/llm/custom_ops/spinquant:fast_hadamard_transform",
     ],
-    compiler_flags = ["-Wno-missing-prototypes", "-Wno-global-constructors"],
+    compiler_flags = ["-Wno-missing-prototypes", "-Wno-global-constructors"] + get_compiler_optimization_flags(),
    visibility = [
        "//executorch/...",
        "//executorch/extension/llm/custom_ops/...",

kernels/optimized/cpu/binary_ops.h

Lines changed: 86 additions & 2 deletions
@@ -41,10 +41,62 @@ enum class ElementwiseOptimizedPath {
   kTreatAs1d,
   kBroadcast2dBy1d,
   kBroadcast2dBy1dReverseArguments,
+  kBroadcastNdByNd,
+  kBroadcastNdByNdReverseArguments,
 };

 namespace internal {
-inline ElementwiseOptimizedPath select_broadcast_2d_by_1d_optimized_path(
+
+// Find the single broadcast dimension, if it exists.
+// This path aims to handle broadcasts of the following form:
+// A = [a1, a2, ..., 1, ..., an]
+// B = [b1, b2, ..., bm, ..., bn]
+// OR
+// A = [a1, a2, ..., am, ..., an]
+// B = [b1, b2, ..., 1, ..., bn]
+int32_t inline get_broadcast_dim(const Tensor& lhs, const Tensor& rhs) {
+  auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs.sizes());
+  auto lhs_end = lhs.sizes().end();
+
+  auto rhs_begin = arrayref_begin_ignoring_leading_1s(rhs.sizes());
+  auto rhs_end = rhs.sizes().end();
+
+  const auto lhs_size = lhs_end - lhs_begin;
+  const auto rhs_size = rhs_end - rhs_begin;
+
+  // The following example is not handled at the moment:
+  // [1, 3, 4, 5]
+  // [2, 3, 4, 5]
+  if (lhs_size != rhs_size) {
+    return 0;
+  }
+
+  int32_t broadcast_dim = 0;
+  // Check
+  // 1. if any dim value is 1 (it constitutes a broadcast dim)
+  // 2. if more than one dim value is 1 (we cannot handle that)
+  // 3. if non-1 dim values are equal
+  lhs_end--;
+  rhs_end--;
+  while (lhs_end != lhs_begin) {
+    if (*lhs_end == 1 || *rhs_end == 1) {
+      // If more than one broadcast dim is found, return 0.
+      if (broadcast_dim != 0) {
+        return 0;
+      }
+      // A negative index, counted from the last dim, is used.
+      broadcast_dim = lhs_end - lhs.sizes().end();
+    } else if (*lhs_end != *rhs_end) {
+      // If non-1 dim values are not equal, return 0.
+      return 0;
+    }
+    lhs_end--;
+    rhs_end--;
+  }
+  return broadcast_dim;
+}
+
+inline ElementwiseOptimizedPath select_broadcast_optimized_path(
     const Tensor& lhs,
     const Tensor& rhs) {
   auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs.sizes());
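To make the returned convention concrete, here is a self-contained sketch of the same search on plain size vectors. The name `find_broadcast_dim` is illustrative, and unlike the real helper this sketch assumes leading 1s have already been stripped from both shapes:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Standalone model of get_broadcast_dim: returns a negative index counted
// from the last dim, or 0 if no single broadcast dim exists.
int32_t find_broadcast_dim(
    const std::vector<int64_t>& lhs, const std::vector<int64_t>& rhs) {
  if (lhs.size() != rhs.size()) {
    return 0;
  }
  int32_t broadcast_dim = 0;
  // Walk from the last dim toward the front; the first dim is never
  // examined, mirroring the `lhs_end != lhs_begin` loop condition above.
  for (int64_t d = static_cast<int64_t>(lhs.size()) - 1; d > 0; --d) {
    if (lhs[d] == 1 || rhs[d] == 1) {
      if (broadcast_dim != 0) {
        return 0;  // more than one broadcast dim: unhandled
      }
      broadcast_dim = static_cast<int32_t>(d) - static_cast<int32_t>(lhs.size());
    } else if (lhs[d] != rhs[d]) {
      return 0;  // mismatched non-1 dims: not a single-dim broadcast
    }
  }
  return broadcast_dim;
}

int main() {
  // [2, 1, 4, 5] vs [2, 3, 4, 5]: the 1 sits at index 1, i.e. -3 from the end.
  std::printf("%d\n", find_broadcast_dim({2, 1, 4, 5}, {2, 3, 4, 5}));  // -3
  // A last-dim broadcast yields -1, which the caller treats as unhandled.
  std::printf("%d\n", find_broadcast_dim({2, 3, 4, 1}, {2, 3, 4, 5}));  // -1
  // Equal shapes have no broadcast dim.
  std::printf("%d\n", find_broadcast_dim({2, 3, 4, 5}, {2, 3, 4, 5}));  // 0
}
```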
@@ -63,6 +115,17 @@ inline ElementwiseOptimizedPath select_broadcast_2d_by_1d_optimized_path(
     return ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments;
   }

+  int32_t broadcast_dim = get_broadcast_dim(lhs, rhs);
+  // Right now we don't handle last-dim broadcast.
+  if (broadcast_dim < -1) {
+    if (std::count_if(rhs_begin, rhs_end, [](Tensor::SizesType x) {
+          return x == 1;
+        }) == 1) {
+      return ElementwiseOptimizedPath::kBroadcastNdByNd;
+    } else {
+      return ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments;
+    }
+  }
   return ElementwiseOptimizedPath::kNone;
 }
 } // namespace internal
@@ -85,7 +148,28 @@ ElementwiseOptimizedPath inline select_optimized_path(
       internal::sizes_match_ignoring_leading_1s(a.sizes(), b.sizes())))) {
     return ElementwiseOptimizedPath::kTreatAs1d;
   }
-  return internal::select_broadcast_2d_by_1d_optimized_path(a, b);
+  return internal::select_broadcast_optimized_path(a, b);
+}
+
+std::array<int32_t, 3> inline get_normalized_tensor_size(
+    const Tensor& a,
+    const int32_t broadcast_dim) {
+  ET_CHECK_MSG(
+      a.dim() > broadcast_dim,
+      "Size of tensor: %zd, must be larger than broadcast_dim: %d",
+      a.dim(),
+      broadcast_dim);
+  std::array<int32_t, 3> normalized_tensor_size;
+  normalized_tensor_size[0] = 1;
+  normalized_tensor_size[1] = a.size(broadcast_dim);
+  normalized_tensor_size[2] = 1;
+  for (size_t i = 0; i < broadcast_dim; i++) {
+    normalized_tensor_size[0] *= a.size(i);
+  }
+  for (size_t i = broadcast_dim + 1; i < a.dim(); i++) {
+    normalized_tensor_size[2] *= a.size(i);
+  }
+  return normalized_tensor_size;
 }

 } // namespace executor
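The normalization collapses everything before the broadcast dim into `outer`, keeps the broadcast dim itself, and collapses everything after it into `inner`. A hypothetical mimic on a plain shape vector (`normalize` is an illustrative name; the real helper additionally checks `a.dim() > broadcast_dim` via ET_CHECK_MSG):

```cpp
#include <array>
#include <cstdint>
#include <cstdio>
#include <vector>

// Dims before broadcast_dim fold into [0] (outer), the broadcast dim lands
// in [1], and dims after it fold into [2] (inner).
std::array<int32_t, 3> normalize(
    const std::vector<int64_t>& sizes, int32_t broadcast_dim) {
  std::array<int32_t, 3> out = {
      1, static_cast<int32_t>(sizes[broadcast_dim]), 1};
  for (int32_t i = 0; i < broadcast_dim; i++) {
    out[0] *= static_cast<int32_t>(sizes[i]);
  }
  for (size_t i = broadcast_dim + 1; i < sizes.size(); i++) {
    out[2] *= static_cast<int32_t>(sizes[i]);
  }
  return out;
}

int main() {
  // Shape [2, 3, 4, 5] with broadcast_dim == 1 -> {2, 3, 20}: the tensor is
  // viewed as [outer=2, broadcast=3, inner=4*5].
  auto n = normalize({2, 3, 4, 5}, 1);
  std::printf("{%d, %d, %d}\n", n[0], n[1], n[2]);
}
```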

kernels/optimized/cpu/op_mul.cpp

Lines changed: 30 additions & 7 deletions
@@ -130,15 +130,19 @@ Tensor& opt_mul_out(
   } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
     const Tensor* lhs;
     const Tensor* rhs;
-    if (selected_optimized_path ==
-        ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
+    if ((selected_optimized_path ==
+         ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) ||
+        (selected_optimized_path ==
+         ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
       lhs = &b;
       rhs = &a;
     } else {
       // Catch failure to update logic when adding new broadcasting possibility.
       ET_DCHECK(
-          selected_optimized_path ==
-          ElementwiseOptimizedPath::kBroadcast2dBy1d);
+          (selected_optimized_path ==
+           ElementwiseOptimizedPath::kBroadcast2dBy1d) ||
+          (selected_optimized_path ==
+           ElementwiseOptimizedPath::kBroadcastNdByNd));
       lhs = &a;
       rhs = &b;
     }
@@ -149,15 +153,34 @@ Tensor& opt_mul_out(
         InvalidArgument,
         out,
         "Failed to resize output tensor.");
+    int64_t outer_size = 1;
+    int64_t broadcast_size;
+    int64_t inner_size;
+    if ((selected_optimized_path ==
+         ElementwiseOptimizedPath::kBroadcastNdByNd) ||
+        (selected_optimized_path ==
+         ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
+      int32_t broadcast_dim = internal::get_broadcast_dim(*lhs, *rhs);
+      int32_t broadcast_dim_lhs = lhs->dim() + broadcast_dim;
+      auto normalized_tensor_size_lhs =
+          get_normalized_tensor_size(*lhs, broadcast_dim_lhs);
+      outer_size = normalized_tensor_size_lhs[0];
+      broadcast_size = normalized_tensor_size_lhs[1];
+      inner_size = normalized_tensor_size_lhs[2];
+    } else {
+      broadcast_size = lhs->sizes()[lhs->dim() - 2];
+      inner_size = lhs->sizes()[lhs->dim() - 1];
+    }
     ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
       using Vec = executorch::vec::Vectorized<CTYPE>;
-      executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
+      executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE>(
           [](Vec x, Vec y) { return x * y; },
           out.mutable_data_ptr<CTYPE>(),
           lhs->const_data_ptr<CTYPE>(),
           rhs->const_data_ptr<CTYPE>(),
-          lhs->sizes()[lhs->dim() - 2],
-          lhs->sizes()[lhs->dim() - 1]);
+          outer_size,
+          broadcast_size,
+          inner_size);
     });
   } else {
     ScalarType common_type =
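To see the Nd-by-Nd path end to end, consider `a = [2, 3, 4, 5]` multiplied by `b = [2, 1, 4, 5]` (a hypothetical shape pair; the arithmetic below just traces the calls in `opt_mul_out` with plain integers):

```cpp
#include <cstdint>
#include <cstdio>

// Worked example: b holds the 1, so the path is kBroadcastNdByNd and
// lhs = &a. All values mirror the dispatch code above.
int main() {
  const int32_t lhs_dim = 4;
  // get_broadcast_dim(a, b): the 1 sits at -3 counted from the last dim.
  const int32_t broadcast_dim = -3;
  // Convert the negative index into a positive dim of lhs.
  const int32_t broadcast_dim_lhs = lhs_dim + broadcast_dim;  // == 1
  // get_normalized_tensor_size(a, 1) on shape [2, 3, 4, 5]:
  const int64_t outer_size = 2;      // product of dims before dim 1
  const int64_t broadcast_size = 3;  // a.size(1)
  const int64_t inner_size = 4 * 5;  // product of dims after dim 1
  // broadcasting_map_3d_and_unsqueezed_3d then treats a as [2, 3, 20]
  // and b as [2, 1, 20].
  std::printf("dim=%d outer=%lld broadcast=%lld inner=%lld\n",
              broadcast_dim_lhs,
              static_cast<long long>(outer_size),
              static_cast<long long>(broadcast_size),
              static_cast<long long>(inner_size));
}
```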

kernels/optimized/vec/functional_base.h

Lines changed: 44 additions & 24 deletions
@@ -326,10 +326,49 @@ inline void map4(
 }


-// Map vec_fun across input_data and input_data2, where input_data is
-// a two-dimensional array of size (size, size2), input_data2 is a
-// one-dimensional array of size size2, and input_data2 is broadcast
-// to be of size (size, size2).
+// This function implements a broadcasting binary operation on two tensors,
+// where the lhs tensor is treated as having shape
+// [outer_size, broadcast_size, inner_size] and the rhs tensor is treated as
+// having shape [outer_size, 1, inner_size]; the middle dimension is the
+// broadcasting dimension. This formulation can express broadcasting on any
+// dim = broadcast_dim of two N-dimensional tensors, for 0 < broadcast_dim < N-1.
+template <typename scalar_t, typename Op>
+inline void broadcasting_map_3d_and_unsqueezed_3d(
+    const Op& vec_fun,
+    scalar_t* output_data,
+    const scalar_t* lhs,
+    const scalar_t* rhs,
+    int64_t outer_size,
+    int64_t broadcast_size,
+    int64_t inner_size) {
+  using Vec = vec::Vectorized<scalar_t>;
+  int64_t outer_stride_lhs = inner_size * broadcast_size;
+  int64_t outer_stride_rhs = inner_size;
+  int64_t broadcast_stride_lhs = inner_size;
+  for (int64_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
+    const scalar_t* lhs_outer = lhs + outer_idx * outer_stride_lhs;
+    scalar_t* output_data_row = output_data + outer_idx * outer_stride_lhs;
+    const scalar_t* rhs_outer = rhs + outer_idx * outer_stride_rhs;
+    for (int64_t broadcast_idx = 0; broadcast_idx < broadcast_size; ++broadcast_idx) {
+      const scalar_t* lhs_outer_2 = lhs_outer + broadcast_idx * broadcast_stride_lhs;
+      scalar_t* output_data_row_2 = output_data_row + broadcast_idx * broadcast_stride_lhs;
+      int64_t inner_idx = 0;
+      for (; inner_idx < inner_size - (inner_size % Vec::size()); inner_idx += Vec::size()) {
+        Vec data_vec = Vec::loadu(lhs_outer_2 + inner_idx);
+        Vec data_vec2 = Vec::loadu(rhs_outer + inner_idx);
+        Vec output_vec = vec_fun(data_vec, data_vec2);
+        output_vec.store(output_data_row_2 + inner_idx);
+      }
+      if (inner_size - inner_idx > 0) {
+        Vec data_vec = Vec::loadu(lhs_outer_2 + inner_idx, inner_size - inner_idx);
+        Vec data_vec2 = Vec::loadu(rhs_outer + inner_idx, inner_size - inner_idx);
+        Vec output_vec = vec_fun(data_vec, data_vec2);
+        output_vec.store(output_data_row_2 + inner_idx, inner_size - inner_idx);
+      }
+    }
+  }
+}
+
 template <typename scalar_t, typename Op>
 inline void broadcasting_map_2d_by_1d(
     const Op& vec_fun,
@@ -338,27 +377,8 @@ inline void broadcasting_map_2d_by_1d(
     const scalar_t* input_data2,
     int64_t size,
     int64_t size2) {
-  using Vec = vec::Vectorized<scalar_t>;
-  for (int64_t outer_idx = 0; outer_idx < size; ++outer_idx) {
-    const scalar_t* input_data_row = input_data + outer_idx * size2;
-    scalar_t* output_data_row = output_data + outer_idx * size2;
-    int64_t inner_idx = 0;
-    for (; inner_idx < size2 - (size2 % Vec::size()); inner_idx += Vec::size()) {
-      Vec data_vec = Vec::loadu(input_data_row + inner_idx);
-      Vec data_vec2 = Vec::loadu(input_data2 + inner_idx);
-      Vec output_vec = vec_fun(data_vec, data_vec2);
-      output_vec.store(output_data_row + inner_idx);
-    }
-    if (size2 - inner_idx > 0) {
-      Vec data_vec = Vec::loadu(input_data_row + inner_idx, size2 - inner_idx);
-      Vec data_vec2 = Vec::loadu(input_data2 + inner_idx, size2 - inner_idx);
-      Vec output_vec = vec_fun(data_vec, data_vec2);
-      output_vec.store(output_data_row + inner_idx, size2 - inner_idx);
-    }
-  }
+  broadcasting_map_3d_and_unsqueezed_3d(vec_fun, output_data, input_data, input_data2, 1, size, size2);
 }

 } // namespace vec
 } // namespace executorch
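For reference, a scalar (non-vectorized) sketch of the addressing the new routine performs; `broadcast_3d_reference` is an illustrative name, not part of the library:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// lhs is [outer, broadcast, inner], rhs is [outer, 1, inner]: the broadcast
// index b is dropped when addressing rhs, which is exactly the "unsqueezed"
// middle dimension above. The real routine vectorizes the inner loop with
// vec::Vectorized and a masked tail load/store.
template <typename T, typename Op>
void broadcast_3d_reference(const Op& op, T* out, const T* lhs, const T* rhs,
                            int64_t outer, int64_t broadcast, int64_t inner) {
  for (int64_t o = 0; o < outer; ++o) {
    for (int64_t b = 0; b < broadcast; ++b) {
      for (int64_t i = 0; i < inner; ++i) {
        const int64_t lhs_idx = (o * broadcast + b) * inner + i;
        const int64_t rhs_idx = o * inner + i;  // broadcast index dropped
        out[lhs_idx] = op(lhs[lhs_idx], rhs[rhs_idx]);
      }
    }
  }
}

int main() {
  // lhs [1, 2, 3] x rhs [1, 1, 3]: the single rhs row multiplies both lhs
  // rows, matching the old broadcasting_map_2d_by_1d with size=2, size2=3.
  std::vector<float> lhs = {1, 2, 3, 4, 5, 6}, rhs = {10, 20, 30}, out(6);
  broadcast_3d_reference([](float x, float y) { return x * y; },
                         out.data(), lhs.data(), rhs.data(), 1, 2, 3);
  for (float v : out) std::printf("%g ", v);  // 10 40 90 40 100 180
  std::printf("\n");
}
```

Passing `outer = 1` reproduces the old 2d-by-1d behavior, which is why `broadcasting_map_2d_by_1d` can now delegate to the 3d routine.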
