Commit c359f76
[mlir][CAPI][python] expose the python bindings for linalg::isaContractionOpInterface and linalg::inferContractionDims (#134935)
This PR is mainly about exposing the Python bindings for `linalg::isaContractionOpInterface` and `linalg::inferContractionDims`.

Signed-off-by: Bangtian Liu <[email protected]>
1 parent 9b50167 commit c359f76
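
As a quick orientation before the diffs below, here is a minimal sketch of how the newly exposed bindings can be called from Python. It assumes an MLIR build with the Python bindings enabled; the function name matmul_fn, the tensor shapes, and the indexing maps are illustrative and simply mirror the test added in this commit.

from mlir.dialects import func, linalg
from mlir.ir import *

with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):
        a_type = RankedTensorType.get((4, 8), f32)
        b_type = RankedTensorType.get((8, 4), f32)
        c_type = RankedTensorType.get((4, 4), f32)

        @func.FuncOp.from_py_func(a_type, b_type, c_type)
        def matmul_fn(a, b, c):
            # Plain matmul indexing maps: (m, k) x (k, n) -> (m, n).
            d_m, d_n, d_k = (AffineDimExpr.get(i) for i in range(3))
            result = linalg.contract(
                a,
                b,
                outs=(c,),
                indexing_maps=[
                    AffineMap.get(3, 0, [d_m, d_k]),
                    AffineMap.get(3, 0, [d_k, d_n]),
                    AffineMap.get(3, 0, [d_m, d_n]),
                ],
            )
            op = result.owner
            assert linalg.isa_contraction_op(op)
            # Returns None when `op` is not a contraction op.
            dims = linalg.infer_contraction_dimensions(op)
            if dims is not None:
                print(list(dims.batch), list(dims.m), list(dims.n), list(dims.k))

For this plain matmul the expected result is batch=[], m=[0], n=[1], k=[2], matching the assertions in the new test file at the end of this commit.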

File tree: 4 files changed (+183, -1 lines)
mlir/include/mlir-c/Dialect/Linalg.h

Lines changed: 12 additions & 0 deletions
@@ -22,6 +22,18 @@ extern "C" {
 MLIR_CAPI_EXPORTED void
 mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp);

+MLIR_CAPI_EXPORTED bool mlirLinalgIsContractionOp(MlirOperation op);
+
+struct MlirLinalgContractionDimensions {
+  MlirAttribute batch;
+  MlirAttribute m;
+  MlirAttribute n;
+  MlirAttribute k;
+};
+
+MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
+mlirLinalgInferContractionDimensions(MlirOperation op);
+
 MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg);

 #ifdef __cplusplus

mlir/lib/Bindings/Python/DialectLinalg.cpp

Lines changed: 40 additions & 1 deletion
@@ -8,10 +8,25 @@

 #include "mlir-c/Dialect/Linalg.h"
 #include "mlir-c/IR.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
 #include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"

 namespace nb = nanobind;
+using namespace mlir::python::nanobind_adaptors;
+
+static std::optional<MlirLinalgContractionDimensions>
+InferContractionDimensions(MlirOperation op) {
+  MlirLinalgContractionDimensions dims =
+      mlirLinalgInferContractionDimensions(op);
+
+  // Detect "empty" result. This occurs when `op` is not a contraction op,
+  // or when `linalg::inferContractionDims` fails.
+  if (mlirAttributeIsNull(dims.batch) && mlirAttributeIsNull(dims.m) &&
+      mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k)) {
+    return std::nullopt;
+  }
+  return dims;
+}

 static void populateDialectLinalgSubmodule(nb::module_ m) {
   m.def(
@@ -20,6 +35,30 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
       nb::arg("op"),
       "Fill the region for `op`, which is assumed to be a builtin named Linalg "
       "op.");
+
+  m.def("isa_contraction_op", &mlirLinalgIsContractionOp,
+        "Checks if the given operation is a Linalg contraction operation.",
+        nb::arg("op"));
+
+  nb::class_<MlirLinalgContractionDimensions>(m, "ContractionDimensions")
+      .def_prop_ro("batch",
+                   [](const MlirLinalgContractionDimensions &self) {
+                     return self.batch;
+                   })
+      .def_prop_ro(
+          "m",
+          [](const MlirLinalgContractionDimensions &self) { return self.m; })
+      .def_prop_ro(
+          "n",
+          [](const MlirLinalgContractionDimensions &self) { return self.n; })
+      .def_prop_ro("k", [](const MlirLinalgContractionDimensions &self) {
+        return self.k;
+      });
+
+  m.def("infer_contraction_dimensions", &InferContractionDimensions,
+        "Infers contraction dimensions (batch/m/n/k) for a Linalg contraction "
+        "op.",
+        nb::arg("op"));
 }

 NB_MODULE(_mlirDialectsLinalg, m) {

mlir/lib/CAPI/Dialect/Linalg.cpp

Lines changed: 34 additions & 0 deletions
@@ -41,4 +41,38 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) {
   fun(b, *body, op->getAttrs());
 }

+MLIR_CAPI_EXPORTED bool mlirLinalgIsContractionOp(MlirOperation op) {
+  auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
+  // isaContractionOpInterface handles null linalgOp internally.
+  return linalg::isaContractionOpInterface(linalgOp);
+}
+
+MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
+mlirLinalgInferContractionDimensions(MlirOperation op) {
+  MlirLinalgContractionDimensions result{};
+  auto linalgOp = dyn_cast<linalg::LinalgOp>(unwrap(op));
+  if (!linalgOp)
+    return result;
+
+  FailureOr<linalg::ContractionDimensions> maybeDims =
+      linalg::inferContractionDims(linalgOp);
+  if (failed(maybeDims))
+    return result;
+
+  linalg::ContractionDimensions contractionDims = *maybeDims;
+  MLIRContext *ctx = linalgOp.getContext();
+
+  auto toAttr = [&ctx](const SmallVector<unsigned, 2> &vals) -> MlirAttribute {
+    return wrap(
+        DenseI32ArrayAttr::get(ctx, llvm::to_vector_of<int32_t, 2>(vals)));
+  };
+
+  result.batch = toAttr(contractionDims.batch);
+  result.m = toAttr(contractionDims.m);
+  result.n = toAttr(contractionDims.n);
+  result.k = toAttr(contractionDims.k);
+
+  return result;
+}
+
 MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect)
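
A note on the encoding above: each dimension-index list (batch/m/n/k) is wrapped into a DenseI32ArrayAttr before being returned through the C API, so on the Python side the corresponding ContractionDimensions properties surface as dense i32 array attributes whose elements are loop-dimension indices. A small continuation of the earlier sketch, assuming dims is the value returned by linalg.infer_contraction_dimensions for the plain matmul:

# Each field is a dense i32 array attribute of loop indices.
assert list(dims.k) == [2]
print(dims.k)  # expected to print something like: array<i32: 2>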
New Python test file

Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
# RUN: %PYTHON %s

from mlir.dialects import arith, func, linalg
from mlir.dialects.linalg.opdsl.lang import *
from mlir.ir import *


def run(f):
    print("\nTEST:", f.__name__)
    f()
    return f


@run
def test_infer_contraction_dimensions_from_ops():
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):
            # === Static shapes ===
            m, n, k = 4, 4, 4
            a_type = RankedTensorType.get((m, k), f32)
            b_type = RankedTensorType.get((k, n), f32)
            c_type = RankedTensorType.get((m, n), f32)

            @func.FuncOp.from_py_func(a_type, b_type, c_type)
            def contraction_fn(a, b, c):
                zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.0), result=f32)
                filled = linalg.fill(zero, outs=[c])
                fill_op = filled.owner

                assert not linalg.isa_contraction_op(zero.operation)
                assert not linalg.isa_contraction_op(fill_op)
                assert linalg.infer_contraction_dimensions(fill_op) is None

                dim_m = AffineDimExpr.get(0)
                dim_n = AffineDimExpr.get(1)
                dim_k = AffineDimExpr.get(2)

                a_map = AffineMap.get(3, 0, [dim_m, dim_k])
                b_map = AffineMap.get(3, 0, [dim_k, dim_n])
                c_map = AffineMap.get(3, 0, [dim_m, dim_n])
                result = linalg.contract(
                    a,
                    b,
                    outs=(filled,),
                    indexing_maps=[a_map, b_map, c_map],
                )
                contraction_op = result.owner

                assert linalg.isa_contraction_op(contraction_op)
                dims = linalg.infer_contraction_dimensions(contraction_op)
                assert dims is not None

                # Expect m=[0], n=[1], k=[2] as per standard matmul.
                assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}"
                assert list(dims.n) == [1], f"Expected n=[1], got {list(dims.n)}"
                assert list(dims.k) == [2], f"Expected k=[2], got {list(dims.k)}"
                assert (
                    list(dims.batch) == []
                ), f"Expected batch=[], got {list(dims.batch)}"

            # === Dynamic shape case ===
            dyn = ShapedType.get_dynamic_size()
            a_dyn_type = RankedTensorType.get((4, dyn), f32)
            b_dyn_type = RankedTensorType.get((dyn, 4), f32)
            c_type = RankedTensorType.get((4, 4), f32)

            @func.FuncOp.from_py_func(a_dyn_type, b_dyn_type, c_type)
            def dynamic_contraction_fn(a, b, c):
                zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.0), result=f32)
                filled = linalg.fill(zero, outs=[c])
                dim_m = AffineDimExpr.get(0)
                dim_n = AffineDimExpr.get(1)
                dim_k = AffineDimExpr.get(2)

                a_map = AffineMap.get(3, 0, [dim_m, dim_k])
                b_map = AffineMap.get(3, 0, [dim_k, dim_n])
                c_map = AffineMap.get(3, 0, [dim_m, dim_n])

                result = linalg.contract(
                    a,
                    b,
                    outs=(filled,),
                    indexing_maps=[a_map, b_map, c_map],
                )
                contraction_op = result.owner

                assert linalg.isa_contraction_op(contraction_op)
                dims = linalg.infer_contraction_dimensions(contraction_op)
                assert dims is not None
                assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}"
                assert list(dims.n) == [1], f"Expected n=[1], got {list(dims.n)}"
                assert list(dims.k) == [2], f"Expected k=[2], got {list(dims.k)}"
                assert (
                    list(dims.batch) == []
                ), f"Expected batch=[], got {list(dims.batch)}"
