-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][CAPI][python] expose the python bindings for linalg::isaContractionOpInterface and linalg::inferContractionDims #134935
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg Author: Bangtian Liu (bangtianliu) ChangesThis PR is mainly about exposing the python bindings for Full diff: https://github.com/llvm/llvm-project/pull/134935.diff 4 Files Affected:
diff --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h
index 0ab201e158033..c57d193e62d25 100644
--- a/mlir/include/mlir-c/Dialect/Linalg.h
+++ b/mlir/include/mlir-c/Dialect/Linalg.h
@@ -22,6 +22,18 @@ extern "C" {
MLIR_CAPI_EXPORTED void
mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp);
+MLIR_CAPI_EXPORTED bool mlirLinalgIsContractionOp(MlirOperation op);
+
+struct MlirLinalgContractionDimensions {
+ MlirAttribute batch;
+ MlirAttribute m;
+ MlirAttribute n;
+ MlirAttribute k;
+};
+
+MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
+mlirLinalgInferContractionDimensions(MlirOperation op);
+
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg);
#ifdef __cplusplus
diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp
index 548df4ee100aa..0dbd4f18b7212 100644
--- a/mlir/lib/Bindings/Python/DialectLinalg.cpp
+++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp
@@ -6,12 +6,45 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/Dialect/Linalg.h"
#include "mlir-c/IR.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
+using namespace mlir::python::nanobind_adaptors;
+
+struct PyContractionDimensions {
+ MlirLinalgContractionDimensions value;
+
+ PyContractionDimensions() = default;
+ PyContractionDimensions(const MlirLinalgContractionDimensions &v)
+ : value(v) {}
+};
+
+static std::optional<PyContractionDimensions>
+mlirLinalgInferContractionDimensionsBinding(MlirOperation op) {
+ MlirLinalgContractionDimensions dims =
+ mlirLinalgInferContractionDimensions(op);
+
+ // Detect "empty" result.
+ if (mlirAttributeIsNull(dims.batch) && mlirAttributeIsNull(dims.m) &&
+ mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k)) {
+ return std::nullopt;
+ }
+ return PyContractionDimensions{dims};
+}
+
+static std::vector<int32_t> convertDenseI32AttrToList(MlirAttribute attr) {
+ std::vector<int32_t> result;
+ int64_t size = mlirDenseArrayGetNumElements(attr);
+ result.reserve(size);
+ for (int64_t i = 0; i < size; ++i) {
+ result.push_back(mlirDenseI32ArrayGetElement(attr, i));
+ }
+ return result;
+}
static void populateDialectLinalgSubmodule(nb::module_ m) {
m.def(
@@ -20,6 +53,33 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
nb::arg("op"),
"Fill the region for `op`, which is assumed to be a builtin named Linalg "
"op.");
+
+ m.def("isa_contraction_op", &mlirLinalgIsContractionOp,
+ "Checks if the given operation is a Linalg contraction operation.",
+ nb::arg("op"));
+
+ nb::class_<PyContractionDimensions>(m, "ContractionDimensions")
+ .def_prop_ro("batch",
+ [](const PyContractionDimensions &self) {
+ return convertDenseI32AttrToList(self.value.batch);
+ })
+ .def_prop_ro("m",
+ [](const PyContractionDimensions &self) {
+ return convertDenseI32AttrToList(self.value.m);
+ })
+ .def_prop_ro("n",
+ [](const PyContractionDimensions &self) {
+ return convertDenseI32AttrToList(self.value.n);
+ })
+ .def_prop_ro("k", [](const PyContractionDimensions &self) {
+ return convertDenseI32AttrToList(self.value.k);
+ });
+
+ m.def("infer_contraction_dimensions",
+ &mlirLinalgInferContractionDimensionsBinding,
+ "Infers contraction dimensions (batch/m/n/k) for a Linalg contraction "
+ "op.",
+ nb::arg("op"));
}
NB_MODULE(_mlirDialectsLinalg, m) {
diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp
index 2fb5bc651de07..7e053d1188f24 100644
--- a/mlir/lib/CAPI/Dialect/Linalg.cpp
+++ b/mlir/lib/CAPI/Dialect/Linalg.cpp
@@ -41,4 +41,36 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) {
fun(b, *body, op->getAttrs());
}
+MLIR_CAPI_EXPORTED bool mlirLinalgIsContractionOp(MlirOperation op) {
+ auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
+ return linalg::isaContractionOpInterface(linalgOp);
+}
+
+MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
+mlirLinalgInferContractionDimensions(MlirOperation op) {
+ MlirLinalgContractionDimensions result{};
+ auto linalgOp = dyn_cast<linalg::LinalgOp>(unwrap(op));
+ if (!linalgOp)
+ return result;
+
+ auto maybeDims = linalg::inferContractionDims(linalgOp);
+ if (failed(maybeDims))
+ return result;
+
+ linalg::ContractionDimensions contractionDims = maybeDims.value();
+ MLIRContext *ctx = linalgOp.getContext();
+
+ auto toAttr = [&](const SmallVector<unsigned, 2> &vals) -> MlirAttribute {
+ SmallVector<int32_t> intVals(vals.begin(), vals.end());
+ return wrap(DenseI32ArrayAttr::get(ctx, intVals));
+ };
+
+ result.batch = toAttr(contractionDims.batch);
+ result.m = toAttr(contractionDims.m);
+ result.n = toAttr(contractionDims.n);
+ result.k = toAttr(contractionDims.k);
+
+ return result;
+}
+
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect)
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index e32a911b24b11..3129a9bbe1d8a 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -606,3 +606,36 @@ def tensor_pack(src, dst):
# CHECK: return %[[VAL_4]] : tensor<128x128xf32>
# CHECK: }
print(module)
+
+
+@run
+def test_infer_contraction_dimensions():
+ with Context(), Location.unknown():
+ module = ir.Module.parse(r"""
+ module {
+ func.func @matmul(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>)
+ -> tensor<4x4xf32> {
+ %cst = arith.constant 0.0 : f32
+ %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32>
+ %1 = linalg.matmul ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>)
+ outs(%0 : tensor<4x4xf32>) -> tensor<4x4xf32>
+ return %1 : tensor<4x4xf32>
+ }
+ }
+ """)
+ func_op = module.body.operations[0]
+ body_block = func_op.regions[0].blocks[0]
+ fill_op = body_block.operations[1]
+ matmul_op = body_block.operations[2]
+
+ assert not linalg.isa_contraction_op(fill_op)
+ assert linalg.isa_contraction_op(matmul_op)
+
+ dims = linalg.infer_contraction_dimensions(fill_op)
+ assert dims is None
+ dims = linalg.infer_contraction_dimensions(matmul_op)
+ assert dims
+
+ assert dims.m == [0], f"Expected m=[0], got {dims.m}"
+ assert dims.n == [1], f"Expected n=[1], got {dims.n}"
+ assert dims.k == [2], f"Expected k=[2], got {dims.k}"
|
✅ With the latest revision this PR passed the Python code formatter. |
29f4f99
to
d151987
Compare
…erface and linalg::inferContractionDims Signed-off-by: Bangtian Liu <[email protected]>
113cd15
to
24fc211
Compare
Signed-off-by: Bangtian Liu <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the changes, looks solid now.
Please wait for an approval from @makslevental before merging.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks good - looks clean. thanks dude
…ctionOpInterface and linalg::inferContractionDims (llvm#134935) This PR is mainly about exposing the python bindings for` linalg::isaContractionOpInterface` and` linalg::inferContractionDims`. --------- Signed-off-by: Bangtian Liu <[email protected]>
This PR is after #135253 and #134935 to fix the error reported by #135253 (comment). This PR Adds typedef declarations for `MlirLinalgContractionDimensions` and `MlirLinalgConvolutionDimensions` in the C API to ensure compatibility with pure C code. I confirm that this fix resolves the reported error based on my testing. Signed-off-by: Bangtian Liu <[email protected]>
…ctionOpInterface and linalg::inferContractionDims (llvm#134935) This PR is mainly about exposing the python bindings for` linalg::isaContractionOpInterface` and` linalg::inferContractionDims`. --------- Signed-off-by: Bangtian Liu <[email protected]>
…5380) This PR is after llvm#135253 and llvm#134935 to fix the error reported by llvm#135253 (comment). This PR Adds typedef declarations for `MlirLinalgContractionDimensions` and `MlirLinalgConvolutionDimensions` in the C API to ensure compatibility with pure C code. I confirm that this fix resolves the reported error based on my testing. Signed-off-by: Bangtian Liu <[email protected]>
This PR is mainly about exposing the python bindings for
linalg::isaContractionOpInterface
andlinalg::inferContractionDims
.