@@ -80,6 +80,42 @@ std::vector<int64_t> calculate_strides(
   return strides;
 }
 
+/*
+ * Axis mapping is somewhat analogous to strides for texture-backed tensors.
+ *
+ * The axis mapping is normalized to 4 dimensions, similar to the padded sizes.
+ * The first 3 values of the axis mapping indicate the (X,Y,Z) image texture
+ * axis that corresponds to the width, height, and channels dimension of the
+ * tensor. Thus the axis mapping can be considered to be in WHCN dimension
+ * order.
+ *
+ * The last value `axis_mapping.at(3)` indicates the WHCN index of the tensor
+ * dimension along which batches will be concatenated. This dimension can be
+ * referred to as the "inner dimension". To determine which image texture axis
+ * is used for the concatenation, a double lookup needs to be performed
+ * (axis_mapping.at(axis_mapping.at(3))).
+ *
+ * The reason for structuring the axis mapping this way is that, for the batch
+ * dim, two things need to be easily derived:
+ *
+ * 1. The dim idx of the inner dimension, so that the size of the inner
+ *    dimension can be easily determined.
+ * 2. The texture axis used to concatenate batches.
+ *
+ * By storing the dim index of the inner dimension instead of the texture axis
+ * it maps to, both pieces of information are readily available.
+ *
+ * The axis mapping allows for permuted views of texture-backed tensors.
+ */
+std::vector<int64_t> default_axis_mapping() {
+  // Currently, all compute shaders have an assumption that the channels dim is
+  // used to combine with the batch dim of a tensor. However, once dim mapping
+  // is integrated into the tensor indexing logic for each compute shader, we
+  // can be more flexible with mapping the batch dim to different texture axes
+  // in order to improve performance or memory footprint.
+  return {0, 1, 2, 2};
+}
+
 bool dim_order_is_valid(const std::vector<int64_t>& dim_order) {
   int64_t sum = 0;
   for (size_t i = 0; i < dim_order.size(); ++i) {
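// For illustration, a minimal standalone sketch of the double lookup described
// in the comment above, assuming the default axis mapping {0, 1, 2, 2} (width
// -> X, height -> Y, channels -> Z, batches concatenated along the channels
// dim). The helper name here is hypothetical and not part of the library.
#include <cstdint>
#include <vector>

int64_t batch_concat_texture_axis(const std::vector<int64_t>& axis_mapping) {
  // axis_mapping.at(3) gives the WHCN index of the "inner dimension" along
  // which batches are concatenated...
  const int64_t inner_whcn_dim = axis_mapping.at(3);
  // ...and a second lookup gives the (X,Y,Z) texture axis that the inner
  // dimension maps to, i.e. the axis actually used for batch concatenation.
  return axis_mapping.at(inner_whcn_dim);
}

// With {0, 1, 2, 2}: inner_whcn_dim == 2 (channels), so the batch
// concatenation axis is axis_mapping.at(2) == 2, i.e. the Z axis.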
@@ -137,30 +173,44 @@ std::vector<int64_t> calculate_padded_sizes(
 
 utils::uvec3 calculate_image_extents(
     const std::vector<int64_t>& padded_sizes,
+    const std::vector<int64_t>& axis_mapping,
     const utils::GPUMemoryLayout memory_layout) {
   VK_CHECK_COND(padded_sizes.size() == 4);
+  VK_CHECK_COND(axis_mapping.size() == 4);
+
+  utils::uvec3 extents({1, 1, 1});
+  // First three elements of axis_mapping indicate which (X,Y,Z) image axis the
+  // width, height, and channels dim of the tensor maps to.
+  for (int whcn_dim = 0; whcn_dim < 3; ++whcn_dim) {
+    const int64_t axis = axis_mapping.at(whcn_dim);
+    const int64_t dim = padded_sizes.size() - 1 - whcn_dim;
+    extents[axis] = utils::safe_downcast<uint32_t>(padded_sizes.at(dim));
+  }
 
-  uint32_t N = utils::safe_downcast<uint32_t>(padded_sizes.at(0));
-  uint32_t C = utils::safe_downcast<uint32_t>(padded_sizes.at(1));
-  uint32_t H = utils::safe_downcast<uint32_t>(padded_sizes.at(2));
-  uint32_t W = utils::safe_downcast<uint32_t>(padded_sizes.at(3));
+  // axis_mapping[3] indicates the WHCN index of the dimension used for batch
+  // concatenation. Thus a double lookup is required to determine the image
+  // axis used for batch concatenation.
+  const int64_t concatted_whcn_dim = axis_mapping.at(3);
+  const int64_t batch_axis = axis_mapping.at(concatted_whcn_dim);
+  // Multiply the extents of the batch axis by the batch size.
+  extents[batch_axis] *= padded_sizes.at(0);
 
   switch (memory_layout) {
     case utils::kWidthPacked:
-      VK_CHECK_COND(W % 4 == 0);
-      W /= 4;
+      VK_CHECK_COND(extents[0] % 4 == 0);
+      extents[0] /= 4;
       break;
     case utils::kHeightPacked:
-      VK_CHECK_COND(H % 4 == 0);
-      H /= 4;
+      VK_CHECK_COND(extents[1] % 4 == 0);
+      extents[1] /= 4;
       break;
     case utils::kChannelsPacked:
-      VK_CHECK_COND(C % 4 == 0);
-      C /= 4;
+      VK_CHECK_COND(extents[2] % 4 == 0);
+      extents[2] /= 4;
       break;
   }
 
-  return {W, H, C * N};
+  return extents;
 }
 
 //
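// Worked example of the extent calculation above, as a standalone sketch that
// mirrors the arithmetic without the utils:: types. It assumes padded sizes
// {N=2, C=4, H=6, W=8}, the default axis mapping {0, 1, 2, 2}, and a
// channels-packed layout (four channel values per texel along Z).
#include <array>
#include <cstdint>
#include <vector>

std::array<uint32_t, 3> sketch_image_extents(
    const std::vector<int64_t>& padded_sizes, // {N, C, H, W}
    const std::vector<int64_t>& axis_mapping) {
  std::array<uint32_t, 3> extents = {1, 1, 1};
  // W -> X, H -> Y, C -> Z under the default mapping.
  for (int whcn_dim = 0; whcn_dim < 3; ++whcn_dim) {
    extents[axis_mapping[whcn_dim]] =
        static_cast<uint32_t>(padded_sizes[3 - whcn_dim]);
  }
  // Double lookup: batches concatenate along the channels dim, which maps to Z.
  extents[axis_mapping[axis_mapping[3]]] *=
      static_cast<uint32_t>(padded_sizes[0]);
  // Channels-packed: divide the Z extent by 4.
  extents[2] /= 4;
  return extents; // {8, 6, (4 * 2) / 4} == {8, 6, 2}
}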
@@ -176,9 +226,10 @@ vTensor::vTensor(
     const bool allocate_memory)
     : dtype_(dtype),
       memory_layout_(memory_layout),
-      // Calculate tensor size metadata
+      // Calculate tensor metadata
       sizes_(sizes.begin(), sizes.end()),
       dim_order_(calculate_dim_order(sizes_.size(), memory_layout_)),
+      axis_mapping_(default_axis_mapping()),
       strides_(calculate_strides(sizes, dim_order_)),
       numel_(utils::multiply_integers(sizes_)),
       padded_sizes_{calculate_padded_sizes(sizes, memory_layout_)},
@@ -189,12 +240,14 @@ vTensor::vTensor(
       sizes_uniform_(),
       strides_uniform_(),
       numel_uniform_(),
+      axis_mapping_uniform_(),
       texture_limits_uniform_(),
       // Construct Tensor storage
       storage_(
           context,
           storage_type,
           memory_layout_,
+          axis_mapping_,
           padded_sizes_,
           dtype_,
           allocate_memory) {
@@ -222,6 +275,7 @@ vTensor::vTensor(const vTensor& other)
       // Copy tensor size metadata
       sizes_(other.sizes_.begin(), other.sizes_.end()),
       dim_order_(other.dim_order_.begin(), other.dim_order_.end()),
+      axis_mapping_(other.axis_mapping_.begin(), other.axis_mapping_.end()),
       strides_(other.strides_.begin(), other.strides_.end()),
       numel_(other.numel_),
       padded_sizes_{other.padded_sizes_.begin(), other.padded_sizes_.end()},
@@ -234,6 +288,7 @@ vTensor::vTensor(const vTensor& other)
       sizes_uniform_(),
       strides_uniform_(),
       numel_uniform_(),
+      axis_mapping_uniform_(),
       texture_limits_uniform_(),
       // Copy Tensor storage
       storage_(other.storage_) {}
@@ -248,6 +303,7 @@ vTensor::vTensor(
       // Copy tensor size metadata
       sizes_(sizes.begin(), sizes.end()),
       dim_order_(dim_order.begin(), dim_order.end()),
+      axis_mapping_(default_axis_mapping()),
       strides_(calculate_strides(sizes_, dim_order_)),
       numel_(utils::multiply_integers(sizes_)),
       padded_sizes_{calculate_padded_sizes(sizes, memory_layout_)},
@@ -258,6 +314,7 @@ vTensor::vTensor(
       sizes_uniform_(),
       strides_uniform_(),
       numel_uniform_(),
+      axis_mapping_uniform_(),
       texture_limits_uniform_(),
       // Copy Tensor storage
       storage_(other.storage_, vkapi::element_size(dtype_) * offset_numel) {
@@ -315,6 +372,14 @@ const vkapi::BufferBindInfo vTensor::strides_ubo() {
   return vkapi::BufferBindInfo(strides_uniform_.buffer());
 }
 
+const vkapi::BufferBindInfo vTensor::axis_mapping_ubo() {
+  if (!axis_mapping_uniform_.buffer()) {
+    axis_mapping_uniform_ =
+        ParamsBuffer(storage_.context_, utils::make_ivec4(axis_mapping_));
+  }
+  return vkapi::BufferBindInfo(axis_mapping_uniform_.buffer());
+}
+
 const vkapi::BufferBindInfo vTensor::texture_limits_ubo() {
   if (!texture_limits_uniform_.buffer()) {
     texture_limits_uniform_ = ParamsBuffer(storage_.context_, texture_limits_);
@@ -376,11 +441,7 @@ void vTensor::bind_allocation(const vkapi::Allocation& allocation) {
   }
 }
 
-void vTensor::update_metadata(
-    const std::vector<int64_t>& new_sizes,
-    const std::vector<int64_t>& new_dim_order) {
-  sizes_ = new_sizes;
-  dim_order_ = new_dim_order;
+void vTensor::update_metadata() {
   strides_ = calculate_strides(sizes_, dim_order_);
   // Only update the memory layout for buffer-backed tensors. Strides are
   // meaningless for texture-backed tensors and do not impact the memory layout.
@@ -396,7 +457,7 @@ void vTensor::update_metadata(
   // Calculate the extents of the image texture that would have been required
   // for a tensor of the new sizes.
   utils::uvec3 virtual_extents =
-      calculate_image_extents(padded_sizes_, memory_layout_);
+      calculate_image_extents(padded_sizes_, axis_mapping_, memory_layout_);
 
   // Update the texture limits to reflect the new virtual extents.
   texture_limits_.limits = utils::ivec3{
@@ -407,23 +468,26 @@
   if (sizes_uniform_.buffer()) {
     sizes_uniform_.update(utils::make_whcn_ivec4(sizes_));
   }
-  if (texture_limits_uniform_.buffer()) {
-    texture_limits_uniform_.update(texture_limits_);
-  }
   if (strides_uniform_.buffer()) {
     strides_uniform_.update(utils::make_whcn_ivec4(unsqueezed_strides_));
   }
   if (numel_uniform_.buffer()) {
     numel_uniform_.update(numel_);
   }
+  if (axis_mapping_uniform_.buffer()) {
+    axis_mapping_uniform_.update(utils::make_ivec4(axis_mapping_));
+  }
+  if (texture_limits_uniform_.buffer()) {
+    texture_limits_uniform_.update(texture_limits_);
+  }
 }
 
 void vTensor::check_sizes(const std::vector<int64_t>& sizes) const {
   if (storage_type() != utils::kBuffer) {
     // For texture storage check that the current texture is large enough for
     // the new sizes of the tensor.
     utils::uvec3 virtual_extents =
-        calculate_image_extents(padded_sizes_, memory_layout_);
+        calculate_image_extents(padded_sizes_, axis_mapping_, memory_layout_);
 
     bool valid_resize = virtual_extents[0] <= image_extents()[0];
     valid_resize = valid_resize && virtual_extents[1] <= image_extents()[1];
@@ -454,7 +518,9 @@ void vTensor::virtual_reconfigure(
   VK_CHECK_COND(dim_order_is_valid(new_dim_order));
 
   check_sizes(new_sizes);
-  update_metadata(new_sizes, new_dim_order);
+  sizes_ = new_sizes;
+  dim_order_ = new_dim_order;
+  update_metadata();
 }
 
 void vTensor::virtual_resize(const std::vector<int64_t>& new_sizes) {
@@ -463,13 +529,16 @@ void vTensor::virtual_resize(const std::vector<int64_t>& new_sizes) {
       "new sizes cannot modify the dimensionality of the tensor");
 
   check_sizes(new_sizes);
-  update_metadata(new_sizes, dim_order_);
+  sizes_ = new_sizes;
+  update_metadata();
 }
 
 void vTensor::reallocate(const std::vector<int64_t>& new_sizes) {
-  update_metadata(new_sizes, dim_order_);
+  sizes_ = new_sizes;
+  update_metadata();
   storage_.discard_and_reallocate(
       calculate_padded_sizes(new_sizes, memory_layout_),
+      axis_mapping_,
       memory_layout_,
       dtype_);
 }
@@ -547,12 +616,16 @@ vTensorStorage::vTensorStorage(
     Context* const context,
     const utils::StorageType storage_type,
    const utils::GPUMemoryLayout gpu_memory_layout,
+    const std::vector<int64_t>& axis_mapping,
     const std::vector<int64_t>& padded_sizes,
     const vkapi::ScalarType dtype,
     const bool allocate_memory)
     : context_(context),
       storage_type_{storage_type},
-      image_extents_(calculate_image_extents(padded_sizes, gpu_memory_layout)),
+      image_extents_(calculate_image_extents(
+          padded_sizes,
+          axis_mapping,
+          gpu_memory_layout)),
       buffer_length_{utils::multiply_integers(padded_sizes)},
       buffer_offset_{0},
       image_(allocate_image(
@@ -665,14 +738,16 @@ bool vTensorStorage::is_copy_of(const vTensorStorage& other) const {
 
 void vTensorStorage::discard_and_reallocate(
     const std::vector<int64_t>& padded_sizes,
+    const std::vector<int64_t>& axis_mapping,
     const utils::GPUMemoryLayout gpu_memory_layout,
     const vkapi::ScalarType dtype) {
   const bool image_owns_memory = image_.owns_memory();
   const bool buffer_owns_memory = buffer_.owns_memory();
 
   flush();
 
-  image_extents_ = calculate_image_extents(padded_sizes, gpu_memory_layout);
+  image_extents_ =
+      calculate_image_extents(padded_sizes, axis_mapping, gpu_memory_layout);
   image_ = allocate_image(
       context_,
       image_extents_,