Skip to content

Commit 788d339

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
Functions for handling zero dim tensor (#693)
Summary: Pull Request resolved: #693 Update `tensor_has_dim` to also account for zero-rank tensors, which allow dim to be 0 or -1. Also introduced two helper functions to `nonzero_dim` and `nonempty_size` that are used to handle zero-rank tensors ghstack-source-id: 203341574 exported-using-ghexport Reviewed By: SS-JIA Differential Revision: D49735859 fbshipit-source-id: 2ed843a12d922fd78b0f7e67b6ff89cc871298d4
1 parent 006e3b2 commit 788d339

File tree

1 file changed

+32
-6
lines changed

1 file changed

+32
-6
lines changed

runtime/core/exec_aten/util/tensor_util.h

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,26 @@ using ScalarType = exec_aten::ScalarType;
412412
// Utility functions for checking tensor attributes
413413
//
414414

415+
/*
416+
* Returns the tensor's number of dimensions, except when the tensor is zero
417+
* dimensional. In this case, it returns 1. This is used to properly handle
418+
* the zero dimensional tensors in some kernels, that treat them as 1D tensors
419+
* with a single element.
420+
*/
421+
inline ssize_t nonzero_dim(const Tensor& tensor) {
422+
return tensor.dim() == 0 ? 1 : tensor.dim();
423+
}
424+
425+
/*
426+
* Returns the size along a dimension dim, except when the tensor is zero
427+
* dimensional. In this case, it returns 1. This is used to properly handle
428+
* the zero dimensional tensors in some kernels, that treat them as 1D tensors
429+
* with a single element.
430+
*/
431+
inline ssize_t nonempty_size(const Tensor& tensor, ssize_t dim) {
432+
return tensor.dim() == 0 ? 1 : tensor.size(dim);
433+
}
434+
415435
inline bool tensor_can_cast_to(
416436
exec_aten::Tensor a,
417437
exec_aten::ScalarType dtype) {
@@ -516,12 +536,18 @@ inline bool tensor_has_rank_greater_or_equal_to(
516536
}
517537

518538
inline bool tensor_has_dim(exec_aten::Tensor t, int64_t d) {
519-
ET_LOG_MSG_AND_RETURN_IF_FALSE(
520-
d > 0 ? d < t.dim() : t.dim() + d >= 0,
521-
"%zu-dim tensor does not have dim at index %zu",
522-
static_cast<size_t>(t.dim()),
523-
static_cast<size_t>(d));
524-
539+
if (t.dim() == 0) {
540+
ET_LOG_MSG_AND_RETURN_IF_FALSE(
541+
d == 0 || d == -1,
542+
"dim must be 0 or -1 for 0-dim tensor, got %" PRId64,
543+
d);
544+
} else {
545+
ET_LOG_MSG_AND_RETURN_IF_FALSE(
546+
d > 0 ? d < t.dim() : t.dim() + d >= 0,
547+
"%zu-dim tensor does not have dim at index %zu",
548+
static_cast<size_t>(t.dim()),
549+
static_cast<size_t>(d));
550+
}
525551
return true;
526552
}
527553

0 commit comments

Comments
 (0)