
Commit d5a6c3a

[ET-VK] Introduce axis mapping for no-copy permute of texture-backed tensors
Pull Request resolved: #5092

## Context

This diff introduces the `axis_mapping` field for `vTensor`s, which can be used to implement no-copy permutes. The idea behind the axis mapping is that it is somewhat analogous to dim order for texture-backed tensors.

The axis mapping is normalized to 4 dimensions, similar to padded sizes. The first 3 elements indicate which of the (X,Y,Z) image texture axes the width, height, and channels dims of the tensor map to. The final element indicates the WHCN index of the tensor dimension along which batches will be concatenated.

The benefit of introducing the axis mapping is twofold:

1. Permutes can be performed without any data copying, by re-using the same texture but updating the axis mapping.
2. The memory layout of texture-backed tensors becomes more flexible, and can be optimized for performance or memory footprint by using unconventional axis mappings.

Regarding the second point, we have found that adding length to a texture's Z axis is more costly than adding length to its X or Y axes. Similarly, we have found that reading along the Z axis yields slightly lower throughput than reading along the X or Y axes. By introducing the axis mapping, we can map the largest tensor dimension to a texture's X axis instead of mapping it to the most intuitive texture axis (i.e. channels to the Z axis). This can save a lot of texture memory and potentially improve compute shader latency as well.

However, the prerequisite of using the axis mapping heavily is that the overhead introduced in calculating tensor indices and texture positions does not significantly increase compute shader latency. The impact of this will be investigated and shown in the following diffs.

Note that this diff only introduces the `axis_mapping` field.

ghstack-source-id: 241282077
@exported-using-ghexport

