Skip to content

Commit b2cdfcf

Browse files
Abhi-hppfacebook-github-bot
authored andcommitted
Optimized axis map
Summary: Add a flag to optimize the axis map layout to be in descending order of axis size. The default is still the same. Differential Revision: D67692960
1 parent debafbe commit b2cdfcf

File tree

7 files changed

+76
-33
lines changed

7 files changed

+76
-33
lines changed

backends/vulkan/runtime/api/containers/Tensor.cpp

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -99,12 +99,25 @@ std::vector<int64_t> calculate_strides(
9999
*
100100
* The axis mapping allows for permuted views of texture-backed tensors.
101101
*/
102-
std::vector<int64_t> default_axis_map() {
103-
// Currently, all compute shaders have an assumption that the channels dim is
104-
// used to combine with the batch dim of a tensor. However, once dim mapping
105-
// is integrated into the tensor indexing logic for each compute shader, we
106-
// can be more flexible with mapping the batch dim to different texture axes
107-
// in order to improve performance or memory footprint.
102+
std::vector<int64_t> calculate_axis_map(const std::vector<int64_t>& sizes, utils::AxisMapLayout axis_map_layout){
103+
if(axis_map_layout == utils::AxisMapLayout::OPTIMIZED){
104+
std::vector<int64_t> axis_map(sizes.size() + 1);
105+
std::iota(axis_map.begin(), axis_map.end() - 1, 0);
106+
107+
std::stable_sort(axis_map.begin(), axis_map.end() - 1,
108+
[&sizes](size_t i1, size_t i2) {return sizes[i1] < sizes[i2];});
109+
110+
// Find the index of the channel dimension
111+
for(size_t i = 0; i < axis_map.size() - 1; ++i){
112+
if(sizes[axis_map[i]] == 2){
113+
axis_map.back() = i;
114+
break;
115+
}
116+
}
117+
118+
return axis_map;
119+
}
120+
// default
108121
return {0, 1, 2, 2};
109122
}
110123

@@ -439,13 +452,14 @@ vTensor::vTensor(
439452
const vkapi::ScalarType dtype,
440453
const utils::StorageType storage_type,
441454
const utils::GPUMemoryLayout memory_layout,
442-
const bool allocate_memory)
455+
const bool allocate_memory,
456+
const utils::AxisMapLayout axis_map_layout)
443457
: dtype_(dtype),
444458
// Calculate tensor metadata
445459
sizes_(sizes.begin(), sizes.end()),
446460
packed_dim_(utils::to_packed_dim<int32_t>(memory_layout)),
447461
dim_order_(calculate_dim_order(sizes_.size(), packed_dim_)),
448-
axis_map_(default_axis_map()),
462+
axis_map_(calculate_axis_map(sizes_, axis_map_layout)),
449463
strides_(calculate_strides(sizes, dim_order_)),
450464
padded_sizes_{calculate_padded_sizes(sizes, packed_dim_)},
451465
unsqueezed_strides_{
@@ -484,13 +498,14 @@ vTensor::vTensor(
484498
vTensor::vTensor(
485499
Context* context,
486500
const vkapi::VulkanImage& image,
487-
const utils::GPUMemoryLayout memory_layout)
501+
const utils::GPUMemoryLayout memory_layout,
502+
const utils::AxisMapLayout axis_map_layout)
488503
: dtype_(vkapi::element_scalartype(image.format())),
489504
// Calculate tensor metadata
490505
sizes_(calculate_sizes(image, memory_layout)),
491506
packed_dim_(utils::to_packed_dim<int32_t>(memory_layout)),
492507
dim_order_(),
493-
axis_map_(default_axis_map()),
508+
axis_map_(calculate_axis_map(sizes_, axis_map_layout)),
494509
strides_(),
495510
padded_sizes_(calculate_padded_sizes(sizes_, packed_dim_)),
496511
unsqueezed_strides_(),

backends/vulkan/runtime/api/containers/Tensor.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,14 +183,16 @@ class vTensor final {
183183
const vkapi::ScalarType dtype,
184184
const utils::StorageType storage_type = utils::kTexture3D,
185185
const utils::GPUMemoryLayout memory_layout = utils::kChannelsPacked,
186-
const bool allocate_memory = true);
186+
const bool allocate_memory = true,
187+
const utils::AxisMapLayout axis_map_layout = utils::kDefaultAxisMap);
187188

188189
vTensor(const vTensor& other) = delete;
189190

190191
explicit vTensor(
191192
Context* context,
192193
const vkapi::VulkanImage& image,
193-
const utils::GPUMemoryLayout memory_layout = utils::kChannelsPacked);
194+
const utils::GPUMemoryLayout memory_layout = utils::kChannelsPacked,
195+
const utils::AxisMapLayout axis_map_layout = utils::kDefaultAxisMap);
194196

