Skip to content

Commit 3428e50

Browse files
committed
using DLTI instead of global symbol for static rank in comm_world
1 parent 3e7d98d commit 3428e50

File tree

4 files changed

+42
-39
lines changed

4 files changed

+42
-39
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -881,10 +881,10 @@ def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> {
881881
let description = [{
882882
This pass converts communication operations from the Mesh dialect to the
883883
MPI dialect.
884-
If it finds a global named "static_mpi_rank" it will use that splat value
885-
instead of calling MPI_Comm_rank. This allows optimizations like constant
886-
shape propagation and fusion because shard/partition sizes depend on the
887-
rank.
884+
If it finds the DLTI attribute "MPI:comm_world-rank" on the module it will
885+
use that integer value instead of calling MPI_Comm_rank. This allows
886+
optimizations like constant shape propagation and fusion because
887+
shard/partition sizes depend on the rank.
888888
}];
889889
let dependentDialects = [
890890
"memref::MemRefDialect",

mlir/lib/Conversion/MeshToMPI/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRMeshToMPI
1111
Core
1212

1313
LINK_LIBS PUBLIC
14+
MLIRDLTIDialect
1415
MLIRFuncDialect
1516
MLIRIR
1617
MLIRLinalgTransforms

mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -284,32 +284,33 @@ struct ConvertProcessMultiIndexOp
284284
}
285285

286286
rewriter.replaceOp(op, mIdx);
287-
return mlir::success();
287+
return success();
288288
}
289289
};
290290

291-
struct ConvertProcessLinearIndexOp
292-
: public mlir::OpRewritePattern<mlir::mesh::ProcessLinearIndexOp> {
293-
using OpRewritePattern::OpRewritePattern;
291+
class ConvertProcessLinearIndexOp
292+
: public OpConversionPattern<ProcessLinearIndexOp> {
293+
int64_t worldRank; // rank in MPI_COMM_WORLD if available, else < 0
294294

295-
mlir::LogicalResult
296-
matchAndRewrite(mlir::mesh::ProcessLinearIndexOp op,
297-
mlir::PatternRewriter &rewriter) const override {
295+
public:
296+
using OpConversionPattern::OpConversionPattern;
298297

299-
// Finds a global named "static_mpi_rank" it will use that splat value.
300-
// Otherwise it defaults to mpi.comm_rank.
298+
// Constructor accepting worldRank
299+
ConvertProcessLinearIndexOp(const TypeConverter &typeConverter,
300+
MLIRContext *context, int64_t worldRank_ = -1)
301+
: OpConversionPattern(typeConverter, context), worldRank(worldRank_) {}
301302

302-
auto loc = op.getLoc();
303-
auto rankOpName = StringAttr::get(op->getContext(), "static_mpi_rank");
304-
if (auto globalOp = SymbolTable::lookupNearestSymbolFrom<memref::GlobalOp>(
305-
op, rankOpName)) {
306-
if (auto initTnsr = globalOp.getInitialValueAttr()) {
307-
auto val = cast<DenseElementsAttr>(initTnsr).getSplatValue<int64_t>();
308-
rewriter.replaceOp(op,
309-
rewriter.create<arith::ConstantIndexOp>(loc, val));
310-
return mlir::success();
311-
}
303+
LogicalResult
304+
matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor,
305+
ConversionPatternRewriter &rewriter) const override {
306+
307+
Location loc = op.getLoc();
308+
if (worldRank >= 0) { // if rank in MPI_COMM_WORLD is known -> use it
309+
rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, worldRank);
310+
return success();
312311
}
312+
313+
// Otherwise call create mpi::CommRankOp
313314
auto rank =
314315
rewriter
315316
.create<mpi::CommRankOp>(

mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -60,23 +60,24 @@ func.func @neighbors_dim2(%arg0 : tensor<120x120x120xi8>) -> (index, index) {
6060

6161
// -----
6262
// CHECK: mesh.mesh @mesh0
63-
mesh.mesh @mesh0(shape = 3x4x5)
64-
memref.global constant @static_mpi_rank : memref<index> = dense<24>
65-
func.func @process_multi_index() -> (index, index, index) {
66-
// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
67-
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
68-
// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
69-
%0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index
70-
// CHECK: return %[[c1]], %[[c0]], %[[c4]] : index, index, index
71-
return %0#0, %0#1, %0#2 : index, index, index
72-
}
63+
module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
64+
mesh.mesh @mesh0(shape = 3x4x5)
65+
func.func @process_multi_index() -> (index, index, index) {
66+
// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
67+
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
68+
// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
69+
%0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index
70+
// CHECK: return %[[c1]], %[[c0]], %[[c4]] : index, index, index
71+
return %0#0, %0#1, %0#2 : index, index, index
72+
}
7373

74-
// CHECK-LABEL: func @process_linear_index
75-
func.func @process_linear_index() -> index {
76-
// CHECK: %[[c24:.*]] = arith.constant 24 : index
77-
%0 = mesh.process_linear_index on @mesh0 : index
78-
// CHECK: return %[[c24]] : index
79-
return %0 : index
74+
// CHECK-LABEL: func @process_linear_index
75+
func.func @process_linear_index() -> index {
76+
// CHECK: %[[c24:.*]] = arith.constant 24 : index
77+
%0 = mesh.process_linear_index on @mesh0 : index
78+
// CHECK: return %[[c24]] : index
79+
return %0 : index
80+
}
8081
}
8182

8283
// -----

0 commit comments

Comments
 (0)