-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][linalg] Enable CollapseLinalgDimensions to collapse memref based operations #68522
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][linalg] Enable CollapseLinalgDimensions to collapse memref based operations #68522
Conversation
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir Changesoperations Full diff: https://github.com/llvm/llvm-project/pull/68522.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 069c613cc246d6a..6f4b0ff60ca97c6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1388,9 +1388,15 @@ static Value getCollapsedOpOperand(Location loc, GenericOp genericOp,
return operand;
// Insert a reshape to collapse the dimensions.
- auto reshapeOp = builder.create<tensor::CollapseShapeOp>(
- loc, operand, operandReassociation);
- return reshapeOp.getResult();
+ if (isa<MemRefType>(operand.getType())) {
+ return builder
+ .create<memref::CollapseShapeOp>(loc, operand, operandReassociation)
+ .getResult();
+ } else {
+ return builder
+ .create<tensor::CollapseShapeOp>(loc, operand, operandReassociation)
+ .getResult();
+ }
}
/// Modify the `linalg.index` operations in the original generic op, to its
@@ -1444,6 +1450,19 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
}))
return failure();
+ bool hasBufferSemantics = genericOp.hasBufferSemantics();
+ if (hasBufferSemantics &&
+ !llvm::all_of(genericOp->getOperands(), [&](Value operand) -> bool {
+ MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType());
+ if (!memRefToCollapse)
+ return true;
+
+ return memref::CollapseShapeOp::isGuaranteedCollapsible(
+ memRefToCollapse, foldedIterationDims);
+ }))
+ return rewriter.notifyMatchFailure(genericOp,
+ "memref is not guaranteed collapsible");
+
CollapsingInfo collapsingInfo;
if (failed(collapsingInfo.initialize(genericOp.getNumLoops(),
foldedIterationDims))) {
@@ -1499,7 +1518,10 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
Value newOutput = getCollapsedOpOperand(loc, genericOp, &output,
collapsingInfo, rewriter);
outputOperands.push_back(newOutput);
- resultTypes.push_back(newOutput.getType());
+ // If the op has "buffer semantics", then the init operands are ranked
+ // memrefs and the op has no results.
+ if (!hasBufferSemantics)
+ resultTypes.push_back(newOutput.getType());
}
// Create the generic op.
@@ -1538,9 +1560,15 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
genericOp.getIndexingMapMatchingResult(originalResult.value());
SmallVector<ReassociationIndices> reassociation =
getOperandReassociation(indexingMap, collapsingInfo);
- Value result = rewriter.create<tensor::ExpandShapeOp>(
- loc, originalResultType, collapsedOpResult, reassociation);
- results.push_back(result);
+ if (isa<MemRefType>(collapsedOpResult.getType())) {
+ Value result = rewriter.create<memref::ExpandShapeOp>(
+ loc, originalResultType, collapsedOpResult, reassociation);
+ results.push_back(result);
+ } else {
+ Value result = rewriter.create<tensor::ExpandShapeOp>(
+ loc, originalResultType, collapsedOpResult, reassociation);
+ results.push_back(result);
+ }
} else {
results.push_back(collapsedOpResult);
}
diff --git a/mlir/test/Dialect/Linalg/collapse-dim.mlir b/mlir/test/Dialect/Linalg/collapse-dim.mlir
index 6737a6e15da5afe..106154ba3a553bd 100644
--- a/mlir/test/Dialect/Linalg/collapse-dim.mlir
+++ b/mlir/test/Dialect/Linalg/collapse-dim.mlir
@@ -70,3 +70,49 @@ func.func @uncollapsable(%arg0 : tensor<41x3x1x57xf32>, %arg1 : tensor<3x1x57x41
// CHECK-LABEL: func @uncollapsable(
// CHECK: linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+
+// -----
+
+// CHECK-LABEL: func.func private @collapsable_memref(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<1x24x32x8xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<1x24x32x8xf32>) -> memref<1x24x32x8xf32> {
+// CHECK: %[[VAL_2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x24x32x8xf32>
+// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32> into memref<1x24x256xf32>
+// CHECK: %[[VAL_4:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32> into memref<1x24x256xf32>
+// CHECK: %[[VAL_5:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32> into memref<1x24x256xf32>
+// CHECK: linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_3]], %[[VAL_4]] : memref<1x24x256xf32>, memref<1x24x256xf32>) outs(%[[VAL_5]] : memref<1x24x256xf32>) {
+// CHECK: ^bb0(%[[VAL_6:.*]]: f32, %[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32):
+// CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_6]], %[[VAL_7]] : f32
+// CHECK: linalg.yield %[[VAL_9]] : f32
+// CHECK: }
+// CHECK: return %[[VAL_2]] : memref<1x24x32x8xf32>
+// CHECK: }
+
+func.func private @collapsable_memref(%arg0: memref<1x24x32x8xf32>, %arg1: memref<1x24x32x8xf32>) -> (memref<1x24x32x8xf32>) {
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x24x32x8xf32>
+ linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : memref<1x24x32x8xf32>, memref<1x24x32x8xf32>) outs(%alloc : memref<1x24x32x8xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %0 = arith.addf %in, %in_0 : f32
+ linalg.yield %0 : f32
+ }
+ return %alloc : memref<1x24x32x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @uncollapsable_strided_memref(
+// CHECK: linalg.generic
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+
+func.func @uncollapsable_strided_memref(%arg0: memref<2x6x24x48xi32>, %arg1: memref<2x6x24x48xi32>) -> (memref<2x6x24x48xi32>) {
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<2x6x24x48xi32>
+ %subview = memref.subview %arg0[0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>
+ %subview0 = memref.subview %arg1[0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>
+ %subview1 = memref.subview %alloc[0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>
+ linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%subview, %subview0 : memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>, memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>) outs(%subview1 : memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>) {
+ ^bb0(%in: i32, %in_0: i32, %out: i32):
+ %0 = arith.addi %in, %in_0 : i32
+ linalg.yield %0 : i32
+ }
+ return %alloc : memref<2x6x24x48xi32>
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks good. It seems like I had left comments earlier on a similar PR. I am not sure what happened here.
@MaheshRavishankar thanks for reviewing. Indeed there are 2 PRs. The second one is based on this one and I didn't know how to split. I will address the comments in the second PR (#68526) also. |
5e59fd1
to
37c2e56
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for the delay, traveling this week.
LG, thanks!
operations