
Commit 31e652d

SS-JIA authored and facebook-github-bot committed
Integrate axis mapping into naive matrix multiplication shaders (#5277)
Summary:
Pull Request resolved: #5277

## Context

Give similar treatment as #5223 to integrate axis mapping into the naive matrix multiplication shaders. As with the previous diff, code cleanup is performed as well to consolidate shaders and improve code readability.

## Performance impact

Running the matrix multiplication operator benchmark, we can observe the following results:

| commit | matmul_naive_texture3d_float | linear_naive_texture3d_float |
|-------------|------------------------------|------------------------------|
| master | 6.53645 | 6.98834 |
| this commit | 6.61293 | 6.34905 |

Evidently, accounting for axis mapping did not have any significant adverse impact on shader latency.

ghstack-source-id: 242452079
exported-using-ghexport

Reviewed By: jorgep31415

Differential Revision: D62518403

fbshipit-source-id: de873164fa6202b9b3312d1f62ff1dc2cec86db8
1 parent 71602a0 commit 31e652d

21 files changed: +416 -592 lines
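To make the change easier to follow, here is a minimal, hypothetical C++ sketch (not code from this commit) of the indexing idea that the axis map introduces: the map records which physical image axis each logical WHCN coordinate lives on, so a kernel remaps a logical position before addressing the texture. All names and types below are illustrative only.

```cpp
// Hypothetical illustration of axis-map based indexing. In the real shaders
// this remapping happens in GLSL; this C++ only sketches the lookup logic.
#include <array>
#include <cstdint>

// axis_map[i] gives the physical image axis used for logical axis i
// (WHCN order for the first three entries; the last entry names the WHCN
// dim that batches are concatenated into).
using AxisMap = std::array<int32_t, 4>;

std::array<int32_t, 3> to_physical_pos(
    const std::array<int32_t, 3>& logical_pos,
    const AxisMap& axis_map) {
  std::array<int32_t, 3> physical_pos{};
  for (int i = 0; i < 3; ++i) {
    // The logical i-th coordinate is stored along image axis axis_map[i].
    physical_pos[axis_map[i]] = logical_pos[i];
  }
  return physical_pos;
}
```

With an identity-style map the logical and physical positions coincide, while a permuted map changes which texture axis the width and height coordinates address. The extra indirection is a handful of integer lookups per invocation, which is consistent with the benchmark above showing no significant latency impact.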

backends/vulkan/runtime/api/containers/Tensor.cpp (71 additions, 47 deletions)
```diff
@@ -89,11 +89,11 @@ std::vector<int64_t> calculate_strides(
  * tensor. Thus the axis mapping can be considered to be in WHCN dimension
  * order.
  *
- * The last value `axis_mapping.at(3)` indicates the WHCN index of the tensor
+ * The last value `axis_map.at(3)` indicates the WHCN index of the tensor
  * dimension along which batches will be concatenated. This dimension can be
  * referred to as the "inner dimension" To determine which image texture axis is
  * used for the concatenation, a double lookup will need to be performed
- * (axis_mapping.at(axis_mapping.at(3))).
+ * (axis_map.at(axis_map.at(3))).
  *
  * The reason for strucuring axis mapping this way is because for the batch dim,
  * two things need to be easily derived:
@@ -107,7 +107,7 @@ std::vector<int64_t> calculate_strides(
  *
  * The axis mapping allows for permuted views of texture-backed tensors.
  */
-std::vector<int64_t> default_axis_mapping() {
+std::vector<int64_t> default_axis_map() {
   // Currently, all compute shaders have an assumption that the channels dim is
   // used to combine with the batch dim of a tensor. However, once dim mapping
   // is integrated into the tensor indexing logic for each compute shader, we
@@ -173,40 +173,40 @@ std::vector<int64_t> calculate_padded_sizes(
 
 utils::uvec3 calculate_image_extents(
     const std::vector<int64_t>& padded_sizes,
-    const std::vector<int64_t>& axis_mapping,
+    const std::vector<int64_t>& axis_map,
     const utils::GPUMemoryLayout memory_layout) {
   VK_CHECK_COND(padded_sizes.size() == 4);
-  VK_CHECK_COND(axis_mapping.size() == 4);
+  VK_CHECK_COND(axis_map.size() == 4);
 
   utils::uvec3 extents({1, 1, 1});
-  // First three elements of axis_mapping indicate which (X,Y,Z) image axis the
+  // First three elements of axis_map indicate which (X,Y,Z) image axis the
   // width, height, and channels dim of the tensor maps to.
   for (int whcn_dim = 0; whcn_dim < 3; ++whcn_dim) {
-    const int64_t axis = axis_mapping.at(whcn_dim);
+    const int64_t axis = axis_map.at(whcn_dim);
     const int64_t dim = padded_sizes.size() - 1 - whcn_dim;
     extents[axis] = utils::safe_downcast<uint32_t>(padded_sizes.at(dim));
   }
 
-  // axis_mapping[3] indicates the WHCN index of the dimension used for batch
+  // axis_map[3] indicates the WHCN index of the dimension used for batch
   // concatenation. Thus a double lookup is required to determine the image axis
   // used for batch concatenation.
-  const int64_t concatted_whcn_dim = axis_mapping.at(3);
-  const int64_t batch_axis = axis_mapping.at(concatted_whcn_dim);
+  const int64_t concatted_whcn_dim = axis_map.at(3);
+  const int64_t batch_axis = axis_map.at(concatted_whcn_dim);
   // Multiply the extents of the batch axis by the batch size.
   extents[batch_axis] *= padded_sizes.at(0);
 
   switch (memory_layout) {
     case utils::kWidthPacked:
-      VK_CHECK_COND(extents[0] % 4 == 0);
-      extents[0] /= 4;
+      VK_CHECK_COND(extents[axis_map.at(0)] % 4 == 0);
+      extents[axis_map.at(0)] /= 4;
       break;
     case utils::kHeightPacked:
-      VK_CHECK_COND(extents[1] % 4 == 0);
-      extents[1] /= 4;
+      VK_CHECK_COND(extents[axis_map.at(1)] % 4 == 0);
+      extents[axis_map.at(1)] /= 4;
       break;
     case utils::kChannelsPacked:
-      VK_CHECK_COND(extents[2] % 4 == 0);
-      extents[2] /= 4;
+      VK_CHECK_COND(extents[axis_map.at(2)] % 4 == 0);
+      extents[axis_map.at(2)] /= 4;
       break;
   }
 
@@ -229,25 +229,27 @@ vTensor::vTensor(
       // Calculate tensor metadata
       sizes_(sizes.begin(), sizes.end()),
       dim_order_(calculate_dim_order(sizes_.size(), memory_layout_)),
-      axis_mapping_(default_axis_mapping()),
+      axis_map_(default_axis_map()),
       strides_(calculate_strides(sizes, dim_order_)),
       numel_(utils::multiply_integers(sizes_)),
       padded_sizes_{calculate_padded_sizes(sizes, memory_layout_)},
       unsqueezed_strides_{unsqueeze_strides(strides_, numel_)},
       padded_numel_(utils::multiply_integers(padded_sizes_)),
       texture_limits_{{0, 0, 0}},
+      logical_limits_{{0, 0, 0}},
       // Utility Uniform Buffers that can be passed to shaders as arguments
       sizes_uniform_(),
       strides_uniform_(),
       numel_uniform_(),
-      axis_mapping_uniform_(),
+      axis_map_uniform_(),
       texture_limits_uniform_(),
+      logical_limits_uniform_(),
       // Construct Tensor storage
       storage_(
           context,
           storage_type,
           memory_layout_,
-          axis_mapping_,
+          axis_map_,
          padded_sizes_,
          dtype_,
          allocate_memory) {
@@ -259,6 +261,8 @@ vTensor::vTensor(
         utils::safe_downcast<int32_t>(storage_.image_extents_[0]),
         utils::safe_downcast<int32_t>(storage_.image_extents_[1]),
         utils::safe_downcast<int32_t>(storage_.image_extents_[2])};
+
+    update_logical_limits();
   }
 
   if (dtype == vkapi::kHalf) {
@@ -275,7 +279,7 @@ vTensor::vTensor(const vTensor& other)
       // Copy tensor size metadata
       sizes_(other.sizes_.begin(), other.sizes_.end()),
       dim_order_(other.dim_order_.begin(), other.dim_order_.end()),
-      axis_mapping_(other.axis_mapping_.begin(), other.axis_mapping_.end()),
+      axis_map_(other.axis_map_.begin(), other.axis_map_.end()),
       strides_(other.strides_.begin(), other.strides_.end()),
       numel_(other.numel_),
       padded_sizes_{other.padded_sizes_.begin(), other.padded_sizes_.end()},
@@ -284,12 +288,14 @@ vTensor::vTensor(const vTensor& other)
           other.unsqueezed_strides_.end()},
       padded_numel_(other.padded_numel_),
       texture_limits_{other.texture_limits_},
+      logical_limits_{other.logical_limits_},
       // Empty initialize Utility Uniform Buffers
       sizes_uniform_(),
       strides_uniform_(),
       numel_uniform_(),
-      axis_mapping_uniform_(),
+      axis_map_uniform_(),
       texture_limits_uniform_(),
+      logical_limits_uniform_(),
       // Copy Tensor storage
       storage_(other.storage_) {}
 
@@ -303,19 +309,21 @@ vTensor::vTensor(
       // Copy tensor size metadata
       sizes_(sizes.begin(), sizes.end()),
       dim_order_(dim_order.begin(), dim_order.end()),
-      axis_mapping_(default_axis_mapping()),
+      axis_map_(default_axis_map()),
       strides_(calculate_strides(sizes_, dim_order_)),
       numel_(utils::multiply_integers(sizes_)),
       padded_sizes_{calculate_padded_sizes(sizes, memory_layout_)},
       unsqueezed_strides_{unsqueeze_strides(strides_, numel_)},
       padded_numel_(utils::multiply_integers(padded_sizes_)),
-      texture_limits_{{0, 0, 0}},
+      texture_limits_{other.texture_limits_},
+      logical_limits_(other.logical_limits_),
       // Empty initialize Utility Uniform Buffers
       sizes_uniform_(),
       strides_uniform_(),
       numel_uniform_(),
-      axis_mapping_uniform_(),
+      axis_map_uniform_(),
       texture_limits_uniform_(),
+      logical_limits_uniform_(),
       // Copy Tensor storage
       storage_(other.storage_, vkapi::element_size(dtype_) * offset_numel) {
   VK_CHECK_COND(
@@ -356,12 +364,18 @@ vkapi::VulkanBuffer& vTensor::buffer(
   return storage_.buffer_;
 }
 
-utils::uvec3 vTensor::mapped_extents() const {
-  utils::uvec3 m_extents;
-  m_extents[0] = storage_.image_extents_[axis_mapping_.at(0)];
-  m_extents[1] = storage_.image_extents_[axis_mapping_.at(1)];
-  m_extents[2] = storage_.image_extents_[axis_mapping_.at(2)];
-  return m_extents;
+void vTensor::update_logical_limits() {
+  logical_limits_.limits[0] = texture_limits_.limits[axis_map_.at(0)];
+  logical_limits_.limits[1] = texture_limits_.limits[axis_map_.at(1)];
+  logical_limits_.limits[2] = texture_limits_.limits[axis_map_.at(2)];
+}
+
+utils::uvec3 vTensor::logical_extents() const {
+  utils::uvec3 logical_extents(
+      {utils::safe_downcast<uint32_t>(logical_limits_.limits[0]),
+       utils::safe_downcast<uint32_t>(logical_limits_.limits[1]),
+       utils::safe_downcast<uint32_t>(logical_limits_.limits[2])});
+  return logical_extents;
 }
 
 const vkapi::BufferBindInfo vTensor::sizes_ubo() {
@@ -380,12 +394,12 @@ const vkapi::BufferBindInfo vTensor::strides_ubo() {
   return vkapi::BufferBindInfo(strides_uniform_.buffer());
 }
 
-const vkapi::BufferBindInfo vTensor::axis_mapping_ubo() {
-  if (!axis_mapping_uniform_.buffer()) {
-    axis_mapping_uniform_ =
-        ParamsBuffer(storage_.context_, utils::make_ivec4(axis_mapping_));
+const vkapi::BufferBindInfo vTensor::axis_map_ubo() {
+  if (!axis_map_uniform_.buffer()) {
+    axis_map_uniform_ =
+        ParamsBuffer(storage_.context_, utils::make_ivec4(axis_map_));
   }
-  return vkapi::BufferBindInfo(axis_mapping_uniform_.buffer());
+  return vkapi::BufferBindInfo(axis_map_uniform_.buffer());
 }
 
 const vkapi::BufferBindInfo vTensor::texture_limits_ubo() {
@@ -395,6 +409,13 @@ const vkapi::BufferBindInfo vTensor::texture_limits_ubo() {
   return vkapi::BufferBindInfo(texture_limits_uniform_.buffer());
 }
 
+const vkapi::BufferBindInfo vTensor::logical_limits_ubo() {
+  if (!logical_limits_uniform_.buffer()) {
+    logical_limits_uniform_ = ParamsBuffer(storage_.context_, logical_limits_);
+  }
+  return vkapi::BufferBindInfo(logical_limits_uniform_.buffer());
+}
+
 const vkapi::BufferBindInfo vTensor::numel_ubo() {
   if (!numel_uniform_.buffer()) {
     numel_uniform_ = ParamsBuffer(storage_.context_, numel_);
@@ -465,14 +486,16 @@ void vTensor::update_metadata() {
   // Calculate the extents of the image texture that would have been required
   // for a tensor of the new sizes.
   utils::uvec3 virtual_extents =
-      calculate_image_extents(padded_sizes_, axis_mapping_, memory_layout_);
+      calculate_image_extents(padded_sizes_, axis_map_, memory_layout_);
 
   // Update the texture limits to reflect the new virtual extents.
   texture_limits_.limits = utils::ivec3{
       utils::safe_downcast<int32_t>(virtual_extents[0]),
       utils::safe_downcast<int32_t>(virtual_extents[1]),
       utils::safe_downcast<int32_t>(virtual_extents[2])};
 
+  update_logical_limits();
+
   if (sizes_uniform_.buffer()) {
     sizes_uniform_.update(utils::make_whcn_ivec4(sizes_));
   }
@@ -482,20 +505,23 @@ void vTensor::update_metadata() {
   if (numel_uniform_.buffer()) {
     numel_uniform_.update(numel_);
   }
-  if (axis_mapping_uniform_.buffer()) {
-    axis_mapping_uniform_.update(utils::make_ivec4(axis_mapping_));
+  if (axis_map_uniform_.buffer()) {
+    axis_map_uniform_.update(utils::make_ivec4(axis_map_));
   }
   if (texture_limits_uniform_.buffer()) {
     texture_limits_uniform_.update(texture_limits_);
   }
+  if (logical_limits_uniform_.buffer()) {
+    logical_limits_uniform_.update(logical_limits_);
+  }
 }
 
 void vTensor::check_sizes(const std::vector<int64_t>& sizes) const {
   if (storage_type() != utils::kBuffer) {
     // For texture storage check that the current texture is large enough for
     // the new sizes of the tensor.
     utils::uvec3 virtual_extents =
-        calculate_image_extents(padded_sizes_, axis_mapping_, memory_layout_);
+        calculate_image_extents(padded_sizes_, axis_map_, memory_layout_);
 
     bool valid_resize = virtual_extents[0] <= image_extents()[0];
     valid_resize = valid_resize && virtual_extents[1] <= image_extents()[1];
@@ -546,7 +572,7 @@ void vTensor::reallocate(const std::vector<int64_t>& new_sizes) {
   update_metadata();
   storage_.discard_and_reallocate(
       calculate_padded_sizes(new_sizes, memory_layout_),
-      axis_mapping_,
+      axis_map_,
       memory_layout_,
       dtype_);
 }
@@ -624,16 +650,14 @@ vTensorStorage::vTensorStorage(
     Context* const context,
     const utils::StorageType storage_type,
     const utils::GPUMemoryLayout gpu_memory_layout,
-    const std::vector<int64_t>& axis_mapping,
+    const std::vector<int64_t>& axis_map,
     const std::vector<int64_t>& padded_sizes,
     const vkapi::ScalarType dtype,
     const bool allocate_memory)
     : context_(context),
      storage_type_{storage_type},
-      image_extents_(calculate_image_extents(
-          padded_sizes,
-          axis_mapping,
-          gpu_memory_layout)),
+      image_extents_(
+          calculate_image_extents(padded_sizes, axis_map, gpu_memory_layout)),
      buffer_length_{utils::multiply_integers(padded_sizes)},
      buffer_offset_{0},
      image_(allocate_image(
@@ -746,7 +770,7 @@ bool vTensorStorage::is_copy_of(const vTensorStorage& other) const {
 
 void vTensorStorage::discard_and_reallocate(
     const std::vector<int64_t>& padded_sizes,
-    const std::vector<int64_t>& axis_mapping,
+    const std::vector<int64_t>& axis_map,
     const utils::GPUMemoryLayout gpu_memory_layout,
     const vkapi::ScalarType dtype) {
   const bool image_owns_memory = image_.owns_memory();
@@ -755,7 +779,7 @@ void vTensorStorage::discard_and_reallocate(
   flush();
 
   image_extents_ =
-      calculate_image_extents(padded_sizes, axis_mapping, gpu_memory_layout);
+      calculate_image_extents(padded_sizes, axis_map, gpu_memory_layout);
   image_ = allocate_image(
       context_,
      image_extents_,
```
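The doc comment in the first hunk describes a double lookup for the batch dimension: `axis_map.at(3)` names the WHCN dim that batches are concatenated into, and `axis_map.at(axis_map.at(3))` names the image axis that dim is stored on. The standalone sketch below walks through that lookup. It assumes the default mapping `{0, 1, 2, 2}` (width to X, height to Y, channels to Z, batches folded into channels, as the comment in `default_axis_map()` describes); the actual default values are an assumption here, not quoted from this diff.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Assumed default axis map: W->X, H->Y, C->Z, and batches are
  // concatenated along the channels (WHCN index 2) dimension.
  const std::vector<int64_t> axis_map = {0, 1, 2, 2};

  // Double lookup described in the comment: first find which WHCN dim the
  // batch is folded into, then find which image axis that dim maps to.
  const int64_t concatted_whcn_dim = axis_map.at(3);          // 2 (channels)
  const int64_t batch_axis = axis_map.at(concatted_whcn_dim); // 2 (Z axis)

  assert(concatted_whcn_dim == 2);
  assert(batch_axis == 2);
  return 0;
}
```

The same remapping is what `update_logical_limits()` does on the host side: the logical limits presented to a shader are the texture limits read back through the axis map, so a kernel can iterate over logical extents without knowing how the tensor is physically laid out in the image.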