[mlir][mpi] Lowering Mpi To LLVM #127053


Merged: 18 commits, Feb 21, 2025

Conversation

fschlimb (Contributor) commented Feb 13, 2025

As agreed with @AntonLydike, this replaces #95524.

  • The first set of patterns to convert the MPI dialect to LLVM.
  • Further conversion patterns will be added in future PRs.

This adds the following on top of #95524

  • Support for Intel MPI by introducing MPIImplTraits to distinguish MPI implementations
  • Works end-to-end in a downstream project using Intel MPI (going through Mesh dialect)
  • Lowering MPI_Send/MPI_Recv
  • Conversion will only be included if MPI is found

cc @mofeing @wsmoses @tobiasgrosser @hhkit @Groverkss @sjw36 @joker-eph @zero9178
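For context on what the conversion targets: conceptually, each `mpi.send` is rewritten into a call with the C signature of `MPI_Send`, after unpacking the memref descriptor into a pointer and element count. The sketch below is not code from this PR; it stubs out MPI with MPICH-style integer handles purely to illustrate the call shape (the actual patterns emit LLVM dialect ops, not C).

```cpp
#include <cassert>
#include <cstdio>

// Stand-ins for the real MPI types. MPICH-family implementations encode
// handles as fixed integers; the values below follow that convention and
// are illustrative only.
using MPI_Comm = int;
using MPI_Datatype = int;
constexpr MPI_Comm MPI_COMM_WORLD = 0x44000000;
constexpr MPI_Datatype MPI_FLOAT = 0x4c00040a;
constexpr int MPI_SUCCESS = 0;

// Stub with the signature the lowering targets.
inline int MPI_Send(const void *buf, int count, MPI_Datatype, int dest,
                    int tag, MPI_Comm) {
  std::printf("MPI_Send(buf=%p, count=%d, dest=%d, tag=%d)\n", buf, count,
              dest, tag);
  return MPI_SUCCESS;
}

// Roughly what `mpi.send(%ref, %tag, %dest) : memref<100xf32>, i32, i32`
// becomes: extract (ptr, count) from the memref descriptor, pick the MPI
// datatype matching f32, and call MPI_Send on MPI_COMM_WORLD.
inline int loweredSend(const float *ptr, int count, int dest, int tag) {
  return MPI_Send(ptr, count, MPI_FLOAT, dest, tag, MPI_COMM_WORLD);
}
```

(`mpi.recv` is symmetric, calling `MPI_Recv` with `MPI_STATUS_IGNORE`.)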

llvmbot (Member) commented Feb 13, 2025

@llvm/pr-subscribers-mlir

Author: Frank Schlimbach (fschlimb)

Changes

As agreed with @AntonLydike, this replaces #95524.

  • The first set of patterns to convert the MPI dialect to LLVM.
  • Further conversion patterns will be added in future PRs.

This adds the following on top of #95524

  • Support for Intel MPI by introducing MPIImplTraits to distinguish MPI implementations
  • Works end-to-end in a downstream project using Intel MPI (going through Mesh dialect)
  • Lowering MPI_Send/MPI_Recv
  • Conversion will only be included if MPI is found

cc @mofeing @wsmoses @tobiasgrosser @hhkit


Patch is 44.08 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/127053.diff

10 Files Affected:

  • (added) mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h (+31)
  • (modified) mlir/include/mlir/Dialect/MPI/IR/MPI.td (+77-77)
  • (modified) mlir/include/mlir/Dialect/MPI/IR/MPIOps.td (+8-8)
  • (modified) mlir/include/mlir/Dialect/MPI/IR/MPITypes.td (+1-1)
  • (modified) mlir/include/mlir/InitAllExtensions.h (+2)
  • (modified) mlir/lib/Conversion/CMakeLists.txt (+1)
  • (added) mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt (+34)
  • (added) mlir/lib/Conversion/MPIToLLVM/MPIImplTraits.h (+151)
  • (added) mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp (+332)
  • (added) mlir/test/Conversion/MPIToLLVM/ops.mlir (+84)
diff --git a/mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h b/mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h
new file mode 100644
index 0000000000000..8d2698aa91c7c
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h
@@ -0,0 +1,31 @@
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_MPITOLLVM_H
+#define MLIR_CONVERSION_MPITOLLVM_H
+
+#include "mlir/IR/DialectRegistry.h"
+
+namespace mlir {
+
+class LLVMTypeConverter;
+class RewritePatternSet;
+
+#define GEN_PASS_DECL_MPITOLLVMCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+namespace mpi {
+
+void populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
+                                         RewritePatternSet &patterns);
+
+void registerConvertMPIToLLVMInterface(DialectRegistry &registry);
+
+} // namespace mpi
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MPITOLLVM_H
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPI.td b/mlir/include/mlir/Dialect/MPI/IR/MPI.td
index 7c84443e5520d..df0cf9d518faf 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPI.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPI.td
@@ -42,104 +42,104 @@ def MPI_Dialect : Dialect {
 // Error classes enum:
 //===----------------------------------------------------------------------===//
 
-def MPI_CodeSuccess : I32EnumAttrCase<"MPI_SUCCESS", 0, "MPI_SUCCESS">;
-def MPI_CodeErrAccess : I32EnumAttrCase<"MPI_ERR_ACCESS", 1, "MPI_ERR_ACCESS">;
-def MPI_CodeErrAmode : I32EnumAttrCase<"MPI_ERR_AMODE", 2, "MPI_ERR_AMODE">;
-def MPI_CodeErrArg : I32EnumAttrCase<"MPI_ERR_ARG", 3, "MPI_ERR_ARG">;
-def MPI_CodeErrAssert : I32EnumAttrCase<"MPI_ERR_ASSERT", 4, "MPI_ERR_ASSERT">;
+def MPI_CodeSuccess : I32EnumAttrCase<"_MPI_SUCCESS", 0, "MPI_SUCCESS">;
+def MPI_CodeErrAccess : I32EnumAttrCase<"_MPI_ERR_ACCESS", 1, "MPI_ERR_ACCESS">;
+def MPI_CodeErrAmode : I32EnumAttrCase<"_MPI_ERR_AMODE", 2, "MPI_ERR_AMODE">;
+def MPI_CodeErrArg : I32EnumAttrCase<"_MPI_ERR_ARG", 3, "MPI_ERR_ARG">;
+def MPI_CodeErrAssert : I32EnumAttrCase<"_MPI_ERR_ASSERT", 4, "MPI_ERR_ASSERT">;
 def MPI_CodeErrBadFile
-    : I32EnumAttrCase<"MPI_ERR_BAD_FILE", 5, "MPI_ERR_BAD_FILE">;
-def MPI_CodeErrBase : I32EnumAttrCase<"MPI_ERR_BASE", 6, "MPI_ERR_BASE">;
-def MPI_CodeErrBuffer : I32EnumAttrCase<"MPI_ERR_BUFFER", 7, "MPI_ERR_BUFFER">;
-def MPI_CodeErrComm : I32EnumAttrCase<"MPI_ERR_COMM", 8, "MPI_ERR_COMM">;
+    : I32EnumAttrCase<"_MPI_ERR_BAD_FILE", 5, "MPI_ERR_BAD_FILE">;
+def MPI_CodeErrBase : I32EnumAttrCase<"_MPI_ERR_BASE", 6, "MPI_ERR_BASE">;
+def MPI_CodeErrBuffer : I32EnumAttrCase<"_MPI_ERR_BUFFER", 7, "MPI_ERR_BUFFER">;
+def MPI_CodeErrComm : I32EnumAttrCase<"_MPI_ERR_COMM", 8, "MPI_ERR_COMM">;
 def MPI_CodeErrConversion
-    : I32EnumAttrCase<"MPI_ERR_CONVERSION", 9, "MPI_ERR_CONVERSION">;
-def MPI_CodeErrCount : I32EnumAttrCase<"MPI_ERR_COUNT", 10, "MPI_ERR_COUNT">;
-def MPI_CodeErrDims : I32EnumAttrCase<"MPI_ERR_DIMS", 11, "MPI_ERR_DIMS">;
-def MPI_CodeErrDisp : I32EnumAttrCase<"MPI_ERR_DISP", 12, "MPI_ERR_DISP">;
+    : I32EnumAttrCase<"_MPI_ERR_CONVERSION", 9, "MPI_ERR_CONVERSION">;
+def MPI_CodeErrCount : I32EnumAttrCase<"_MPI_ERR_COUNT", 10, "MPI_ERR_COUNT">;
+def MPI_CodeErrDims : I32EnumAttrCase<"_MPI_ERR_DIMS", 11, "MPI_ERR_DIMS">;
+def MPI_CodeErrDisp : I32EnumAttrCase<"_MPI_ERR_DISP", 12, "MPI_ERR_DISP">;
 def MPI_CodeErrDupDatarep
-    : I32EnumAttrCase<"MPI_ERR_DUP_DATAREP", 13, "MPI_ERR_DUP_DATAREP">;
+    : I32EnumAttrCase<"_MPI_ERR_DUP_DATAREP", 13, "MPI_ERR_DUP_DATAREP">;
 def MPI_CodeErrErrhandler
-    : I32EnumAttrCase<"MPI_ERR_ERRHANDLER", 14, "MPI_ERR_ERRHANDLER">;
-def MPI_CodeErrFile : I32EnumAttrCase<"MPI_ERR_FILE", 15, "MPI_ERR_FILE">;
+    : I32EnumAttrCase<"_MPI_ERR_ERRHANDLER", 14, "MPI_ERR_ERRHANDLER">;
+def MPI_CodeErrFile : I32EnumAttrCase<"_MPI_ERR_FILE", 15, "MPI_ERR_FILE">;
 def MPI_CodeErrFileExists
-    : I32EnumAttrCase<"MPI_ERR_FILE_EXISTS", 16, "MPI_ERR_FILE_EXISTS">;
+    : I32EnumAttrCase<"_MPI_ERR_FILE_EXISTS", 16, "MPI_ERR_FILE_EXISTS">;
 def MPI_CodeErrFileInUse
-    : I32EnumAttrCase<"MPI_ERR_FILE_IN_USE", 17, "MPI_ERR_FILE_IN_USE">;
-def MPI_CodeErrGroup : I32EnumAttrCase<"MPI_ERR_GROUP", 18, "MPI_ERR_GROUP">;
-def MPI_CodeErrInfo : I32EnumAttrCase<"MPI_ERR_INFO", 19, "MPI_ERR_INFO">;
+    : I32EnumAttrCase<"_MPI_ERR_FILE_IN_USE", 17, "MPI_ERR_FILE_IN_USE">;
+def MPI_CodeErrGroup : I32EnumAttrCase<"_MPI_ERR_GROUP", 18, "MPI_ERR_GROUP">;
+def MPI_CodeErrInfo : I32EnumAttrCase<"_MPI_ERR_INFO", 19, "MPI_ERR_INFO">;
 def MPI_CodeErrInfoKey
-    : I32EnumAttrCase<"MPI_ERR_INFO_KEY", 20, "MPI_ERR_INFO_KEY">;
+    : I32EnumAttrCase<"_MPI_ERR_INFO_KEY", 20, "MPI_ERR_INFO_KEY">;
 def MPI_CodeErrInfoNokey
-    : I32EnumAttrCase<"MPI_ERR_INFO_NOKEY", 21, "MPI_ERR_INFO_NOKEY">;
+    : I32EnumAttrCase<"_MPI_ERR_INFO_NOKEY", 21, "MPI_ERR_INFO_NOKEY">;
 def MPI_CodeErrInfoValue
-    : I32EnumAttrCase<"MPI_ERR_INFO_VALUE", 22, "MPI_ERR_INFO_VALUE">;
+    : I32EnumAttrCase<"_MPI_ERR_INFO_VALUE", 22, "MPI_ERR_INFO_VALUE">;
 def MPI_CodeErrInStatus
-    : I32EnumAttrCase<"MPI_ERR_IN_STATUS", 23, "MPI_ERR_IN_STATUS">;
-def MPI_CodeErrIntern : I32EnumAttrCase<"MPI_ERR_INTERN", 24, "MPI_ERR_INTERN">;
-def MPI_CodeErrIo : I32EnumAttrCase<"MPI_ERR_IO", 25, "MPI_ERR_IO">;
-def MPI_CodeErrKeyval : I32EnumAttrCase<"MPI_ERR_KEYVAL", 26, "MPI_ERR_KEYVAL">;
+    : I32EnumAttrCase<"_MPI_ERR_IN_STATUS", 23, "MPI_ERR_IN_STATUS">;
+def MPI_CodeErrIntern : I32EnumAttrCase<"_MPI_ERR_INTERN", 24, "MPI_ERR_INTERN">;
+def MPI_CodeErrIo : I32EnumAttrCase<"_MPI_ERR_IO", 25, "MPI_ERR_IO">;
+def MPI_CodeErrKeyval : I32EnumAttrCase<"_MPI_ERR_KEYVAL", 26, "MPI_ERR_KEYVAL">;
 def MPI_CodeErrLocktype
-    : I32EnumAttrCase<"MPI_ERR_LOCKTYPE", 27, "MPI_ERR_LOCKTYPE">;
-def MPI_CodeErrName : I32EnumAttrCase<"MPI_ERR_NAME", 28, "MPI_ERR_NAME">;
-def MPI_CodeErrNoMem : I32EnumAttrCase<"MPI_ERR_NO_MEM", 29, "MPI_ERR_NO_MEM">;
+    : I32EnumAttrCase<"_MPI_ERR_LOCKTYPE", 27, "MPI_ERR_LOCKTYPE">;
+def MPI_CodeErrName : I32EnumAttrCase<"_MPI_ERR_NAME", 28, "MPI_ERR_NAME">;
+def MPI_CodeErrNoMem : I32EnumAttrCase<"_MPI_ERR_NO_MEM", 29, "MPI_ERR_NO_MEM">;
 def MPI_CodeErrNoSpace
-    : I32EnumAttrCase<"MPI_ERR_NO_SPACE", 30, "MPI_ERR_NO_SPACE">;
+    : I32EnumAttrCase<"_MPI_ERR_NO_SPACE", 30, "MPI_ERR_NO_SPACE">;
 def MPI_CodeErrNoSuchFile
-    : I32EnumAttrCase<"MPI_ERR_NO_SUCH_FILE", 31, "MPI_ERR_NO_SUCH_FILE">;
+    : I32EnumAttrCase<"_MPI_ERR_NO_SUCH_FILE", 31, "MPI_ERR_NO_SUCH_FILE">;
 def MPI_CodeErrNotSame
-    : I32EnumAttrCase<"MPI_ERR_NOT_SAME", 32, "MPI_ERR_NOT_SAME">;
-def MPI_CodeErrOp : I32EnumAttrCase<"MPI_ERR_OP", 33, "MPI_ERR_OP">;
-def MPI_CodeErrOther : I32EnumAttrCase<"MPI_ERR_OTHER", 34, "MPI_ERR_OTHER">;
+    : I32EnumAttrCase<"_MPI_ERR_NOT_SAME", 32, "MPI_ERR_NOT_SAME">;
+def MPI_CodeErrOp : I32EnumAttrCase<"_MPI_ERR_OP", 33, "MPI_ERR_OP">;
+def MPI_CodeErrOther : I32EnumAttrCase<"_MPI_ERR_OTHER", 34, "MPI_ERR_OTHER">;
 def MPI_CodeErrPending
-    : I32EnumAttrCase<"MPI_ERR_PENDING", 35, "MPI_ERR_PENDING">;
-def MPI_CodeErrPort : I32EnumAttrCase<"MPI_ERR_PORT", 36, "MPI_ERR_PORT">;
+    : I32EnumAttrCase<"_MPI_ERR_PENDING", 35, "MPI_ERR_PENDING">;
+def MPI_CodeErrPort : I32EnumAttrCase<"_MPI_ERR_PORT", 36, "MPI_ERR_PORT">;
 def MPI_CodeErrProcAborted
-    : I32EnumAttrCase<"MPI_ERR_PROC_ABORTED", 37, "MPI_ERR_PROC_ABORTED">;
-def MPI_CodeErrQuota : I32EnumAttrCase<"MPI_ERR_QUOTA", 38, "MPI_ERR_QUOTA">;
-def MPI_CodeErrRank : I32EnumAttrCase<"MPI_ERR_RANK", 39, "MPI_ERR_RANK">;
+    : I32EnumAttrCase<"_MPI_ERR_PROC_ABORTED", 37, "MPI_ERR_PROC_ABORTED">;
+def MPI_CodeErrQuota : I32EnumAttrCase<"_MPI_ERR_QUOTA", 38, "MPI_ERR_QUOTA">;
+def MPI_CodeErrRank : I32EnumAttrCase<"_MPI_ERR_RANK", 39, "MPI_ERR_RANK">;
 def MPI_CodeErrReadOnly
-    : I32EnumAttrCase<"MPI_ERR_READ_ONLY", 40, "MPI_ERR_READ_ONLY">;
+    : I32EnumAttrCase<"_MPI_ERR_READ_ONLY", 40, "MPI_ERR_READ_ONLY">;
 def MPI_CodeErrRequest
-    : I32EnumAttrCase<"MPI_ERR_REQUEST", 41, "MPI_ERR_REQUEST">;
+    : I32EnumAttrCase<"_MPI_ERR_REQUEST", 41, "MPI_ERR_REQUEST">;
 def MPI_CodeErrRmaAttach
-    : I32EnumAttrCase<"MPI_ERR_RMA_ATTACH", 42, "MPI_ERR_RMA_ATTACH">;
+    : I32EnumAttrCase<"_MPI_ERR_RMA_ATTACH", 42, "MPI_ERR_RMA_ATTACH">;
 def MPI_CodeErrRmaConflict
-    : I32EnumAttrCase<"MPI_ERR_RMA_CONFLICT", 43, "MPI_ERR_RMA_CONFLICT">;
+    : I32EnumAttrCase<"_MPI_ERR_RMA_CONFLICT", 43, "MPI_ERR_RMA_CONFLICT">;
 def MPI_CodeErrRmaFlavor
-    : I32EnumAttrCase<"MPI_ERR_RMA_FLAVOR", 44, "MPI_ERR_RMA_FLAVOR">;
+    : I32EnumAttrCase<"_MPI_ERR_RMA_FLAVOR", 44, "MPI_ERR_RMA_FLAVOR">;
 def MPI_CodeErrRmaRange
-    : I32EnumAttrCase<"MPI_ERR_RMA_RANGE", 45, "MPI_ERR_RMA_RANGE">;
+    : I32EnumAttrCase<"_MPI_ERR_RMA_RANGE", 45, "MPI_ERR_RMA_RANGE">;
 def MPI_CodeErrRmaShared
-    : I32EnumAttrCase<"MPI_ERR_RMA_SHARED", 46, "MPI_ERR_RMA_SHARED">;
+    : I32EnumAttrCase<"_MPI_ERR_RMA_SHARED", 46, "MPI_ERR_RMA_SHARED">;
 def MPI_CodeErrRmaSync
-    : I32EnumAttrCase<"MPI_ERR_RMA_SYNC", 47, "MPI_ERR_RMA_SYNC">;
-def MPI_CodeErrRoot : I32EnumAttrCase<"MPI_ERR_ROOT", 48, "MPI_ERR_ROOT">;
+    : I32EnumAttrCase<"_MPI_ERR_RMA_SYNC", 47, "MPI_ERR_RMA_SYNC">;
+def MPI_CodeErrRoot : I32EnumAttrCase<"_MPI_ERR_ROOT", 48, "MPI_ERR_ROOT">;
 def MPI_CodeErrService
-    : I32EnumAttrCase<"MPI_ERR_SERVICE", 49, "MPI_ERR_SERVICE">;
+    : I32EnumAttrCase<"_MPI_ERR_SERVICE", 49, "MPI_ERR_SERVICE">;
 def MPI_CodeErrSession
-    : I32EnumAttrCase<"MPI_ERR_SESSION", 50, "MPI_ERR_SESSION">;
-def MPI_CodeErrSize : I32EnumAttrCase<"MPI_ERR_SIZE", 51, "MPI_ERR_SIZE">;
-def MPI_CodeErrSpawn : I32EnumAttrCase<"MPI_ERR_SPAWN", 52, "MPI_ERR_SPAWN">;
-def MPI_CodeErrTag : I32EnumAttrCase<"MPI_ERR_TAG", 53, "MPI_ERR_TAG">;
+    : I32EnumAttrCase<"_MPI_ERR_SESSION", 50, "MPI_ERR_SESSION">;
+def MPI_CodeErrSize : I32EnumAttrCase<"_MPI_ERR_SIZE", 51, "MPI_ERR_SIZE">;
+def MPI_CodeErrSpawn : I32EnumAttrCase<"_MPI_ERR_SPAWN", 52, "MPI_ERR_SPAWN">;
+def MPI_CodeErrTag : I32EnumAttrCase<"_MPI_ERR_TAG", 53, "MPI_ERR_TAG">;
 def MPI_CodeErrTopology
-    : I32EnumAttrCase<"MPI_ERR_TOPOLOGY", 54, "MPI_ERR_TOPOLOGY">;
+    : I32EnumAttrCase<"_MPI_ERR_TOPOLOGY", 54, "MPI_ERR_TOPOLOGY">;
 def MPI_CodeErrTruncate
-    : I32EnumAttrCase<"MPI_ERR_TRUNCATE", 55, "MPI_ERR_TRUNCATE">;
-def MPI_CodeErrType : I32EnumAttrCase<"MPI_ERR_TYPE", 56, "MPI_ERR_TYPE">;
+    : I32EnumAttrCase<"_MPI_ERR_TRUNCATE", 55, "MPI_ERR_TRUNCATE">;
+def MPI_CodeErrType : I32EnumAttrCase<"_MPI_ERR_TYPE", 56, "MPI_ERR_TYPE">;
 def MPI_CodeErrUnknown
-    : I32EnumAttrCase<"MPI_ERR_UNKNOWN", 57, "MPI_ERR_UNKNOWN">;
+    : I32EnumAttrCase<"_MPI_ERR_UNKNOWN", 57, "MPI_ERR_UNKNOWN">;
 def MPI_CodeErrUnsupportedDatarep
-    : I32EnumAttrCase<"MPI_ERR_UNSUPPORTED_DATAREP", 58,
+    : I32EnumAttrCase<"_MPI_ERR_UNSUPPORTED_DATAREP", 58,
                       "MPI_ERR_UNSUPPORTED_DATAREP">;
 def MPI_CodeErrUnsupportedOperation
-    : I32EnumAttrCase<"MPI_ERR_UNSUPPORTED_OPERATION", 59,
+    : I32EnumAttrCase<"_MPI_ERR_UNSUPPORTED_OPERATION", 59,
                       "MPI_ERR_UNSUPPORTED_OPERATION">;
 def MPI_CodeErrValueTooLarge
-    : I32EnumAttrCase<"MPI_ERR_VALUE_TOO_LARGE", 60, "MPI_ERR_VALUE_TOO_LARGE">;
-def MPI_CodeErrWin : I32EnumAttrCase<"MPI_ERR_WIN", 61, "MPI_ERR_WIN">;
+    : I32EnumAttrCase<"_MPI_ERR_VALUE_TOO_LARGE", 60, "MPI_ERR_VALUE_TOO_LARGE">;
+def MPI_CodeErrWin : I32EnumAttrCase<"_MPI_ERR_WIN", 61, "MPI_ERR_WIN">;
 def MPI_CodeErrLastcode
-    : I32EnumAttrCase<"MPI_ERR_LASTCODE", 62, "MPI_ERR_LASTCODE">;
+    : I32EnumAttrCase<"_MPI_ERR_LASTCODE", 62, "MPI_ERR_LASTCODE">;
 
 def MPI_ErrorClassEnum
     : I32EnumAttr<"MPI_ErrorClassEnum", "MPI error class name", [
@@ -215,20 +215,20 @@ def MPI_ErrorClassAttr : EnumAttr<MPI_Dialect, MPI_ErrorClassEnum, "errclass"> {
   let assemblyFormat = "`<` $value `>`";
 }
 
-def MPI_OpNull : I32EnumAttrCase<"MPI_OP_NULL", 0, "MPI_OP_NULL">;
-def MPI_OpMax : I32EnumAttrCase<"MPI_MAX", 1, "MPI_MAX">;
-def MPI_OpMin : I32EnumAttrCase<"MPI_MIN", 2, "MPI_MIN">;
-def MPI_OpSum : I32EnumAttrCase<"MPI_SUM", 3, "MPI_SUM">;
-def MPI_OpProd : I32EnumAttrCase<"MPI_PROD", 4, "MPI_PROD">;
-def MPI_OpLand : I32EnumAttrCase<"MPI_LAND", 5, "MPI_LAND">;
-def MPI_OpBand : I32EnumAttrCase<"MPI_BAND", 6, "MPI_BAND">;
-def MPI_OpLor : I32EnumAttrCase<"MPI_LOR", 7, "MPI_LOR">;
-def MPI_OpBor : I32EnumAttrCase<"MPI_BOR", 8, "MPI_BOR">;
-def MPI_OpLxor : I32EnumAttrCase<"MPI_LXOR", 9, "MPI_LXOR">;
-def MPI_OpBxor : I32EnumAttrCase<"MPI_BXOR", 10, "MPI_BXOR">;
-def MPI_OpMinloc : I32EnumAttrCase<"MPI_MINLOC", 11, "MPI_MINLOC">;
-def MPI_OpMaxloc : I32EnumAttrCase<"MPI_MAXLOC", 12, "MPI_MAXLOC">;
-def MPI_OpReplace : I32EnumAttrCase<"MPI_REPLACE", 13, "MPI_REPLACE">;
+def MPI_OpNull : I32EnumAttrCase<"_MPI_OP_NULL", 0, "MPI_OP_NULL">;
+def MPI_OpMax : I32EnumAttrCase<"_MPI_MAX", 1, "MPI_MAX">;
+def MPI_OpMin : I32EnumAttrCase<"_MPI_MIN", 2, "MPI_MIN">;
+def MPI_OpSum : I32EnumAttrCase<"_MPI_SUM", 3, "MPI_SUM">;
+def MPI_OpProd : I32EnumAttrCase<"_MPI_PROD", 4, "MPI_PROD">;
+def MPI_OpLand : I32EnumAttrCase<"_MPI_LAND", 5, "MPI_LAND">;
+def MPI_OpBand : I32EnumAttrCase<"_MPI_BAND", 6, "MPI_BAND">;
+def MPI_OpLor : I32EnumAttrCase<"_MPI_LOR", 7, "MPI_LOR">;
+def MPI_OpBor : I32EnumAttrCase<"_MPI_BOR", 8, "MPI_BOR">;
+def MPI_OpLxor : I32EnumAttrCase<"_MPI_LXOR", 9, "MPI_LXOR">;
+def MPI_OpBxor : I32EnumAttrCase<"_MPI_BXOR", 10, "MPI_BXOR">;
+def MPI_OpMinloc : I32EnumAttrCase<"_MPI_MINLOC", 11, "MPI_MINLOC">;
+def MPI_OpMaxloc : I32EnumAttrCase<"_MPI_MAXLOC", 12, "MPI_MAXLOC">;
+def MPI_OpReplace : I32EnumAttrCase<"_MPI_REPLACE", 13, "MPI_REPLACE">;
 
 def MPI_OpClassEnum : I32EnumAttr<"MPI_OpClassEnum", "MPI operation class", [
       MPI_OpNull,
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index 284ba72af9768..db28bd09678f8 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -102,13 +102,13 @@ def MPI_SendOp : MPI_Op<"send", []> {
   let arguments = (
     ins AnyMemRef : $ref,
     I32 : $tag,
-    I32 : $rank
+    I32 : $dest
   );
 
   let results = (outs Optional<MPI_Retval>:$retval);
 
-  let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` "
-                       "type($ref) `,` type($tag) `,` type($rank)"
+  let assemblyFormat = "`(` $ref `,` $tag `,` $dest `)` attr-dict `:` "
+                       "type($ref) `,` type($tag) `,` type($dest)"
                        "(`->` type($retval)^)?";
   let hasCanonicalizer = 1;
 }
@@ -154,11 +154,11 @@ def MPI_ISendOp : MPI_Op<"isend", []> {
 //===----------------------------------------------------------------------===//
 
 def MPI_RecvOp : MPI_Op<"recv", []> {
-  let summary = "Equivalent to `MPI_Recv(ptr, size, dtype, dest, tag, "
+  let summary = "Equivalent to `MPI_Recv(ptr, size, dtype, source, tag, "
                 "MPI_COMM_WORLD, MPI_STATUS_IGNORE)`";
   let description = [{
     MPI_Recv performs a blocking receive of `size` elements of type `dtype` 
-    from rank `dest`. The `tag` value and communicator enables the library to 
+    from rank `source`. The `tag` value and communicator enable the library to
     determine the matching of multiple sends and receives between the same 
     ranks.
 
@@ -172,13 +172,13 @@ def MPI_RecvOp : MPI_Op<"recv", []> {
 
   let arguments = (
     ins AnyMemRef : $ref,
-    I32 : $tag, I32 : $rank
+    I32 : $tag, I32 : $source
   );
 
   let results = (outs Optional<MPI_Retval>:$retval);
 
-  let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:`"
-                       "type($ref) `,` type($tag) `,` type($rank)"
+  let assemblyFormat = "`(` $ref `,` $tag `,` $source `)` attr-dict `:` "
+                       "type($ref) `,` type($tag) `,` type($source)"
                        "(`->` type($retval)^)?";
   let hasCanonicalizer = 1;
 }
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
index fafea0eac8bb7..a55d30e778e22 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
@@ -30,7 +30,7 @@ class MPI_Type<string name, string typeMnemonic, list<Trait> traits = []>
 //===----------------------------------------------------------------------===//
 
 def MPI_Retval : MPI_Type<"Retval", "retval"> {
-  let summary = "MPI function call return value";
+  let summary = "MPI function call return value (!mpi.retval)";
   let description = [{
     This type represents a return value from an MPI function call.
     This value can be MPI_SUCCESS, MPI_ERR_IN_STATUS, or any error code.
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index 887db344ed88b..6ab23ff86b3c6 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -14,6 +14,7 @@
 #ifndef MLIR_INITALLEXTENSIONS_H_
 #define MLIR_INITALLEXTENSIONS_H_
 
+#include "Conversion/MPIToLLVM/MPIToLLVM.h"
 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
@@ -70,6 +71,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
   registerConvertFuncToLLVMInterface(registry);
   index::registerConvertIndexToLLVMInterface(registry);
   registerConvertMathToLLVMInterface(registry);
+  mpi::registerConvertMPIToLLVMInterface(registry);
   registerConvertMemRefToLLVMInterface(registry);
   registerConvertNVVMToLLVMInterface(registry);
   registerConvertOpenMPToLLVMInterface(registry);
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 0bd08ec6333e6..3dc7472584cf9 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -41,6 +41,7 @@ add_subdirectory(MemRefToEmitC)
 add_subdirectory(MemRefToLLVM)
 add_subdirectory(MemRefToSPIRV)
 add_subdirectory(MeshToMPI)
+add_subdirectory(MPIToLLVM)
 add_subdirectory(NVGPUToNVVM)
 add_subdirectory(NVVMToLLVM)
 add_subdirectory(OpenACCToSCF)
diff --git a/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt b/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
new file mode 100644
index 0000000000000..17df603ff5686
--- /dev/null
+++ b/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
@@ -0,0 +1,34 @@
+find_path(MPI_C_HEADER_DIR mpi.h
+    PATHS $ENV{I_MPI_ROOT}/include
+          $ENV{MPI_HOME}/include
+          $ENV{MPI_ROOT}/include)
+if(MPI_C_HEADER_DIR)
+  # cmake_path(REMOVE_FILENAME MPI_C_HEADER_DIR)
+  message(STATUS "found MPI_C_HEADER_DIR: ${MPI_C_HEADER_DIR}")
+else()
+  message(WARNING "MPI not found, disabling MLIRMPIToLLVM conversion")
+  return()
+endif()
+
+add_mlir_conversion_library(MLIRMPIToLLVM
+  MPIToLLVM.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MPIToLLVM
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRLLVMCommonConversion
+  MLIRLLVMDialect
+  MLIRMPIDialect
+)
+target_include_directories(
+    MLIRMPIToLLVM
+    PRIVATE
+    ${MPI_C_HEADER_DIR}
+)
\ No newline at end of file
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIImplTraits.h b/mlir/lib/Conversion/MPIToLLVM/MPIImplTraits.h
new file mode 100644
index 0000000000000..09811e1cb7c61
--- /dev/null
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIImplTraits.h
@@ -0,0 +1,151 @@
+#define MPICH_SKIP_MPICXX 1
+#define OMPI_SKIP_MPICXX 1
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/MPI/IR/MPI.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include <mpi.h>
+
+namespace {
+
+// When lowering the MPI dialect to function calls, certain details
+// differ between the various MPI implementations. This class provides
+// them depending on which MPI implementation was included.
+struct MPIImplTraits {
+  // get/create MPI_COMM_WORLD as a mlir::Value
+  static mlir::Value getCommWorld(mlir::ModuleOp &moduleOp,
+                                  const mlir::Location loc,
+                                  mlir::ConversionPatternRewriter &rewriter);
+  // get/create MPI datatype as a mlir::Value which corresponds to the given
+  // mlir::Type
+  static mlir::Value getDataType(mlir::ModuleOp &moduleOp,
+                                 cons...
[truncated]

mofeing (Contributor) commented Feb 13, 2025

Support for Intel MPI by introducing MPIImplTraits to distinguish MPI implementations

The next MPI standard stabilizes the ABI, but I guess we still need to support the vendor ABIs of earlier MPI implementations, right?

cc @JBlaschke @hhkit

fschlimb (Contributor, Author):

Support for Intel MPI by introducing MPIImplTraits to distinguish MPI implementations

The next MPI standard stabilizes the ABI, but I guess we still need to support the vendor ABIs of earlier MPI implementations, right?

cc @JBlaschke @hhkit

Yes, we can move to that once that is broadly available.
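To make the ABI issue concrete: each MPI library currently fixes its own binary encoding for handles like `MPI_COMM_WORLD`. MPICH-family libraries (including Intel MPI) use fixed integer handles, while Open MPI exposes addresses of global objects such as `ompi_mpi_comm_world`. Below is a standalone caricature of the traits dispatch; the macro and member names are invented here, not taken from the PR.

```cpp
#include <cassert>
#include <cstdint>

// The real MPIImplTraits.h detects the implementation from macros defined
// by whichever <mpi.h> was included at build time. This sketch hard-picks
// the MPICH branch with a local macro so it compiles on its own.
#define SKETCH_ASSUME_MPICH 1

struct SketchImplTraits {
  // MPICH family: MPI_COMM_WORLD is a fixed 32-bit value, so the lowering
  // can materialize it as a plain integer constant. Open MPI: it is the
  // address of a global object, so the lowering must instead emit an
  // llvm.mlir.addressof of an external symbol.
  static bool commWorldIsIntConstant() {
#ifdef SKETCH_ASSUME_MPICH
    return true;
#else
    return false;
#endif
  }
  static std::int32_t commWorldHandle() { return 0x44000000; }
};
```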

@fschlimb fschlimb requested a review from joker-eph February 13, 2025 14:58
@Dinistro Dinistro self-requested a review February 14, 2025 07:09
AntonLydike (Contributor) left a comment:

❤️ Really cool work! Thanks @fschlimb! Super excited to see this progress!

Comment on lines 47 to 48
// CHECK-NEXT: [[v28:%.*]] = llvm.mlir.
// CHECK-NEXT: [[v29:%.*]] = llvm.mlir.
Contributor:

these appear to be missing something?

Contributor (Author):

Yes, the constants used in there depend on the MPI implementation. This makes sure the test passes for any MPI.

Dinistro (Contributor) left a comment:

Very nice work and sorry for the late review.

Unfortunately, there are some issues with the dependencies on MPI headers and packages. I highly doubt that this is an acceptable thing to depend on in the build system. Note that tests might break when someone has another header around that defines different values for the constants used.

I suggest restructuring the pass to take all the necessary values as configuration parameters. In the worst case, this could even be a JSON file if the list of options grows. I'm not really sure whom to consult on this.

Comment on lines 18 to 19
#define GEN_PASS_DECL_MPITOLLVMCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
Contributor:

I don't understand why this is required when there is no actual pass that backs it. This is an interface-based pass extension, no?

Contributor (Author):

Left-over. Can be removed.

Comment on lines 22 to 23
// TODO: this was copied from GPUOpsLowering.cpp:288
// is this okay, or should this be moved to some common file?
Contributor:

This should be moved into Conversion/LLVMCommon, I presume

Contributor:

This might be problematic to include here given the different licence. I'm not sure how this should be handled, tbh.

Contributor (Author):

The MPICH license is compatible because it is fully permissive. Should anyone see an issue with this, I can write it myself and make up my own values. We'd lose compatibility with MPICH/Intel MPI, though.

matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// get loc
auto loc = op.getLoc();
Contributor:

Nit: Don't use auto when the type isn't exceptionally complex or given on the RHS.

Contributor (Author):

The guidance is: "Don't "almost always" use auto, but do use auto with initializers like `cast<Foo>(...)` or other places where the type is already obvious from the context."
I guess it's in the eye of the beholder, but to me the type is clearly obvious from the context.

Contributor:

Many places in MLIR adopted a very clear stance on this. I'm aware that this rule is quite flexible, which makes it a bad one.

Contributor (Author):

Anyway, I changed it as you suggested.

Comment on lines 119 to 121
auto loc = op.getLoc();
auto context = rewriter.getContext();
auto i32 = rewriter.getI32Type();
Contributor:

Nit: No auto here

Contributor (Author):

See above.

func.func @mpi_test(%arg0: memref<100xf32>) {
// CHECK: [[varg0:%.*]]: !llvm.ptr, [[varg1:%.*]]: !llvm.ptr, [[varg2:%.*]]: i64, [[varg3:%.*]]: i64, [[varg4:%.*]]: i64
// CHECK: [[v0:%.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK-NEXT: [[v1:%.*]] = llvm.insertvalue [[varg0]], [[v0]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
Contributor:

Nit: Only use CHECK-NEXT when necessary.

Contributor (Author):

done

Comment on lines 79 to 81
%4 = mpi.retval_check %retval = <MPI_SUCCESS> : i1

%5 = mpi.error_class %0 : !mpi.retval
Contributor:

Should there be checks for these as well?

Contributor (Author):

We currently do not handle these. Can add a check nevertheless.

Contributor (Author):

removed.

@@ -0,0 +1,84 @@
// RUN: mlir-opt -convert-to-llvm %s | FileCheck %s

module {
Contributor:

Nit: The module is not required.

Contributor (Author):

removed.

fschlimb (Contributor, Author):

Very nice work and sorry for the late review.

Unfortunately, there are some issues with the dependencies on MPI headers and packages. I highly doubt that this is an acceptable thing to depend on in the build system. Note that tests might break when someone has another header around that defines different values for the constants used.

I suggest restructuring the pass to take all the necessary values as configuration parameters. In the worst case, this could even be a JSON file if the list of options grows. I'm not really sure whom to consult on this.

Yes, there is no straightforward solution for this right now. We could depend on yet another (external) library that unifies the interfaces. To me, that's not really tempting just to get a few defines.

Why would configuration parameters change anything about the issue you are mentioning? Whatever was configured can or cannot conflict later, no matter what. Notice that later is not when the compiler is built, but when the compiler is executed. If you are concerned about other passes/dialects using a different MPI in the same build, we could move finding the header-file to an upper level cmake file. Frankly, I don't see this as a problem.

As mentioned above, as soon as the new MPI standard defines a stable ABI this problem is gone anyway. I suggest keeping this like this for the time being.

fschlimb (Contributor, Author):

Thanks a lot @Dinistro for your valuable comments and questions.
I applied most of your suggestions and replied to others.

@Dinistro (Contributor) left a comment

Why would configuration parameters change anything about the issue you are mentioning? Whatever was configured can or cannot conflict later, no matter what. Notice that later is not when the compiler is built, but when the compiler is executed. If you are concerned about other passes/dialects using a different MPI in the same build, we could move finding the header-file to an upper level cmake file. Frankly, I don't see this as a problem.

I was referring to lowering pass options that can be passed into the compiler, i.e., compiling to a specific MPI implementation/draft instead of tying the supported MPI implementation to the configuration time of MLIR itself.

In any case, I don't feel confident to judge if this is an acceptable workaround or not. Thus, I'm explicitly requesting someone from core to weigh in. CC @joker-eph @matthias-springer @ftynse

Contributor

Internal headers like this are used very infrequently, that's why I raised it.

Comment on lines 8 to 12
#ifdef FOUND_MPI_C_HEADER
#include <mpi.h>
#else // not FOUND_MPI_C_HEADER
#include "mpi_fallback.h"
#endif // FOUND_MPI_C_HEADER
Contributor

Do we have a clear plan of adapting this general format? I understand that this can be annoying to implement, but not doing it now might result in more work down the line.

I know of multiple companies that distribute MLIR in binary form, mine included. For now, we will not rely on the MPI dialect, but this might become an issue at some point. Specifically, setting up the build containers to contain the correct MPI version we want to target will be problematic.

matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// get loc
auto loc = op.getLoc();
Contributor

Many places in MLIR have adopted a very clear stance on this. I'm aware that this rule is quite flexible, which makes it a bad one.

Contributor

Please adapt the name to follow the established format.

Contributor Author

Inlined into cpp.


// CHECK: mpi.retval_check
%4 = mpi.retval_check %retval = <MPI_SUCCESS> : i1
// CEHCK: mpi.error_class
Contributor

Suggested change
// CEHCK: mpi.error_class
// CHECK: mpi.error_class

Contributor Author

Removed.

// CHECK: llvm.call @MPI_Finalize() : () -> i32
%3 = mpi.finalize : !mpi.retval

// CHECK: mpi.retval_check
Contributor

Nit: Maybe add a TODO that states that these are not yet lowered and will be in the future.

Contributor Author

Removed the ops, they are unnecessary.

@tobiasgrosser
Contributor

Why would configuration parameters change anything about the issue you are mentioning? Whatever was configured can or cannot conflict later, no matter what. Notice that later is not when the compiler is built, but when the compiler is executed. If you are concerned about other passes/dialects using a different MPI in the same build, we could move finding the header-file to an upper level cmake file. Frankly, I don't see this as a problem.

I was referring to lowering pass options that can be passed into the compiler, i.e., compiling to a specific MPI implementation/draft instead of tying the supported MPI implementation to the configuration time of MLIR itself.

In any case, I don't feel confident to judge if this is an acceptable workaround or not. Thus, I'm explicitly requesting someone from core to weigh in. CC @joker-eph @matthias-springer @ftynse

Instead of linking in the MPI header file that is decided on by the pre-processor, could we make the MPI library a run-time choice? Instead of using a header file, I suggest we hardcode the ABI based on our knowledge. We could use a little python script to extract the ABI we need, but requiring the mpi.h header might not be a good idea. In the end, 90% of this PR already hardcodes the MPI API without relying on the MPI headers. I feel the cost of adding some ABI knowledge is small, and the benefits of being independent of headers and in line with MLIR style are notable.

@fschlimb
Contributor Author

fschlimb commented Feb 18, 2025

Instead of linking in the MPI header file that is decided on by the pre-processor, could we make the MPI library a run-time choice? Instead of using a header file, I suggest we hardcode the ABI based on our knowledge. We can use a little python script to extract the ABI we need, but requiring the mpi.h header might not be a good idea. In the end, 90% of this PR already hardcodes the MPI API without relying on the MPI headers. I feel the cost of adding some ABI knowledge is small and the benefit of being independent of headers and in-line with MLIR-style are notable.

Yes, as the "fallback" shows, currently there is not much we need to know. We can do the same for OpenMPI. Of course this raises the same license questions as for MPICH. I am convinced both licenses are sufficiently permissive.

Adding a "runtime" option is of course more work (we must consider various versions of MPI implementations).

Again, I do not think the current approach is a super special case in MLIR/LLVM. Other passes/dialects/features are also tied to specific HW or SW at build time.

Any objections to addressing this in a separate PR?

@Dinistro
Contributor

Any objections to addressing this in a separate PR?

IIUC, the current lowering pass changes behavior depending on the MPI version that was around during build time. If that is really the case, then the test should not work properly depending on the build setting, no?

@fschlimb
Contributor Author

Any objections to addressing this in a separate PR?

IIUC, the current lowering pass changes behavior depending on the MPI version that was around during build time. If that is really the case, then the test should not work properly depending on the build setting, no?

The tests are written in a way that hides the differences between OpenMPI and MPICH* (mostly different constant values). @AntonLydike also found places where details of checked operations are missing (and only the operation name itself gets verified).

@fschlimb
Contributor Author

fschlimb commented Feb 19, 2025

I thought a bit about the runtime option. I see only two reasonable ways of doing this:

  1. I found no way of adding custom options to --convert-to-llvm. So if we want this to become an option to the pass, we need a separate pass (and not populate the generic --convert-to-llvm pass). There seems to be some consensus that the current approach is preferred (and in [MLIR][MPI] Add LLVM lowering patterns for some MPI operations #95524 @joker-eph explicitly asked for it). Any hint on how to add an option to --convert-to-llvm would be appreciated.
  2. Use DLTI to annotate the IR. This could be something like the below. I prefer this solution; it is more flexible (e.g. one could write an extra pass which inserts the DLTI according to an option to the pass) and keeps the integration into --convert-to-llvm.
module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "Intel"> } {
  %0 = mpi.....
}

Any thoughts on this?
cc @rengolin @rolfmorel

@tobiasgrosser
Contributor

Nice!


github-actions bot commented Feb 19, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@fschlimb
Contributor Author

@Dinistro @tobiasgrosser I just (force-)pushed a solution which selects the MPI implementation at runtime. As suggested above, it uses DLTI. Hope this addresses your concerns.

@fschlimb fschlimb requested a review from Dinistro February 19, 2025 17:36
@Dinistro (Contributor) left a comment

Very nice, thanks for changing this to use the DLTI. I mainly added nit comments, related to things from the style guide: https://mlir.llvm.org/getting_started/DeveloperGuide/

Considering the new structure, it might be more elegant to rely on virtual dispatch to implement the different lowering strategies. This use case very nicely fits the strategy pattern, I believe.
This would ensure that you only need to check the DLTI once and can thereafter rely on the virtual dispatch to delegate to the proper implementation.

Co-authored-by: Christian Ulmann <[email protected]>
@fschlimb
Contributor Author

Very nice, thanks for changing this to use the DLTI. I mainly added nit comments, related to things from the style guide: https://mlir.llvm.org/getting_started/DeveloperGuide/

Considering the new structure, it might be more elegant to rely on virtual dispatch to implement the different lowering strategies. This use case very nicely fits the strategy pattern, I believe. This would ensure that you only need to check the DLTI once and can thereafter rely on the virtual dispatch to delegate to the proper implementation.

Yes, I have been thinking about this as well. Virtual dispatch requires an object. The easy choice is to tie that object's lifetime to each matchAndRewrite call. What we ideally want is one object per module. Without a proper pass, I don't know how to do this.

Yes, the current solution is not perfect. We can go on like this forever, or we declare this as good enough as a first step and together work on continuous improvements in future PRs. I vote for the latter.

@fschlimb
Contributor Author

Thanks a lot for your review and suggestions, @Dinistro !
I have converted the Traits classes to use virtual functions for dispatch. Each conversion pattern invocation creates its own trait object (once). This is not ideal in terms of performance, but much better than before.
I hope the latest commit addresses all your remaining concerns and you are ok with merging.

@Dinistro (Contributor) left a comment

Very nice work, thanks for addressing all the comments. I added a few more nits but this is essentially ready to go 🙂

Comment on lines +9 to +13
//
// Copyright (C) by Argonne National Laboratory
// See COPYRIGHT in top-level directory
// of MPICH source repository.
//
Contributor

I guess this is no longer needed?

Comment on lines 58 to 59
/// Instantiate a new MPIImplTraits object according to the DLTI attribute
/// on the given module.
Contributor

Maybe clarify that the fallback is MPICH

Contributor Author

done

@fschlimb
Contributor Author

Excellent. Thanks again @AntonLydike, @Dinistro, @tobiasgrosser and @mofeing for your comments!

@fschlimb fschlimb merged commit ab166d4 into llvm:main Feb 21, 2025
11 checks passed
@kazutakahirata
Contributor

kazutakahirata commented Feb 21, 2025

@fschlimb With this PR, I am getting:

/usr/lib/gcc/x86_64-linux-gnu/14/../../../../include/c++/14/bits/unique_ptr.h:93:2: error: delete called on '(anonymous namespace)::MPIImplTraits' that is abstract but has non-virtual destructor [-Werror,-Wdelete-abstract-non-virtual-dtor]
        delete __ptr;
        ^
/usr/lib/gcc/x86_64-linux-gnu/14/../../../../include/c++/14/bits/unique_ptr.h:398:4: note: in instantiation of member function 'std::default_delete<(anonymous namespace)::MPIImplTraits>::operator()' requested here
          get_deleter()(std::move(__ptr));
          ^
mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp:313:22: note: in instantiation of member function 'std::unique_ptr<(anonymous namespace)::MPIImplTraits>::~unique_ptr' requested here
    auto mpiTraits = MPIImplTraits::get(moduleOp);

Is there any way you could take a look? My cmake configuration is -DLLVM_ENABLE_WERROR=On with clang being the host compiler. Thanks in advance!

@kazutakahirata
Contributor

@fschlimb I've fixed the warnings with 386a45c. Thanks!

@fschlimb
Contributor Author

@fschlimb I've fixed the warnings with 386a45c. Thanks!

@kazutakahirata Thanks so much!