195197
/*
196198
* This constructor allows for the creation of a vTensor that references the

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -239,13 +239,14 @@ ValueRef ComputeGraph::add_tensor(
239239
const vkapi::ScalarType dtype,
240240
const utils::StorageType storage_type,
241241
const utils::GPUMemoryLayout memory_layout,
242-
const int64_t shared_object_idx) {
242+
const int64_t shared_object_idx,
243+
const utils::AxisMapLayout axis_map_layout) {
243244
bool allocate_memory = shared_object_idx < 0;
244245

245246
ValueRef idx(static_cast<int>(values_.size()));
246247
check_no_active_value_ptrs();
247248
values_.emplace_back(api::vTensor(
248-
context(), sizes, dtype, storage_type, memory_layout, allocate_memory));
249+
context(), sizes, dtype, storage_type, memory_layout, allocate_memory, axis_map_layout));
249250

250251
if (!allocate_memory) {
251252
get_shared_object(shared_object_idx).add_user(this, idx);
@@ -257,44 +258,50 @@ ValueRef ComputeGraph::add_tensor(
257258
const std::vector<int64_t>& sizes,
258259
const vkapi::ScalarType dtype,
259260
const utils::StorageType storage_type,
260-
const int64_t shared_object_idx) {
261+
const int64_t shared_object_idx,
262+
const utils::AxisMapLayout axis_map_layout) {
261263
return add_tensor(
262264
sizes,
263265
dtype,
264266
storage_type,
265267
suggested_memory_layout(sizes),
266-
shared_object_idx);
268+
shared_object_idx,
269+
axis_map_layout);
267270
}
268271

269272
ValueRef ComputeGraph::add_tensor(
270273
const std::vector<int64_t>& sizes,
271274
const vkapi::ScalarType dtype,
272275
const utils::GPUMemoryLayout memory_layout,
273-
const int64_t shared_object_idx) {
276+
const int64_t shared_object_idx,
277+
const utils::AxisMapLayout axis_map_layout) {
274278
return add_tensor(
275-
sizes, dtype, suggested_storage_type(), memory_layout, shared_object_idx);
279+
sizes, dtype, suggested_storage_type(), memory_layout, shared_object_idx, axis_map_layout);
276280
}
277281

278282
ValueRef ComputeGraph::add_tensor_like(
279283
const ValueRef idx,
280284
const utils::StorageType storage_type,
281-
const utils::GPUMemoryLayout memory_layout) {
282-
return add_tensor(sizes_of(idx), dtype_of(idx), storage_type, memory_layout);
285+
const utils::GPUMemoryLayout memory_layout,
286+
const utils::AxisMapLayout axis_map_layout) {
287+
return add_tensor(sizes_of(idx), dtype_of(idx), storage_type, memory_layout, -1, axis_map_layout);
283288
}
284289

285290
ValueRef ComputeGraph::add_tensor_like(
286291
const ValueRef idx,
287-
const utils::GPUMemoryLayout memory_layout) {
292+
const utils::GPUMemoryLayout memory_layout,
293+
const utils::AxisMapLayout axis_map_layout) {
288294
return add_tensor(
289-
sizes_of(idx), dtype_of(idx), storage_type_of(idx), memory_layout);
295+
sizes_of(idx), dtype_of(idx), storage_type_of(idx), memory_layout, -1, axis_map_layout);
290296
}
291297

292298
ValueRef ComputeGraph::add_tensor(
293299
const std::vector<int64_t>& sizes,
294300
const vkapi::ScalarType dtype,
295-
const int64_t shared_object_idx) {
301+
const int64_t shared_object_idx,
302+
const utils::AxisMapLayout axis_map_layout) {
296303
return add_tensor(
297-
sizes, dtype, suggested_memory_layout(sizes), shared_object_idx);
304+
sizes, dtype, suggested_memory_layout(sizes), shared_object_idx, axis_map_layout);
298305
}
299306

300307
ValueRef ComputeGraph::add_tensor(const vkapi::VulkanImage& image) {

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -461,7 +461,8 @@ class ComputeGraph final {
461461
const vkapi::ScalarType dtype,
462462
const utils::StorageType storage_type,
463463
const utils::GPUMemoryLayout memory_layout,
464-
const int64_t shared_object_idx = -1);
464+
const int64_t shared_object_idx = -1,
465+
const utils::AxisMapLayout axis_map_layout = utils::kDefaultAxisMap);
465466

466467
/*
467468
* Add a `api::vTensor` value to the graph with the specified properties. The
@@ -471,7 +472,8 @@ class ComputeGraph final {
471472
const std::vector<int64_t>& sizes,
472473
const vkapi::ScalarType dtype,
473474
const utils::StorageType storage_type,
474-
const int64_t shared_object_idx = -1);
475+
const int64_t shared_object_idx = -1,
476+
const utils::AxisMapLayout axis_map_layout = utils::kDefaultAxisMap);
475477

476478
/*
477479
* Add a `api::vTensor` value to the graph with the specified properties. The
@@ -481,7 +483,8 @@ class ComputeGraph final {
481483
const std::vector<int64_t>& sizes,
482484
const vkapi::ScalarType dtype,
483485
const utils::GPUMemoryLayout memory_layout,
484-
const int64_t shared_object_idx = -1);
486+
const int64_t shared_object_idx = -1,
487+
const utils::AxisMapLayout axis_map_layout = utils::kDefaultAxisMap);
485488

486489
/*
487490
* Add a `api::vTensor` value to the graph with the specified properties. The
@@ -491,7 +494,8 @@ class ComputeGraph final {
491494
ValueRef add_tensor(
492495
const std::vector<int64_t>& sizes,
493496
const vkapi::ScalarType dtype,
494-
const int64_t shared_object_idx = -1);
497+
const int64_t shared_object_idx = -1,
498+
const utils::AxisMapLayout axis_map_layout = utils::kDefaultAxisMap);
495499

496500
/*
497501
* Add a `api::vTensor` value to the graph with the specified image.
@@ -504,15 +508,17 @@ class ComputeGraph final {
504508
ValueRef add_tensor_like(
505509
const ValueRef vref,
506510
const utils::StorageType storage_type,
507-
const utils::GPUMemoryLayout memory_layout);
511+
const utils::GPUMemoryLayout memory_layout,
512+
const utils::AxisMapLayout axis_map_layout = utils::kDefaultAxisMap);
508513

509514
/*
510515
* Add a `api::vTensor` value to the graph with the properties of `vref`. The
511516
* suggested storage type will be used to construct the `api::vTensor`.
512517
*/
513518
ValueRef add_tensor_like(
514519
const ValueRef vref,
515-
const utils::GPUMemoryLayout memory_layout);
520+
const utils::GPUMemoryLayout memory_layout,
521+
const utils::AxisMapLayout axis_map_layout = utils::kDefaultAxisMap);
516522

