Commit 891521a

[ET-VK] Use dim order as the source of truth for tensor strides
Differential Revision: D61666464
Pull Request resolved: #4844
1 parent d7c069f commit 891521a
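
The core of this change: instead of storing strides directly and inferring the memory layout from them, each buffer-backed vTensor now stores a dim order, i.e. a permutation of NCHW dimension indices listed from the slowest to the fastest moving dimension, and derives its strides from the sizes plus that permutation. A minimal standalone sketch of the sizes + dim order -> strides relationship (it mirrors the calculate_strides() helper added below, with the zero-size handling simplified; the function name here is illustrative only, not part of the diff):

    #include <cstdint>
    #include <vector>

    // Derive strides for `sizes` (NCHW indexing) from a dim order that lists
    // dimensions from slowest moving to fastest moving. Each dimension's
    // stride is the product of the sizes of all faster moving dimensions.
    std::vector<int64_t> strides_from_dim_order(
        const std::vector<int64_t>& sizes,
        const std::vector<int64_t>& dim_order) {
      std::vector<int64_t> strides(sizes.size(), 1);
      int64_t running = 1;
      for (size_t i = sizes.size(); i-- > 0;) {
        strides[dim_order[i]] = running;
        running *= sizes[dim_order[i]];
      }
      return strides;
    }

    // Example: sizes = {2, 3, 4, 5} (NCHW).
    //   dim_order = {0, 1, 2, 3} (contiguous)    -> strides {60, 20, 5, 1}
    //   dim_order = {0, 2, 3, 1} (channels last) -> strides {60, 1, 15, 3}

Keeping the permutation rather than the strides makes it cheap to validate (it must be a permutation of 0..ndim-1) and to recover the packed memory layout (the last entry of the dim order is the fastest moving dimension), which is what the diff below does.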

File tree

5 files changed (+298, -93 lines)

backends/vulkan/runtime/api/containers/Tensor.cpp

Lines changed: 140 additions & 64 deletions
@@ -13,36 +13,15 @@
 namespace vkcompute {
 namespace api {
 
-/*
- * Given the strides of a buffer-backed tensor, find the index of the "fastest
- * moving" dimension in WHCN dimension order. If multiple dims have the lowest
- * stride, then the "earlier" dim is assumed to be the fastest moving (width is
- * "earlier" than height).
- */
-int32_t find_fastest_whcn_dim(const std::vector<int64_t>& strides) {
-  if (strides.size() == 0) {
-    return 0;
-  }
-  int32_t fastest_dim = 0;
-  int64_t min_stride = strides.at(0);
-  for (int d = strides.size() - 1; d >= 0; --d) {
-    if (strides.at(d) < min_stride) {
-      fastest_dim = d;
-      min_stride = strides.at(d);
-    }
-  }
-  return (strides.size() - 1 - fastest_dim);
-}
-
 /*
  * Given the strides of a buffer-backed tensor, estimate the equivalent memory
  * layout enum value by identifying the fastest moving dimension.
  */
 utils::GPUMemoryLayout estimate_memory_layout(
-    const std::vector<int64_t>& strides) {
-  int32_t fastest_dim = find_fastest_whcn_dim(strides);
-  if (fastest_dim <= 3) {
-    return utils::GPUMemoryLayout(fastest_dim);
+    const std::vector<int64_t>& dim_order) {
+  int64_t fastest_dim_whcn = dim_order.size() - 1 - dim_order.back();
+  if (fastest_dim_whcn >= 0 && fastest_dim_whcn < 3) {
+    return utils::GPUMemoryLayout(fastest_dim_whcn);
   }
 
   // TODO(ssjia) find a way to gracefully recover from this case by i.e. adding
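
For example, given a 4-D tensor with dim_order = {0, 2, 3, 1} (channels fastest moving), dim_order.back() is 1, so fastest_dim_whcn = 4 - 1 - 1 = 2, which selects the channels-packed GPUMemoryLayout (assuming enum values 0, 1, 2 correspond to width-, height-, and channels-packed layouts, per the WHCN naming used in this file).
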
@@ -51,41 +30,70 @@ utils::GPUMemoryLayout estimate_memory_layout(
   VK_THROW("No compatible GPUMemoryLayout value");
 }
 
+std::vector<int64_t> calculate_dim_order(
+    const size_t ndim,
+    const utils::GPUMemoryLayout memory_layout) {
+  // Special case for zero dim tensors
+  if (ndim == 0) {
+    return {0};
+  }
+  std::vector<int64_t> dim_order(ndim);
+  int64_t last_dim =
+      ndim - utils::to_packed_dim_nchw_offset<int64_t>(memory_layout);
+
+  int64_t cur_dim = 0;
+  for (int d = 0; d < ndim; ++d) {
+    if (d == last_dim) {
+      cur_dim++;
+    }
+    dim_order[d] = cur_dim;
+    cur_dim++;
+  }
+  if (last_dim >= 0) {
+    dim_order[ndim - 1] = last_dim;
+  }
+
+  return dim_order;
+}
+
 std::vector<int64_t> calculate_strides(
     const std::vector<int64_t>& sizes,
-    const utils::GPUMemoryLayout memory_layout) {
+    const std::vector<int64_t>& dim_order) {
   // For zero dim tensors
   if (sizes.size() == 0) {
     return {1};
   }
 
-  const int64_t dim_offset =
-      utils::to_packed_dim_nchw_offset<int64_t>(memory_layout);
-  int64_t last_dim = sizes.size() - dim_offset;
-  if (last_dim < 0) {
-    last_dim = sizes.size() - 1;
-  }
-
   size_t ndim = sizes.size();
   std::vector<int64_t> strides(ndim);
 
-  const int64_t last_dim_size = sizes.at(last_dim);
-
-  for (int stride_d = ndim - 1; stride_d >= 0; stride_d--) {
-    strides.at(stride_d) = 1;
-    if (stride_d == last_dim) {
-      continue;
-    }
-    strides.at(stride_d) = last_dim_size;
-    for (int size_d = ndim - 1; size_d > stride_d; size_d--) {
-      if (size_d != last_dim) {
-        strides.at(stride_d) *= sizes.at(size_d);
-      }
+  strides[dim_order[ndim - 1]] = 1;
+  for (int32_t i = ndim - 2; i >= 0; --i) {
+    if (sizes[dim_order[i + 1]] == 0) {
+      strides[dim_order[i]] = strides[dim_order[i + 1]];
+    } else {
+      strides[dim_order[i]] =
+          strides[dim_order[i + 1]] * sizes[dim_order[i + 1]];
     }
   }
+
   return strides;
 }
 
+bool dim_order_is_valid(const std::vector<int64_t>& dim_order) {
+  int64_t sum = 0;
+  for (size_t i = 0; i < dim_order.size(); ++i) {
+    if (dim_order[i] < 0 || dim_order[i] >= dim_order.size()) {
+      return false;
+    }
+    sum += dim_order[i];
+  }
+  int64_t n = static_cast<int64_t>(dim_order.size() - 1);
+  // Sanity check that the sum of the indices in the vector is equal to the sum
+  // of 0 + 1 + 2 + ... + (ndim - 1)
+  return sum == n * (n + 1) / 2;
+}
+
 std::vector<int64_t> unsqueeze_strides(
     const std::vector<int64_t>& strides,
     const int64_t numel) {
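
Worked example for the helpers above, assuming to_packed_dim_nchw_offset() returns 3 for a channels-packed layout: with ndim = 4, last_dim = 4 - 3 = 1, the loop fills {0, 2, 3, 4} and the final fixup writes last_dim into the last slot, giving dim_order = {0, 2, 3, 1}. For sizes = {2, 3, 4, 5} (NCHW), calculate_strides() then produces {60, 1, 15, 3}: channels get stride 1 as the fastest moving dim, width 1 * 3 = 3, height 3 * 5 = 15, and batch 15 * 4 = 60. dim_order_is_valid() accepts this order since every entry is in [0, 4) and the entries sum to 6 = 0 + 1 + 2 + 3.
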
@@ -170,7 +178,8 @@ vTensor::vTensor(
       memory_layout_(memory_layout),
       // Calculate tensor size metadata
       sizes_(sizes.begin(), sizes.end()),
-      strides_(calculate_strides(sizes, memory_layout_)),
+      dim_order_(calculate_dim_order(sizes_.size(), memory_layout_)),
+      strides_(calculate_strides(sizes, dim_order_)),
       numel_(utils::multiply_integers(sizes_)),
       padded_sizes_{calculate_padded_sizes(sizes, memory_layout_)},
       unsqueezed_strides_{unsqueeze_strides(strides_, numel_)},
@@ -189,6 +198,9 @@ vTensor::vTensor(
           padded_sizes_,
           dtype_,
           allocate_memory) {
+  VK_CHECK_COND(
+      dim_order_is_valid(dim_order_), "computed dim order is invalid");
+
   if (storage_type != utils::kBuffer) {
     texture_limits_.limits = utils::ivec3{
         utils::safe_downcast<int32_t>(storage_.image_extents_[0]),
@@ -204,16 +216,39 @@ vTensor::vTensor(
   }
 }
 
+vTensor::vTensor(const vTensor& other)
+    : dtype_(other.dtype_),
+      memory_layout_(other.memory_layout_),
+      // Copy tensor size metadata
+      sizes_(other.sizes_.begin(), other.sizes_.end()),
+      dim_order_(other.dim_order_.begin(), other.dim_order_.end()),
+      strides_(other.strides_.begin(), other.strides_.end()),
+      numel_(other.numel_),
+      padded_sizes_{other.padded_sizes_.begin(), other.padded_sizes_.end()},
+      unsqueezed_strides_{
+          other.unsqueezed_strides_.begin(),
+          other.unsqueezed_strides_.end()},
+      padded_numel_(other.padded_numel_),
+      texture_limits_{other.texture_limits_},
+      // Empty initialize Utility Uniform Buffers
+      sizes_uniform_(),
+      strides_uniform_(),
+      numel_uniform_(),
+      texture_limits_uniform_(),
+      // Copy Tensor storage
+      storage_(other.storage_) {}
+
 vTensor::vTensor(
     const vTensor& other,
     const std::vector<int64_t>& sizes,
-    const std::vector<int64_t>& strides,
-    const size_t offset_numel)
+    const std::vector<int64_t>& dim_order,
+    const int64_t offset_numel)
     : dtype_(other.dtype_),
-      memory_layout_(estimate_memory_layout(strides)),
+      memory_layout_(estimate_memory_layout(dim_order)),
       // Copy tensor size metadata
       sizes_(sizes.begin(), sizes.end()),
-      strides_(strides.begin(), strides.end()),
+      dim_order_(dim_order.begin(), dim_order.end()),
+      strides_(calculate_strides(sizes_, dim_order_)),
       numel_(utils::multiply_integers(sizes_)),
       padded_sizes_{calculate_padded_sizes(sizes, memory_layout_)},
       unsqueezed_strides_{unsqueeze_strides(strides_, numel_)},
@@ -226,6 +261,8 @@ vTensor::vTensor(
       texture_limits_uniform_(),
       // Copy Tensor storage
       storage_(other.storage_, vkapi::element_size(dtype_) * offset_numel) {
+  VK_CHECK_COND(
+      dim_order_is_valid(dim_order_), "new dim order provided is invalid");
   VK_CHECK_COND(
       offset_numel + numel_ <= other.numel(),
       "Tensor alias cannot access more elements than available in the original"
@@ -339,9 +376,17 @@ void vTensor::bind_allocation(const vkapi::Allocation& allocation) {
   }
 }
 
-void vTensor::update_size_metadata(const std::vector<int64_t>& new_sizes) {
+void vTensor::update_metadata(
+    const std::vector<int64_t>& new_sizes,
+    const std::vector<int64_t>& new_dim_order) {
   sizes_ = new_sizes;
-  strides_ = calculate_strides(new_sizes, memory_layout_);
+  dim_order_ = new_dim_order;
+  strides_ = calculate_strides(sizes_, dim_order_);
+  // Only update the memory layout for buffer-backed tensors. Strides are
+  // meaningless for texture-backed tensors and do not impact the memory layout.
+  if (storage_type() == utils::kBuffer) {
+    memory_layout_ = estimate_memory_layout(dim_order_);
+  }
   numel_ = utils::multiply_integers(sizes_);
 
   padded_sizes_ = calculate_padded_sizes(sizes_, memory_layout_);
@@ -373,15 +418,7 @@ void vTensor::update_size_metadata(const std::vector<int64_t>& new_sizes) {
   }
 }
 
-void vTensor::reallocate(const std::vector<int64_t>& new_sizes) {
-  update_size_metadata(new_sizes);
-  storage_.discard_and_reallocate(
-      calculate_padded_sizes(new_sizes, memory_layout_),
-      memory_layout_,
-      dtype_);
-}
-
-void vTensor::virtual_resize(const std::vector<int64_t>& new_sizes) {
+void vTensor::check_sizes(const std::vector<int64_t>& sizes) const {
   if (storage_type() != utils::kBuffer) {
     // For texture storage check that the current texture is large enough for
     // the new sizes of the tensor.
@@ -394,10 +431,47 @@ void vTensor::virtual_resize(const std::vector<int64_t>& new_sizes) {
 
     VK_CHECK_COND(
         valid_resize,
-        "Cannot use virtual resize if new sizes requires a larger texture.");
+        "tensor sizes requires a larger texture than the current one.");
+  } else {
+    // For buffer storage check that the current buffer is large enough for the
+    // new sizes of the tensor.
+    int64_t numel = utils::multiply_integers(sizes);
+    bool valid_resize =
+        numel + storage_.buffer_offset_ <= storage_.buffer_length_;
+    VK_CHECK_COND(
+        valid_resize,
+        "tensor sizes requires a larger buffer than the current one.");
   }
+}
+
+void vTensor::virtual_reconfigure(
+    const std::vector<int64_t>& new_sizes,
+    const std::vector<int64_t>& new_dim_order) {
+  VK_CHECK_COND(
+      storage_type() == utils::kBuffer,
+      "virtual_reconfigure is only applicable for buffer backed tensors");
+  VK_CHECK_COND(new_sizes.size() == new_dim_order.size());
+  VK_CHECK_COND(dim_order_is_valid(new_dim_order));
 
-  update_size_metadata(new_sizes);
+  check_sizes(new_sizes);
+  update_metadata(new_sizes, new_dim_order);
+}
+
+void vTensor::virtual_resize(const std::vector<int64_t>& new_sizes) {
+  VK_CHECK_COND(
+      new_sizes.size() == dim_order_.size(),
+      "new sizes cannot modify the dimensionality of the tensor ");
+
+  check_sizes(new_sizes);
+  update_metadata(new_sizes, dim_order_);
+}
+
+void vTensor::reallocate(const std::vector<int64_t>& new_sizes) {
+  update_metadata(new_sizes, dim_order_);
+  storage_.discard_and_reallocate(
+      calculate_padded_sizes(new_sizes, memory_layout_),
+      memory_layout_,
+      dtype_);
 }
 
 //
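
To make the split between the new entry points concrete, here is a simplified, self-contained model of the virtual_reconfigure() flow above using plain vectors in place of vTensor (the struct and function names are illustrative, not part of the commit, and the buffer-offset bookkeeping of check_sizes() is omitted): validate the new dim order, check that the new sizes still fit the existing allocation, then refresh the metadata.

    #include <cassert>
    #include <cstdint>
    #include <functional>
    #include <numeric>
    #include <vector>

    // Toy stand-in for a buffer-backed tensor's metadata (illustrative only).
    struct ToyBufferTensor {
      std::vector<int64_t> sizes;
      std::vector<int64_t> dim_order;
      std::vector<int64_t> strides;
      int64_t buffer_length;  // elements available in the backing buffer
    };

    // Same range + sum sanity check as dim_order_is_valid() above.
    bool toy_dim_order_is_valid(const std::vector<int64_t>& dim_order) {
      int64_t sum = 0;
      for (int64_t d : dim_order) {
        if (d < 0 || d >= static_cast<int64_t>(dim_order.size())) {
          return false;
        }
        sum += d;
      }
      const int64_t n = static_cast<int64_t>(dim_order.size()) - 1;
      return sum == n * (n + 1) / 2;
    }

    void toy_virtual_reconfigure(
        ToyBufferTensor& t,
        const std::vector<int64_t>& new_sizes,
        const std::vector<int64_t>& new_dim_order) {
      assert(new_sizes.size() == new_dim_order.size());
      assert(toy_dim_order_is_valid(new_dim_order));
      // check_sizes() equivalent: the new element count must fit the buffer.
      const int64_t numel = std::accumulate(
          new_sizes.begin(), new_sizes.end(), int64_t{1}, std::multiplies<>());
      assert(numel <= t.buffer_length);
      // update_metadata() equivalent: refresh sizes, dim order, and strides.
      t.sizes = new_sizes;
      t.dim_order = new_dim_order;
      t.strides.assign(new_sizes.size(), 1);
      int64_t running = 1;
      for (size_t i = new_sizes.size(); i-- > 0;) {
        t.strides[new_dim_order[i]] = running;
        running *= new_sizes[new_dim_order[i]];
      }
    }

In the diff itself, virtual_resize() is the same flow with the current dim order reused, while reallocate() skips the fit check and discards and recreates the backing storage instead.
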
@@ -480,6 +554,7 @@ vTensorStorage::vTensorStorage(
       storage_type_{storage_type},
       image_extents_(calculate_image_extents(padded_sizes, gpu_memory_layout)),
       buffer_length_{utils::multiply_integers(padded_sizes)},
+      buffer_offset_{0},
       image_(allocate_image(
           context_,
           image_extents_,
@@ -496,11 +571,12 @@ vTensorStorage::vTensorStorage(
 
 vTensorStorage::vTensorStorage(
     const vTensorStorage& other,
-    const size_t buffer_offset)
+    const int64_t buffer_offset)
     : context_(other.context_),
       storage_type_{other.storage_type_},
       image_extents_(other.image_extents_),
       buffer_length_{other.buffer_length_},
+      buffer_offset_{buffer_offset},
       image_(),
       buffer_(other.buffer_, buffer_offset),
       last_access_{other.last_access_} {
