
[mlir][CAPI][python] expose the python bindings for linalg::isaContractionOpInterface and linalg::inferContractionDims #134935


Merged: 2 commits, Apr 10, 2025
12 changes: 12 additions & 0 deletions mlir/include/mlir-c/Dialect/Linalg.h
@@ -22,6 +22,18 @@ extern "C" {
MLIR_CAPI_EXPORTED void
mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp);

MLIR_CAPI_EXPORTED bool mlirLinalgIsContractionOp(MlirOperation op);

typedef struct MlirLinalgContractionDimensions {
  MlirAttribute batch;
  MlirAttribute m;
  MlirAttribute n;
  MlirAttribute k;
} MlirLinalgContractionDimensions;

MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
mlirLinalgInferContractionDimensions(MlirOperation op);

MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg);

#ifdef __cplusplus
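For orientation, here is a minimal C sketch (not part of the change) of how the new entry points might be called. reportContraction is a hypothetical helper; it assumes op is a valid MlirOperation obtained elsewhere, and it relies on the all-fields-null convention that the implementation further down uses to signal failure.

#include "mlir-c/Dialect/Linalg.h"
#include "mlir-c/IR.h"
#include <stdio.h>

/* Hypothetical helper, for illustration only: classify `op` with the new
 * C API. `op` is assumed to be a valid MlirOperation obtained elsewhere. */
static void reportContraction(MlirOperation op) {
  if (!mlirLinalgIsContractionOp(op)) {
    printf("not a linalg contraction\n");
    return;
  }
  MlirLinalgContractionDimensions dims =
      mlirLinalgInferContractionDimensions(op);
  /* All four fields are null attributes when inference fails. */
  if (mlirAttributeIsNull(dims.batch) && mlirAttributeIsNull(dims.m) &&
      mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k)) {
    printf("contraction, but dimensions could not be inferred\n");
    return;
  }
  printf("contraction with inferred batch/m/n/k dimension groups\n");
}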
41 changes: 40 additions & 1 deletion mlir/lib/Bindings/Python/DialectLinalg.cpp
@@ -8,10 +8,25 @@

#include "mlir-c/Dialect/Linalg.h"
#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"

namespace nb = nanobind;
using namespace mlir::python::nanobind_adaptors;

static std::optional<MlirLinalgContractionDimensions>
InferContractionDimensions(MlirOperation op) {
MlirLinalgContractionDimensions dims =
mlirLinalgInferContractionDimensions(op);

// Detect "empty" result. This occurs when `op` is not a contraction op,
// or when `linalg::inferContractionDims` fails.
if (mlirAttributeIsNull(dims.batch) && mlirAttributeIsNull(dims.m) &&
mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k)) {
return std::nullopt;
}
return dims;
}

static void populateDialectLinalgSubmodule(nb::module_ m) {
m.def(
@@ -20,6 +35,30 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
nb::arg("op"),
"Fill the region for `op`, which is assumed to be a builtin named Linalg "
"op.");

m.def("isa_contraction_op", &mlirLinalgIsContractionOp,
"Checks if the given operation is a Linalg contraction operation.",
nb::arg("op"));

nb::class_<MlirLinalgContractionDimensions>(m, "ContractionDimensions")
.def_prop_ro("batch",
[](const MlirLinalgContractionDimensions &self) {
return self.batch;
})
.def_prop_ro(
"m",
[](const MlirLinalgContractionDimensions &self) { return self.m; })
.def_prop_ro(
"n",
[](const MlirLinalgContractionDimensions &self) { return self.n; })
.def_prop_ro("k", [](const MlirLinalgContractionDimensions &self) {
return self.k;
});

m.def("infer_contraction_dimensions", &InferContractionDimensions,
"Infers contraction dimensions (batch/m/n/k) for a Linalg contraction "
"op.",
nb::arg("op"));
}

NB_MODULE(_mlirDialectsLinalg, m) {
34 changes: 34 additions & 0 deletions mlir/lib/CAPI/Dialect/Linalg.cpp
@@ -41,4 +41,38 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) {
fun(b, *body, op->getAttrs());
}

MLIR_CAPI_EXPORTED bool mlirLinalgIsContractionOp(MlirOperation op) {
auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
// isaContractionOpInterface handles null linalgOp internally.
return linalg::isaContractionOpInterface(linalgOp);
}

MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
mlirLinalgInferContractionDimensions(MlirOperation op) {
MlirLinalgContractionDimensions result{};
auto linalgOp = dyn_cast<linalg::LinalgOp>(unwrap(op));
if (!linalgOp)
return result;

FailureOr<linalg::ContractionDimensions> maybeDims =
linalg::inferContractionDims(linalgOp);
if (failed(maybeDims))
return result;

linalg::ContractionDimensions contractionDims = *maybeDims;
MLIRContext *ctx = linalgOp.getContext();

auto toAttr = [&ctx](const SmallVector<unsigned, 2> &vals) -> MlirAttribute {
return wrap(
DenseI32ArrayAttr::get(ctx, llvm::to_vector_of<int32_t, 2>(vals)));
};

result.batch = toAttr(contractionDims.batch);
result.m = toAttr(contractionDims.m);
result.n = toAttr(contractionDims.n);
result.k = toAttr(contractionDims.k);

return result;
}

MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect)
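The four dimension groups are packed as DenseI32ArrayAttr values. Below is a short illustrative C sketch of unpacking one field into plain integers; it assumes the dense-array accessors declared in mlir-c/BuiltinAttributes.h (mlirDenseArrayGetNumElements and mlirDenseI32ArrayGetElement).

#include "mlir-c/BuiltinAttributes.h"
#include <stdio.h>

/* Hypothetical helper: print the loop indices stored in one dimension
 * field (e.g. dims.m). Assumes `attr` is a non-null DenseI32ArrayAttr as
 * produced by mlirLinalgInferContractionDimensions. */
static void printDimIndices(MlirAttribute attr) {
  intptr_t numElements = mlirDenseArrayGetNumElements(attr);
  for (intptr_t i = 0; i < numElements; ++i)
    printf("%d ", (int)mlirDenseI32ArrayGetElement(attr, i));
  printf("\n");
}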
97 changes: 97 additions & 0 deletions mlir/test/python/dialects/linalg/utils.py
@@ -0,0 +1,97 @@
# RUN: %PYTHON %s

from mlir.dialects import arith, func, linalg
from mlir.dialects.linalg.opdsl.lang import *
from mlir.ir import *


def run(f):
print("\nTEST:", f.__name__)
f()
return f


@run
def test_infer_contraction_dimensions_from_ops():
with Context(), Location.unknown():
module = Module.create()
f32 = F32Type.get()
with InsertionPoint(module.body):
# === Static shapes ===
m, n, k = 4, 4, 4
a_type = RankedTensorType.get((m, k), f32)
b_type = RankedTensorType.get((k, n), f32)
c_type = RankedTensorType.get((m, n), f32)

@func.FuncOp.from_py_func(a_type, b_type, c_type)
def contraction_fn(a, b, c):
zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.0), result=f32)
filled = linalg.fill(zero, outs=[c])
fill_op = filled.owner

assert not linalg.isa_contraction_op(zero.operation)
assert not linalg.isa_contraction_op(fill_op)
assert linalg.infer_contraction_dimensions(fill_op) is None

dim_m = AffineDimExpr.get(0)
dim_n = AffineDimExpr.get(1)
dim_k = AffineDimExpr.get(2)

a_map = AffineMap.get(3, 0, [dim_m, dim_k])
b_map = AffineMap.get(3, 0, [dim_k, dim_n])
c_map = AffineMap.get(3, 0, [dim_m, dim_n])
result = linalg.contract(
a,
b,
outs=(filled,),
indexing_maps=[a_map, b_map, c_map],
)
contraction_op = result.owner

assert linalg.isa_contraction_op(contraction_op)
dims = linalg.infer_contraction_dimensions(contraction_op)
assert dims is not None

# Expect m=[0], n=[1], k=[2] as per standard matmul
assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}"
assert list(dims.n) == [1], f"Expected n=[1], got {list(dims.n)}"
assert list(dims.k) == [2], f"Expected k=[2], got {list(dims.k)}"
assert (
list(dims.batch) == []
), f"Expected batch=[], got {list(dims.batch)}"

# === Dynamic shape case ===
dyn = ShapedType.get_dynamic_size()
a_dyn_type = RankedTensorType.get((4, dyn), f32)
b_dyn_type = RankedTensorType.get((dyn, 4), f32)
c_type = RankedTensorType.get((4, 4), f32)

@func.FuncOp.from_py_func(a_dyn_type, b_dyn_type, c_type)
def dynamic_contraction_fn(a, b, c):
zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.0), result=f32)
filled = linalg.fill(zero, outs=[c])
dim_m = AffineDimExpr.get(0)
dim_n = AffineDimExpr.get(1)
dim_k = AffineDimExpr.get(2)

a_map = AffineMap.get(3, 0, [dim_m, dim_k])
b_map = AffineMap.get(3, 0, [dim_k, dim_n])
c_map = AffineMap.get(3, 0, [dim_m, dim_n])

result = linalg.contract(
a,
b,
outs=(filled,),
indexing_maps=[a_map, b_map, c_map],
)
contraction_op = result.owner

assert linalg.isa_contraction_op(contraction_op)
dims = linalg.infer_contraction_dimensions(contraction_op)
assert dims is not None
assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}"
assert list(dims.n) == [1], f"Expected n=[1], got {list(dims.n)}"
assert list(dims.k) == [2], f"Expected k=[2], got {list(dims.k)}"
assert (
list(dims.batch) == []
), f"Expected batch=[], got {list(dims.batch)}"