Skip to content

Commit b40f420

Browse files
author
Mahesh Ravishankar
committed
[mlir][MemRef] Add early exit for computing dropped unit-dims.
Computing dropped unit-dims when all the unit dims are dropped, does not need to check for strides being dropped. This also enables canonicalization of reduced-rank subviews. Reviewed By: gysit Differential Revision: https://reviews.llvm.org/D121766
1 parent 5b81158 commit b40f420

File tree

3 files changed

+21
-9
lines changed

3 files changed

+21
-9
lines changed

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -759,6 +759,11 @@ computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
759759
if (attr.cast<IntegerAttr>().getInt() == 1)
760760
unusedDims.set(dim.index());
761761

762+
// Early exit for the case where the number of unused dims matches the number
763+
// of ranks reduced.
764+
if (unusedDims.count() + reducedType.getRank() == originalType.getRank())
765+
return unusedDims;
766+
762767
SmallVector<int64_t> originalStrides, candidateStrides;
763768
int64_t originalOffset, candidateOffset;
764769
if (failed(

mlir/test/Dialect/MemRef/canonicalize.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -719,3 +719,19 @@ func @reinterpret_of_subview(%arg : memref<?xi8>, %size1: index, %size2: index)
719719
%1 = memref.reinterpret_cast %0 to offset: [0], sizes: [%size2], strides: [1] : memref<?xi8> to memref<?xi8>
720720
return %1 : memref<?xi8>
721721
}
722+
723+
// -----
724+
725+
func @canonicalize_rank_reduced_subview(%arg0 : memref<8x?xf32>,
726+
%arg1 : index) -> memref<?xf32, offset : ?, strides : [?]> {
727+
%c0 = arith.constant 0 : index
728+
%c1 = arith.constant 1 : index
729+
%0 = memref.subview %arg0[%c0, %c0] [1, %arg1] [%c1, %c1] : memref<8x?xf32> to memref<?xf32, offset : ?, strides : [?]>
730+
return %0 : memref<?xf32, offset : ?, strides : [?]>
731+
}
732+
// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
733+
// CHECK: func @canonicalize_rank_reduced_subview
734+
// CHECK-SAME: %[[ARG0:.+]]: memref<8x?xf32>
735+
// CHECK-SAME: %[[ARG1:.+]]: index
736+
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG0]][0, 0] [1, %[[ARG1]]] [1, 1]
737+
// CHECK-SAME: memref<8x?xf32> to memref<?xf32, #[[MAP]]>

mlir/test/Dialect/MemRef/invalid.mlir

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -677,15 +677,6 @@ func @invalid_rank_reducing_subview(%arg0 : memref<?x?xf32>, %arg1 : index, %arg
677677

678678
// -----
679679

680-
func @static_stride_to_dynamic_stride(%arg0 : memref<?x?x?xf32>, %arg1 : index,
681-
%arg2 : index) -> memref<?x?xf32, offset:?, strides: [?, ?]> {
682-
// expected-error @+1 {{expected result type to be 'memref<1x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>>' or a rank-reduced version. (mismatch of result layout)}}
683-
%0 = memref.subview %arg0[0, 0, 0] [1, %arg1, %arg2] [1, 1, 1] : memref<?x?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
684-
return %0 : memref<?x?xf32, offset: ?, strides: [?, ?]>
685-
}
686-
687-
// -----
688-
689680
#map0 = affine_map<(d0, d1)[s0] -> (d0 * 16 + d1)>
690681

691682
func @subview_bad_offset_1(%arg0: memref<16x16xf32>) {

0 commit comments

Comments
 (0)