517523
/*
518524
* Use the copy constructor of `api::vTensor` to create a "view" of the

backends/vulkan/runtime/graph/ops/impl/Staging.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,12 +146,13 @@ ValueRef prepack_standard(
146146
const ValueRef tensor_data,
147147
const utils::StorageType storage_type,
148148
const utils::GPUMemoryLayout layout,
149-
const bool passthrough) {
149+
const bool passthrough,
150+
const utils::AxisMapLayout axis_map_layout) {
150151
if (passthrough && graph.val_is_tensor(tensor_data)) {
151152
return tensor_data;
152153
}
153154
VK_CHECK_COND(graph.val_is_tref(tensor_data));
154-
ValueRef tensor = graph.add_tensor_like(tensor_data, storage_type, layout);
155+
ValueRef tensor = graph.add_tensor_like(tensor_data, storage_type, layout, axis_map_layout);
155156
add_prepack_standard_node(graph, tensor_data, tensor);
156157
return tensor;
157158
}

backends/vulkan/runtime/graph/ops/impl/Staging.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ ValueRef prepack_standard(
4848
const ValueRef tensor_data,
4949
const utils::StorageType storage_type,
5050
const utils::GPUMemoryLayout layout,
51-
const bool passthrough = false);
51+
const bool passthrough = false,
52+
const utils::AxisMapLayout axis_map_layout = utils::kDefaultAxisMap);
5253

5354
/*
5455
* Equivalent to `prepack_standard()` function, except the `storage_type` and

backends/vulkan/runtime/utils/StorageUtils.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,5 +146,16 @@ inline std::ostream& operator<<(
146146
return os;
147147
}
148148

149+
enum class AxisMapLayout : uint8_t {
150+
DEFAULT = 0u,
151+
OPTIMIZED = 1u,
152+
};
153+
154+
static constexpr AxisMapLayout kDefaultAxisMap =
155+
AxisMapLayout::DEFAULT;
156+
157+
static constexpr AxisMapLayout kOptimizedAxisMap =
158+
AxisMapLayout::OPTIMIZED;
159+
149160
} // namespace utils
150161
} // namespace vkcompute

0 commit comments

Comments
 (0)