
Commit cab29ea (1 parent: 5f4a811)

[ET-VK] Introduce axis mapping for no-copy permute of texture-backed tensors

Differential Revision: D62210118
Pull Request resolved: #5092

File tree: 4 files changed, +167 −41 lines

backends/vulkan/runtime/api/containers/Tensor.cpp

Lines changed: 102 additions & 27 deletions
@@ -80,6 +80,42 @@ std::vector<int64_t> calculate_strides(
   return strides;
 }
 
+/*
+ * Axis mapping is somewhat analogous to strides for texture-backed tensors.
+ *
+ * The axis mapping is normalized to 4 dimensions, similar to the padded sizes.
+ * The first 3 values of the axis mapping indicate the (X,Y,Z) image texture
+ * axis that corresponds to the width, height, and channels dimension of the
+ * tensor. Thus the axis mapping can be considered to be in WHCN dimension
+ * order.
+ *
+ * The last value, `axis_mapping.at(3)`, indicates the WHCN index of the tensor
+ * dimension along which batches will be concatenated. This dimension can be
+ * referred to as the "inner dimension". To determine which image texture axis
+ * is used for the concatenation, a double lookup must be performed
+ * (axis_mapping.at(axis_mapping.at(3))).
+ *
+ * The reason for structuring the axis mapping this way is that, for the batch
+ * dim, two things need to be easily derived:
+ *
+ * 1. The dim idx of the inner dimension, so that the size of the inner
+ *    dimension can be easily determined.
+ * 2. The texture axis used to concatenate batches.
+ *
+ * By storing the dim index of the inner dimension instead of the texture axis
+ * it maps to, both pieces of information are readily available.
+ *
+ * The axis mapping allows for permuted views of texture-backed tensors.
+ */
+std::vector<int64_t> default_axis_mapping() {
+  // Currently, all compute shaders assume that the channels dim is used to
+  // combine with the batch dim of a tensor. However, once dim mapping is
+  // integrated into the tensor indexing logic for each compute shader, we can
+  // be more flexible with mapping the batch dim to different texture axes in
+  // order to improve performance or memory footprint.
+  return {0, 1, 2, 2};
+}
+
 bool dim_order_is_valid(const std::vector<int64_t>& dim_order) {
   int64_t sum = 0;
   for (size_t i = 0; i < dim_order.size(); ++i) {
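To make the double lookup concrete, here is a small standalone sketch (not part of the diff). Only default_axis_mapping() mirrors the patch; the helper resolve_batch_axis and the main() driver are hypothetical illustrations.

#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors default_axis_mapping() from the patch: W->X, H->Y, C->Z, and
// batches are concatenated along the channels dim (WHCN index 2).
std::vector<int64_t> default_axis_mapping() {
  return {0, 1, 2, 2};
}

// Hypothetical helper showing the double lookup described in the comment.
int64_t resolve_batch_axis(const std::vector<int64_t>& axis_mapping) {
  // WHCN index of the "inner dimension" that batches are folded into.
  const int64_t concatted_whcn_dim = axis_mapping.at(3);
  // Image texture axis that the inner dimension maps to.
  return axis_mapping.at(concatted_whcn_dim);
}

int main() {
  // Prints 2: with the default mapping, batches stack along the Z axis.
  std::cout << resolve_batch_axis(default_axis_mapping()) << "\n";
  return 0;
}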
@@ -137,30 +173,44 @@ std::vector<int64_t> calculate_padded_sizes(
 
 utils::uvec3 calculate_image_extents(
     const std::vector<int64_t>& padded_sizes,
+    const std::vector<int64_t>& axis_mapping,
     const utils::GPUMemoryLayout memory_layout) {
   VK_CHECK_COND(padded_sizes.size() == 4);
+  VK_CHECK_COND(axis_mapping.size() == 4);
+
+  utils::uvec3 extents({1, 1, 1});
+  // First three elements of axis_mapping indicate which (X,Y,Z) image axis the
+  // width, height, and channels dim of the tensor maps to.
+  for (int whcn_dim = 0; whcn_dim < 3; ++whcn_dim) {
+    const int64_t axis = axis_mapping.at(whcn_dim);
+    const int64_t dim = padded_sizes.size() - 1 - whcn_dim;
+    extents[axis] = utils::safe_downcast<uint32_t>(padded_sizes.at(dim));
+  }
 
-  uint32_t N = utils::safe_downcast<uint32_t>(padded_sizes.at(0));
-  uint32_t C = utils::safe_downcast<uint32_t>(padded_sizes.at(1));
-  uint32_t H = utils::safe_downcast<uint32_t>(padded_sizes.at(2));
-  uint32_t W = utils::safe_downcast<uint32_t>(padded_sizes.at(3));
+  // axis_mapping[3] indicates the WHCN index of the dimension used for batch
+  // concatenation. Thus a double lookup is required to determine the image
+  // axis used for batch concatenation.
+  const int64_t concatted_whcn_dim = axis_mapping.at(3);
+  const int64_t batch_axis = axis_mapping.at(concatted_whcn_dim);
+  // Multiply the extents of the batch axis by the batch size.
+  extents[batch_axis] *= padded_sizes.at(0);
 
   switch (memory_layout) {
     case utils::kWidthPacked:
-      VK_CHECK_COND(W % 4 == 0);
-      W /= 4;
+      VK_CHECK_COND(extents[0] % 4 == 0);
+      extents[0] /= 4;
       break;
     case utils::kHeightPacked:
-      VK_CHECK_COND(H % 4 == 0);
-      H /= 4;
+      VK_CHECK_COND(extents[1] % 4 == 0);
+      extents[1] /= 4;
       break;
     case utils::kChannelsPacked:
-      VK_CHECK_COND(C % 4 == 0);
-      C /= 4;
+      VK_CHECK_COND(extents[2] % 4 == 0);
+      extents[2] /= 4;
       break;
   }
 
-  return {W, H, C * N};
+  return extents;
 }
 
 //
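As a worked example of the new logic, consider a channels-packed NCHW tensor with padded sizes {2, 8, 4, 6} (N=2, C=8, H=4, W=6; assuming padding leaves these values unchanged since C is already a multiple of 4) and the default axis mapping {0, 1, 2, 2}. The standalone sketch below re-implements the same arithmetic with plain arrays so it can be run on its own; it is illustrative, not the library code.

#include <array>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Padded sizes in NCHW order and the default axis mapping {0, 1, 2, 2}.
  const std::vector<int64_t> padded_sizes = {2, 8, 4, 6};
  const std::vector<int64_t> axis_mapping = {0, 1, 2, 2};

  std::array<uint32_t, 3> extents = {1, 1, 1};
  // W -> X, H -> Y, C -> Z, reading padded_sizes back to front.
  for (int whcn_dim = 0; whcn_dim < 3; ++whcn_dim) {
    const int64_t axis = axis_mapping.at(whcn_dim);
    const int64_t dim = padded_sizes.size() - 1 - whcn_dim;
    extents[axis] = static_cast<uint32_t>(padded_sizes.at(dim));
  }
  // Double lookup: batches concatenate along the channels dim (Z axis here).
  const int64_t batch_axis = axis_mapping.at(axis_mapping.at(3));
  extents[batch_axis] *= static_cast<uint32_t>(padded_sizes.at(0));

  // Channels-packed: four channel values share one texel along Z.
  extents[2] /= 4;

  // Prints 6 4 4: X=W=6, Y=H=4, Z=(C*N)/4=(8*2)/4=4.
  std::cout << extents[0] << " " << extents[1] << " " << extents[2] << "\n";
  return 0;
}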
@@ -176,9 +226,10 @@ vTensor::vTensor(
     const bool allocate_memory)
     : dtype_(dtype),
       memory_layout_(memory_layout),
-      // Calculate tensor size metadata
+      // Calculate tensor metadata
       sizes_(sizes.begin(), sizes.end()),
       dim_order_(calculate_dim_order(sizes_.size(), memory_layout_)),
+      axis_mapping_(default_axis_mapping()),
       strides_(calculate_strides(sizes, dim_order_)),
       numel_(utils::multiply_integers(sizes_)),
       padded_sizes_{calculate_padded_sizes(sizes, memory_layout_)},
@@ -189,12 +240,14 @@ vTensor::vTensor(
       sizes_uniform_(),
       strides_uniform_(),
       numel_uniform_(),
+      axis_mapping_uniform_(),
       texture_limits_uniform_(),
       // Construct Tensor storage
       storage_(
           context,
           storage_type,
           memory_layout_,
+          axis_mapping_,
           padded_sizes_,
           dtype_,
           allocate_memory) {
@@ -222,6 +275,7 @@ vTensor::vTensor(const vTensor& other)
       // Copy tensor size metadata
       sizes_(other.sizes_.begin(), other.sizes_.end()),
       dim_order_(other.dim_order_.begin(), other.dim_order_.end()),
+      axis_mapping_(other.axis_mapping_.begin(), other.axis_mapping_.end()),
       strides_(other.strides_.begin(), other.strides_.end()),
       numel_(other.numel_),
       padded_sizes_{other.padded_sizes_.begin(), other.padded_sizes_.end()},
@@ -234,6 +288,7 @@ vTensor::vTensor(const vTensor& other)
       sizes_uniform_(),
       strides_uniform_(),
       numel_uniform_(),
+      axis_mapping_uniform_(),
       texture_limits_uniform_(),
       // Copy Tensor storage
       storage_(other.storage_) {}
@@ -248,6 +303,7 @@ vTensor::vTensor(
       // Copy tensor size metadata
       sizes_(sizes.begin(), sizes.end()),
       dim_order_(dim_order.begin(), dim_order.end()),
+      axis_mapping_(default_axis_mapping()),
       strides_(calculate_strides(sizes_, dim_order_)),
       numel_(utils::multiply_integers(sizes_)),
       padded_sizes_{calculate_padded_sizes(sizes, memory_layout_)},
@@ -258,6 +314,7 @@ vTensor::vTensor(
       sizes_uniform_(),
       strides_uniform_(),
       numel_uniform_(),
+      axis_mapping_uniform_(),
       texture_limits_uniform_(),
       // Copy Tensor storage
       storage_(other.storage_, vkapi::element_size(dtype_) * offset_numel) {
@@ -315,6 +372,14 @@ const vkapi::BufferBindInfo vTensor::strides_ubo() {
   return vkapi::BufferBindInfo(strides_uniform_.buffer());
 }
 
+const vkapi::BufferBindInfo vTensor::axis_mapping_ubo() {
+  if (!axis_mapping_uniform_.buffer()) {
+    axis_mapping_uniform_ =
+        ParamsBuffer(storage_.context_, utils::make_ivec4(axis_mapping_));
+  }
+  return vkapi::BufferBindInfo(axis_mapping_uniform_.buffer());
+}
+
 const vkapi::BufferBindInfo vTensor::texture_limits_ubo() {
   if (!texture_limits_uniform_.buffer()) {
     texture_limits_uniform_ = ParamsBuffer(storage_.context_, texture_limits_);
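The new axis_mapping_ubo() exists so that compute shaders can receive the mapping as an ivec4 and index the texture generically instead of hard-coding W->X, H->Y, C->Z. The following is a plausible shader-side lookup transcribed into C++ for readability; it is a hypothetical sketch (texel packing is ignored for simplicity), not code from this patch.

#include <array>
#include <cstdint>

// Hypothetical sketch: map a (w, h, c, n) tensor coordinate to a texture
// position using the axis mapping. inner_dim_size is the size of the
// dimension that batches are concatenated along (packing ignored).
std::array<int32_t, 3> tensor_pos_to_texel_pos(
    const std::array<int32_t, 4>& whcn_pos,
    const std::array<int32_t, 4>& axis_mapping,
    int32_t inner_dim_size) {
  std::array<int32_t, 3> texel_pos = {0, 0, 0};
  // Route each of the W, H, C coordinates to its mapped texture axis.
  for (int whcn_dim = 0; whcn_dim < 3; ++whcn_dim) {
    texel_pos[axis_mapping[whcn_dim]] = whcn_pos[whcn_dim];
  }
  // Double lookup, then offset along the batch axis by the batch index.
  const int32_t batch_axis = axis_mapping[axis_mapping[3]];
  texel_pos[batch_axis] += whcn_pos[3] * inner_dim_size;
  return texel_pos;
}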
@@ -376,11 +441,7 @@ void vTensor::bind_allocation(const vkapi::Allocation& allocation) {
   }
 }
 
-void vTensor::update_metadata(
-    const std::vector<int64_t>& new_sizes,
-    const std::vector<int64_t>& new_dim_order) {
-  sizes_ = new_sizes;
-  dim_order_ = new_dim_order;
+void vTensor::update_metadata() {
   strides_ = calculate_strides(sizes_, dim_order_);
   // Only update the memory layout for buffer-backed tensors. Strides are
   // meaningless for texture-backed tensors and do not impact the memory layout.
@@ -396,7 +457,7 @@ void vTensor::update_metadata(
   // Calculate the extents of the image texture that would have been required
   // for a tensor of the new sizes.
   utils::uvec3 virtual_extents =
-      calculate_image_extents(padded_sizes_, memory_layout_);
+      calculate_image_extents(padded_sizes_, axis_mapping_, memory_layout_);
 
   // Update the texture limits to reflect the new virtual extents.
   texture_limits_.limits = utils::ivec3{
@@ -407,23 +468,26 @@ void vTensor::update_metadata(
   if (sizes_uniform_.buffer()) {
     sizes_uniform_.update(utils::make_whcn_ivec4(sizes_));
   }
-  if (texture_limits_uniform_.buffer()) {
-    texture_limits_uniform_.update(texture_limits_);
-  }
   if (strides_uniform_.buffer()) {
     strides_uniform_.update(utils::make_whcn_ivec4(unsqueezed_strides_));
   }
   if (numel_uniform_.buffer()) {
     numel_uniform_.update(numel_);
   }
+  if (axis_mapping_uniform_.buffer()) {
+    axis_mapping_uniform_.update(utils::make_ivec4(axis_mapping_));
+  }
+  if (texture_limits_uniform_.buffer()) {
+    texture_limits_uniform_.update(texture_limits_);
+  }
 }
 
 void vTensor::check_sizes(const std::vector<int64_t>& sizes) const {
   if (storage_type() != utils::kBuffer) {
     // For texture storage check that the current texture is large enough for
     // the new sizes of the tensor.
     utils::uvec3 virtual_extents =
-        calculate_image_extents(padded_sizes_, memory_layout_);
+        calculate_image_extents(padded_sizes_, axis_mapping_, memory_layout_);
 
     bool valid_resize = virtual_extents[0] <= image_extents()[0];
     valid_resize = valid_resize && virtual_extents[1] <= image_extents()[1];
@@ -454,7 +518,9 @@ void vTensor::virtual_reconfigure(
   VK_CHECK_COND(dim_order_is_valid(new_dim_order));
 
   check_sizes(new_sizes);
-  update_metadata(new_sizes, new_dim_order);
+  sizes_ = new_sizes;
+  dim_order_ = new_dim_order;
+  update_metadata();
 }
 
 void vTensor::virtual_resize(const std::vector<int64_t>& new_sizes) {
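This refactor is what enables the no-copy permute named in the commit title: a permute op can swap sizes_ and dim_order_ and let update_metadata() refresh everything derived from them. A hedged usage sketch follows (namespaces and includes elided; the sizes() accessor and the contiguous-dim-order assumption are ours, not guaranteed by this diff).

// Hedged sketch, not from this patch: build a no-copy permuted view of a
// vTensor `t` that currently has the contiguous dim order {0, 1, ..., N-1}.
// perm[i] names the old dim that becomes new dim i (e.g. {0, 2, 3, 1} for
// NCHW -> NHWC).
void permute_view(vTensor& t, const std::vector<int64_t>& perm) {
  const std::vector<int64_t> old_sizes = t.sizes();
  std::vector<int64_t> new_sizes(old_sizes.size());
  std::vector<int64_t> new_dim_order(old_sizes.size());
  for (size_t i = 0; i < perm.size(); ++i) {
    new_sizes[i] = old_sizes.at(perm.at(i));
    // Inverse permutation: memory order is unchanged, so old dim perm[i]
    // keeps its stride rank while taking the new index i.
    new_dim_order[perm.at(i)] = static_cast<int64_t>(i);
  }
  // Metadata-only update: strides, texture limits, and any allocated UBOs
  // are refreshed in place; the backing image or buffer is untouched.
  t.virtual_reconfigure(new_sizes, new_dim_order);
}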
@@ -463,13 +529,16 @@ void vTensor::virtual_resize(const std::vector<int64_t>& new_sizes) {
       "new sizes cannot modify the dimensionality of the tensor ");
 
   check_sizes(new_sizes);
-  update_metadata(new_sizes, dim_order_);
+  sizes_ = new_sizes;
+  update_metadata();
 }
 
 void vTensor::reallocate(const std::vector<int64_t>& new_sizes) {
-  update_metadata(new_sizes, dim_order_);
+  sizes_ = new_sizes;
+  update_metadata();
   storage_.discard_and_reallocate(
       calculate_padded_sizes(new_sizes, memory_layout_),
+      axis_mapping_,
       memory_layout_,
       dtype_);
 }
@@ -547,12 +616,16 @@ vTensorStorage::vTensorStorage(
     Context* const context,
     const utils::StorageType storage_type,
     const utils::GPUMemoryLayout gpu_memory_layout,
+    const std::vector<int64_t>& axis_mapping,
    const std::vector<int64_t>& padded_sizes,
     const vkapi::ScalarType dtype,
     const bool allocate_memory)
     : context_(context),
       storage_type_{storage_type},
-      image_extents_(calculate_image_extents(padded_sizes, gpu_memory_layout)),
+      image_extents_(calculate_image_extents(
+          padded_sizes,
+          axis_mapping,
+          gpu_memory_layout)),
       buffer_length_{utils::multiply_integers(padded_sizes)},
       buffer_offset_{0},
       image_(allocate_image(
@@ -665,14 +738,16 @@ bool vTensorStorage::is_copy_of(const vTensorStorage& other) const {
 
 void vTensorStorage::discard_and_reallocate(
     const std::vector<int64_t>& padded_sizes,
+    const std::vector<int64_t>& axis_mapping,
     const utils::GPUMemoryLayout gpu_memory_layout,
     const vkapi::ScalarType dtype) {
   const bool image_owns_memory = image_.owns_memory();
   const bool buffer_owns_memory = buffer_.owns_memory();
 
   flush();
 
-  image_extents_ = calculate_image_extents(padded_sizes, gpu_memory_layout);
+  image_extents_ =
+      calculate_image_extents(padded_sizes, axis_mapping, gpu_memory_layout);
   image_ = allocate_image(
       context_,
       image_extents_,
