@@ -80,6 +80,31 @@ std::vector<int64_t> calculate_strides(
   return strides;
 }
 
+/*
+ * Axis mapping is somewhat analogous to strides for texture backed tensors.
+ *
+ * The axis mapping is normalized to 4 dimensions, similar to the padded sizes.
+ * The first 3 values of the axis mapping indicate the (X,Y,Z) image texture
+ * axis that corresponds to the width, height, and channels dimension of the
+ * tensor. Thus the axis mapping can be considered to be in WHCN dimension
+ * order.
+ *
+ * The last value `axis_mapping.at(3)` indicates the WHCN index of the tensor
+ * dimension along which batches will be concatenated. To determine which image
+ * texture axis is used for the concatenation, a double lookup will need to be
+ * performed (axis_mapping.at(axis_mapping.at(3))).
+ *
+ * The axis mapping allows for permuted views of texture-backed tensors.
+ */
+std::vector<int64_t> default_axis_mapping() {
+  // Currently, all compute shaders have an assumption that the channels dim is
+  // used to combine with the batch dim of a tensor. However, once dim mapping
+  // is integrated into the tensor indexing logic for each compute shader, we
+  // can be more flexible with mapping the batch dim to different texture axes
+  // in order to improve performance or memory footprint.
+  return {0, 1, 2, 2};
+}
+
 bool dim_order_is_valid(const std::vector<int64_t>& dim_order) {
   int64_t sum = 0;
   for (size_t i = 0; i < dim_order.size(); ++i) {
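For intuition, here is a minimal standalone sketch of the double lookup described in the comment above; the helper name and the `main` driver are hypothetical and not part of this patch.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper illustrating the double lookup: axis_mapping.at(3) gives
// the WHCN dim along which batches are concatenated, and indexing the mapping
// again gives the image texture axis that dim is stored along.
int64_t batch_concat_image_axis(const std::vector<int64_t>& axis_mapping) {
  const int64_t concatted_whcn_dim = axis_mapping.at(3);
  return axis_mapping.at(concatted_whcn_dim);
}

int main() {
  // Default mapping: W->X, H->Y, C->Z, batches concatenated along channels.
  const std::vector<int64_t> default_mapping = {0, 1, 2, 2};
  assert(batch_concat_image_axis(default_mapping) == 2);  // Z axis
  return 0;
}
```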
@@ -137,30 +162,44 @@ std::vector<int64_t> calculate_padded_sizes(
 
 utils::uvec3 calculate_image_extents(
     const std::vector<int64_t>& padded_sizes,
+    const std::vector<int64_t>& axis_mapping,
     const utils::GPUMemoryLayout memory_layout) {
   VK_CHECK_COND(padded_sizes.size() == 4);
+  VK_CHECK_COND(axis_mapping.size() == 4);
 
-  uint32_t N = utils::safe_downcast<uint32_t>(padded_sizes.at(0));
-  uint32_t C = utils::safe_downcast<uint32_t>(padded_sizes.at(1));
-  uint32_t H = utils::safe_downcast<uint32_t>(padded_sizes.at(2));
-  uint32_t W = utils::safe_downcast<uint32_t>(padded_sizes.at(3));
+  utils::uvec3 extents({1, 1, 1});
+  // First three elements of axis_mapping indicate which (X,Y,Z) image axis the
+  // width, height, and channels dim of the tensor maps to.
+  for (int whcn_dim = 0; whcn_dim < 3; ++whcn_dim) {
+    const int64_t axis = axis_mapping.at(whcn_dim);
+    const int64_t dim = padded_sizes.size() - 1 - whcn_dim;
+    extents[axis] = utils::safe_downcast<uint32_t>(padded_sizes.at(dim));
+  }
+
+  // axis_mapping[3] indicates the WHCN index of the dimension used for batch
+  // concatenation. Thus a double lookup is required to determine the image
+  // axis used for batch concatenation.
+  const int64_t concatted_whcn_dim = axis_mapping.at(3);
+  const int64_t batch_axis = axis_mapping.at(concatted_whcn_dim);
+  // Multiply the extents of the batch axis by the batch size.
+  extents[batch_axis] *= padded_sizes.at(0);
 
   switch (memory_layout) {
     case utils::kWidthPacked:
-      VK_CHECK_COND(W % 4 == 0);
-      W /= 4;
+      VK_CHECK_COND(extents[0] % 4 == 0);
+      extents[0] /= 4;
       break;
     case utils::kHeightPacked:
-      VK_CHECK_COND(H % 4 == 0);
-      H /= 4;
+      VK_CHECK_COND(extents[1] % 4 == 0);
+      extents[1] /= 4;
       break;
     case utils::kChannelsPacked:
-      VK_CHECK_COND(C % 4 == 0);
-      C /= 4;
+      VK_CHECK_COND(extents[2] % 4 == 0);
+      extents[2] /= 4;
       break;
   }
 
-  return {W, H, C * N};
+  return extents;
 }
 
 //
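As a sanity check on the new extent computation, the following standalone sketch mirrors the logic of `calculate_image_extents` with plain standard-library types, assuming the default `{0, 1, 2, 2}` axis mapping and a channels-packed layout; the padded sizes are made-up example values, so this is illustrative only.

```cpp
#include <array>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const std::vector<int64_t> padded_sizes = {2, 4, 3, 5};  // NCHW: N=2, C=4, H=3, W=5
  const std::vector<int64_t> axis_mapping = {0, 1, 2, 2};  // default mapping

  // W (WHCN dim 0) -> X, H (dim 1) -> Y, C (dim 2) -> Z.
  std::array<uint32_t, 3> extents = {1, 1, 1};
  for (int whcn_dim = 0; whcn_dim < 3; ++whcn_dim) {
    const int64_t axis = axis_mapping.at(whcn_dim);
    const int64_t dim = padded_sizes.size() - 1 - whcn_dim;
    extents[axis] = static_cast<uint32_t>(padded_sizes.at(dim));
  }
  // Double lookup to find the image axis used for batch concatenation,
  // then fold the batch size into that axis: {5, 3, 4} -> {5, 3, 8}.
  const int64_t batch_axis = axis_mapping.at(axis_mapping.at(3));
  extents[batch_axis] *= static_cast<uint32_t>(padded_sizes.at(0));

  // Channels-packed layout stores 4 channel values per texel: {5, 3, 8} -> {5, 3, 2}.
  extents[2] /= 4;

  // Prints: extents = {5, 3, 2}
  std::cout << "extents = {" << extents[0] << ", " << extents[1] << ", "
            << extents[2] << "}\n";
  return 0;
}
```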
@@ -176,9 +215,10 @@ vTensor::vTensor(
     const bool allocate_memory)
     : dtype_(dtype),
       memory_layout_(memory_layout),
-      // Calculate tensor size metadata
+      // Calculate tensor metadata
       sizes_(sizes.begin(), sizes.end()),
       dim_order_(calculate_dim_order(sizes_.size(), memory_layout_)),
+      axis_mapping_(default_axis_mapping()),
       strides_(calculate_strides(sizes, dim_order_)),
       numel_(utils::multiply_integers(sizes_)),
       padded_sizes_{calculate_padded_sizes(sizes, memory_layout_)},
@@ -189,12 +229,14 @@ vTensor::vTensor(
       sizes_uniform_(),
       strides_uniform_(),
       numel_uniform_(),
+      axis_mapping_uniform_(),
       texture_limits_uniform_(),
       // Construct Tensor storage
       storage_(
           context,
           storage_type,
           memory_layout_,
+          axis_mapping_,
           padded_sizes_,
           dtype_,
           allocate_memory) {
@@ -222,6 +264,7 @@ vTensor::vTensor(const vTensor& other)
       // Copy tensor size metadata
       sizes_(other.sizes_.begin(), other.sizes_.end()),
       dim_order_(other.dim_order_.begin(), other.dim_order_.end()),
+      axis_mapping_(other.axis_mapping_.begin(), other.axis_mapping_.end()),
       strides_(other.strides_.begin(), other.strides_.end()),
       numel_(other.numel_),
       padded_sizes_{other.padded_sizes_.begin(), other.padded_sizes_.end()},
@@ -234,6 +277,7 @@ vTensor::vTensor(const vTensor& other)
       sizes_uniform_(),
       strides_uniform_(),
       numel_uniform_(),
+      axis_mapping_uniform_(),
       texture_limits_uniform_(),
       // Copy Tensor storage
       storage_(other.storage_) {}
@@ -248,6 +292,7 @@ vTensor::vTensor(
       // Copy tensor size metadata
       sizes_(sizes.begin(), sizes.end()),
       dim_order_(dim_order.begin(), dim_order.end()),
+      axis_mapping_(default_axis_mapping()),
       strides_(calculate_strides(sizes_, dim_order_)),
       numel_(utils::multiply_integers(sizes_)),
       padded_sizes_{calculate_padded_sizes(sizes, memory_layout_)},
@@ -258,6 +303,7 @@ vTensor::vTensor(
       sizes_uniform_(),
       strides_uniform_(),
       numel_uniform_(),
+      axis_mapping_uniform_(),
       texture_limits_uniform_(),
       // Copy Tensor storage
       storage_(other.storage_, vkapi::element_size(dtype_) * offset_numel) {
@@ -315,6 +361,14 @@ const vkapi::BufferBindInfo vTensor::strides_ubo() {
   return vkapi::BufferBindInfo(strides_uniform_.buffer());
 }
 
+const vkapi::BufferBindInfo vTensor::axis_mapping_ubo() {
+  if (!axis_mapping_uniform_.buffer()) {
+    axis_mapping_uniform_ =
+        ParamsBuffer(storage_.context_, utils::make_ivec4(axis_mapping_));
+  }
+  return vkapi::BufferBindInfo(axis_mapping_uniform_.buffer());
+}
+
 const vkapi::BufferBindInfo vTensor::texture_limits_ubo() {
   if (!texture_limits_uniform_.buffer()) {
     texture_limits_uniform_ = ParamsBuffer(storage_.context_, texture_limits_);
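For a rough picture of how a consumer of `axis_mapping_ubo()` might use the packed ivec4 (e.g. the per-shader tensor indexing anticipated in the comment on `default_axis_mapping()`), here is a hypothetical host-side sketch that translates a (w, h, c) tensor position into an (x, y, z) texture position. It is illustrative only and not part of this patch.

```cpp
#include <array>
#include <cstdint>

// Hypothetical mirror of the indexing a compute shader could perform with the
// axis mapping: each of the first three mapping entries names the texture axis
// that the corresponding WHCN tensor dim is stored along.
std::array<int32_t, 3> to_texture_pos(
    const std::array<int32_t, 4>& axis_mapping,
    int32_t w,
    int32_t h,
    int32_t c) {
  std::array<int32_t, 3> pos = {0, 0, 0};
  pos[axis_mapping[0]] = w;  // width dim -> mapped image axis
  pos[axis_mapping[1]] = h;  // height dim -> mapped image axis
  pos[axis_mapping[2]] = c;  // channels dim -> mapped image axis
  return pos;
}
```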
@@ -376,11 +430,7 @@ void vTensor::bind_allocation(const vkapi::Allocation& allocation) {
   }
 }
 
-void vTensor::update_metadata(
-    const std::vector<int64_t>& new_sizes,
-    const std::vector<int64_t>& new_dim_order) {
-  sizes_ = new_sizes;
-  dim_order_ = new_dim_order;
+void vTensor::update_metadata() {
   strides_ = calculate_strides(sizes_, dim_order_);
   // Only update the memory layout for buffer-backed tensors. Strides are
   // meaningless for texture-backed tensors and do not impact the memory layout.
@@ -396,7 +446,7 @@ void vTensor::update_metadata(
   // Calculate the extents of the image texture that would have been required
   // for a tensor of the new sizes.
   utils::uvec3 virtual_extents =
-      calculate_image_extents(padded_sizes_, memory_layout_);
+      calculate_image_extents(padded_sizes_, axis_mapping_, memory_layout_);
 
   // Update the texture limits to reflect the new virtual extents.
   texture_limits_.limits = utils::ivec3{
@@ -407,23 +457,26 @@ void vTensor::update_metadata(
   if (sizes_uniform_.buffer()) {
     sizes_uniform_.update(utils::make_whcn_ivec4(sizes_));
   }
-  if (texture_limits_uniform_.buffer()) {
-    texture_limits_uniform_.update(texture_limits_);
-  }
   if (strides_uniform_.buffer()) {
     strides_uniform_.update(utils::make_whcn_ivec4(unsqueezed_strides_));
   }
   if (numel_uniform_.buffer()) {
     numel_uniform_.update(numel_);
   }
+  if (axis_mapping_uniform_.buffer()) {
+    axis_mapping_uniform_.update(utils::make_ivec4(axis_mapping_));
+  }
+  if (texture_limits_uniform_.buffer()) {
+    texture_limits_uniform_.update(texture_limits_);
+  }
 }
 
 void vTensor::check_sizes(const std::vector<int64_t>& sizes) const {
   if (storage_type() != utils::kBuffer) {
     // For texture storage check that the current texture is large enough for
     // the new sizes of the tensor.
     utils::uvec3 virtual_extents =
-        calculate_image_extents(padded_sizes_, memory_layout_);
+        calculate_image_extents(padded_sizes_, axis_mapping_, memory_layout_);
 
     bool valid_resize = virtual_extents[0] <= image_extents()[0];
     valid_resize = valid_resize && virtual_extents[1] <= image_extents()[1];
@@ -454,7 +507,9 @@ void vTensor::virtual_reconfigure(
   VK_CHECK_COND(dim_order_is_valid(new_dim_order));
 
   check_sizes(new_sizes);
-  update_metadata(new_sizes, new_dim_order);
+  sizes_ = new_sizes;
+  dim_order_ = new_dim_order;
+  update_metadata();
 }
 
 void vTensor::virtual_resize(const std::vector<int64_t>& new_sizes) {
@@ -463,13 +518,16 @@ void vTensor::virtual_resize(const std::vector<int64_t>& new_sizes) {
       "new sizes cannot modify the dimensionality of the tensor");
 
   check_sizes(new_sizes);
-  update_metadata(new_sizes, dim_order_);
+  sizes_ = new_sizes;
+  update_metadata();
 }
 
 void vTensor::reallocate(const std::vector<int64_t>& new_sizes) {
-  update_metadata(new_sizes, dim_order_);
+  sizes_ = new_sizes;
+  update_metadata();
   storage_.discard_and_reallocate(
       calculate_padded_sizes(new_sizes, memory_layout_),
+      axis_mapping_,
       memory_layout_,
       dtype_);
 }
@@ -547,12 +605,16 @@ vTensorStorage::vTensorStorage(
     Context* const context,
     const utils::StorageType storage_type,
     const utils::GPUMemoryLayout gpu_memory_layout,
+    const std::vector<int64_t>& axis_mapping,
     const std::vector<int64_t>& padded_sizes,
     const vkapi::ScalarType dtype,
     const bool allocate_memory)
     : context_(context),
       storage_type_{storage_type},
-      image_extents_(calculate_image_extents(padded_sizes, gpu_memory_layout)),
+      image_extents_(calculate_image_extents(
+          padded_sizes,
+          axis_mapping,
+          gpu_memory_layout)),
       buffer_length_{utils::multiply_integers(padded_sizes)},
       buffer_offset_{0},
       image_(allocate_image(
@@ -665,14 +727,16 @@ bool vTensorStorage::is_copy_of(const vTensorStorage& other) const {
 
 void vTensorStorage::discard_and_reallocate(
     const std::vector<int64_t>& padded_sizes,
+    const std::vector<int64_t>& axis_mapping,
     const utils::GPUMemoryLayout gpu_memory_layout,
     const vkapi::ScalarType dtype) {
   const bool image_owns_memory = image_.owns_memory();
   const bool buffer_owns_memory = buffer_.owns_memory();
 
   flush();
 
-  image_extents_ = calculate_image_extents(padded_sizes, gpu_memory_layout);
+  image_extents_ =
+      calculate_image_extents(padded_sizes, axis_mapping, gpu_memory_layout);
   image_ = allocate_image(
       context_,
       image_extents_,