Skip to content

Commit 958afe1

Browse files
SS-JIAfacebook-github-bot
authored andcommitted
vTensor cleanup 3/N - Introduce conversion constructors for vec types (#5423)
Summary: Pull Request resolved: #5423 ## Context Introduce implicit conversion functions to `vec` types. This allows the following pattern: ```cpp utils::ivec3 v1{4, 5, 2}; utils::uvec3 v2 = v1; ``` Whereas before, we would have to do ```cpp utils::ivec3 v1{4, 5, 2}; utils::uvec3 v2( safe_downcast<uint32_t>(v1[0], safe_downcast<uint32_t>(v1[1], safe_downcast<uint32_t>(v1[2]); ``` The connection with `vTensor` class cleanup is that this will allow for consolidation of `vTensor` class methods, specifically circumventing the need to provide two functions to retrieve the logical limits of the `Tensor`, one to retrieve it as a `ivec3` and another to retrieve it as a `uvec3`. ghstack-source-id: 243309910 exported-using-ghexport Reviewed By: jorgep31415 Differential Revision: D62878650 fbshipit-source-id: 8ed9e8ecfcc35aaf8ba6277fcf3e56c97e8fe8c7
1 parent ebff33c commit 958afe1

File tree

2 files changed

+85
-19
lines changed

2 files changed

+85
-19
lines changed

backends/vulkan/runtime/utils/VecUtils.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,28 @@ struct vec final {
238238
// NOLINTNEXTLINE
239239
Type data[N];
240240

241+
vec() = default;
242+
243+
// Standard constructor with initializer list
244+
vec(std::initializer_list<Type> values) {
245+
VK_CHECK_COND(values.size() == N);
246+
std::copy(values.begin(), values.end(), data);
247+
}
248+
249+
// Conversion constructor from an _integral_ vec type. Note that this is only
250+
// defined if `OtherType` is an integral type to disallow implicit narrowing.
251+
template <
252+
typename OtherType,
253+
typename std::enable_if<
254+
!std::is_same<Type, OtherType>::value &&
255+
std::is_integral<OtherType>::value,
256+
int>::type = 0>
257+
/* implicit */ vec(const vec<OtherType, N>& other) {
258+
for (int i = 0; i < N; ++i) {
259+
data[i] = safe_downcast<Type>(other[i]);
260+
}
261+
}
262+
241263
const Type& operator[](const uint32_t& i) const {
242264
VK_CHECK_COND(i >= 0 && i < N, "Index out of bounds!");
243265
return data[i];

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 63 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -301,26 +301,70 @@ TEST_F(VulkanComputeAPITest, virtual_transpose_test) {
301301
}
302302
}
303303

304+
utils::ivec3 make_temp_ivec3(int x, int y, int z) {
305+
return utils::ivec3{x, y, z};
306+
}
307+
304308
TEST_F(VulkanComputeAPITest, vec_test) {
305-
utils::vec3 v3({1, 2, 3});
306-
ASSERT_TRUE(v3[0] == 1);
307-
ASSERT_TRUE(v3[1] == 2);
308-
ASSERT_TRUE(v3[2] == 3);
309-
v3 = {4, 5, 6};
310-
ASSERT_TRUE(v3[0] == 4);
311-
ASSERT_TRUE(v3[1] == 5);
312-
ASSERT_TRUE(v3[2] == 6);
313-
314-
utils::uvec4 uv4({4, 3, 2, 1});
315-
ASSERT_TRUE(uv4[0] == 4);
316-
ASSERT_TRUE(uv4[1] == 3);
317-
ASSERT_TRUE(uv4[2] == 2);
318-
ASSERT_TRUE(uv4[3] == 1);
319-
uv4 = {11, 13, 12, 88};
320-
ASSERT_TRUE(uv4[0] == 11);
321-
ASSERT_TRUE(uv4[1] == 13);
322-
ASSERT_TRUE(uv4[2] == 12);
323-
ASSERT_TRUE(uv4[3] == 88);
309+
{
310+
utils::vec3 v3({1, 2, 3});
311+
ASSERT_TRUE(v3[0] == 1);
312+
ASSERT_TRUE(v3[1] == 2);
313+
ASSERT_TRUE(v3[2] == 3);
314+
v3 = {4, 5, 6};
315+
ASSERT_TRUE(v3[0] == 4);
316+
ASSERT_TRUE(v3[1] == 5);
317+
ASSERT_TRUE(v3[2] == 6);
318+
}
319+
320+
{
321+
utils::uvec4 uv4({4, 3, 2, 1});
322+
ASSERT_TRUE(uv4[0] == 4);
323+
ASSERT_TRUE(uv4[1] == 3);
324+
ASSERT_TRUE(uv4[2] == 2);
325+
ASSERT_TRUE(uv4[3] == 1);
326+
uv4 = {11, 13, 12, 88};
327+
ASSERT_TRUE(uv4[0] == 11);
328+
ASSERT_TRUE(uv4[1] == 13);
329+
ASSERT_TRUE(uv4[2] == 12);
330+
ASSERT_TRUE(uv4[3] == 88);
331+
}
332+
333+
// Test copy from same type
334+
{
335+
utils::ivec3 v{5, 6, 8};
336+
utils::ivec3 v2 = v;
337+
338+
ASSERT_TRUE(v2[0] == 5);
339+
ASSERT_TRUE(v2[1] == 6);
340+
ASSERT_TRUE(v2[2] == 8);
341+
}
342+
343+
// Test copy from different type
344+
{
345+
utils::uvec3 v{5, 6, 8};
346+
utils::ivec3 v2 = v;
347+
348+
ASSERT_TRUE(v2[0] == 5);
349+
ASSERT_TRUE(v2[1] == 6);
350+
ASSERT_TRUE(v2[2] == 8);
351+
}
352+
353+
// Test construction from temporary vec
354+
{
355+
utils::uvec3 v{make_temp_ivec3(4, 5, 10)};
356+
ASSERT_TRUE(v[0] == 4);
357+
ASSERT_TRUE(v[1] == 5);
358+
ASSERT_TRUE(v[2] == 10);
359+
}
360+
361+
// Test initalization from temporary vec
362+
{
363+
utils::uvec3 v = make_temp_ivec3(4, 5, 10);
364+
ASSERT_TRUE(v[0] == 4);
365+
ASSERT_TRUE(v[1] == 5);
366+
ASSERT_TRUE(v[2] == 10);
367+
}
324368
}
325369

326370
TEST_F(VulkanComputeAPITest, retrieve_custom_shader_test) {

0 commit comments

Comments
 (0)