Differential Revision: [D62210118](https://our.internmc.facebook.com/intern/diff/D62210118/)
1 parent cdb5438 commit d5a6c3a

File tree

4 files changed: +167 −41 lines changed

backends/vulkan/runtime/api/containers/Tensor.cpp

Lines changed: 102 additions & 27 deletions
@@ -80,6 +80,42 @@ std::vector<int64_t> calculate_strides(
   return strides;
 }
 
+/*
+ * Axis mapping is somewhat analogous to strides for texture backed tensors.
+ *
+ * The axis mapping is normalized to 4 dimensions, similar to the padded sizes.
+ * The first 3 values of the axis mapping indicate the (X,Y,Z) image texture
+ * axis that corresponds to the width, height, and channels dimension of the
+ * tensor. Thus the axis mapping can be considered to be in WHCN dimension
+ * order.
+ *
+ * The last value `axis_mapping.at(3)` indicates the WHCN index of the tensor
+ * dimension along which batches will be concatenated. This dimension can be
+ * referred to as the "inner dimension". To determine which image texture axis
+ * is used for the concatenation, a double lookup will need to be performed
+ * (axis_mapping.at(axis_mapping.at(3))).
+ *
+ * The reason for structuring axis mapping this way is that for the batch dim,
+ * two things need to be easily derived:
+ *
+ * 1. The dim idx of the inner dimension, so that the size of the inner
+ *    dimension can be easily determined.
+ * 2. The texture axis used to concatenate batches.
+ *
+ * By storing the dim index of the inner dimension instead of the texture axis
+ * it maps to, both pieces of information are readily available.
+ *
+ * The axis mapping allows for permuted views of texture-backed tensors.
+ */
+std::vector<int64_t> default_axis_mapping() {
+  // Currently, all compute shaders have an assumption that the channels dim is
+  // used to combine with the batch dim of a tensor. However, once dim mapping
+  // is integrated into the tensor indexing logic for each compute shader, we
+  // can be more flexible with mapping the batch dim to different texture axes
+  // in order to improve performance or memory footprint.
+  return {0, 1, 2, 2};
+}
+
 bool dim_order_is_valid(const std::vector<int64_t>& dim_order) {
   int64_t sum = 0;
   for (size_t i = 0; i < dim_order.size(); ++i) {
@@ -137,30 +173,44 @@ std::vector<int64_t> calculate_padded_sizes(
 
 utils::uvec3 calculate_image_extents(
     const std::vector<int64_t>& padded_sizes,
+    const std::vector<int64_t>& axis_mapping,
     const utils::GPUMemoryLayout memory_layout) {
   VK_CHECK_COND(padded_sizes.size() == 4);
+  VK_CHECK_COND(axis_mapping.size() == 4);
+
+  utils::uvec3 extents({1, 1, 1});
+  // First three elements of axis_mapping indicate which (X,Y,Z) image axis the
+  // width, height, and channels dim of the tensor maps to.
+  for (int whcn_dim = 0; whcn_dim < 3; ++whcn_dim) {
+    const int64_t axis = axis_mapping.at(whcn_dim);
+    const int64_t dim = padded_sizes.size() - 1 - whcn_dim;
+    extents[axis] = utils::safe_downcast<uint32_t>(padded_sizes.at(dim));
+  }
 
-  uint32_t N = utils::safe_downcast<uint32_t>(padded_sizes.at(0));
-  uint32_t C = utils::safe_downcast<uint32_t>(padded_sizes.at(1));
-  uint32_t H = utils::safe_downcast<uint32_t>(padded_sizes.at(2));
-  uint32_t W = utils::safe_downcast<uint32_t>(padded_sizes.at(3));
+  // axis_mapping.at(3) indicates the WHCN index of the dimension used for
+  // batch concatenation. Thus a double lookup is required to determine the
+  // image axis used for batch concatenation.
+  const int64_t concatted_whcn_dim = axis_mapping.at(3);
+  const int64_t batch_axis = axis_mapping.at(concatted_whcn_dim);
+  // Multiply the extents of the batch axis by the batch size.
+  extents[batch_axis] *= padded_sizes.at(0);
 
   switch (memory_layout) {
     case utils::kWidthPacked:
-      VK_CHECK_COND(W % 4 == 0);
-      W /= 4;
+      VK_CHECK_COND(extents[0] % 4 == 0);
+      extents[0] /= 4;
       break;
     case utils::kHeightPacked:
-      VK_CHECK_COND(H % 4 == 0);
-      H /= 4;
+      VK_CHECK_COND(extents[1] % 4 == 0);
+      extents[1] /= 4;
       break;
     case utils::kChannelsPacked:
-      VK_CHECK_COND(C % 4 == 0);
-      C /= 4;
+      VK_CHECK_COND(extents[2] % 4 == 0);
+      extents[2] /= 4;
       break;
   }
 
-  return {W, H, C * N};
+  return extents;
 }
 
 //
@@ -176,9 +226,10 @@ vTensor::vTensor(
     const bool allocate_memory)
     : dtype_(dtype),
       memory_layout_(memory_layout),
-      // Calculate tensor size metadata
+      // Calculate tensor metadata
       sizes_(sizes.begin(), sizes.end()),
       dim_order_(calculate_dim_order(sizes_.size(), memory_layout_)),
+      axis_mapping_(default_axis_mapping()),
       strides_(calculate_strides(sizes, dim_order_)),
       numel_(utils::multiply_integers(sizes_)),
       padded_sizes_{calculate_padded_sizes(sizes, memory_layout_)},
@@ -189,12 +240,14 @@ vTensor::vTensor(
       sizes_uniform_(),
       strides_uniform_(),
       numel_uniform_(),
+      axis_mapping_uniform_(),
       texture_limits_uniform_(),
       // Construct Tensor storage
       storage_(
           context,
           storage_type,
           memory_layout_,
+          axis_mapping_,
           padded_sizes_,
           dtype_,
           allocate_memory) {
@@ -222,6 +275,7 @@ vTensor::vTensor(const vTensor& other)
       // Copy tensor size metadata
       sizes_(other.sizes_.begin(), other.sizes_.end()),
       dim_order_(other.dim_order_.begin(), other.dim_order_.end()),
+      axis_mapping_(other.axis_mapping_.begin(), other.axis_mapping_.end()),
       strides_(other.strides_.begin(), other.strides_.end()),
       numel_(other.numel_),
       padded_sizes_{other.padded_sizes_.begin(), other.padded_sizes_.end()},
@@ -234,6 +288,7 @@ vTensor::vTensor(const vTensor& other)
       sizes_uniform_(),
       strides_uniform_(),
       numel_uniform_(),
+      axis_mapping_uniform_(),
       texture_limits_uniform_(),
       // Copy Tensor storage
       storage_(other.storage_) {}
@@ -248,6 +303,7 @@ vTensor::vTensor(
       // Copy tensor size metadata
       sizes_(sizes.begin(), sizes.end()),
       dim_order_(dim_order.begin(), dim_order.end()),
+      axis_mapping_(default_axis_mapping()),
       strides_(calculate_strides(sizes_, dim_order_)),
       numel_(utils::multiply_integers(sizes_)),
       padded_sizes_{calculate_padded_sizes(sizes, memory_layout_)},
@@ -258,6 +314,7 @@ vTensor::vTensor(
       sizes_uniform_(),
       strides_uniform_(),
       numel_uniform_(),
+      axis_mapping_uniform_(),
       texture_limits_uniform_(),
       // Copy Tensor storage
       storage_(other.storage_, vkapi::element_size(dtype_) * offset_numel) {
@@ -315,6 +372,14 @@ const vkapi::BufferBindInfo vTensor::strides_ubo() {
   return vkapi::BufferBindInfo(strides_uniform_.buffer());
 }
 
+const vkapi::BufferBindInfo vTensor::axis_mapping_ubo() {
+  if (!axis_mapping_uniform_.buffer()) {
+    axis_mapping_uniform_ =
+        ParamsBuffer(storage_.context_, utils::make_ivec4(axis_mapping_));
+  }
+  return vkapi::BufferBindInfo(axis_mapping_uniform_.buffer());
+}
+
 const vkapi::BufferBindInfo vTensor::texture_limits_ubo() {
   if (!texture_limits_uniform_.buffer()) {
     texture_limits_uniform_ = ParamsBuffer(storage_.context_, texture_limits_);
@@ -376,11 +441,7 @@ void vTensor::bind_allocation(const vkapi::Allocation& allocation) {
   }
 }
 
-void vTensor::update_metadata(
-    const std::vector<int64_t>& new_sizes,
-    const std::vector<int64_t>& new_dim_order) {
-  sizes_ = new_sizes;
-  dim_order_ = new_dim_order;
+void vTensor::update_metadata() {
   strides_ = calculate_strides(sizes_, dim_order_);
   // Only update the memory layout for buffer-backed tensors. Strides are
   // meaningless for texture-backed tensors and do not impact the memory layout.
@@ -396,7 +457,7 @@ void vTensor::update_metadata(
   // Calculate the extents of the image texture that would have been required
   // for a tensor of the new sizes.
   utils::uvec3 virtual_extents =
-      calculate_image_extents(padded_sizes_, memory_layout_);
+      calculate_image_extents(padded_sizes_, axis_mapping_, memory_layout_);
 
   // Update the texture limits to reflect the new virtual extents.
   texture_limits_.limits = utils::ivec3{
@@ -407,23 +468,26 @@ void vTensor::update_metadata(
   if (sizes_uniform_.buffer()) {
     sizes_uniform_.update(utils::make_whcn_ivec4(sizes_));
   }
-  if (texture_limits_uniform_.buffer()) {
-    texture_limits_uniform_.update(texture_limits_);
-  }
   if (strides_uniform_.buffer()) {
     strides_uniform_.update(utils::make_whcn_ivec4(unsqueezed_strides_));
   }
   if (numel_uniform_.buffer()) {
     numel_uniform_.update(numel_);
   }
+  if (axis_mapping_uniform_.buffer()) {
+    axis_mapping_uniform_.update(utils::make_ivec4(axis_mapping_));
+  }
+  if (texture_limits_uniform_.buffer()) {
+    texture_limits_uniform_.update(texture_limits_);
+  }
 }
 
 void vTensor::check_sizes(const std::vector<int64_t>& sizes) const {
   if (storage_type() != utils::kBuffer) {
     // For texture storage check that the current texture is large enough for
     // the new sizes of the tensor.
     utils::uvec3 virtual_extents =
-        calculate_image_extents(padded_sizes_, memory_layout_);
+        calculate_image_extents(padded_sizes_, axis_mapping_, memory_layout_);
 
     bool valid_resize = virtual_extents[0] <= image_extents()[0];
     valid_resize = valid_resize && virtual_extents[1] <= image_extents()[1];
@@ -454,7 +518,9 @@ void vTensor::virtual_reconfigure(
   VK_CHECK_COND(dim_order_is_valid(new_dim_order));
 
   check_sizes(new_sizes);
-  update_metadata(new_sizes, new_dim_order);
+  sizes_ = new_sizes;
+  dim_order_ = new_dim_order;
+  update_metadata();
 }
 
 void vTensor::virtual_resize(const std::vector<int64_t>& new_sizes) {
@@ -463,13 +529,16 @@ void vTensor::virtual_resize(const std::vector<int64_t>& new_sizes) {
       "new sizes cannot modify the dimensionality of the tensor ");
 
   check_sizes(new_sizes);
-  update_metadata(new_sizes, dim_order_);
+  sizes_ = new_sizes;
+  update_metadata();
 }
 
 void vTensor::reallocate(const std::vector<int64_t>& new_sizes) {
-  update_metadata(new_sizes, dim_order_);
+  sizes_ = new_sizes;
+  update_metadata();
   storage_.discard_and_reallocate(
       calculate_padded_sizes(new_sizes, memory_layout_),
+      axis_mapping_,
       memory_layout_,
       dtype_);
 }
@@ -547,12 +616,16 @@ vTensorStorage::vTensorStorage(
     Context* const context,
     const utils::StorageType storage_type,
    const utils::GPUMemoryLayout gpu_memory_layout,
+    const std::vector<int64_t>& axis_mapping,
    const std::vector<int64_t>& padded_sizes,
    const vkapi::ScalarType dtype,
    const bool allocate_memory)
    : context_(context),
      storage_type_{storage_type},
-      image_extents_(calculate_image_extents(padded_sizes, gpu_memory_layout)),
+      image_extents_(calculate_image_extents(
+          padded_sizes,
+          axis_mapping,
+          gpu_memory_layout)),
      buffer_length_{utils::multiply_integers(padded_sizes)},
      buffer_offset_{0},
      image_(allocate_image(
@@ -665,14 +738,16 @@ bool vTensorStorage::is_copy_of(const vTensorStorage& other) const {
 
 void vTensorStorage::discard_and_reallocate(
     const std::vector<int64_t>& padded_sizes,
+    const std::vector<int64_t>& axis_mapping,
     const utils::GPUMemoryLayout gpu_memory_layout,
     const vkapi::ScalarType dtype) {
   const bool image_owns_memory = image_.owns_memory();
   const bool buffer_owns_memory = buffer_.owns_memory();
 
   flush();
 
-  image_extents_ = calculate_image_extents(padded_sizes, gpu_memory_layout);
+  image_extents_ =
+      calculate_image_extents(padded_sizes, axis_mapping, gpu_memory_layout);
   image_ = allocate_image(
       context_,
       image_extents_,
