Skip to content

Commit d4ae7ee

Browse files
authored
[mlir][linalg] Enable CollapseLinalgDimensions to collapse memref based operations (#68522)
1 parent 97b989b commit d4ae7ee

File tree

2 files changed

+81
-7
lines changed

2 files changed

+81
-7
lines changed

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1388,9 +1388,15 @@ static Value getCollapsedOpOperand(Location loc, GenericOp genericOp,
13881388
return operand;
13891389

13901390
// Insert a reshape to collapse the dimensions.
1391-
auto reshapeOp = builder.create<tensor::CollapseShapeOp>(
1392-
loc, operand, operandReassociation);
1393-
return reshapeOp.getResult();
1391+
if (isa<MemRefType>(operand.getType())) {
1392+
return builder
1393+
.create<memref::CollapseShapeOp>(loc, operand, operandReassociation)
1394+
.getResult();
1395+
} else {
1396+
return builder
1397+
.create<tensor::CollapseShapeOp>(loc, operand, operandReassociation)
1398+
.getResult();
1399+
}
13941400
}
13951401

13961402
/// Modify the `linalg.index` operations in the original generic op, to its
@@ -1444,6 +1450,19 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
14441450
}))
14451451
return failure();
14461452

1453+
bool hasBufferSemantics = genericOp.hasBufferSemantics();
1454+
if (hasBufferSemantics &&
1455+
!llvm::all_of(genericOp->getOperands(), [&](Value operand) -> bool {
1456+
MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType());
1457+
if (!memRefToCollapse)
1458+
return true;
1459+
1460+
return memref::CollapseShapeOp::isGuaranteedCollapsible(
1461+
memRefToCollapse, foldedIterationDims);
1462+
}))
1463+
return rewriter.notifyMatchFailure(genericOp,
1464+
"memref is not guaranteed collapsible");
1465+
14471466
CollapsingInfo collapsingInfo;
14481467
if (failed(collapsingInfo.initialize(genericOp.getNumLoops(),
14491468
foldedIterationDims))) {
@@ -1499,7 +1518,10 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
14991518
Value newOutput = getCollapsedOpOperand(loc, genericOp, &output,
15001519
collapsingInfo, rewriter);
15011520
outputOperands.push_back(newOutput);
1502-
resultTypes.push_back(newOutput.getType());
1521+
// If the op has "buffer semantics", then the init operands are ranked
1522+
// memrefs and the op has no results.
1523+
if (!hasBufferSemantics)
1524+
resultTypes.push_back(newOutput.getType());
15031525
}
15041526

15051527
// Create the generic op.
@@ -1538,9 +1560,15 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
15381560
genericOp.getIndexingMapMatchingResult(originalResult.value());
15391561
SmallVector<ReassociationIndices> reassociation =
15401562
getOperandReassociation(indexingMap, collapsingInfo);
1541-
Value result = rewriter.create<tensor::ExpandShapeOp>(
1542-
loc, originalResultType, collapsedOpResult, reassociation);
1543-
results.push_back(result);
1563+
if (isa<MemRefType>(collapsedOpResult.getType())) {
1564+
Value result = rewriter.create<memref::ExpandShapeOp>(
1565+
loc, originalResultType, collapsedOpResult, reassociation);
1566+
results.push_back(result);
1567+
} else {
1568+
Value result = rewriter.create<tensor::ExpandShapeOp>(
1569+
loc, originalResultType, collapsedOpResult, reassociation);
1570+
results.push_back(result);
1571+
}
15441572
} else {
15451573
results.push_back(collapsedOpResult);
15461574
}

mlir/test/Dialect/Linalg/collapse-dim.mlir

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,49 @@ func.func @uncollapsable(%arg0 : tensor<41x3x1x57xf32>, %arg1 : tensor<3x1x57x41
7070
// CHECK-LABEL: func @uncollapsable(
7171
// CHECK: linalg.generic
7272
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
73+
74+
// -----
75+
76+
// CHECK-LABEL: func.func private @collapsable_memref(
77+
// CHECK-SAME: %[[VAL_0:.*]]: memref<1x24x32x8xf32>,
78+
// CHECK-SAME: %[[VAL_1:.*]]: memref<1x24x32x8xf32>) -> memref<1x24x32x8xf32> {
79+
// CHECK: %[[VAL_2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x24x32x8xf32>
80+
// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32> into memref<1x24x256xf32>
81+
// CHECK: %[[VAL_4:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32> into memref<1x24x256xf32>
82+
// CHECK: %[[VAL_5:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32> into memref<1x24x256xf32>
83+
// CHECK: linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_3]], %[[VAL_4]] : memref<1x24x256xf32>, memref<1x24x256xf32>) outs(%[[VAL_5]] : memref<1x24x256xf32>) {
84+
// CHECK: ^bb0(%[[VAL_6:.*]]: f32, %[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32):
85+
// CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_6]], %[[VAL_7]] : f32
86+
// CHECK: linalg.yield %[[VAL_9]] : f32
87+
// CHECK: }
88+
// CHECK: return %[[VAL_2]] : memref<1x24x32x8xf32>
89+
// CHECK: }
90+
91+
// Elementwise addf of two 1x24x32x8 memrefs into a freshly allocated output
// buffer. All operands have the identity indexing map and contiguous layout,
// so (per the CHECK block for this test) the trailing 32x8 dims are expected
// to be collapsed into a single 256 dim on every operand before the generic.
func.func private @collapsable_memref(%arg0: memref<1x24x32x8xf32>, %arg1: memref<1x24x32x8xf32>) -> (memref<1x24x32x8xf32>) {
92+
// Output buffer for the generic; also the function result.
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1x24x32x8xf32>
93+
// 4-D all-parallel elementwise generic; identity maps on ins and outs.
linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : memref<1x24x32x8xf32>, memref<1x24x32x8xf32>) outs(%alloc : memref<1x24x32x8xf32>) {
94+
^bb0(%in: f32, %in_0: f32, %out: f32):
95+
%0 = arith.addf %in, %in_0 : f32
96+
linalg.yield %0 : f32
97+
}
98+
return %alloc : memref<1x24x32x8xf32>
99+
}
100+
101+
// -----
102+
103+
// CHECK-LABEL: func @uncollapsable_strided_memref(
104+
// CHECK: linalg.generic
105+
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
106+
107+
// Negative test: the generic operates on strided subviews
// (strided<[6912, 1152, 48, 1]>), for which collapsing is not guaranteed to
// produce a contiguous memref, so (per the CHECK block for this test) the
// generic is expected to be left untouched with all four iterator_types.
func.func @uncollapsable_strided_memref(%arg0: memref<2x6x24x48xi32>, %arg1: memref<2x6x24x48xi32>) -> (memref<2x6x24x48xi32>) {
108+
%alloc = memref.alloc() {alignment = 64 : i64} : memref<2x6x24x48xi32>
109+
// Non-contiguous views of the inputs and the output buffer; the non-trivial
// strides are what make the collapse not guaranteed collapsible.
%subview = memref.subview %arg0[0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>
110+
%subview0 = memref.subview %arg1[0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>
111+
%subview1 = memref.subview %alloc[0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>
112+
// 4-D all-parallel elementwise addi over the strided subviews.
linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%subview, %subview0 : memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>, memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>) outs(%subview1 : memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>) {
113+
^bb0(%in: i32, %in_0: i32, %out: i32):
114+
%0 = arith.addi %in, %in_0 : i32
115+
linalg.yield %0 : i32
116+
}
117+
return %alloc : memref<2x6x24x48xi32>
118+
}

0 commit comments

Comments (0)