Skip to content

Commit a5e7633

Browse files
committed
[ET-VK][Ez] Introduce convenience constexpr for StorageTypes and GPUMemoryLayouts
Pull Request resolved: #2948 ## Context Introduce the following convenience `constexpr`: * `api::kBuffer`, `api::kTexture3D`, and `api::kTexture2D` * `api::kWidthPacked`, `api::kHeightPacked`, and `api::kChannelsPacked` Also remove the `api::StorageType::UNKNOWN` enum entry as it doesn't really serve any purpose. ghstack-source-id: 221871428 Differential Revision: [D55811278](https://our.internmc.facebook.com/intern/diff/D55811278/)
1 parent 55cb116 commit a5e7633

File tree

16 files changed

+95
-144
lines changed

16 files changed

+95
-144
lines changed

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -77,26 +77,26 @@ api::StorageType get_storage_type(
7777
const vkgraph::VkStorageType& vk_storage_type) {
7878
switch (vk_storage_type) {
7979
case vkgraph::VkStorageType::BUFFER:
80-
return api::StorageType::BUFFER;
80+
return api::kBuffer;
8181
case vkgraph::VkStorageType::TEXTURE_3D:
82-
return api::StorageType::TEXTURE_3D;
82+
return api::kTexture3D;
8383
case vkgraph::VkStorageType::TEXTURE_2D:
84-
return api::StorageType::TEXTURE_2D;
84+
return api::kTexture2D;
8585
default:
8686
break;
8787
}
88-
return api::StorageType::UNKNOWN;
88+
VK_THROW("Invalid storage type encountered!");
8989
}
9090

9191
api::GPUMemoryLayout get_memory_layout(
9292
const vkgraph::VkMemoryLayout& vk_memory_layout) {
9393
switch (vk_memory_layout) {
9494
case vkgraph::VkMemoryLayout::TENSOR_WIDTH_PACKED:
95-
return api::GPUMemoryLayout::TENSOR_WIDTH_PACKED;
95+
return api::kWidthPacked;
9696
case vkgraph::VkMemoryLayout::TENSOR_HEIGHT_PACKED:
97-
return api::GPUMemoryLayout::TENSOR_HEIGHT_PACKED;
97+
return api::kHeightPacked;
9898
case vkgraph::VkMemoryLayout::TENSOR_CHANNELS_PACKED:
99-
return api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED;
99+
return api::kChannelsPacked;
100100
default:
101101
break;
102102
}

backends/vulkan/runtime/api/Shader.cpp

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -23,38 +23,19 @@ ShaderInfo::ShaderInfo()
2323
0u,
2424
} {}
2525

