Skip to content

Commit 7063c94

Browse files
authored
[mlir][Linalg] Bugfix for folder of linalg.transpose (#102888)
Folder of linalg transpose only support tensor type. Fix #102576.
1 parent 9fa2386 commit 7063c94

File tree

3 files changed

+56
-0
lines changed

3 files changed

+56
-0
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1908,6 +1908,10 @@ void TransposeOp::getEffects(
19081908

19091909
LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
19101910
SmallVectorImpl<OpFoldResult> &result) {
1911+
// Only the tensor type is supported.
1912+
if (!isa<TensorType>(getInput().getType()))
1913+
return failure();
1914+
19111915
// Single dimension transpose.
19121916
if (getPermutation().size() == 0) {
19131917
result.push_back(getInput());

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1216,3 +1216,19 @@ func.func @concats_of_fill(
12161216
// CHECK: %[[CONCAT:.+]] = tensor.concat dim(1) %[[EMPTY0]], %[[EMPTY1]]
12171217
// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[CONCAT]] :
12181218
// CHECK: return %[[FILL]]
1219+
1220+
// -----
1221+
1222+
func.func @transpose_buffer(%input: memref<?xf32>,
1223+
%init: memref<?xf32>) {
1224+
linalg.transpose ins(%input:memref<?xf32>)
1225+
outs(%init:memref<?xf32>)
1226+
permutation = [0]
1227+
func.return
1228+
}
1229+
1230+
// CHECK-LABEL: func.func @transpose_buffer(
1231+
// CHECK-SAME: %[[VAL_0:.*]]: memref<?xf32>,
1232+
// CHECK-SAME: %[[VAL_1:.*]]: memref<?xf32>) {
1233+
// CHECK: linalg.transpose ins(%[[VAL_0]] : memref<?xf32>)
1234+
// CHECK-SAME: outs(%[[VAL_1]] : memref<?xf32>) permutation = [0]

mlir/test/Dialect/Linalg/loops.mlir

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -873,3 +873,39 @@ func.func @lower_to_loops_with_rank_reducing_subviews(
873873
// CHECKPARALLEL: %[[VAL:.+]] = memref.load %{{.+}}[%[[IV]]]
874874
// CHECKPARALLEL: memref.store %[[VAL]], %{{.+}}[%[[IV]]]
875875
// CHECKPARALLEL: }
876+
877+
// -----
878+
879+
func.func @transpose(%input: memref<?xf32>,
880+
%init: memref<?xf32>) {
881+
linalg.transpose ins(%input:memref<?xf32>)
882+
outs(%init:memref<?xf32>)
883+
permutation = [0]
884+
return
885+
}
886+
// CHECK-LABEL: func.func @transpose(
887+
// CHECK-SAME: %[[VAL_0:.*]]: memref<?xf32>,
888+
// CHECK-SAME: %[[VAL_1:.*]]: memref<?xf32>) {
889+
// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
890+
// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
891+
// CHECK: %[[VAL_4:.*]] = memref.dim %[[VAL_0]], %[[VAL_3]] : memref<?xf32>
892+
// CHECK: scf.for %[[VAL_5:.*]] = %[[VAL_3]] to %[[VAL_4]] step %[[VAL_2]] {
893+
// CHECK: %[[VAL_6:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_5]]] : memref<?xf32>
894+
// CHECK: memref.store %[[VAL_6]], %[[VAL_1]]{{\[}}%[[VAL_5]]] : memref<?xf32>
895+
// CHECK: }
896+
// CHECK: return
897+
// CHECK: }
898+
899+
// CHECKPARALLEL-LABEL: func.func @transpose(
900+
// CHECKPARALLEL-SAME: %[[VAL_0:.*]]: memref<?xf32>,
901+
// CHECKPARALLEL-SAME: %[[VAL_1:.*]]: memref<?xf32>) {
902+
// CHECKPARALLEL: %[[VAL_2:.*]] = arith.constant 1 : index
903+
// CHECKPARALLEL: %[[VAL_3:.*]] = arith.constant 0 : index
904+
// CHECKPARALLEL: %[[VAL_4:.*]] = memref.dim %[[VAL_0]], %[[VAL_3]] : memref<?xf32>
905+
// CHECKPARALLEL: scf.parallel (%[[VAL_5:.*]]) = (%[[VAL_3]]) to (%[[VAL_4]]) step (%[[VAL_2]]) {
906+
// CHECKPARALLEL: %[[VAL_6:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_5]]] : memref<?xf32>
907+
// CHECKPARALLEL: memref.store %[[VAL_6]], %[[VAL_1]]{{\[}}%[[VAL_5]]] : memref<?xf32>
908+
// CHECKPARALLEL: scf.reduce
909+
// CHECKPARALLEL: }
910+
// CHECKPARALLEL: return
911+
// CHECKPARALLEL: }

0 commit comments

Comments
 (0)