Commit 5e2302d

[ET-VK] Introduce axis mapping for no-copy permute of texture-backed tensors
## Context

This diff introduces the `axis_mapping` field for `vTensor`s, which can be used to implement no-copy permutes. The idea behind the axis mapping is that it is somewhat analogous to dim order for texture-backed tensors.

The axis mapping is normalized to 4 dimensions, similar to the padded sizes. The first 3 elements indicate which of the (X, Y, Z) image texture axes the width, height, and channels dims of the tensor map to. The final element indicates the WHCN index of the tensor dimension along which batches will be concatenated.

The benefit of introducing the axis mapping is twofold:

1. Permutes can be performed without any data copying, by re-using a texture but updating the axis mapping.
2. The memory layout of texture-backed tensors becomes more flexible, and can be optimized for performance or memory footprint by using unconventional axis mappings.

Regarding the second point, we have found that adding length to a texture's Z axis is more costly than adding length to its X or Y axes. Similarly, we have found that reading along the Z axis yields slightly lower throughput than reading along the X or Y axes. By introducing the axis mapping, we can map the largest tensor dimension to a texture's X axis instead of the most intuitive texture axis (i.e. channels to the Z axis). This can save a lot of texture memory and potentially improve compute shader latency as well. However, the prerequisite for using the axis mapping heavily is that the overhead introduced in calculating tensor indices and texture positions does not significantly increase compute shader latency. The impact of this will be investigated and shown in the following diffs.

Note that this diff only introduces the `axis_mapping` field.

Differential Revision: [D62210118](https://our.internmc.facebook.com/intern/diff/D62210118/)

ghstack-source-id: 241066641
Pull Request resolved: #5092
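To illustrate the convention described above, here is a minimal sketch with hypothetical values (not taken from this diff): the axis mapping is read in WHCN order, and a permuted view changes only the mapping, not the texture contents.

#include <cstdint>
#include <vector>

// Illustrative axis mappings in WHCN order; a minimal sketch, not part of this diff.
// Default: width -> X, height -> Y, channels -> Z, and batches are concatenated
// along the channels dim (WHCN index 2), i.e. along the texture's Z axis.
const std::vector<int64_t> default_mapping = {0, 1, 2, 2};
// A hypothetical permuted view of the same texture: width -> Z, height -> Y,
// channels -> X. No texel data is copied; only the mapping changes.
const std::vector<int64_t> permuted_mapping = {2, 1, 0, 2};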
1 parent 2e3f62f commit 5e2302d

File tree

3 files changed (+148, -39 lines)


backends/vulkan/runtime/api/containers/Tensor.cpp

Lines changed: 91 additions & 27 deletions
@@ -80,6 +80,31 @@ std::vector<int64_t> calculate_strides(
   return strides;
 }
 
+/*
+ * Axis mapping is somewhat analogous to strides for texture backed tensors.
+ *
+ * The axis mapping is normalized to 4 dimensions, similar to the padded sizes.
+ * The first 3 values of the axis mapping indicate the (X,Y,Z) image texture
+ * axis that corresponds to the width, height, and channels dimension of the
+ * tensor. Thus the axis mapping can be considered to be in WHCN dimension
+ * order.
+ *
+ * The last value `axis_mapping.at(3)` indicates the WHCN index of the tensor
+ * dimension along which batches will be concatenated. To determine which image
+ * texture axis is used for the concatenation, a double lookup will need to be
+ * performed (axis_mapping.at(axis_mapping.at(3))).
+ *
+ * The axis mapping allows for permuted views of texture-backed tensors.
+ */
+std::vector<int64_t> default_axis_mapping() {
+  // Currently, all compute shaders have an assumption that the channels dim is
+  // used to combine with the batch dim of a tensor. However, once dim mapping
+  // is integrated into the tensor indexing logic for each compute shader, we
+  // can be more flexible with mapping the batch dim to different texture axes
+  // in order to improve performance or memory footprint.
+  return {0, 1, 2, 2};
+}
+
 bool dim_order_is_valid(const std::vector<int64_t>& dim_order) {
   int64_t sum = 0;
   for (size_t i = 0; i < dim_order.size(); ++i) {
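To make the double lookup described in the comment above concrete, with the default mapping (illustrative values only):

// axis_mapping = {0, 1, 2, 2}
// axis_mapping.at(3) == 2                  -> batches are concatenated along
//                                             the channels dim (WHCN index 2)
// axis_mapping.at(axis_mapping.at(3)) == 2 -> i.e. along the texture's Z axis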
@@ -137,30 +162,44 @@ std::vector<int64_t> calculate_padded_sizes(
 
 utils::uvec3 calculate_image_extents(
     const std::vector<int64_t>& padded_sizes,
+    const std::vector<int64_t>& axis_mapping,
     const utils::GPUMemoryLayout memory_layout) {
   VK_CHECK_COND(padded_sizes.size() == 4);
+  VK_CHECK_COND(axis_mapping.size() == 4);
 
-  uint32_t N = utils::safe_downcast<uint32_t>(padded_sizes.at(0));
-  uint32_t C = utils::safe_downcast<uint32_t>(padded_sizes.at(1));
-  uint32_t H = utils::safe_downcast<uint32_t>(padded_sizes.at(2));
-  uint32_t W = utils::safe_downcast<uint32_t>(padded_sizes.at(3));
+  utils::uvec3 extents({1, 1, 1});
+  // First three elements of axis_mapping indicate which (X,Y,Z) image axis the
+  // width, height, and channels dim of the tensor maps to.
+  for (int whcn_dim = 0; whcn_dim < 3; ++whcn_dim) {
+    const int64_t axis = axis_mapping.at(whcn_dim);
+    const int64_t dim = padded_sizes.size() - 1 - whcn_dim;
+    extents[axis] = utils::safe_downcast<uint32_t>(padded_sizes.at(dim));
+  }
+
+  // axis_mapping[3] indicates the WHCN index of the dimension used for batch
+  // concatenation. Thus a double lookup is required to determine the image axis
+  // used for batch concatenation.
+  const int64_t concatted_whcn_dim = axis_mapping.at(3);
+  const int64_t batch_axis = axis_mapping.at(concatted_whcn_dim);
+  // Multiply the extents of the batch axis by the batch size.
+  extents[batch_axis] *= padded_sizes.at(0);
 
   switch (memory_layout) {
     case utils::kWidthPacked:
-      VK_CHECK_COND(W % 4 == 0);
-      W /= 4;
+      VK_CHECK_COND(extents[0] % 4 == 0);
+      extents[0] /= 4;
       break;
     case utils::kHeightPacked:
-      VK_CHECK_COND(H % 4 == 0);
-      H /= 4;
+      VK_CHECK_COND(extents[1] % 4 == 0);
+      extents[1] /= 4;
       break;
     case utils::kChannelsPacked:
-      VK_CHECK_COND(C % 4 == 0);
-      C /= 4;
+      VK_CHECK_COND(extents[2] % 4 == 0);
+      extents[2] /= 4;
       break;
   }
 
-  return {W, H, C * N};
+  return extents;
 }
 
 //
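As a standalone sketch of the extent calculation in the hunk above, the following simplified version uses concrete, hypothetical padded sizes; the helper name and the plain std::array return type are illustrative and not part of the actual API.

#include <array>
#include <cstdint>
#include <vector>

// Mirrors the mapping logic of calculate_image_extents() above, without the
// packed-dim division, as a quick mental model.
std::array<uint32_t, 3> sketch_extents(
    const std::vector<int64_t>& padded_sizes, // NCHW order: {N, C, H, W}
    const std::vector<int64_t>& axis_mapping) {
  std::array<uint32_t, 3> extents = {1, 1, 1};
  for (int whcn_dim = 0; whcn_dim < 3; ++whcn_dim) {
    // WHCN index 0 -> NCHW index 3 (W), 1 -> 2 (H), 2 -> 1 (C).
    extents[axis_mapping.at(whcn_dim)] =
        static_cast<uint32_t>(padded_sizes.at(3 - whcn_dim));
  }
  // Double lookup to find the texture axis that receives the batches.
  extents[axis_mapping.at(axis_mapping.at(3))] *=
      static_cast<uint32_t>(padded_sizes.at(0));
  return extents;
}

// With padded_sizes = {2, 8, 16, 32} and the default mapping {0, 1, 2, 2},
// this gives {32, 16, 16}; kChannelsPacked then divides the Z extent by 4,
// for final image extents of {32, 16, 4}.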
@@ -176,9 +215,10 @@ vTensor::vTensor(
     const bool allocate_memory)
     : dtype_(dtype),
       memory_layout_(memory_layout),
-      // Calculate tensor size metadata
+      // Calculate tensor metadata
       sizes_(sizes.begin(), sizes.end()),
       dim_order_(calculate_dim_order(sizes_.size(), memory_layout_)),
+      axis_mapping_(default_axis_mapping()),
       strides_(calculate_strides(sizes, dim_order_)),
       numel_(utils::multiply_integers(sizes_)),
       padded_sizes_{calculate_padded_sizes(sizes, memory_layout_)},
@@ -189,12 +229,14 @@ vTensor::vTensor(
       sizes_uniform_(),
       strides_uniform_(),
       numel_uniform_(),
+      axis_mapping_uniform_(),
       texture_limits_uniform_(),
       // Construct Tensor storage
       storage_(
           context,
           storage_type,
           memory_layout_,
+          axis_mapping_,
           padded_sizes_,
           dtype_,
           allocate_memory) {
@@ -222,6 +264,7 @@ vTensor::vTensor(const vTensor& other)
       // Copy tensor size metadata
       sizes_(other.sizes_.begin(), other.sizes_.end()),
       dim_order_(other.dim_order_.begin(), other.dim_order_.end()),
+      axis_mapping_(other.axis_mapping_.begin(), other.axis_mapping_.end()),
       strides_(other.strides_.begin(), other.strides_.end()),
       numel_(other.numel_),
       padded_sizes_{other.padded_sizes_.begin(), other.padded_sizes_.end()},
@@ -234,6 +277,7 @@ vTensor::vTensor(const vTensor& other)
       sizes_uniform_(),
       strides_uniform_(),
       numel_uniform_(),
+      axis_mapping_uniform_(),
       texture_limits_uniform_(),
       // Copy Tensor storage
       storage_(other.storage_) {}
@@ -248,6 +292,7 @@ vTensor::vTensor(
       // Copy tensor size metadata
       sizes_(sizes.begin(), sizes.end()),
       dim_order_(dim_order.begin(), dim_order.end()),
+      axis_mapping_(default_axis_mapping()),
       strides_(calculate_strides(sizes_, dim_order_)),
       numel_(utils::multiply_integers(sizes_)),
       padded_sizes_{calculate_padded_sizes(sizes, memory_layout_)},
@@ -258,6 +303,7 @@ vTensor::vTensor(
       sizes_uniform_(),
       strides_uniform_(),
       numel_uniform_(),
+      axis_mapping_uniform_(),
       texture_limits_uniform_(),
       // Copy Tensor storage
       storage_(other.storage_, vkapi::element_size(dtype_) * offset_numel) {
@@ -315,6 +361,14 @@ const vkapi::BufferBindInfo vTensor::strides_ubo() {
   return vkapi::BufferBindInfo(strides_uniform_.buffer());
 }
 
+const vkapi::BufferBindInfo vTensor::axis_mapping_ubo() {
+  if (!axis_mapping_uniform_.buffer()) {
+    axis_mapping_uniform_ =
+        ParamsBuffer(storage_.context_, utils::make_ivec4(axis_mapping_));
+  }
+  return vkapi::BufferBindInfo(axis_mapping_uniform_.buffer());
+}
+
 const vkapi::BufferBindInfo vTensor::texture_limits_ubo() {
   if (!texture_limits_uniform_.buffer()) {
     texture_limits_uniform_ = ParamsBuffer(storage_.context_, texture_limits_);
@@ -376,11 +430,7 @@ void vTensor::bind_allocation(const vkapi::Allocation& allocation) {
   }
 }
 
-void vTensor::update_metadata(
-    const std::vector<int64_t>& new_sizes,
-    const std::vector<int64_t>& new_dim_order) {
-  sizes_ = new_sizes;
-  dim_order_ = new_dim_order;
+void vTensor::update_metadata() {
   strides_ = calculate_strides(sizes_, dim_order_);
   // Only update the memory layout for buffer-backed tensors. Strides are
   // meaningless for texture-backed tensors and do not impact the memory layout.
@@ -396,7 +446,7 @@ void vTensor::update_metadata(
   // Calculate the extents of the image texture that would have been required
   // for a tensor of the new sizes.
   utils::uvec3 virtual_extents =
-      calculate_image_extents(padded_sizes_, memory_layout_);
+      calculate_image_extents(padded_sizes_, axis_mapping_, memory_layout_);
 
   // Update the texture limits to reflect the new virtual extents.
   texture_limits_.limits = utils::ivec3{
@@ -407,23 +457,26 @@
   if (sizes_uniform_.buffer()) {
     sizes_uniform_.update(utils::make_whcn_ivec4(sizes_));
   }
-  if (texture_limits_uniform_.buffer()) {
-    texture_limits_uniform_.update(texture_limits_);
-  }
   if (strides_uniform_.buffer()) {
     strides_uniform_.update(utils::make_whcn_ivec4(unsqueezed_strides_));
   }
   if (numel_uniform_.buffer()) {
     numel_uniform_.update(numel_);
   }
+  if (axis_mapping_uniform_.buffer()) {
+    axis_mapping_uniform_.update(utils::make_ivec4(axis_mapping_));
+  }
+  if (texture_limits_uniform_.buffer()) {
+    texture_limits_uniform_.update(texture_limits_);
+  }
 }
 
 void vTensor::check_sizes(const std::vector<int64_t>& sizes) const {
   if (storage_type() != utils::kBuffer) {
     // For texture storage check that the current texture is large enough for
     // the new sizes of the tensor.
     utils::uvec3 virtual_extents =
-        calculate_image_extents(padded_sizes_, memory_layout_);
+        calculate_image_extents(padded_sizes_, axis_mapping_, memory_layout_);
 
     bool valid_resize = virtual_extents[0] <= image_extents()[0];
     valid_resize = valid_resize && virtual_extents[1] <= image_extents()[1];
@@ -454,7 +507,9 @@ void vTensor::virtual_reconfigure(
   VK_CHECK_COND(dim_order_is_valid(new_dim_order));
 
   check_sizes(new_sizes);
-  update_metadata(new_sizes, new_dim_order);
+  sizes_ = new_sizes;
+  dim_order_ = new_dim_order;
+  update_metadata();
 }
 
 void vTensor::virtual_resize(const std::vector<int64_t>& new_sizes) {
@@ -463,13 +518,16 @@ void vTensor::virtual_resize(const std::vector<int64_t>& new_sizes) {
       "new sizes cannot modify the dimensionality of the tensor ");
 
   check_sizes(new_sizes);
-  update_metadata(new_sizes, dim_order_);
+  sizes_ = new_sizes;
+  update_metadata();
 }
 
 void vTensor::reallocate(const std::vector<int64_t>& new_sizes) {
-  update_metadata(new_sizes, dim_order_);
+  sizes_ = new_sizes;
+  update_metadata();
   storage_.discard_and_reallocate(
       calculate_padded_sizes(new_sizes, memory_layout_),
+      axis_mapping_,
      memory_layout_,
       dtype_);
 }
@@ -547,12 +605,16 @@ vTensorStorage::vTensorStorage(
     Context* const context,
     const utils::StorageType storage_type,
     const utils::GPUMemoryLayout gpu_memory_layout,
+    const std::vector<int64_t>& axis_mapping,
    const std::vector<int64_t>& padded_sizes,
     const vkapi::ScalarType dtype,
     const bool allocate_memory)
     : context_(context),
      storage_type_{storage_type},
-      image_extents_(calculate_image_extents(padded_sizes, gpu_memory_layout)),
+      image_extents_(calculate_image_extents(
+          padded_sizes,
+          axis_mapping,
+          gpu_memory_layout)),
       buffer_length_{utils::multiply_integers(padded_sizes)},
       buffer_offset_{0},
       image_(allocate_image(
@@ -665,14 +727,16 @@ bool vTensorStorage::is_copy_of(const vTensorStorage& other) const {
 
 void vTensorStorage::discard_and_reallocate(
     const std::vector<int64_t>& padded_sizes,
+    const std::vector<int64_t>& axis_mapping,
     const utils::GPUMemoryLayout gpu_memory_layout,
     const vkapi::ScalarType dtype) {
   const bool image_owns_memory = image_.owns_memory();
   const bool buffer_owns_memory = buffer_.owns_memory();
 
   flush();
 
-  image_extents_ = calculate_image_extents(padded_sizes, gpu_memory_layout);
+  image_extents_ =
+      calculate_image_extents(padded_sizes, axis_mapping, gpu_memory_layout);
   image_ = allocate_image(
       context_,
       image_extents_,