Skip to content

[mlir][CAPI][python] expose the python bindings for linalg::isaContractionOpInterface and linalg::inferContractionDims #134935

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 10, 2025

Conversation

bangtianliu
Copy link
Contributor

@bangtianliu bangtianliu commented Apr 8, 2025

This PR is mainly about exposing the python bindings for linalg::isaContractionOpInterface and linalg::inferContractionDims.

@llvmbot
Copy link
Member

llvmbot commented Apr 8, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Bangtian Liu (bangtianliu)

Changes

This PR is mainly about exposing the python bindings for linalg::isaContractionOpInterface and linalg::inferContractionDims


Full diff: https://github.com/llvm/llvm-project/pull/134935.diff

4 Files Affected:

  • (modified) mlir/include/mlir-c/Dialect/Linalg.h (+12)
  • (modified) mlir/lib/Bindings/Python/DialectLinalg.cpp (+61-1)
  • (modified) mlir/lib/CAPI/Dialect/Linalg.cpp (+32)
  • (modified) mlir/test/python/dialects/linalg/ops.py (+33)
diff --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h
index 0ab201e158033..c57d193e62d25 100644
--- a/mlir/include/mlir-c/Dialect/Linalg.h
+++ b/mlir/include/mlir-c/Dialect/Linalg.h
@@ -22,6 +22,18 @@ extern "C" {
 MLIR_CAPI_EXPORTED void
 mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp);
 
+MLIR_CAPI_EXPORTED bool mlirLinalgIsContractionOp(MlirOperation op);
+
+struct MlirLinalgContractionDimensions {
+  MlirAttribute batch;
+  MlirAttribute m;
+  MlirAttribute n;
+  MlirAttribute k;
+};
+
+MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
+mlirLinalgInferContractionDimensions(MlirOperation op);
+
 MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg);
 
 #ifdef __cplusplus
diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp
index 548df4ee100aa..0dbd4f18b7212 100644
--- a/mlir/lib/Bindings/Python/DialectLinalg.cpp
+++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp
@@ -6,12 +6,45 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/Dialect/Linalg.h"
 #include "mlir-c/IR.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
 #include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
 
 namespace nb = nanobind;
+using namespace mlir::python::nanobind_adaptors;
+
+struct PyContractionDimensions {
+  MlirLinalgContractionDimensions value;
+
+  PyContractionDimensions() = default;
+  PyContractionDimensions(const MlirLinalgContractionDimensions &v)
+      : value(v) {}
+};
+
+static std::optional<PyContractionDimensions>
+mlirLinalgInferContractionDimensionsBinding(MlirOperation op) {
+  MlirLinalgContractionDimensions dims =
+      mlirLinalgInferContractionDimensions(op);
+
+  // Detect "empty" result.
+  if (mlirAttributeIsNull(dims.batch) && mlirAttributeIsNull(dims.m) &&
+      mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k)) {
+    return std::nullopt;
+  }
+  return PyContractionDimensions{dims};
+}
+
+static std::vector<int32_t> convertDenseI32AttrToList(MlirAttribute attr) {
+  std::vector<int32_t> result;
+  int64_t size = mlirDenseArrayGetNumElements(attr);
+  result.reserve(size);
+  for (int64_t i = 0; i < size; ++i) {
+    result.push_back(mlirDenseI32ArrayGetElement(attr, i));
+  }
+  return result;
+}
 
 static void populateDialectLinalgSubmodule(nb::module_ m) {
   m.def(
@@ -20,6 +53,33 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
       nb::arg("op"),
       "Fill the region for `op`, which is assumed to be a builtin named Linalg "
       "op.");
+
+  m.def("isa_contraction_op", &mlirLinalgIsContractionOp,
+        "Checks if the given operation is a Linalg contraction operation.",
+        nb::arg("op"));
+
+  nb::class_<PyContractionDimensions>(m, "ContractionDimensions")
+      .def_prop_ro("batch",
+                   [](const PyContractionDimensions &self) {
+                     return convertDenseI32AttrToList(self.value.batch);
+                   })
+      .def_prop_ro("m",
+                   [](const PyContractionDimensions &self) {
+                     return convertDenseI32AttrToList(self.value.m);
+                   })
+      .def_prop_ro("n",
+                   [](const PyContractionDimensions &self) {
+                     return convertDenseI32AttrToList(self.value.n);
+                   })
+      .def_prop_ro("k", [](const PyContractionDimensions &self) {
+        return convertDenseI32AttrToList(self.value.k);
+      });
+
+  m.def("infer_contraction_dimensions",
+        &mlirLinalgInferContractionDimensionsBinding,
+        "Infers contraction dimensions (batch/m/n/k) for a Linalg contraction "
+        "op.",
+        nb::arg("op"));
 }
 
 NB_MODULE(_mlirDialectsLinalg, m) {
diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp
index 2fb5bc651de07..7e053d1188f24 100644
--- a/mlir/lib/CAPI/Dialect/Linalg.cpp
+++ b/mlir/lib/CAPI/Dialect/Linalg.cpp
@@ -41,4 +41,36 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) {
   fun(b, *body, op->getAttrs());
 }
 
