 namespace vkcompute {
 namespace api {

-/*
- * Given the strides of a buffer-backed tensor, estimate the equivalent memory
- * layout enum value by identifying the fastest moving dimension.
- */
-utils::GPUMemoryLayout estimate_memory_layout(
-    const std::vector<int64_t>& dim_order) {
-  int64_t fastest_dim_whcn = dim_order.size() - 1 - dim_order.back();
-  if (fastest_dim_whcn >= 0 && fastest_dim_whcn < 3) {
-    return utils::GPUMemoryLayout(fastest_dim_whcn);
-  }
-
-  // TODO(ssjia) find a way to gracefully recover from this case by i.e. adding
-  // a UNKOWN GPUMemoryLayout. This is not high priority though because we don't
-  // expect this to ever come up in practice.
-  VK_THROW("No compatible GPUMemoryLayout value");
-}
-
 std::vector<int64_t> calculate_dim_order(
     const size_t ndim,
-    const utils::GPUMemoryLayout memory_layout) {
+    const int32_t packed_dim_whcn_idx) {
   // Special case for zero dim tensors
   if (ndim == 0) {
     return {0};
   }
   std::vector<int64_t> dim_order(ndim);
-  int64_t last_dim =
-      ndim - utils::to_packed_dim_nchw_offset<int64_t>(memory_layout);
+  int64_t last_dim = ndim - 1 - packed_dim_whcn_idx;

   int64_t cur_dim = 0;
   for (int d = 0; d < ndim; ++d) {
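
For orientation, here is a minimal standalone sketch (not the vkcompute API; all names below are local to the example) of the convention this change adopts: packed_dim_whcn_idx counts dimensions from the fastest-moving width end, so the NCHW dimension it names is ndim - 1 - packed_dim_whcn_idx, which is exactly the last_dim that calculate_dim_order appears to place at the back of the dim order.

#include <cstdint>
#include <iostream>
#include <vector>

// Plausible reading of the overall effect of calculate_dim_order: keep NCHW
// order but move the packed (fastest-moving) dim to the back.
std::vector<int64_t> dim_order_sketch(size_t ndim, int32_t packed_dim_whcn_idx) {
  const int64_t last_dim = static_cast<int64_t>(ndim) - 1 - packed_dim_whcn_idx;
  std::vector<int64_t> order;
  for (int64_t d = 0; d < static_cast<int64_t>(ndim); ++d) {
    if (d != last_dim) {
      order.push_back(d);
    }
  }
  order.push_back(last_dim);
  return order;
}

int main() {
  // 4-D tensor, channels-packed (WHCN index 2 -> NCHW dim 1):
  for (int64_t d : dim_order_sketch(4, 2)) {
    std::cout << d << ' '; // prints: 0 2 3 1
  }
  std::cout << '\n';
  return 0;
}

Feeding {0, 2, 3, 1} back into the removed estimate_memory_layout() above gives fastest_dim_whcn = 4 - 1 - 1 = 2, i.e. kChannelsPacked, so the old and new conventions agree on this example.
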
@@ -149,7 +131,7 @@ std::vector<int64_t> unsqueeze_strides(

 std::vector<int64_t> calculate_padded_sizes(
     const std::vector<int64_t>& sizes,
-    const utils::GPUMemoryLayout memory_layout) {
+    const int32_t packed_dim_whcn_idx) {
   int64_t ndim = sizes.size();
   if (ndim == 0) {
     ndim = 1;
@@ -163,8 +145,7 @@ std::vector<int64_t> calculate_padded_sizes(
   }

   // Pad the packed dim to the next multiple of 4.
-  const int64_t dim_offset =
-      utils::to_packed_dim_nchw_offset<int64_t>(memory_layout);
+  const int64_t dim_offset = packed_dim_whcn_idx + 1;
   const int64_t padded_dim_size = utils::val_at(-dim_offset, sizes);
   padded_sizes.at(ndim_up4 - dim_offset) = utils::align_up_4(padded_dim_size);

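
To make the new offset arithmetic concrete: dim_offset = packed_dim_whcn_idx + 1 turns the WHCN index into a negative offset from the back of sizes, so the dimension that gets rounded up to a multiple of 4 is the packed one. A hedged sketch with local stand-ins for val_at and align_up_4 (the real helpers live in vkcompute::utils and may be implemented differently):

#include <cstdint>
#include <vector>

// Local stand-ins, written to match how the diff uses the utils helpers.
int64_t val_at_sketch(int64_t neg_idx, const std::vector<int64_t>& sizes) {
  return sizes.at(sizes.size() + neg_idx); // -1 -> last element, -3 -> C of NCHW
}
int64_t align_up_4_sketch(int64_t v) {
  return (v + 3) & ~int64_t(3); // round up to the next multiple of 4
}

int main() {
  // NCHW sizes {2, 3, 5, 7}, channels-packed (WHCN index 2); the tensor is
  // already 4-D, so the rank padding earlier in the function is a no-op here.
  const std::vector<int64_t> sizes = {2, 3, 5, 7};
  std::vector<int64_t> padded_sizes = sizes;
  const int32_t packed_dim_whcn_idx = 2;

  const int64_t dim_offset = packed_dim_whcn_idx + 1; // 3, i.e. the C dim
  const int64_t padded_dim_size = val_at_sketch(-dim_offset, sizes); // 3
  padded_sizes.at(sizes.size() - dim_offset) = align_up_4_sketch(padded_dim_size);

  return padded_sizes.at(1) == 4 ? 0 : 1; // padded_sizes is now {2, 4, 5, 7}
}
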
@@ -174,7 +155,7 @@ std::vector<int64_t> calculate_padded_sizes(
 utils::uvec3 calculate_image_extents(
     const std::vector<int64_t>& padded_sizes,
     const std::vector<int64_t>& axis_map,
-    const utils::GPUMemoryLayout memory_layout) {
+    const int32_t packed_dim_whcn_idx) {
   VK_CHECK_COND(padded_sizes.size() == 4);
   VK_CHECK_COND(axis_map.size() == 4);

@@ -195,21 +176,8 @@ utils::uvec3 calculate_image_extents(
   // Multiply the extents of the batch axis by the batch size.
   extents[batch_axis] *= padded_sizes.at(0);

-  switch (memory_layout) {
-    case utils::kWidthPacked:
-      VK_CHECK_COND(extents[axis_map.at(0)] % 4 == 0);
-      extents[axis_map.at(0)] /= 4;
-      break;
-    case utils::kHeightPacked:
-      VK_CHECK_COND(extents[axis_map.at(1)] % 4 == 0);
-      extents[axis_map.at(1)] /= 4;
-      break;
-    case utils::kChannelsPacked:
-      VK_CHECK_COND(extents[axis_map.at(2)] % 4 == 0);
-      extents[axis_map.at(2)] /= 4;
-      break;
-  }
-
+  VK_CHECK_COND(extents[axis_map.at(packed_dim_whcn_idx)] % 4 == 0);
+  extents[axis_map.at(packed_dim_whcn_idx)] /= 4;
   return extents;
 }

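
The collapsed switch leans on the fact that the three layout enum values line up with WHCN indices (width = 0, height = 1, channels = 2), which is the same assumption the removed estimate_memory_layout() made when it cast an index straight to utils::GPUMemoryLayout. A small self-contained check of the equivalence, using a hypothetical axis map (the real default axis map may differ):

#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Hypothetical texture extents and a hypothetical 4-entry axis map.
  const std::array<uint32_t, 3> base = {8, 6, 12};
  const std::vector<int64_t> axis_map = {0, 1, 2, 2};
  const int32_t packed_dim_whcn_idx = 2; // channels-packed

  // New form: index the axis map directly with the packed WHCN index.
  std::array<uint32_t, 3> new_form = base;
  assert(new_form[axis_map.at(packed_dim_whcn_idx)] % 4 == 0);
  new_form[axis_map.at(packed_dim_whcn_idx)] /= 4;

  // Old form: the kChannelsPacked arm of the removed switch.
  std::array<uint32_t, 3> old_form = base;
  old_form[axis_map.at(2)] /= 4;

  assert(new_form == old_form); // {8, 6, 3} either way
  return 0;
}
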
@@ -285,15 +253,15 @@ vkapi::VulkanBuffer allocate_buffer(
 vTensorStorage::vTensorStorage(
     Context* const context,
     const utils::StorageType storage_type,
-    const utils::GPUMemoryLayout gpu_memory_layout,
     const std::vector<int64_t>& axis_map,
+    const int32_t packed_dim_whcn_idx,
     const std::vector<int64_t>& padded_sizes,
     const vkapi::ScalarType dtype,
     const bool allocate_memory)
     : context_(context),
       storage_type_{storage_type},
       image_extents_(
-          calculate_image_extents(padded_sizes, axis_map, gpu_memory_layout)),
+          calculate_image_extents(padded_sizes, axis_map, packed_dim_whcn_idx)),
       buffer_length_{utils::multiply_integers(padded_sizes)},
       buffer_offset_{0},
       image_(allocate_image(
@@ -408,14 +376,15 @@ vTensor::vTensor(
     const utils::GPUMemoryLayout memory_layout,
     const bool allocate_memory)
     : dtype_(dtype),
-      memory_layout_(memory_layout),
       // Calculate tensor metadata
       sizes_(sizes.begin(), sizes.end()),
-      dim_order_(calculate_dim_order(sizes_.size(), memory_layout_)),
+      packed_dim_whcn_idx_(
+          utils::to_packed_dim_whcn_idx<int32_t>(memory_layout)),
+      dim_order_(calculate_dim_order(sizes_.size(), packed_dim_whcn_idx_)),
       axis_map_(default_axis_map()),
       strides_(calculate_strides(sizes, dim_order_)),
       numel_(utils::multiply_integers(sizes_)),
-      padded_sizes_{calculate_padded_sizes(sizes, memory_layout_)},
+      padded_sizes_{calculate_padded_sizes(sizes, packed_dim_whcn_idx_)},
       unsqueezed_strides_{unsqueeze_strides(strides_, numel_)},
      padded_numel_(utils::multiply_integers(padded_sizes_)),
      logical_limits_{{0, 0, 0}},
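
The constructor now derives the packed dim once from the requested layout via utils::to_packed_dim_whcn_idx. That helper is not shown in this diff; the sketch below only illustrates the mapping it would have to implement to stay consistent with the rest of the change (and with the estimate_memory_layout() member added further down), using local stand-in names:

#include <cstdint>

// Stand-in for utils::GPUMemoryLayout; the removed estimate_memory_layout()
// above constructed the real enum directly from a 0/1/2 WHCN index, so the
// underlying values are assumed to be width = 0, height = 1, channels = 2.
enum class LayoutSketch : int32_t {
  kWidthPacked = 0,
  kHeightPacked = 1,
  kChannelsPacked = 2,
};

// What a to_packed_dim_whcn_idx-style helper plausibly reduces to under that
// assumption: the enum's underlying value already is the packed WHCN index.
int32_t to_packed_dim_whcn_idx_sketch(LayoutSketch layout) {
  return static_cast<int32_t>(layout);
}

int main() {
  return to_packed_dim_whcn_idx_sketch(LayoutSketch::kChannelsPacked) == 2 ? 0 : 1;
}
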
@@ -429,8 +398,8 @@ vTensor::vTensor(
       storage_(
           context,
           storage_type,
-          memory_layout_,
           axis_map_,
+          packed_dim_whcn_idx_,
           padded_sizes_,
           dtype_,
           allocate_memory) {
@@ -451,9 +420,9 @@ vTensor::vTensor(

 vTensor::vTensor(const vTensor& other)
     : dtype_(other.dtype_),
-      memory_layout_(other.memory_layout_),
       // Copy tensor size metadata
       sizes_(other.sizes_.begin(), other.sizes_.end()),
+      packed_dim_whcn_idx_{other.packed_dim_whcn_idx_},
       dim_order_(other.dim_order_.begin(), other.dim_order_.end()),
       axis_map_(other.axis_map_.begin(), other.axis_map_.end()),
       strides_(other.strides_.begin(), other.strides_.end()),
@@ -479,14 +448,14 @@ vTensor::vTensor(
     const std::vector<int64_t>& dim_order,
     const int64_t offset_numel)
     : dtype_(other.dtype_),
-      memory_layout_(estimate_memory_layout(dim_order)),
       // Copy tensor size metadata
       sizes_(sizes.begin(), sizes.end()),
+      packed_dim_whcn_idx_(other.packed_dim_whcn_idx_),
       dim_order_(dim_order.begin(), dim_order.end()),
       axis_map_(default_axis_map()),
       strides_(calculate_strides(sizes_, dim_order_)),
       numel_(utils::multiply_integers(sizes_)),
-      padded_sizes_{calculate_padded_sizes(sizes, memory_layout_)},
+      padded_sizes_{calculate_padded_sizes(sizes, packed_dim_whcn_idx_)},
       unsqueezed_strides_{unsqueeze_strides(strides_, numel_)},
       padded_numel_(utils::multiply_integers(padded_sizes_)),
       logical_limits_(other.logical_limits_),
@@ -542,6 +511,19 @@ void vTensor::set_logical_limits(const utils::uvec3& image_extents) {
   logical_limits_.limits[2] = image_extents[axis_map_.at(2)];
 }

+utils::GPUMemoryLayout vTensor::estimate_memory_layout() const {
+  switch (packed_dim_whcn_idx_) {
+    case WHCN::kWidthDim:
+      return utils::kWidthPacked;
+    case WHCN::kHeightDim:
+      return utils::kHeightPacked;
+    case WHCN::kChannelsDim:
+      return utils::kChannelsPacked;
+    default:
+      VK_THROW("Invalid packed dim");
+  }
+}
+
 const vkapi::BufferBindInfo vTensor::sizes_ubo() {
   if (!sizes_uniform_.buffer()) {
     sizes_uniform_ =
@@ -618,21 +600,16 @@ void vTensor::bind_allocation(const vkapi::Allocation& allocation) {

 void vTensor::update_metadata() {
   strides_ = calculate_strides(sizes_, dim_order_);
-  // Only update the memory layout for buffer-backed tensors. Strides are
-  // meaningless for texture-backed tensors and do not impact the memory layout.
-  if (storage_type() == utils::kBuffer) {
-    memory_layout_ = estimate_memory_layout(dim_order_);
-  }
   numel_ = utils::multiply_integers(sizes_);

-  padded_sizes_ = calculate_padded_sizes(sizes_, memory_layout_);
+  padded_sizes_ = calculate_padded_sizes(sizes_, packed_dim_whcn_idx_);
   unsqueezed_strides_ = unsqueeze_strides(strides_, numel_);
   padded_numel_ = utils::multiply_integers(padded_sizes_);

   // Calculate the image extents that would have been used to allocate a texture
   // withthe current sizes, and use that to set the logical limits.
   set_logical_limits(
-      calculate_image_extents(padded_sizes_, axis_map_, memory_layout_));
+      calculate_image_extents(padded_sizes_, axis_map_, packed_dim_whcn_idx_));

   if (sizes_uniform_.buffer()) {
     sizes_uniform_.update(utils::make_whcn_ivec4(sizes_));
@@ -656,7 +633,7 @@ void vTensor::check_sizes(const std::vector<int64_t>& sizes) const {
     // For texture storage check that the current texture is large enough for
     // the new sizes of the tensor.
     utils::uvec3 virtual_extents =
-        calculate_image_extents(padded_sizes_, axis_map_, memory_layout_);
+        calculate_image_extents(padded_sizes_, axis_map_, packed_dim_whcn_idx_);

     bool valid_resize = virtual_extents[0] <= storage_.image_extents_[0];
     valid_resize =
@@ -725,23 +702,23 @@ void transpose_dim_order_inplace(

 void vTensor::virtual_transpose(const int64_t dim0, const int64_t dim1) {
   std::iter_swap(sizes_.begin() + dim0, sizes_.begin() + dim1);
+
+  const int dim0_whcn = sizes_.size() - 1 - dim0;
+  const int dim1_whcn = sizes_.size() - 1 - dim1;
+  if (packed_dim_whcn_idx_ == dim0_whcn) {
+    packed_dim_whcn_idx_ = dim1_whcn;
+  }
+  if (packed_dim_whcn_idx_ == dim1_whcn) {
+    packed_dim_whcn_idx_ = dim0_whcn;
+  }
+
   if (storage_type() == utils::kBuffer) {
     transpose_dim_order_inplace(dim_order_, dim0, dim1);
   } else {
-    const int dim0_whcn = sizes_.size() - 1 - dim0;
-    const int dim1_whcn = sizes_.size() - 1 - dim1;
     // Cannot transpose batch dimension for texture storage
     VK_CHECK_COND(dim0_whcn < 3 && dim1_whcn < 3);
-
     std::iter_swap(
         axis_map_.begin() + dim0_whcn, axis_map_.begin() + dim1_whcn);
-
-    if (packed_dim_whcn_idx() == dim0_whcn) {
-      memory_layout_ = utils::GPUMemoryLayout(dim1_whcn);
-    }
-    if (packed_dim_whcn_idx() == dim1_whcn) {
-      memory_layout_ = utils::GPUMemoryLayout(dim0_whcn);
-    }
   }
   update_metadata();
 }
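
A worked example of the new packed-dim bookkeeping in virtual_transpose, using only local variables: transposing NCHW dims 1 (C) and 2 (H) of a 4-D channels-packed tensor should leave the data packed along height. The two branches are written with else if here so they stay mutually exclusive; everything else mirrors the arithmetic above.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // 4-D NCHW sizes, channels-packed (WHCN index 2).
  std::vector<int64_t> sizes = {2, 3, 5, 7};
  int32_t packed_dim_whcn_idx = 2;

  const int64_t dim0 = 1; // C
  const int64_t dim1 = 2; // H
  std::iter_swap(sizes.begin() + dim0, sizes.begin() + dim1); // {2, 5, 3, 7}

  // Same index flip as the diff: NCHW dim d corresponds to WHCN index ndim - 1 - d.
  const int dim0_whcn = static_cast<int>(sizes.size()) - 1 - dim0; // 2 (channels)
  const int dim1_whcn = static_cast<int>(sizes.size()) - 1 - dim1; // 1 (height)
  if (packed_dim_whcn_idx == dim0_whcn) {
    packed_dim_whcn_idx = dim1_whcn;
  } else if (packed_dim_whcn_idx == dim1_whcn) {
    packed_dim_whcn_idx = dim0_whcn;
  }

  assert(packed_dim_whcn_idx == 1); // packed data now runs along the height dim
  assert(sizes[1] == 5 && sizes[2] == 3);
  return 0;
}
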