[mlir][mesh, MPI] Mesh2mpi #104566
@llvm/pr-subscribers-mlir

Author: Frank Schlimbach (fschlimb)

Changes: Pass for lowering `Mesh` to `MPI`. Initial commit lowers `UpdateHaloOp` only.

Patch is 28.98 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/104566.diff

9 Files Affected:
diff --git a/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h b/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
new file mode 100644
index 00000000000000..6a2c196da45577
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
@@ -0,0 +1,27 @@
+//===- MeshToMPI.h - Convert Mesh to MPI dialect --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H
+#define MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTMESHTOMPIPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Lowers Mesh communication operations (updateHalo, AllGather, ...)
+/// to MPI primitives.
+std::unique_ptr<Pass> createConvertMeshToMPIPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H
\ No newline at end of file
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 208f26489d6c39..ad8e98442ab8bc 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -51,6 +51,7 @@
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
+#include "mlir/Conversion/MeshToMPI/MeshToMPI.h"
#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
#include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 7bde9e490e4f4e..f9a6f52a22c6ed 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -869,6 +869,23 @@ def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv"> {
];
}
+//===----------------------------------------------------------------------===//
+// MeshToMPI
+//===----------------------------------------------------------------------===//
+
+def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> {
+ let summary = "Convert Mesh dialect to MPI dialect.";
+ let description = [{
+ This pass converts communication operations
+ from the Mesh dialect to operations from the MPI dialect.
+ }];
+ let dependentDialects = [
+ "memref::MemRefDialect",
+ "mpi::MPIDialect",
+ "scf::SCFDialect"
+ ];
+}
+
//===----------------------------------------------------------------------===//
// NVVMToLLVM
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 8f696bbc1a0f6e..9d1684b78f34f2 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -155,6 +155,39 @@ def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
];
}
+def Mesh_NeighborsLinearIndicesOp : Mesh_Op<"neighbors_linear_indices", [
+ Pure,
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+]> {
+ let summary =
+      "For given split axes get the linear indices of the direct neighbor processes.";
+ let description = [{
+ Example:
+ ```
+    %down, %up = mesh.neighbors_linear_indices on @mesh[$device]
+             split_axes = $split_axes : index, index
+ ```
+    Given `@mesh` with shape `(10, 20, 30)`,
+          `$device` = `(1, 2, 3)`, and
+          `$split_axes` = `[1]`,
+    it returns the linear indices of the processes at positions `(1, 1, 3)`
+    and `(1, 3, 3)`: `633` (= 1*20*30 + 1*30 + 3) and `693` (= 1*20*30 + 3*30 + 3).
+
+ A negative value is returned if `$device` has no neighbor in the given
+ direction along the given `split_axes`.
+ }];
+ let arguments = (ins FlatSymbolRefAttr:$mesh,
+ Variadic<Index>:$device,
+ Mesh_MeshAxesAttr:$split_axes);
+ let results = (outs Index:$neighbor_down, Index:$neighbor_up);
+ let assemblyFormat = [{
+ `on` $mesh `[` $device `]`
+ `split_axes` `=` $split_axes
+ attr-dict `:` type(results)
+ }];
+}
+
//===----------------------------------------------------------------------===//
// Sharding operations.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 813f700c5556e1..3ee237f4e62acd 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -41,6 +41,7 @@ add_subdirectory(MathToSPIRV)
add_subdirectory(MemRefToEmitC)
add_subdirectory(MemRefToLLVM)
add_subdirectory(MemRefToSPIRV)
+add_subdirectory(MeshToMPI)
add_subdirectory(NVGPUToNVVM)
add_subdirectory(NVVMToLLVM)
add_subdirectory(OpenACCToSCF)
diff --git a/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt b/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
new file mode 100644
index 00000000000000..95815a683f6d6a
--- /dev/null
+++ b/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
@@ -0,0 +1,22 @@
+add_mlir_conversion_library(MLIRMeshToMPI
+ MeshToMPI.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MeshToMPI
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRFuncDialect
+ MLIRIR
+ MLIRLinalgTransforms
+ MLIRMemRefDialect
+ MLIRPass
+ MLIRMeshDialect
+ MLIRMPIDialect
+ MLIRTransforms
+ )
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
new file mode 100644
index 00000000000000..42d885a109ee79
--- /dev/null
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -0,0 +1,225 @@
+//===- MeshToMPI.cpp - Mesh to MPI dialect conversion -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation of Mesh communication ops to MPI ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MeshToMPI/MeshToMPI.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MPI/IR/MPI.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "mesh-to-mpi"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMESHTOMPIPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::mesh;
+
+namespace {
+
+// This pattern converts the mesh.update_halo operation to MPI calls
+struct ConvertUpdateHaloOp
+ : public mlir::OpRewritePattern<mlir::mesh::UpdateHaloOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ mlir::LogicalResult
+ matchAndRewrite(mlir::mesh::UpdateHaloOp op,
+ mlir::PatternRewriter &rewriter) const override {
+ // Halos are exchanged as 2 blocks per dimension (one for each side: down
+ // and up). It is assumed that the last dim in a default memref is
+ // contiguous, hence iteration starts with the complete halo on the first
+ // dim which should be contiguous (unless the source is not). The size of
+ // the exchanged data will decrease when iterating over dimensions. That's
+ // good because the halos of last dim will be most fragmented.
+ // memref.subview is used to read and write the halo data from and to the
+ // local data. subviews and halos have dynamic and static values, so
+ // OpFoldResults are used whenever possible.
+
+ SymbolTableCollection symbolTableCollection;
+ auto loc = op.getLoc();
+
+    // Convert an OpFoldResult into a Value.
+ auto toValue = [&rewriter, &loc](OpFoldResult &v) {
+ return v.is<Value>()
+ ? v.get<Value>()
+ : rewriter.create<::mlir::arith::ConstantOp>(
+ loc,
+ rewriter.getIndexAttr(
+ cast<IntegerAttr>(v.get<Attribute>()).getInt()));
+ };
+
+ auto array = op.getInput();
+ auto rank = array.getType().getRank();
+ auto mesh = op.getMesh();
+ auto meshOp = getMesh(op, symbolTableCollection);
+ auto haloSizes = getMixedValues(op.getStaticHaloSizes(),
+ op.getDynamicHaloSizes(), rewriter);
+ // subviews need Index values
+ for (auto &sz : haloSizes) {
+ if (sz.is<Value>()) {
+ sz = rewriter
+ .create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
+ sz.get<Value>())
+ .getResult();
+ }
+ }
+
+ // most of the offset/size/stride data is the same for all dims
+ SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
+ SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
+ SmallVector<OpFoldResult> shape(rank);
+ // we need the actual shape to compute offsets and sizes
+ for (auto [i, s] : llvm::enumerate(array.getType().getShape())) {
+ if (ShapedType::isDynamic(s)) {
+        shape[i] = rewriter.create<memref::DimOp>(loc, array, i).getResult();
+ } else {
+ shape[i] = rewriter.getIndexAttr(s);
+ }
+ }
+
+ auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something
+ auto tag = rewriter.create<::mlir::arith::ConstantOp>(loc, tagAttr);
+ auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0
+ auto zero = rewriter.create<::mlir::arith::ConstantOp>(loc, zeroAttr);
+ SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
+ rewriter.getIndexType());
+ auto myMultiIndex =
+ rewriter.create<ProcessMultiIndexOp>(loc, indexResultTypes, mesh)
+ .getResult();
+ // halo sizes are provided for split dimensions only
+ auto currHaloDim = 0;
+
+ for (auto [dim, splitAxes] : llvm::enumerate(op.getSplitAxes())) {
+ if (splitAxes.empty()) {
+ continue;
+ }
+ // Get the linearized ids of the neighbors (down and up) for the
+ // given split
+ auto tmp = rewriter
+ .create<NeighborsLinearIndicesOp>(loc, mesh, myMultiIndex,
+ splitAxes)
+ .getResults();
+ // MPI operates on i32...
+ Value neighbourIDs[2] = {rewriter.create<arith::IndexCastOp>(
+ loc, rewriter.getI32Type(), tmp[0]),
+ rewriter.create<arith::IndexCastOp>(
+ loc, rewriter.getI32Type(), tmp[1])};
+ // store for later
+ auto orgDimSize = shape[dim];
+ // this dim's offset to the start of the upper halo
+ auto upperOffset = rewriter.create<arith::SubIOp>(
+ loc, toValue(shape[dim]), toValue(haloSizes[currHaloDim * 2 + 1]));
+
+      // Make sure we send/recv in a way that does not lead to a deadlock.
+      // The current approach is by far not optimal; it should at least
+      // be a red-black pattern or use MPI_sendrecv.
+      // Also, buffers should be reused.
+      // Still using temporary contiguous buffers for MPI communication...
+      // Still yielding a "serialized" communication pattern...
+ auto genSendRecv = [&](auto dim, bool upperHalo) {
+ auto orgOffset = offsets[dim];
+ shape[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1]
+ : haloSizes[currHaloDim * 2];
+ // Check if we need to send and/or receive
+ // Processes on the mesh borders have only one neighbor
+ auto to = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
+ auto from = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
+ auto hasFrom = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::sge, from, zero);
+ auto hasTo = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::sge, to, zero);
+ auto buffer = rewriter.create<memref::AllocOp>(
+ loc, shape, array.getType().getElementType());
+ // if has neighbor: copy halo data from array to buffer and send
+ rewriter.create<scf::IfOp>(
+ loc, hasTo, [&](OpBuilder &builder, Location loc) {
+ offsets[dim] = upperHalo ? OpFoldResult(builder.getIndexAttr(0))
+ : OpFoldResult(upperOffset);
+ auto subview = builder.create<memref::SubViewOp>(
+ loc, array, offsets, shape, strides);
+ builder.create<memref::CopyOp>(loc, subview, buffer);
+ builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag, to);
+ builder.create<scf::YieldOp>(loc);
+ });
+ // if has neighbor: receive halo data into buffer and copy to array
+ rewriter.create<scf::IfOp>(
+ loc, hasFrom, [&](OpBuilder &builder, Location loc) {
+ offsets[dim] = upperHalo ? OpFoldResult(upperOffset)
+ : OpFoldResult(builder.getIndexAttr(0));
+ builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag, from);
+ auto subview = builder.create<memref::SubViewOp>(
+ loc, array, offsets, shape, strides);
+ builder.create<memref::CopyOp>(loc, buffer, subview);
+ builder.create<scf::YieldOp>(loc);
+ });
+ rewriter.create<memref::DeallocOp>(loc, buffer);
+ offsets[dim] = orgOffset;
+ };
+
+ genSendRecv(dim, false);
+ genSendRecv(dim, true);
+
+ // prepare shape and offsets for next split dim
+ auto _haloSz =
+ rewriter
+ .create<arith::AddIOp>(loc, toValue(haloSizes[currHaloDim * 2]),
+ toValue(haloSizes[currHaloDim * 2 + 1]))
+ .getResult();
+      // the shape for the next halo excludes the halos on both ends of the
+      // current dim
+ shape[dim] =
+ rewriter.create<arith::SubIOp>(loc, toValue(orgDimSize), _haloSz)
+ .getResult();
+      // the offset for the next halo starts after the down halo of the
+      // current dim
+ offsets[dim] = haloSizes[currHaloDim * 2];
+ // on to next halo
+ ++currHaloDim;
+ }
+ rewriter.eraseOp(op);
+ return mlir::success();
+ }
+};
+
+struct ConvertMeshToMPIPass
+ : public impl::ConvertMeshToMPIPassBase<ConvertMeshToMPIPass> {
+ using Base::Base;
+
+ /// Run the dialect converter on the module.
+ void runOnOperation() override {
+ auto *ctx = &getContext();
+ mlir::RewritePatternSet patterns(ctx);
+
+ patterns.insert<ConvertUpdateHaloOp>(ctx);
+
+ (void)mlir::applyPatternsAndFoldGreedily(getOperation(),
+ std::move(patterns));
+ }
+};
+
+} // namespace
+
+// Create a pass that converts Mesh to MPI.
+std::unique_ptr<Pass> mlir::createConvertMeshToMPIPass() {
+ return std::make_unique<ConvertMeshToMPIPass>();
+}
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index c35020b4c20ccc..f25bbbf8e274b6 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -730,6 +730,25 @@ void ProcessLinearIndexOp::getAsmResultNames(
setNameFn(getResult(), "proc_linear_idx");
}
+//===----------------------------------------------------------------------===//
+// mesh.neighbors_linear_indices op
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+NeighborsLinearIndicesOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+ auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
+ if (failed(mesh)) {
+ return failure();
+ }
+ return success();
+}
+
+void NeighborsLinearIndicesOp::getAsmResultNames(
+ function_ref<void(Value, StringRef)> setNameFn) {
+ setNameFn(getNeighborDown(), "down_linear_idx");
+ setNameFn(getNeighborUp(), "up_linear_idx");
+}
+
//===----------------------------------------------------------------------===//
// collective communication ops
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
new file mode 100644
index 00000000000000..5f563364272d96
--- /dev/null
+++ b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
@@ -0,0 +1,173 @@
+// RUN: mlir-opt %s -convert-mesh-to-mpi | FileCheck %s
+
+// CHECK: mesh.mesh @mesh0
+mesh.mesh @mesh0(shape = 2x2x4)
+
+// CHECK-LABEL: func @update_halo_1d_first
+func.func @update_halo_1d_first(
+ // CHECK-SAME: [[varg0:%.*]]: memref<12x12xi8>
+ %arg0 : memref<12x12xi8>) {
+ // CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index
+ // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
+ // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
+ // CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
+ // CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [0] : index, index
+ // CHECK-NEXT: [[v0:%.*]] = arith.index_cast [[vdown_linear_idx]] : index to i32
+ // CHECK-NEXT: [[v1:%.*]] = arith.index_cast [[vup_linear_idx]] : index to i32
+ // CHECK-NEXT: [[v2:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
+ // CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
+ // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<2x12xi8>
+ // CHECK-NEXT: scf.if [[v3]] {
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][[[vc9]], 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
+ // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<2x12xi8, strided<[12, 1], offset: ?>> to memref<2x12xi8>
+ // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<2x12xi8>, i32, i32
+ // CHECK-NEXT: }
+ // CHECK-NEXT: scf.if [[v2]] {
+ // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref<2x12xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1]>>
+ // CHECK-NEXT: memref.copy [[valloc]], [[vsubview]] : memref<2x12xi8> to memref<2x12xi8, strided<[12, 1]>>
+ // CHECK-NEXT: }
+ // CHECK-NEXT: memref.dealloc [[valloc]] : memref<2x12xi8>
+ // CHECK-NEXT: [[v4:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
+ // CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
+ // CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc() : memref<3x12xi8>
+ // CHECK-NEXT: scf.if [[v5]] {
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [3, 12] [1, 1] : memref<12x12xi8> to memref<3x12xi8, strided<[12, 1]>>
+ // CHECK-NEXT: memref.copy [[vsubview]], [[valloc_0]] : memref<3x12xi8, strided<[12, 1]>> to memref<3x12xi8>
+ // CHECK-NEXT: mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<3x12xi8>, i32, i32
+ // CHECK-NEXT: }
+ // CHECK-NEXT: scf.if [[v4]] {
+ // CHECK-NEXT: mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref<3x12xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][[[vc9]], 0] [3, 12] [1, 1] : memref<12x12xi8> to memref<3x12xi8, strided<[12, 1], offset: ?>>
+ // CHECK-NEXT: memref.copy [[valloc_0]], [[vsubview]] : memref<3x12xi8> to memref<3x12xi8, strided<[12, 1], offset: ?>>
+ // CHECK-NEXT: }
+ // CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<3x12xi8>
+ // CHECK-NEXT: return
+ mesh.update_halo %arg0 on @mesh0 split_axes = [[0]]
...
[truncated]
@AntonLydike @tkarna Please have a look.

@sogartar @yaochengji, could you take a look at this PR?
Looks good, mostly minor remarks about docstrings
> using OpRewritePattern::OpRewritePattern;
>
> mlir::LogicalResult
> matchAndRewrite(mlir::mesh::UpdateHaloOp op,
Thanks for your contribution, @fschlimb. I'm curious: currently all the operations in the mesh dialect work on tensor types except the `update_halo` op. So should we perform bufferization before converting from mesh to MPI?
Will you also make the other operations in the Mesh dialect support memref types?
I made this work on memref because its semantics are memref-like: the halos are updated in-place, the input memref gets mutated, and it does not return a new tensor or memref. In general we can of course think about adding an `updateHalo` variant which operates on tensors and simply returns a new tensor.

Generally we could do the same for other ops, but I thought we should do that once we see a need for it. `updateHalo` is a very special operation which mostly applies to array computations, and is probably less relevant in the tensor/AI world. Currently I do not see that any of the other operations have memref semantics.

Notice: within spmdization, and for `updateHalo` specifically, an `spmdize` for relevant ops (like an in-place `array.insert_slice`) would insert `bufferization.to_memref` and `bufferization.to_tensor` ops appropriately around `updateHalo`. In our (@tkarna) experience this approach works fine (even with one-shot-bufferize) when using `restrict=true` in `bufferization.to_tensor`.
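To make the wrapping concrete, here is a minimal sketch of what such spmdization output might look like around `update_halo`. The value names, the `12x12` shape, and the `halo_sizes = [2, 3]` are illustrative assumptions mirroring the test in this PR, not code taken from the patch:

```mlir
// Hypothetical spmdization output: bufferization casts around the
// memref-based update_halo; the halos are then exchanged in-place on %m.
%m = bufferization.to_memref %sharded : memref<12x12xi8>
mesh.update_halo %m on @mesh0 split_axes = [[0]] halo_sizes = [2, 3]
    : memref<12x12xi8>
// restrict tells one-shot-bufferize that %m has no other aliases.
%t = bufferization.to_tensor %m restrict : memref<12x12xi8>
```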
Wrt when to apply this pass: generally it is probably a good idea to do the MPI conversion after bufferization. This will require changes to the op spec so that the ops are allowed to accept memrefs. I don't know enough about bufferization and tensor optimization to tell whether converting to MPI right after spmdization (using to_memref/to_tensor) would disallow any optimization that is possible otherwise. Any insights are welcome.
I am planning to add `send` and `recv` in a follow-up PR.
> I don't know enough about bufferization and tensor optimization to tell whether converting to MPI right after spmdization (using to_memref/to_tensor) would disallow any optimization that is possible otherwise. Any insights are welcome.

In my understanding, we usually prefer to perform optimizations on tensor types rather than memref types, because RAW dependencies are difficult to detect on memref types.

> I made this work on memref because its semantics are memref-like, e.g. the halos are updated in-place

Even if in most cases it is updated in place, we could still let it support tensor types and make it "in-place" after bufferization.

Combining the two points, I would suggest making `updateHalo` support tensor types; that would make it easier to optimize IR containing both `updateHalo` and other mesh-dialect ops, because we then only need to handle pure tensor types.
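A tiny illustration of the point, using a hypothetical generic op (`"my.update"` is not a real op, just a stand-in):

```mlir
// On tensors, every update yields a new SSA value, so the RAW chain is explicit:
%t1 = "my.update"(%t0) : (tensor<4xi8>) -> tensor<4xi8>
%t2 = "my.update"(%t1) : (tensor<4xi8>) -> tensor<4xi8>

// On memrefs, both ops mutate %m in place; their ordering constraint is
// invisible in the SSA use-def graph and must be inferred by analysis:
"my.update"(%m) : (memref<4xi8>) -> ()
"my.update"(%m) : (memref<4xi8>) -> ()
```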
Agree, let's add a tensor-based `updateHalo` once needed.

We need the memref-based version no matter what. Array/numpy semantics are partially reference-based: for `subview` and `insert_slice` they have `memref` semantics, and copies are disallowed. This cannot really be expressed on the tensor level.

My question was not so much about optimizations in general (which of course are simpler on the tensor level). I was wondering whether early, implicit bufferization - when converting to MPI - would do any harm, compared to doing the conversion right before bufferization.
In the latest version the buffer is given in destination-passing style and accepts memrefs and tensors. Currently the lowering simply applies `bufferization.to_memref`/`bufferization.to_tensor` if a tensor is given. Should this crude approach ever be in the way of some optimization pattern, we can adjust accordingly.
Hi Frank, thanks for this exciting contribution! This looks to be a very solid first step, although I can only comment from the MPI perspective, not so much from the mesh dialect. I'm excited to see where we can go from here!

(Sorry for the delay in response, I was on vacation.)
> auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something
Super nit: I'm curious, why not choose a more "canonical" value like 0?
Since we have only COMM_WORLD this might reduce the risk of tag conflicts (like in multi-threaded cases).
@yaochengji @tkarna @AntonLydike @sogartar Could you approve if you are ok with this?
Thank you @fschlimb for this contribution! Looks to be good from the MPI side!
✅ With the latest revision this PR passed the C/C++ code formatter.
I realized that I had to fix a few things in mesh before this could be useful. Those fixes have been merged now (#114238).

@yaochengji @tkarna @AntonLydike @sogartar @mfrancio could you have a look (again) please?
> // Create operations converting a linear index to a multi-dimensional index
> static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
>                                              Value linearIndex,
Assert that `linearIndex` is `IndexType`?
Values in `dimensions` must be of the same integer type as `linearIndex`.
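For a concrete picture, the ops such a helper might emit when delinearizing a rank over the `2x2x4` shape of `@mesh0` from the test could look roughly like this (a sketch assuming row-major order and an index-typed `%lin`; not the exact generated IR):

```mlir
%c2 = arith.constant 2 : index
%c4 = arith.constant 4 : index
// Row-major delinearization: the last mesh dimension varies fastest.
%i2 = arith.remui %lin, %c4 : index
%t0 = arith.divui %lin, %c4 : index
%i1 = arith.remui %t0, %c2 : index
%i0 = arith.divui %t0, %c2 : index
```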
> // This pattern converts the mesh.update_halo operation to MPI calls
Better docstring: this pattern just converts the process index. Similar issue with the following patterns.
Thanks. Removed the meaningless comments and added more useful ones below.
@yaochengji @AntonLydike @sogartar @mfrancio could you have a look please?
LGTM, thanks
LLVM Buildbot has detected a new failure. Full details are available at: https://lab.llvm.org/buildbot/#/builders/129/builds/10431

LLVM Buildbot has detected a new failure. Full details are available at: https://lab.llvm.org/buildbot/#/builders/177/builds/9141

LLVM Buildbot has detected a new failure. Full details are available at: https://lab.llvm.org/buildbot/#/builders/89/builds/11542

LLVM Buildbot has detected a new failure. Full details are available at: https://lab.llvm.org/buildbot/#/builders/116/builds/6991

LLVM Buildbot has detected a new failure. Full details are available at: https://lab.llvm.org/buildbot/#/builders/140/builds/11913

Working on the post-commit failures. See #117986.

LLVM Buildbot has detected a new failure. Full details are available at: https://lab.llvm.org/buildbot/#/builders/138/builds/7072

LLVM Buildbot has detected a new failure. Full details are available at: https://lab.llvm.org/buildbot/#/builders/117/builds/4203

LLVM Buildbot has detected a new failure. Full details are available at: https://lab.llvm.org/buildbot/#/builders/153/builds/16057

LLVM Buildbot has detected a new failure. Full details are available at: https://lab.llvm.org/buildbot/#/builders/17/builds/4214

LLVM Buildbot has detected a new failure. Full details are available at: https://lab.llvm.org/buildbot/#/builders/143/builds/3755

fixing post-CI failures #104566

LLVM Buildbot has detected a new failure. Full details are available at: https://lab.llvm.org/buildbot/#/builders/130/builds/6814

LLVM Buildbot has detected a new failure. Full details are available at: https://lab.llvm.org/buildbot/#/builders/169/builds/5827

LLVM Buildbot has detected a new failure. Full details are available at: https://lab.llvm.org/buildbot/#/builders/80/builds/6949

LLVM Buildbot has detected a new failure. Full details are available at: https://lab.llvm.org/buildbot/#/builders/161/builds/3413
Pass for lowering `Mesh` to `MPI`.

Initial commit lowers `UpdateHaloOp` only.