26-
ShaderInfo::ShaderInfo(
27-
std::string name,
28-
const uint32_t* const spirv_bin,
29-
const uint32_t size,
30-
std::vector<VkDescriptorType> layout)
31-
: src_code{
32-
spirv_bin,
33-
size,
34-
},
35-
kernel_name{std::move(name)},
36-
kernel_layout{std::move(layout)} {}
37-
3826
ShaderInfo::ShaderInfo(
3927
std::string name,
4028
const uint32_t* const spirv_bin,
4129
const uint32_t size,
4230
std::vector<VkDescriptorType> layout,
43-
const std::vector<uint32_t>& tile_size,
44-
const StorageType bias_storage_type,
45-
const StorageType weight_storage_type)
31+
const utils::uvec3 tile_size)
4632
: src_code{
4733
spirv_bin,
4834
size,
4935
},
5036
kernel_name{std::move(name)},
5137
kernel_layout{std::move(layout)},
52-
tile_size(tile_size),
53-
bias_storage_type(bias_storage_type),
54-
weight_storage_type(weight_storage_type) {
55-
for (uint64_t i = 0; i < tile_size.size(); ++i) {
56-
out_tile_size.data[i] = tile_size[i];
57-
}
38+
out_tile_size(tile_size) {
5839
}
5940

6041
bool operator==(const ShaderInfo& _1, const ShaderInfo& _2) {

backends/vulkan/runtime/api/Shader.h

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -62,25 +62,14 @@ struct ShaderInfo final {
6262
// Shader Metadata
6363
utils::uvec3 out_tile_size{1u, 1u, 1u};
6464

65-
std::vector<uint32_t> tile_size;
66-
StorageType bias_storage_type{StorageType::UNKNOWN};
67-
StorageType weight_storage_type{StorageType::UNKNOWN};
68-
6965
explicit ShaderInfo();
70-
explicit ShaderInfo(std::string, const char*);
71-
explicit ShaderInfo(
72-
std::string,
73-
const uint32_t*,
74-
const uint32_t,
75-
std::vector<VkDescriptorType>);
66+
7667
explicit ShaderInfo(
7768
std::string,
7869
const uint32_t*,
7970
const uint32_t,
8071
std::vector<VkDescriptorType>,
81-
const std::vector<uint32_t>& tile_size,
82-
const StorageType bias_storage_type,
83-
const StorageType weight_storage_type);
72+
const utils::uvec3 tile_size);
8473
};
8574

8675
bool operator==(const ShaderInfo& _1, const ShaderInfo& _2);

backends/vulkan/runtime/api/Tensor.cpp

Lines changed: 26 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -67,20 +67,20 @@ std::vector<int64_t> calc_strides(
6767
const api::GPUMemoryLayout memory_layout,
6868
const api::StorageType storage_type) {
6969
switch (storage_type) {
70-
case api::StorageType::BUFFER:
70+
case api::kBuffer:
7171
switch (memory_layout) {
72-
case api::GPUMemoryLayout::TENSOR_WIDTH_PACKED:
72+
case api::kWidthPacked:
7373
return calc_contiguous_strides(sizes);
7474
break;
75-
case api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED:
75+
case api::kChannelsPacked:
7676
return calc_channels_last_strides(sizes);
7777
break;
7878
default:
7979
VK_THROW("Invalid memory format used to create vTensor!");
8080
}
8181
break;
82-
case api::StorageType::TEXTURE_3D:
83-
case api::StorageType::TEXTURE_2D:
82+
case api::kTexture3D:
83+
case api::kTexture2D:
8484
return std::vector<int64_t>(sizes.size());
8585
default:
8686
VK_THROW("Invalid storage type used to create vTensor!");
@@ -99,10 +99,8 @@ std::vector<int64_t> calc_gpu_sizes(
9999
const std::vector<int64_t>& sizes,
100100
const api::GPUMemoryLayout memory_layout,
101101
const api::StorageType storage_type) {
102-
VK_CHECK_COND(storage_type != api::StorageType::UNKNOWN);
103-
104102
std::vector<int64_t> gpu_sizes;
105-
if (storage_type == api::StorageType::BUFFER) {
103+
if (storage_type == api::kBuffer) {
106104
gpu_sizes.resize(sizes.size());
107105
for (size_t i = 0; i < sizes.size(); i++) {
108106
gpu_sizes.at(i) = sizes.at(i);
@@ -127,21 +125,21 @@ std::vector<int64_t> calc_gpu_sizes(
127125

128126
size_t ndim = gpu_sizes.size();
129127
switch (memory_layout) {
130-
case api::GPUMemoryLayout::TENSOR_WIDTH_PACKED:
128+
case api::kWidthPacked:
131129
if (ndim >= 1) {
132130
gpu_sizes.at(ndim - 1) =
133131
api::utils::align_up(api::utils::val_at(-1, sizes), INT64_C(4));
134132
}
135133
break;
136134

137-
case api::GPUMemoryLayout::TENSOR_HEIGHT_PACKED:
135+
case api::kHeightPacked:
138136
if (ndim >= 2) {
139137
gpu_sizes.at(ndim - 2) =
140138
api::utils::align_up(api::utils::val_at(-2, sizes), INT64_C(4));
141139
}
142140
break;
143141

144-
case api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED:
142+
case api::kChannelsPacked:
145143
if (ndim >= 3) {
146144
gpu_sizes.at(ndim - 3) =
147145
api::utils::align_up(api::utils::val_at(-3, sizes), INT64_C(4));
@@ -162,7 +160,7 @@ api::utils::uvec3 create_image_extents(
162160
const api::GPUMemoryLayout memory_layout) {
163161
size_t ndim = gpu_sizes.size();
164162

165-
if (storage_type == api::StorageType::BUFFER) {
163+
if (storage_type == api::kBuffer) {
166164
// image extents do not apply to buffer storage
167165
return {0u, 0u, 0u};
168166
} else {
@@ -177,15 +175,15 @@ api::utils::uvec3 create_image_extents(
177175
uint32_t batch = safe_downcast<uint32_t>(val_at(-4, gpu_sizes));
178176

179177
switch (memory_layout) {
180-
case api::GPUMemoryLayout::TENSOR_WIDTH_PACKED:
178+
case api::kWidthPacked:
181179
VK_CHECK_COND(width % 4 == 0, "Channels must be divisible by 4!");
182180
width /= 4;
183181
break;
184-
case api::GPUMemoryLayout::TENSOR_HEIGHT_PACKED:
182+
case api::kHeightPacked:
185183
VK_CHECK_COND(height % 4 == 0, "Channels must be divisible by 4!");
186184
height /= 4;
187185
break;
188-
case api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED:
186+
case api::kChannelsPacked:
189187
VK_CHECK_COND(channels % 4 == 0, "Channels must be divisible by 4!");
190188
channels /= 4;
191189
break;
@@ -326,41 +324,35 @@ std::shared_ptr<api::UniformParamsBuffer> vTensor::extents_ubo() {
326324

327325
VmaAllocationCreateInfo vTensor::get_allocation_create_info() const {
328326
switch (storage_type()) {
329-
case api::StorageType::BUFFER:
327+
case api::kBuffer:
330328
return view_->buffer_.allocation_create_info();
331-
case api::StorageType::TEXTURE_2D:
332-
case api::StorageType::TEXTURE_3D:
329+
case api::kTexture2D:
330+
case api::kTexture3D:
333331
return view_->image_.allocation_create_info();
334-
case api::StorageType::UNKNOWN:
335-
break;
336332
}
337333
return {};
338334
}
339335

340336
VkMemoryRequirements vTensor::get_memory_requirements() const {
341337
switch (storage_type()) {
342-
case api::StorageType::BUFFER:
338+
case api::kBuffer:
343339
return view_->buffer_.get_memory_requirements();
344-
case api::StorageType::TEXTURE_2D:
345-
case api::StorageType::TEXTURE_3D:
340+
case api::kTexture2D:
341+
case api::kTexture3D:
346342
return view_->image_.get_memory_requirements();
347-
case api::StorageType::UNKNOWN:
348-
break;
349343
}
350344
return {};
351345
}
352346

353347
void vTensor::bind_allocation(const api::MemoryAllocation& allocation) {
354348
switch (storage_type()) {
355-
case api::StorageType::BUFFER:
349+
case api::kBuffer:
356350
view_->buffer_.bind_allocation(allocation);
357351
break;
358-
case api::StorageType::TEXTURE_2D:
359-
case api::StorageType::TEXTURE_3D:
352+
case api::kTexture2D:
353+
case api::kTexture3D:
360354
view_->image_.bind_allocation(allocation);
361355
break;
362-
case api::StorageType::UNKNOWN:
363-
break;
364356
}
365357
}
366358

@@ -397,7 +389,7 @@ void vTensor::reallocate(const std::vector<int64_t>& new_sizes) {
397389

398390
void vTensor::virtual_resize(const std::vector<int64_t>& new_sizes) {
399391
update_size_metadata(new_sizes);
400-
if (storage_type() == api::StorageType::BUFFER) {
392+
if (storage_type() == api::kBuffer) {
401393
if (gpu_nbytes() > view_->buffer_.mem_size()) {
402394
VK_THROW(
403395
"Cannot virtual_resize a vTensor with sizes that require a larger "
@@ -446,11 +438,11 @@ api::VulkanImage allocate_image(
446438
VkImageViewType image_view_type = VK_IMAGE_VIEW_TYPE_3D;
447439

448440
switch (storage_type) {
449-
case api::StorageType::TEXTURE_3D:
441+
case api::kTexture3D:
450442
image_type = VK_IMAGE_TYPE_3D;
451443
image_view_type = VK_IMAGE_VIEW_TYPE_3D;
452444
break;
453-
case api::StorageType::TEXTURE_2D:
445+
case api::kTexture2D:
454446
image_type = VK_IMAGE_TYPE_2D;
455447
image_view_type = VK_IMAGE_VIEW_TYPE_2D;
456448
break;
@@ -481,7 +473,7 @@ api::VulkanBuffer allocate_buffer(
481473
api::Adapter* adapter_ptr = context_ptr->adapter_ptr();
482474

483475
switch (storage_type) {
484-
case api::StorageType::BUFFER:
476+
case api::kBuffer:
485477
break;
486478
default:
487479
// Return an empty VulkanBuffer if Buffer storage is not used

backends/vulkan/runtime/api/Tensor.h

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,8 @@ class vTensor final {
103103
api::Context* context,
104104
const std::vector<int64_t>& sizes,
105105
const api::ScalarType dtype,
106-
const api::StorageType storage_type = api::StorageType::TEXTURE_3D,
107-
const api::GPUMemoryLayout memory_layout =
108-
api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED,
106+
const api::StorageType storage_type = api::kTexture3D,
107+
const api::GPUMemoryLayout memory_layout = api::kChannelsPacked,
109108
const bool allocate_memory = true);
110109

111110
// Default constructor for quantized vTensor
@@ -115,9 +114,8 @@ class vTensor final {
115114
double q_scale,
116115
int64_t q_zero_point,
117116
const api::ScalarType dtype,
118-
const api::StorageType storage_type = api::StorageType::TEXTURE_3D,
119-
const api::GPUMemoryLayout memory_layout =
120-
api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED);
117+
const api::StorageType storage_type = api::kTexture3D,
118+
const api::GPUMemoryLayout memory_layout = api::kChannelsPacked);
121119

122120
// Copy Constructor and Assignment; Ideally copying would be disabled
123121
// (see the reasoning for move assignment below) but it is required for

backends/vulkan/runtime/api/Types.h

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,13 +162,16 @@ VK_FORALL_SCALAR_TYPES(SPECIALIZE_ScalarTypeToCType)
162162
*
163163
* UNKNOWN is not expected to be used.
164164
*/
165-
enum class StorageType {
165+
enum class StorageType : uint8_t {
166166
BUFFER,
167167
TEXTURE_3D,
168168
TEXTURE_2D,
169-
UNKNOWN,
170169
};
171170

171+
static constexpr StorageType kBuffer = StorageType::BUFFER;
172+
static constexpr StorageType kTexture3D = StorageType::TEXTURE_3D;
173+
static constexpr StorageType kTexture2D = StorageType::TEXTURE_2D;
174+
172175
/**
173176
* The enum below is used to describe how tensor data is laid out when stored in
174177
* GPU memory. The name of the enum describes which dimension is tightly packed;
@@ -182,11 +185,20 @@ enum class StorageType {
182185
* strides of the tensor will be used instead to convert between logical tensor
183186
* coordinates and linear access indices.
184187
*/
185-
enum class GPUMemoryLayout : uint32_t {
188+
enum class GPUMemoryLayout : uint8_t {
186189
TENSOR_WIDTH_PACKED = 0u,
187190
TENSOR_HEIGHT_PACKED = 1u,
188191
TENSOR_CHANNELS_PACKED = 2u,
189192
};
190193

194+
static constexpr GPUMemoryLayout kWidthPacked =
195+
GPUMemoryLayout::TENSOR_WIDTH_PACKED;
196+
197+
static constexpr GPUMemoryLayout kHeightPacked =
198+
GPUMemoryLayout::TENSOR_HEIGHT_PACKED;
199+
200+
static constexpr GPUMemoryLayout kChannelsPacked =
201+
GPUMemoryLayout::TENSOR_CHANNELS_PACKED;
202+
191203
} // namespace api
192204
} // namespace vkcompute

backends/vulkan/runtime/api/gen_vulkan_spv.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -543,13 +543,6 @@ def findRegisterFor(lineStr: str) -> Tuple[str, List[str]]:
543543
r"\buniform\b": "VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER",
544544
}
545545

546-
storageTypeToEnum = {
547-
"TEXTURE_2D": "api::StorageType::TEXTURE_2D",
548-
"TEXTURE_3D": "api::StorageType::TEXTURE_3D",
549-
"BUFFER": "api::StorageType::BUFFER",
550-
"": "api::StorageType::UNKNOWN",
551-
}
552-
553546

554547
def determineDescriptorType(lineStr: str) -> str:
555548
for identifier, typeNum in typeIdMapping.items():
@@ -632,7 +625,7 @@ def generateShaderInfoStr(shader_info: ShaderInfo, name: str, sizeBytes: int) ->
632625
tile_size = (
633626
f"{{{', '.join(str(x) for x in shader_info.tile_size)}}}"
634627
if (len(shader_info.tile_size) > 0)
635-
else "std::vector<uint32_t>()"
628+
else "{1, 1, 1}"
636629
)
637630

638631
shader_info_layouts = "{{{}}}".format(",\n ".join(shader_info.layouts))
@@ -643,8 +636,6 @@ def generateShaderInfoStr(shader_info: ShaderInfo, name: str, sizeBytes: int) ->
643636
str(sizeBytes),
644637
shader_info_layouts,
645638
tile_size,
646-
storageTypeToEnum[shader_info.weight_storage_type],
647-
storageTypeToEnum[shader_info.bias_storage_type],
648639
]
649640

650641
shader_info_str = textwrap.indent(

0 commit comments

Comments
 (0)