Skip to content

Commit 586a22a

Browse files
committed
[ET-VK] Fix negative dim in normalize_to_dim_index
Our `cat` implementation currently fails for negative dim. Differential Revision: [D62270925](https://our.internmc.facebook.com/intern/diff/D62270925/) ghstack-source-id: 241265784 Pull Request resolved: #5118
1 parent 83d92ff commit 586a22a

File tree

1 file changed

+2
-1
lines changed
  • backends/vulkan/runtime/graph/ops/impl/utils

1 file changed

+2
-1
lines changed

backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ constexpr DimIndex kChannel4D = DimIndex::DIM_3RD_LAST;
3232
constexpr DimIndex kBatch4D = DimIndex::DIM_4TH_LAST;
3333

3434
inline DimIndex normalize_to_dim_index(const api::vTensor& v_in, int32_t dim) {
35-
return static_cast<DimIndex>(dim - v_in.dim());
35+
return dim < 0 ? static_cast<DimIndex>(dim)
36+
: static_cast<DimIndex>(dim - v_in.dim());
3637
}
3738

3839
/*

0 commit comments

Comments
 (0)