Skip to content

Commit 1ad7725

Browse files
committed
canonicalizing send and recv towards static memref shapes
1 parent b5013d0 commit 1ad7725

File tree

3 files changed

+55
-21
lines changed

3 files changed

+55
-21
lines changed

mlir/include/mlir/Dialect/MPI/IR/MPIOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def MPI_SendOp : MPI_Op<"send", []> {
8484
let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` "
8585
"type($ref) `,` type($tag) `,` type($rank)"
8686
"(`->` type($retval)^)?";
87+
let hasCanonicalizer = 1;
8788
}
8889

8990
//===----------------------------------------------------------------------===//
@@ -114,6 +115,7 @@ def MPI_RecvOp : MPI_Op<"recv", []> {
114115
let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` "
115116
"type($ref) `,` type($tag) `,` type($rank)"
116117
"(`->` type($retval)^)?";
118+
let hasCanonicalizer = 1;
117119
}
118120

119121

mlir/lib/Dialect/MPI/IR/MPIOps.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,52 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Dialect/MPI/IR/MPI.h"
10+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1011
#include "mlir/IR/Builders.h"
1112
#include "mlir/IR/BuiltinAttributes.h"
13+
#include "mlir/IR/PatternMatch.h"
1214

1315
using namespace mlir;
1416
using namespace mlir::mpi;
1517

18+
namespace {
19+
20+
// If input memref has dynamic shape and is a cast and if the cast's input has
21+
// static shape, fold the cast's static input into the given operation.
22+
template <typename OpT>
23+
struct FoldCast final : public mlir::OpRewritePattern<OpT> {
24+
using mlir::OpRewritePattern<OpT>::OpRewritePattern;
25+
26+
LogicalResult matchAndRewrite(OpT op,
27+
mlir::PatternRewriter &b) const override {
28+
auto mRef = op.getRef();
29+
if (mRef.getType().hasStaticShape()) {
30+
return mlir::failure();
31+
}
32+
auto defOp = mRef.getDefiningOp();
33+
if (!defOp || !mlir::isa<mlir::memref::CastOp>(defOp)) {
34+
return mlir::failure();
35+
}
36+
auto src = mlir::cast<mlir::memref::CastOp>(defOp).getSource();
37+
if (!src.getType().hasStaticShape()) {
38+
return mlir::failure();
39+
}
40+
op.getRefMutable().assign(src);
41+
return mlir::success();
42+
}
43+
};
44+
} // namespace
45+
46+
void mlir::mpi::SendOp::getCanonicalizationPatterns(
47+
mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
48+
results.add<FoldCast<mlir::mpi::SendOp>>(context);
49+
}
50+
51+
void mlir::mpi::RecvOp::getCanonicalizationPatterns(
52+
mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
53+
results.add<FoldCast<mlir::mpi::RecvOp>>(context);
54+
}
55+
1656
//===----------------------------------------------------------------------===//
1757
// TableGen'd op method definitions
1858
//===----------------------------------------------------------------------===//

mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -115,34 +115,30 @@ func.func @update_halo_3d(
115115
// CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
116116
// CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32
117117
// CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
118-
// CHECK-NEXT: [[vcast:%.*]] = memref.cast [[valloc]] : memref<117x113x5xi8> to memref<?x?x5xi8>
119118
// CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
120119
// CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
121-
// CHECK-NEXT: mpi.send([[vcast]], [[vc91_i32]], [[vc4_i32]]) : memref<?x?x5xi8>, i32, i32
122-
// CHECK-NEXT: mpi.recv([[vcast]], [[vc91_i32]], [[vc44_i32]]) : memref<?x?x5xi8>, i32, i32
120+
// CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32
121+
// CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32
123122
// CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
124123
// CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
125124
// CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
126125
// CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
127-
// CHECK-NEXT: [[vcast_2:%.*]] = memref.cast [[valloc_1]] : memref<117x113x6xi8> to memref<?x?x6xi8>
128126
// CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[varg0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
129127
// CHECK-NEXT: memref.copy [[vsubview_3]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
130-
// CHECK-NEXT: mpi.send([[vcast_2]], [[vc91_i32]], [[vc44_i32]]) : memref<?x?x6xi8>, i32, i32
131-
// CHECK-NEXT: mpi.recv([[vcast_2]], [[vc91_i32]], [[vc4_i32]]) : memref<?x?x6xi8>, i32, i32
128+
// CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32
129+
// CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32
132130
// CHECK-NEXT: [[vsubview_4:%.*]] = memref.subview [[varg0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
133131
// CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_4]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
134132
// CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
135133
// CHECK-NEXT: [[valloc_5:%.*]] = memref.alloc() : memref<117x3x120xi8>
136-
// CHECK-NEXT: [[vcast_6:%.*]] = memref.cast [[valloc_5]] : memref<117x3x120xi8> to memref<?x3x120xi8>
137-
// CHECK-NEXT: mpi.recv([[vcast_6]], [[vc91_i32]], [[vc29_i32]]) : memref<?x3x120xi8>, i32, i32
134+
// CHECK-NEXT: mpi.recv([[valloc_5]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
138135
// CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[varg0]][1, 0, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
139136
// CHECK-NEXT: memref.copy [[valloc_5]], [[vsubview_7]] : memref<117x3x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
140137
// CHECK-NEXT: memref.dealloc [[valloc_5]] : memref<117x3x120xi8>
141138
// CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<117x4x120xi8>
142-
// CHECK-NEXT: [[vcast_9:%.*]] = memref.cast [[valloc_8]] : memref<117x4x120xi8> to memref<?x4x120xi8>
143139
// CHECK-NEXT: [[vsubview_10:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>>
144140
// CHECK-NEXT: memref.copy [[vsubview_10]], [[valloc_8]] : memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>> to memref<117x4x120xi8>
145-
// CHECK-NEXT: mpi.send([[vcast_9]], [[vc91_i32]], [[vc29_i32]]) : memref<?x4x120xi8>, i32, i32
141+
// CHECK-NEXT: mpi.send([[valloc_8]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32
146142
// CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<117x4x120xi8>
147143
// CHECK-NEXT: [[valloc_11:%.*]] = memref.alloc() : memref<1x120x120xi8>
148144
// CHECK-NEXT: [[vsubview_12:%.*]] = memref.subview [[varg0]][117, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>>
@@ -170,34 +166,30 @@ func.func @update_halo_3d_tensor(
170166
// CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
171167
// CHECK-NEXT: [[v0:%.*]] = bufferization.to_memref [[varg0]] : memref<120x120x120xi8>
172168
// CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
173-
// CHECK-NEXT: [[vcast:%.*]] = memref.cast [[valloc]] : memref<117x113x5xi8> to memref<?x?x5xi8>
174169
// CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[v0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
175170
// CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
176-
// CHECK-NEXT: mpi.send([[vcast]], [[vc91_i32]], [[vc4_i32]]) : memref<?x?x5xi8>, i32, i32
177-
// CHECK-NEXT: mpi.recv([[vcast]], [[vc91_i32]], [[vc44_i32]]) : memref<?x?x5xi8>, i32, i32
171+
// CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32
172+
// CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32
178173
// CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
179174
// CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
180175
// CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
181176
// CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
182-
// CHECK-NEXT: [[vcast_2:%.*]] = memref.cast [[valloc_1]] : memref<117x113x6xi8> to memref<?x?x6xi8>
183177
// CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[v0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
184178
// CHECK-NEXT: memref.copy [[vsubview_3]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
185-
// CHECK-NEXT: mpi.send([[vcast_2]], [[vc91_i32]], [[vc44_i32]]) : memref<?x?x6xi8>, i32, i32
186-
// CHECK-NEXT: mpi.recv([[vcast_2]], [[vc91_i32]], [[vc4_i32]]) : memref<?x?x6xi8>, i32, i32
179+
// CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32
180+
// CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32
187181
// CHECK-NEXT: [[vsubview_4:%.*]] = memref.subview [[v0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
188182
// CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_4]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
189183
// CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
190184
// CHECK-NEXT: [[valloc_5:%.*]] = memref.alloc() : memref<117x3x120xi8>
191-
// CHECK-NEXT: [[vcast_6:%.*]] = memref.cast [[valloc_5]] : memref<117x3x120xi8> to memref<?x3x120xi8>
192-
// CHECK-NEXT: mpi.recv([[vcast_6]], [[vc91_i32]], [[vc29_i32]]) : memref<?x3x120xi8>, i32, i32
185+
// CHECK-NEXT: mpi.recv([[valloc_5]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
193186
// CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[v0]][1, 0, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
194187
// CHECK-NEXT: memref.copy [[valloc_5]], [[vsubview_7]] : memref<117x3x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
195188
// CHECK-NEXT: memref.dealloc [[valloc_5]] : memref<117x3x120xi8>
196189
// CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<117x4x120xi8>
197-
// CHECK-NEXT: [[vcast_9:%.*]] = memref.cast [[valloc_8]] : memref<117x4x120xi8> to memref<?x4x120xi8>
198190
// CHECK-NEXT: [[vsubview_10:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>>
199191
// CHECK-NEXT: memref.copy [[vsubview_10]], [[valloc_8]] : memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>> to memref<117x4x120xi8>
200-
// CHECK-NEXT: mpi.send([[vcast_9]], [[vc91_i32]], [[vc29_i32]]) : memref<?x4x120xi8>, i32, i32
192+
// CHECK-NEXT: mpi.send([[valloc_8]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32
201193
// CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<117x4x120xi8>
202194
// CHECK-NEXT: [[valloc_11:%.*]] = memref.alloc() : memref<1x120x120xi8>
203195
// CHECK-NEXT: [[vsubview_12:%.*]] = memref.subview [[v0]][117, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>>
@@ -209,7 +201,7 @@ func.func @update_halo_3d_tensor(
209201
// CHECK-NEXT: [[vsubview_14:%.*]] = memref.subview [[v0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
210202
// CHECK-NEXT: memref.copy [[valloc_13]], [[vsubview_14]] : memref<2x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
211203
// CHECK-NEXT: memref.dealloc [[valloc_13]] : memref<2x120x120xi8>
212-
// CHECK-NEXT: [[v1:%.*]] = bufferization.to_tensor [[v0]] : memref<120x120x120xi8>
204+
// CHECK-NEXT: [[v1:%.*]] = bufferization.to_tensor [[v0]] restrict writable : memref<120x120x120xi8>
213205
%res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : tensor<120x120x120xi8>
214206
// CHECK: return [[v1]] : tensor<120x120x120xi8>
215207
return %res : tensor<120x120x120xi8>

0 commit comments

Comments
 (0)