Skip to content

Commit d151987

Browse files
committed
[mlir][python] expose python bindings for linalg::isaContractionOpInterface and linalg::inferContractionDims
Signed-off-by: Bangtian Liu <[email protected]>
1 parent 3b84b1e commit d151987

File tree

4 files changed

+140
-1
lines changed

4 files changed

+140
-1
lines changed

mlir/include/mlir-c/Dialect/Linalg.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,18 @@ extern "C" {
2222
MLIR_CAPI_EXPORTED void
2323
mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp);
2424

25+
MLIR_CAPI_EXPORTED bool mlirLinalgIsContractionOp(MlirOperation op);
26+
27+
struct MlirLinalgContractionDimensions {
28+
MlirAttribute batch;
29+
MlirAttribute m;
30+
MlirAttribute n;
31+
MlirAttribute k;
32+
};
33+
34+
MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
35+
mlirLinalgInferContractionDimensions(MlirOperation op);
36+
2537
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg);
2638

2739
#ifdef __cplusplus

mlir/lib/Bindings/Python/DialectLinalg.cpp

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,45 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include "mlir-c/BuiltinAttributes.h"
910
#include "mlir-c/Dialect/Linalg.h"
1011
#include "mlir-c/IR.h"
11-
#include "mlir/Bindings/Python/NanobindAdaptors.h"
1212
#include "mlir/Bindings/Python/Nanobind.h"
13+
#include "mlir/Bindings/Python/NanobindAdaptors.h"
1314

1415
namespace nb = nanobind;
16+
using namespace mlir::python::nanobind_adaptors;
17+
18+
struct PyContractionDimensions {
19+
MlirLinalgContractionDimensions value;
20+
21+
PyContractionDimensions() = default;
22+
PyContractionDimensions(const MlirLinalgContractionDimensions &v)
23+
: value(v) {}
24+
};
25+
26+
static std::optional<PyContractionDimensions>
27+
mlirLinalgInferContractionDimensionsBinding(MlirOperation op) {
28+
MlirLinalgContractionDimensions dims =
29+
mlirLinalgInferContractionDimensions(op);
30+
31+
// Detect "empty" result.
32+
if (mlirAttributeIsNull(dims.batch) && mlirAttributeIsNull(dims.m) &&
33+
mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k)) {
34+
return std::nullopt;
35+
}
36+
return PyContractionDimensions{dims};
37+
}
38+
39+
static std::vector<int32_t> convertDenseI32AttrToList(MlirAttribute attr) {
40+
std::vector<int32_t> result;
41+
int64_t size = mlirDenseArrayGetNumElements(attr);
42+
result.reserve(size);
43+
for (int64_t i = 0; i < size; ++i) {
44+
result.push_back(mlirDenseI32ArrayGetElement(attr, i));
45+
}
46+
return result;
47+
}
1548

1649
static void populateDialectLinalgSubmodule(nb::module_ m) {
1750
m.def(
@@ -20,6 +53,33 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
2053
nb::arg("op"),
2154
"Fill the region for `op`, which is assumed to be a builtin named Linalg "
2255
"op.");
56+
57+
m.def("isa_contraction_op", &mlirLinalgIsContractionOp,
58+
"Checks if the given operation is a Linalg contraction operation.",
59+
nb::arg("op"));
60+
61+
nb::class_<PyContractionDimensions>(m, "ContractionDimensions")
62+
.def_prop_ro("batch",
63+
[](const PyContractionDimensions &self) {
64+
return convertDenseI32AttrToList(self.value.batch);
65+
})
66+
.def_prop_ro("m",
67+
[](const PyContractionDimensions &self) {
68+
return convertDenseI32AttrToList(self.value.m);
69+
})
70+
.def_prop_ro("n",
71+
[](const PyContractionDimensions &self) {
72+
return convertDenseI32AttrToList(self.value.n);
73+
})
74+
.def_prop_ro("k", [](const PyContractionDimensions &self) {
75+
return convertDenseI32AttrToList(self.value.k);
76+
});
77+
78+
m.def("infer_contraction_dimensions",
79+
&mlirLinalgInferContractionDimensionsBinding,
80+
"Infers contraction dimensions (batch/m/n/k) for a Linalg contraction "
81+
"op.",
82+
nb::arg("op"));
2383
}
2484

2585
NB_MODULE(_mlirDialectsLinalg, m) {

mlir/lib/CAPI/Dialect/Linalg.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,4 +41,36 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) {
4141
fun(b, *body, op->getAttrs());
4242
}
4343

44+
MLIR_CAPI_EXPORTED bool mlirLinalgIsContractionOp(MlirOperation op) {
45+
auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
46+
return linalg::isaContractionOpInterface(linalgOp);
47+
}
48+
49+
MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
50+
mlirLinalgInferContractionDimensions(MlirOperation op) {
51+
MlirLinalgContractionDimensions result{};
52+
auto linalgOp = dyn_cast<linalg::LinalgOp>(unwrap(op));
53+
if (!linalgOp)
54+
return result;
55+
56+
auto maybeDims = linalg::inferContractionDims(linalgOp);
57+
if (failed(maybeDims))
58+
return result;
59+
60+
linalg::ContractionDimensions contractionDims = maybeDims.value();
61+
MLIRContext *ctx = linalgOp.getContext();
62+
63+
auto toAttr = [&](const SmallVector<unsigned, 2> &vals) -> MlirAttribute {
64+
SmallVector<int32_t> intVals(vals.begin(), vals.end());
65+
return wrap(DenseI32ArrayAttr::get(ctx, intVals));
66+
};
67+
68+
result.batch = toAttr(contractionDims.batch);
69+
result.m = toAttr(contractionDims.m);
70+
result.n = toAttr(contractionDims.n);
71+
result.k = toAttr(contractionDims.k);
72+
73+
return result;
74+
}
75+
4476
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect)

mlir/test/python/dialects/linalg/ops.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -606,3 +606,38 @@ def tensor_pack(src, dst):
606606
# CHECK: return %[[VAL_4]] : tensor<128x128xf32>
607607
# CHECK: }
608608
print(module)
609+
610+
611+
@run
612+
def test_infer_contraction_dimensions():
613+
with Context(), Location.unknown():
614+
module = ir.Module.parse(
615+
r"""
616+
module {
617+
func.func @matmul(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>)
618+
-> tensor<4x4xf32> {
619+
%cst = arith.constant 0.0 : f32
620+
%0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32>
621+
%1 = linalg.matmul ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>)
622+
outs(%0 : tensor<4x4xf32>) -> tensor<4x4xf32>
623+
return %1 : tensor<4x4xf32>
624+
}
625+
}
626+
"""
627+
)
628+
func_op = module.body.operations[0]
629+
body_block = func_op.regions[0].blocks[0]
630+
fill_op = body_block.operations[1]
631+
matmul_op = body_block.operations[2]
632+
633+
assert not linalg.isa_contraction_op(fill_op)
634+
assert linalg.isa_contraction_op(matmul_op)
635+
636+
dims = linalg.infer_contraction_dimensions(fill_op)
637+
assert dims is None
638+
dims = linalg.infer_contraction_dimensions(matmul_op)
639+
assert dims
640+
641+
assert dims.m == [0], f"Expected m=[0], got {dims.m}"
642+
assert dims.n == [1], f"Expected n=[1], got {dims.n}"
643+
assert dims.k == [2], f"Expected k=[2], got {dims.k}"

0 commit comments

Comments
 (0)