@@ -89,11 +89,11 @@ std::vector<int64_t> calculate_strides(
  * tensor. Thus the axis mapping can be considered to be in WHCN dimension
  * order.
  *
- * The last value `axis_mapping.at(3)` indicates the WHCN index of the tensor
+ * The last value `axis_map.at(3)` indicates the WHCN index of the tensor
  * dimension along which batches will be concatenated. This dimension can be
  * referred to as the "inner dimension". To determine which image texture axis
  * is used for the concatenation, a double lookup will need to be performed
- * (axis_mapping.at(axis_mapping.at(3))).
+ * (axis_map.at(axis_map.at(3))).
  *
  * The reason for structuring the axis mapping this way is that for the batch
  * dim, two things need to be easily derived:
@@ -107,7 +107,7 @@ std::vector<int64_t> calculate_strides(
  *
  * The axis mapping allows for permuted views of texture-backed tensors.
  */
-std::vector<int64_t> default_axis_mapping() {
+std::vector<int64_t> default_axis_map() {
   // Currently, all compute shaders have an assumption that the channels dim is
   // used to combine with the batch dim of a tensor. However, once dim mapping
   // is integrated into the tensor indexing logic for each compute shader, we
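The doc comment in the hunks above describes how the axis map is laid out and how the batch axis is resolved through a double lookup. The following minimal standalone sketch (not part of the patch) walks through that lookup; the map `{0, 1, 2, 2}` mirrors the default described here, and everything else is illustrative.

```cpp
// Minimal sketch of the axis-map double lookup described in the comment above.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Entries 0..2: which image axis (X=0, Y=1, Z=2) the tensor's width,
  // height, and channels dims map to. Entry 3: the WHCN index of the dim
  // that batches are concatenated into.
  const std::vector<int64_t> axis_map = {0, 1, 2, 2};

  // Double lookup: first find the WHCN dim used for batch concatenation,
  // then find the image axis that WHCN dim is mapped to.
  const int64_t concatted_whcn_dim = axis_map.at(3);          // 2 -> channels
  const int64_t batch_axis = axis_map.at(concatted_whcn_dim); // image Z axis

  std::cout << "batches are concatenated along image axis " << batch_axis
            << "\n";
  return 0;
}
```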
@@ -173,40 +173,40 @@ std::vector<int64_t> calculate_padded_sizes(

 utils::uvec3 calculate_image_extents(
     const std::vector<int64_t>& padded_sizes,
-    const std::vector<int64_t>& axis_mapping,
+    const std::vector<int64_t>& axis_map,
     const utils::GPUMemoryLayout memory_layout) {
   VK_CHECK_COND(padded_sizes.size() == 4);
-  VK_CHECK_COND(axis_mapping.size() == 4);
+  VK_CHECK_COND(axis_map.size() == 4);

   utils::uvec3 extents({1, 1, 1});
-  // First three elements of axis_mapping indicate which (X,Y,Z) image axis the
+  // First three elements of axis_map indicate which (X,Y,Z) image axis the
   // width, height, and channels dim of the tensor maps to.
   for (int whcn_dim = 0; whcn_dim < 3; ++whcn_dim) {
-    const int64_t axis = axis_mapping.at(whcn_dim);
+    const int64_t axis = axis_map.at(whcn_dim);
     const int64_t dim = padded_sizes.size() - 1 - whcn_dim;
     extents[axis] = utils::safe_downcast<uint32_t>(padded_sizes.at(dim));
   }

-  // axis_mapping[3] indicates the WHCN index of the dimension used for batch
+  // axis_map[3] indicates the WHCN index of the dimension used for batch
   // concatenation. Thus a double lookup is required to determine the image axis
   // used for batch concatenation.
-  const int64_t concatted_whcn_dim = axis_mapping.at(3);
-  const int64_t batch_axis = axis_mapping.at(concatted_whcn_dim);
+  const int64_t concatted_whcn_dim = axis_map.at(3);
+  const int64_t batch_axis = axis_map.at(concatted_whcn_dim);
   // Multiply the extents of the batch axis by the batch size.
   extents[batch_axis] *= padded_sizes.at(0);

   switch (memory_layout) {
     case utils::kWidthPacked:
-      VK_CHECK_COND(extents[0] % 4 == 0);
-      extents[0] /= 4;
+      VK_CHECK_COND(extents[axis_map.at(0)] % 4 == 0);
+      extents[axis_map.at(0)] /= 4;
       break;
     case utils::kHeightPacked:
-      VK_CHECK_COND(extents[1] % 4 == 0);
-      extents[1] /= 4;
+      VK_CHECK_COND(extents[axis_map.at(1)] % 4 == 0);
+      extents[axis_map.at(1)] /= 4;
       break;
     case utils::kChannelsPacked:
-      VK_CHECK_COND(extents[2] % 4 == 0);
-      extents[2] /= 4;
+      VK_CHECK_COND(extents[axis_map.at(2)] % 4 == 0);
+      extents[axis_map.at(2)] /= 4;
       break;
   }

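To make the flow of `calculate_image_extents()` above concrete, here is a standalone sketch that re-implements the same arithmetic with plain asserts in place of `VK_CHECK_COND` and plain casts in place of the `utils` helpers. The padded sizes `{2, 8, 6, 4}` (NCHW) and the default axis map are example inputs, not values taken from the diff.

```cpp
// Standalone sketch of the extent calculation in the hunk above.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const std::vector<int64_t> padded_sizes = {2, 8, 6, 4}; // N, C, H, W
  const std::vector<int64_t> axis_map = {0, 1, 2, 2};

  // WHCN dim i lives at padded_sizes[size - 1 - i]; axis_map picks the image
  // axis that dim maps to.
  uint32_t extents[3] = {1, 1, 1};
  for (int whcn_dim = 0; whcn_dim < 3; ++whcn_dim) {
    const int64_t axis = axis_map.at(whcn_dim);
    const int64_t dim =
        static_cast<int64_t>(padded_sizes.size()) - 1 - whcn_dim;
    extents[axis] = static_cast<uint32_t>(padded_sizes.at(dim));
  }

  // Fold the batch dim into the axis chosen by the double lookup.
  const int64_t batch_axis = axis_map.at(axis_map.at(3));
  extents[batch_axis] *= static_cast<uint32_t>(padded_sizes.at(0));

  // Width-packed layout: four values along the *mapped* width axis share one
  // texel, which is exactly what the axis_map.at(0) indexing above fixes.
  const int64_t width_axis = axis_map.at(0);
  assert(extents[width_axis] % 4 == 0);
  extents[width_axis] /= 4;

  // For these inputs: 1 x 6 x 16.
  std::cout << extents[0] << " x " << extents[1] << " x " << extents[2]
            << "\n";
  return 0;
}
```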
@@ -229,25 +229,27 @@ vTensor::vTensor(
       // Calculate tensor metadata
       sizes_(sizes.begin(), sizes.end()),
       dim_order_(calculate_dim_order(sizes_.size(), memory_layout_)),
-      axis_mapping_(default_axis_mapping()),
+      axis_map_(default_axis_map()),
       strides_(calculate_strides(sizes, dim_order_)),
       numel_(utils::multiply_integers(sizes_)),
       padded_sizes_{calculate_padded_sizes(sizes, memory_layout_)},
       unsqueezed_strides_{unsqueeze_strides(strides_, numel_)},
       padded_numel_(utils::multiply_integers(padded_sizes_)),
       texture_limits_{{0, 0, 0}},
+      logical_limits_{{0, 0, 0}},
       // Utility Uniform Buffers that can be passed to shaders as arguments
       sizes_uniform_(),
       strides_uniform_(),
       numel_uniform_(),
-      axis_mapping_uniform_(),
+      axis_map_uniform_(),
       texture_limits_uniform_(),
+      logical_limits_uniform_(),
       // Construct Tensor storage
       storage_(
           context,
           storage_type,
           memory_layout_,
-          axis_mapping_,
+          axis_map_,
           padded_sizes_,
           dtype_,
           allocate_memory) {
@@ -259,6 +261,8 @@ vTensor::vTensor(
         utils::safe_downcast<int32_t>(storage_.image_extents_[0]),
         utils::safe_downcast<int32_t>(storage_.image_extents_[1]),
         utils::safe_downcast<int32_t>(storage_.image_extents_[2])};
+
+    update_logical_limits();
   }

   if (dtype == vkapi::kHalf) {
@@ -275,7 +279,7 @@ vTensor::vTensor(const vTensor& other)
       // Copy tensor size metadata
       sizes_(other.sizes_.begin(), other.sizes_.end()),
       dim_order_(other.dim_order_.begin(), other.dim_order_.end()),
-      axis_mapping_(other.axis_mapping_.begin(), other.axis_mapping_.end()),
+      axis_map_(other.axis_map_.begin(), other.axis_map_.end()),
       strides_(other.strides_.begin(), other.strides_.end()),
       numel_(other.numel_),
       padded_sizes_{other.padded_sizes_.begin(), other.padded_sizes_.end()},
@@ -284,12 +288,14 @@ vTensor::vTensor(const vTensor& other)
           other.unsqueezed_strides_.end()},
       padded_numel_(other.padded_numel_),
       texture_limits_{other.texture_limits_},
+      logical_limits_{other.logical_limits_},
       // Empty initialize Utility Uniform Buffers
       sizes_uniform_(),
       strides_uniform_(),
       numel_uniform_(),
-      axis_mapping_uniform_(),
+      axis_map_uniform_(),
       texture_limits_uniform_(),
+      logical_limits_uniform_(),
       // Copy Tensor storage
       storage_(other.storage_) {}

@@ -303,19 +309,21 @@ vTensor::vTensor(
       // Copy tensor size metadata
       sizes_(sizes.begin(), sizes.end()),
       dim_order_(dim_order.begin(), dim_order.end()),
-      axis_mapping_(default_axis_mapping()),
+      axis_map_(default_axis_map()),
       strides_(calculate_strides(sizes_, dim_order_)),
       numel_(utils::multiply_integers(sizes_)),
       padded_sizes_{calculate_padded_sizes(sizes, memory_layout_)},
       unsqueezed_strides_{unsqueeze_strides(strides_, numel_)},
       padded_numel_(utils::multiply_integers(padded_sizes_)),
-      texture_limits_{{0, 0, 0}},
+      texture_limits_{other.texture_limits_},
+      logical_limits_(other.logical_limits_),
       // Empty initialize Utility Uniform Buffers
       sizes_uniform_(),
       strides_uniform_(),
       numel_uniform_(),
-      axis_mapping_uniform_(),
+      axis_map_uniform_(),
       texture_limits_uniform_(),
+      logical_limits_uniform_(),
       // Copy Tensor storage
       storage_(other.storage_, vkapi::element_size(dtype_) * offset_numel) {
   VK_CHECK_COND(
@@ -356,12 +364,18 @@ vkapi::VulkanBuffer& vTensor::buffer(
   return storage_.buffer_;
 }

-utils::uvec3 vTensor::mapped_extents() const {
-  utils::uvec3 m_extents;
-  m_extents[0] = storage_.image_extents_[axis_mapping_.at(0)];
-  m_extents[1] = storage_.image_extents_[axis_mapping_.at(1)];
-  m_extents[2] = storage_.image_extents_[axis_mapping_.at(2)];
-  return m_extents;
+void vTensor::update_logical_limits() {
+  logical_limits_.limits[0] = texture_limits_.limits[axis_map_.at(0)];
+  logical_limits_.limits[1] = texture_limits_.limits[axis_map_.at(1)];
+  logical_limits_.limits[2] = texture_limits_.limits[axis_map_.at(2)];
+}
+
+utils::uvec3 vTensor::logical_extents() const {
+  utils::uvec3 logical_extents(
+      {utils::safe_downcast<uint32_t>(logical_limits_.limits[0]),
+       utils::safe_downcast<uint32_t>(logical_limits_.limits[1]),
+       utils::safe_downcast<uint32_t>(logical_limits_.limits[2])});
+  return logical_extents;
 }

 const vkapi::BufferBindInfo vTensor::sizes_ubo() {
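In short, `update_logical_limits()` above permutes the physical texture limits through the axis map so that callers can reason in logical WHC order. A small standalone sketch of that permutation, with made-up limit values:

```cpp
// Sketch of the permutation performed by update_logical_limits():
// logical (WHC-ordered) limits are read out of the physical texture limits
// through the axis map. The concrete values here are illustrative only.
#include <array>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const std::vector<int64_t> axis_map = {0, 1, 2, 2};
  const std::array<int32_t, 3> texture_limits = {1, 6, 16}; // physical X, Y, Z

  std::array<int32_t, 3> logical_limits{};
  for (int i = 0; i < 3; ++i) {
    logical_limits[i] = texture_limits[axis_map.at(i)];
  }

  // With the identity-like default map this is a no-op; a permuted map
  // (e.g. {1, 0, 2, 2} for a width/height swap) would reorder the limits.
  std::cout << logical_limits[0] << ", " << logical_limits[1] << ", "
            << logical_limits[2] << "\n";
  return 0;
}
```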
@@ -380,12 +394,12 @@ const vkapi::BufferBindInfo vTensor::strides_ubo() {
   return vkapi::BufferBindInfo(strides_uniform_.buffer());
 }

-const vkapi::BufferBindInfo vTensor::axis_mapping_ubo() {
-  if (!axis_mapping_uniform_.buffer()) {
-    axis_mapping_uniform_ =
-        ParamsBuffer(storage_.context_, utils::make_ivec4(axis_mapping_));
+const vkapi::BufferBindInfo vTensor::axis_map_ubo() {
+  if (!axis_map_uniform_.buffer()) {
+    axis_map_uniform_ =
+        ParamsBuffer(storage_.context_, utils::make_ivec4(axis_map_));
   }
-  return vkapi::BufferBindInfo(axis_mapping_uniform_.buffer());
+  return vkapi::BufferBindInfo(axis_map_uniform_.buffer());
 }

 const vkapi::BufferBindInfo vTensor::texture_limits_ubo() {
@@ -395,6 +409,13 @@ const vkapi::BufferBindInfo vTensor::texture_limits_ubo() {
   return vkapi::BufferBindInfo(texture_limits_uniform_.buffer());
 }

+const vkapi::BufferBindInfo vTensor::logical_limits_ubo() {
+  if (!logical_limits_uniform_.buffer()) {
+    logical_limits_uniform_ = ParamsBuffer(storage_.context_, logical_limits_);
+  }
+  return vkapi::BufferBindInfo(logical_limits_uniform_.buffer());
+}
+
 const vkapi::BufferBindInfo vTensor::numel_ubo() {
   if (!numel_uniform_.buffer()) {
     numel_uniform_ = ParamsBuffer(storage_.context_, numel_);
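Like the other `*_ubo()` accessors shown here, the new `logical_limits_ubo()` creates its ParamsBuffer lazily, and `update_metadata()` refreshes it only if it already exists. Below is a rough standalone sketch of that pattern; `FakeParamsBuffer`, `set_logical_limits`, and the printed value are stand-ins invented for the example, not part of the real vkapi API.

```cpp
// Rough sketch of the lazily-created uniform buffer pattern that
// logical_limits_ubo() follows. FakeParamsBuffer stands in for the real
// ParamsBuffer/BufferBindInfo types, which this diff only uses, not defines.
#include <cstdint>
#include <iostream>
#include <optional>

struct Limits {
  int32_t limits[3];
};

struct FakeParamsBuffer {
  Limits data;
};

class TensorLike {
 public:
  // Allocated on first request, so tensors that never bind this UBO pay
  // nothing for it.
  const FakeParamsBuffer& logical_limits_ubo() {
    if (!logical_limits_uniform_) {
      logical_limits_uniform_ = FakeParamsBuffer{logical_limits_};
    }
    return *logical_limits_uniform_;
  }

  // Mirrors update_metadata(): refresh the UBO only if it was ever created.
  void set_logical_limits(const Limits& new_limits) {
    logical_limits_ = new_limits;
    if (logical_limits_uniform_) {
      logical_limits_uniform_->data = logical_limits_;
    }
  }

 private:
  Limits logical_limits_{{0, 0, 0}};
  std::optional<FakeParamsBuffer> logical_limits_uniform_;
};

int main() {
  TensorLike t;
  t.set_logical_limits({{1, 6, 16}});
  std::cout << t.logical_limits_ubo().data.limits[2] << "\n"; // prints 16
  return 0;
}
```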
@@ -465,14 +486,16 @@ void vTensor::update_metadata() {
   // Calculate the extents of the image texture that would have been required
   // for a tensor of the new sizes.
   utils::uvec3 virtual_extents =
-      calculate_image_extents(padded_sizes_, axis_mapping_, memory_layout_);
+      calculate_image_extents(padded_sizes_, axis_map_, memory_layout_);

   // Update the texture limits to reflect the new virtual extents.
   texture_limits_.limits = utils::ivec3{
       utils::safe_downcast<int32_t>(virtual_extents[0]),
       utils::safe_downcast<int32_t>(virtual_extents[1]),
       utils::safe_downcast<int32_t>(virtual_extents[2])};

+  update_logical_limits();
+
   if (sizes_uniform_.buffer()) {
     sizes_uniform_.update(utils::make_whcn_ivec4(sizes_));
   }
@@ -482,20 +505,23 @@ void vTensor::update_metadata() {
   if (numel_uniform_.buffer()) {
     numel_uniform_.update(numel_);
   }
-  if (axis_mapping_uniform_.buffer()) {
-    axis_mapping_uniform_.update(utils::make_ivec4(axis_mapping_));
+  if (axis_map_uniform_.buffer()) {
+    axis_map_uniform_.update(utils::make_ivec4(axis_map_));
   }
   if (texture_limits_uniform_.buffer()) {
     texture_limits_uniform_.update(texture_limits_);
   }
+  if (logical_limits_uniform_.buffer()) {
+    logical_limits_uniform_.update(logical_limits_);
+  }
 }

 void vTensor::check_sizes(const std::vector<int64_t>& sizes) const {
   if (storage_type() != utils::kBuffer) {
     // For texture storage check that the current texture is large enough for
     // the new sizes of the tensor.
     utils::uvec3 virtual_extents =
-        calculate_image_extents(padded_sizes_, axis_mapping_, memory_layout_);
+        calculate_image_extents(padded_sizes_, axis_map_, memory_layout_);

     bool valid_resize = virtual_extents[0] <= image_extents()[0];
     valid_resize = valid_resize && virtual_extents[1] <= image_extents()[1];
@@ -546,7 +572,7 @@ void vTensor::reallocate(const std::vector<int64_t>& new_sizes) {
   update_metadata();
   storage_.discard_and_reallocate(
       calculate_padded_sizes(new_sizes, memory_layout_),
-      axis_mapping_,
+      axis_map_,
       memory_layout_,
       dtype_);
 }
@@ -624,16 +650,14 @@ vTensorStorage::vTensorStorage(
     Context* const context,
     const utils::StorageType storage_type,
     const utils::GPUMemoryLayout gpu_memory_layout,
-    const std::vector<int64_t>& axis_mapping,
+    const std::vector<int64_t>& axis_map,
     const std::vector<int64_t>& padded_sizes,
     const vkapi::ScalarType dtype,
     const bool allocate_memory)
     : context_(context),
       storage_type_{storage_type},
-      image_extents_(calculate_image_extents(
-          padded_sizes,
-          axis_mapping,
-          gpu_memory_layout)),
+      image_extents_(
+          calculate_image_extents(padded_sizes, axis_map, gpu_memory_layout)),
       buffer_length_{utils::multiply_integers(padded_sizes)},
       buffer_offset_{0},
       image_(allocate_image(
@@ -746,7 +770,7 @@ bool vTensorStorage::is_copy_of(const vTensorStorage& other) const {

 void vTensorStorage::discard_and_reallocate(
     const std::vector<int64_t>& padded_sizes,
-    const std::vector<int64_t>& axis_mapping,
+    const std::vector<int64_t>& axis_map,
     const utils::GPUMemoryLayout gpu_memory_layout,
     const vkapi::ScalarType dtype) {
   const bool image_owns_memory = image_.owns_memory();
@@ -755,7 +779,7 @@ void vTensorStorage::discard_and_reallocate(
   flush();

   image_extents_ =
-      calculate_image_extents(padded_sizes, axis_mapping, gpu_memory_layout);
+      calculate_image_extents(padded_sizes, axis_map, gpu_memory_layout);
   image_ = allocate_image(
       context_,
       image_extents_,