+MLIR_CAPI_EXPORTED bool mlirLinalgIsContractionOp(MlirOperation op) {
+  auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
+  return linalg::isaContractionOpInterface(linalgOp);
+}
+
+MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
+mlirLinalgInferContractionDimensions(MlirOperation op) {
+  MlirLinalgContractionDimensions result{};
+  auto linalgOp = dyn_cast<linalg::LinalgOp>(unwrap(op));
+  if (!linalgOp)
+    return result;
+
+  auto maybeDims = linalg::inferContractionDims(linalgOp);
+  if (failed(maybeDims))
+    return result;
+
+  linalg::ContractionDimensions contractionDims = maybeDims.value();
+  MLIRContext *ctx = linalgOp.getContext();
+
+  auto toAttr = [&](const SmallVector<unsigned, 2> &vals) -> MlirAttribute {
+    SmallVector<int32_t> intVals(vals.begin(), vals.end());
+    return wrap(DenseI32ArrayAttr::get(ctx, intVals));
+  };
+
+  result.batch = toAttr(contractionDims.batch);
+  result.m = toAttr(contractionDims.m);
+  result.n = toAttr(contractionDims.n);
+  result.k = toAttr(contractionDims.k);
+
+  return result;
+}
+
 MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect)
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index e32a911b24b11..3129a9bbe1d8a 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -606,3 +606,36 @@ def tensor_pack(src, dst):
         # CHECK:           return %[[VAL_4]] : tensor<128x128xf32>
         # CHECK:         }
         print(module)
+
+
+@run
+def test_infer_contraction_dimensions():
+    with Context(), Location.unknown():
+        module = ir.Module.parse(r"""
+            module {
+                func.func @matmul(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>)
+                    -> tensor<4x4xf32> {
+                    %cst = arith.constant 0.0 : f32
+                    %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32>
+                    %1 = linalg.matmul ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>) 
+                        outs(%0 : tensor<4x4xf32>) -> tensor<4x4xf32>
+                    return %1 : tensor<4x4xf32>
+                }
+            }
+        """)
+        func_op = module.body.operations[0]
+        body_block = func_op.regions[0].blocks[0]
+        fill_op = body_block.operations[1]
+        matmul_op = body_block.operations[2]
+
+        assert not linalg.isa_contraction_op(fill_op)
+        assert linalg.isa_contraction_op(matmul_op)
+
+        dims = linalg.infer_contraction_dimensions(fill_op)
+        assert dims is None
+        dims = linalg.infer_contraction_dimensions(matmul_op)
+        assert dims
+
+        assert dims.m == [0], f"Expected m=[0], got {dims.m}"
+        assert dims.n == [1], f"Expected n=[1], got {dims.n}"
+        assert dims.k == [2], f"Expected k=[2], got {dims.k}"

Copy link

github-actions bot commented Apr 8, 2025

✅ With the latest revision this PR passed the Python code formatter.

@bangtianliu bangtianliu force-pushed the expose_python_binding branch from 29f4f99 to d151987 Compare April 8, 2025 22:09
…erface and linalg::inferContractionDims

Signed-off-by: Bangtian Liu <[email protected]>
@kuhar kuhar requested review from kuhar and Max191 April 8, 2025 22:56
Signed-off-by: Bangtian Liu <[email protected]>
Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the changes, looks solid now.

Please wait for an approval from @makslevental before merging.

Copy link
Contributor

@makslevental makslevental left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks good - looks clean. thanks dude

@kuhar kuhar merged commit c359f76 into llvm:main Apr 10, 2025
11 checks passed
AllinLeeYL pushed a commit to AllinLeeYL/llvm-project that referenced this pull request Apr 10, 2025
…ctionOpInterface and linalg::inferContractionDims (llvm#134935)

This PR is mainly about exposing the python bindings for`
linalg::isaContractionOpInterface` and` linalg::inferContractionDims`.

---------

Signed-off-by: Bangtian Liu <[email protected]>
kuhar pushed a commit that referenced this pull request Apr 11, 2025
This PR is after #135253 and #134935 to fix the error reported by
#135253 (comment).
This PR Adds typedef declarations for `MlirLinalgContractionDimensions`
and `MlirLinalgConvolutionDimensions` in the C API to ensure
compatibility with pure C code.

I confirm that this fix resolves the reported error based on my testing.

Signed-off-by: Bangtian Liu <[email protected]>
var-const pushed a commit to ldionne/llvm-project that referenced this pull request Apr 17, 2025
…ctionOpInterface and linalg::inferContractionDims (llvm#134935)

This PR is mainly about exposing the python bindings for`
linalg::isaContractionOpInterface` and` linalg::inferContractionDims`.

---------

Signed-off-by: Bangtian Liu <[email protected]>
var-const pushed a commit to ldionne/llvm-project that referenced this pull request Apr 17, 2025
…5380)

This PR is after llvm#135253 and llvm#134935 to fix the error reported by
llvm#135253 (comment).
This PR Adds typedef declarations for `MlirLinalgContractionDimensions`
and `MlirLinalgConvolutionDimensions` in the C API to ensure
compatibility with pure C code.

I confirm that this fix resolves the reported error based on my testing.

Signed-off-by: Bangtian Liu <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants