-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][CAPI][python] expose the python bindings for linalg::isaConvolutionOpInterface and linalg::inferConvolutionDims #135253
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir Author: Bangtian Liu (bangtianliu) ChangesThis PR is mainly about exposing the python bindings for Full diff: https://github.com/llvm/llvm-project/pull/135253.diff 4 Files Affected:
diff --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h
index c57d193e62d25..8715739473f6c 100644
--- a/mlir/include/mlir-c/Dialect/Linalg.h
+++ b/mlir/include/mlir-c/Dialect/Linalg.h
@@ -22,7 +22,7 @@ extern "C" {
MLIR_CAPI_EXPORTED void
mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp);
-MLIR_CAPI_EXPORTED bool mlirLinalgIsContractionOp(MlirOperation op);
+MLIR_CAPI_EXPORTED bool mlirLinalgIsaContractionOp(MlirOperation op);
struct MlirLinalgContractionDimensions {
MlirAttribute batch;
@@ -34,6 +34,22 @@ struct MlirLinalgContractionDimensions {
MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
mlirLinalgInferContractionDimensions(MlirOperation op);
+MLIR_CAPI_EXPORTED bool mlirLinalgIsaConvolutionOp(MlirOperation op);
+
+struct MlirLinalgConvolutionDimensions {
+ MlirAttribute batch;
+ MlirAttribute outputImage;
+ MlirAttribute outputChannel;
+ MlirAttribute filterLoop;
+ MlirAttribute inputChannel;
+ MlirAttribute depth;
+ MlirAttribute strides;
+ MlirAttribute dilations;
+};
+
+MLIR_CAPI_EXPORTED MlirLinalgConvolutionDimensions
+mlirLinalgInferConvolutionDimensions(MlirOperation op);
+
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg);
#ifdef __cplusplus
diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp
index 978ea8664b6b9..d98bfd9f2d979 100644
--- a/mlir/lib/Bindings/Python/DialectLinalg.cpp
+++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp
@@ -28,6 +28,26 @@ InferContractionDimensions(MlirOperation op) {
return dims;
}
+static std::optional<MlirLinalgConvolutionDimensions>
+InferConvolutionDimensions(MlirOperation op) {
+ MlirLinalgConvolutionDimensions dims =
+ mlirLinalgInferConvolutionDimensions(op);
+
+ // Detect "empty" result. This occurs when `op` is not a convolution op,
+ // or when `linalg::inferConvolutionDims` fails.
+ if (mlirAttributeIsNull(dims.batch) &&
+ mlirAttributeIsNull(dims.outputImage) &&
+ mlirAttributeIsNull(dims.outputChannel) &&
+ mlirAttributeIsNull(dims.filterLoop) &&
+ mlirAttributeIsNull(dims.inputChannel) &&
+ mlirAttributeIsNull(dims.depth) && mlirAttributeIsNull(dims.strides) &&
+ mlirAttributeIsNull(dims.dilations)) {
+ return std::nullopt;
+ }
+
+ return dims;
+}
+
static void populateDialectLinalgSubmodule(nb::module_ m) {
m.def(
"fill_builtin_region",
@@ -36,7 +56,7 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
"Fill the region for `op`, which is assumed to be a builtin named Linalg "
"op.");
- m.def("isa_contraction_op", &mlirLinalgIsContractionOp,
+ m.def("isa_contraction_op", &mlirLinalgIsaContractionOp,
"Checks if the given operation is a Linalg contraction operation.",
nb::arg("op"));
@@ -59,6 +79,47 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
"Infers contraction dimensions (batch/m/n/k) for a Linalg contraction "
"op.",
nb::arg("op"));
+
+ m.def("isa_convolution_op", &mlirLinalgIsaConvolutionOp,
+ "Checks if the given operation is a Linalg convolution operation.",
+ nb::arg("op"));
+
+ nb::class_<MlirLinalgConvolutionDimensions>(m, "ConvolutionDimensions")
+ .def_prop_ro("batch",
+ [](const MlirLinalgConvolutionDimensions &self) {
+ return self.batch;
+ })
+ .def_prop_ro("output_image",
+ [](const MlirLinalgConvolutionDimensions &self) {
+ return self.outputImage;
+ })
+ .def_prop_ro("output_channel",
+ [](const MlirLinalgConvolutionDimensions &self) {
+ return self.outputChannel;
+ })
+ .def_prop_ro("filter_loop",
+ [](const MlirLinalgConvolutionDimensions &self) {
+ return self.filterLoop;
+ })
+ .def_prop_ro("input_channel",
+ [](const MlirLinalgConvolutionDimensions &self) {
+ return self.inputChannel;
+ })
+ .def_prop_ro("depth",
+ [](const MlirLinalgConvolutionDimensions &self) {
+ return self.depth;
+ })
+ .def_prop_ro("strides",
+ [](const MlirLinalgConvolutionDimensions &self) {
+ return self.strides;
+ })
+ .def_prop_ro("dilations",
+ [](const MlirLinalgConvolutionDimensions &self) {
+ return self.dilations;
+ });
+
+ m.def("infer_convolution_dimensions", &InferConvolutionDimensions,
+ "Infers convolution dimensions", nb::arg("op"));
}
NB_MODULE(_mlirDialectsLinalg, m) {
diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp
index 362b89bdef6c9..737d7e6e68641 100644
--- a/mlir/lib/CAPI/Dialect/Linalg.cpp
+++ b/mlir/lib/CAPI/Dialect/Linalg.cpp
@@ -41,7 +41,7 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) {
fun(b, *body, op->getAttrs());
}
-MLIR_CAPI_EXPORTED bool mlirLinalgIsContractionOp(MlirOperation op) {
+MLIR_CAPI_EXPORTED bool mlirLinalgIsaContractionOp(MlirOperation op) {
auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
// isaContractionOpInterface handles null linalgOp internally.
return linalg::isaContractionOpInterface(linalgOp);
@@ -75,4 +75,49 @@ mlirLinalgInferContractionDimensions(MlirOperation op) {
return result;
}
+MLIR_CAPI_EXPORTED bool mlirLinalgIsaConvolutionOp(MlirOperation op) {
+ auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
+ if (!linalgOp)
+ return false;
+
+ return linalg::isaConvolutionOpInterface(linalgOp);
+}
+
+MLIR_CAPI_EXPORTED MlirLinalgConvolutionDimensions
+mlirLinalgInferConvolutionDimensions(MlirOperation op) {
+ MlirLinalgConvolutionDimensions result{};
+ auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
+ if (!linalgOp)
+ return result;
+
+ FailureOr<linalg::ConvolutionDimensions> maybeDims =
+ linalg::inferConvolutionDims(linalgOp);
+ if (failed(maybeDims))
+ return result;
+
+ linalg::ConvolutionDimensions dims = *maybeDims;
+ MLIRContext *ctx = linalgOp.getContext();
+
+ auto toI32Attr =
+ [&ctx](const SmallVector<unsigned, 2> &vals) -> MlirAttribute {
+ return wrap(DenseI32ArrayAttr::get(ctx, llvm::to_vector_of<int32_t>(vals)));
+ };
+
+ auto toI64Attr =
+ [&ctx](const SmallVector<int64_t, 2> &vals) -> MlirAttribute {
+ return wrap(DenseI64ArrayAttr::get(ctx, vals));
+ };
+
+ result.batch = toI32Attr(dims.batch);
+ result.outputImage = toI32Attr(dims.outputImage);
+ result.outputChannel = toI32Attr(dims.outputChannel);
+ result.filterLoop = toI32Attr(dims.filterLoop);
+ result.inputChannel = toI32Attr(dims.inputChannel);
+ result.depth = toI32Attr(dims.depth);
+ result.strides = toI64Attr(dims.strides);
+ result.dilations = toI64Attr(dims.dilations);
+
+ return result;
+}
+
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect)
diff --git a/mlir/test/python/dialects/linalg/utils.py b/mlir/test/python/dialects/linalg/utils.py
index a48aa90fa5836..98157b0e443cf 100644
--- a/mlir/test/python/dialects/linalg/utils.py
+++ b/mlir/test/python/dialects/linalg/utils.py
@@ -52,7 +52,7 @@ def contraction_fn(a, b, c):
dims = linalg.infer_contraction_dimensions(contraction_op)
assert dims is not None
- # Expect m=[0], n=[1], k=[2] as per standard matmul
+ # Expect m=[0], n=[1], k=[2] as per standard matmul.
assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}"
assert list(dims.n) == [1], f"Expected n=[1], got {list(dims.n)}"
assert list(dims.k) == [2], f"Expected k=[2], got {list(dims.k)}"
@@ -95,3 +95,67 @@ def dynamic_contraction_fn(a, b, c):
assert (
list(dims.batch) == []
), f"Expected batch=[], got {list(dims.batch)}"
+
+
+@run
+def test_infer_convolution_dimensions_from_ops():
+ with Context(), Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+
+ with InsertionPoint(module.body):
+ # === Static shapes ===
+ batch, h, w, c_in, kh, kw, c_out = 1, 8, 8, 4, 3, 3, 16
+ input_type = RankedTensorType.get((batch, h, w, c_in), f32)
+ filter_type = RankedTensorType.get((kh, kw, c_in, c_out), f32)
+ output_type = RankedTensorType.get(
+ (batch, h - kh + 1, w - kw + 1, c_out), f32
+ )
+
+ @func.FuncOp.from_py_func(input_type, filter_type, output_type)
+ def conv_fn(input, filter, output):
+ zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.0), result=f32)
+ filled = linalg.fill(zero, outs=[output])
+ fill_op = filled.owner
+
+ assert not linalg.isa_convolution_op(fill_op)
+ assert linalg.infer_convolution_dimensions(fill_op) is None
+
+ result = linalg.conv_2d_nhwc_hwcf(input, filter, outs=[filled])
+ conv_op = result.owner
+
+ assert linalg.isa_convolution_op(conv_op)
+ dims = linalg.infer_convolution_dimensions(conv_op)
+ assert dims is not None
+ assert list(dims.batch) == [0]
+ assert list(dims.output_image) == [1, 2]
+ assert list(dims.output_channel) == [3]
+ assert list(dims.filter_loop) == [4, 5]
+ assert list(dims.input_channel) == [6]
+ assert list(dims.depth) == []
+ assert list(dims.strides) == [1, 1]
+ assert list(dims.dilations) == [1, 1]
+
+ # === Dynamic shapes ===
+ dyn = ShapedType.get_dynamic_size()
+ dyn_input_type = RankedTensorType.get((batch, dyn, dyn, c_in), f32)
+ dyn_output_type = RankedTensorType.get((batch, dyn, dyn, c_out), f32)
+
+ @func.FuncOp.from_py_func(dyn_input_type, filter_type, dyn_output_type)
+ def dyn_conv_fn(input, filter, output):
+ zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.0), result=f32)
+ filled = linalg.fill(zero, outs=[output])
+ result = linalg.conv_2d_nhwc_hwcf(input, filter, outs=[filled])
+ conv_op = result.owner
+
+ assert linalg.isa_convolution_op(conv_op)
+ dims = linalg.infer_convolution_dimensions(conv_op)
+ assert dims is not None
+ assert list(dims.batch) == [0]
+ assert list(dims.output_image) == [1, 2]
+ assert list(dims.output_channel) == [3]
+ assert list(dims.filter_loop) == [4, 5]
+ assert list(dims.input_channel) == [6]
+ assert list(dims.depth) == []
+ assert list(dims.strides) == [1, 1]
+ assert list(dims.dilations) == [1, 1]
|
…ace and linalg::inferConvolutionDims Signed-off-by: Bangtian Liu <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
Signed-off-by: Bangtian Liu <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
this PR seems to cause breakage in downstream CAPI users /opt/hostedtoolcache/Python/3.11.11/x64/lib/python3.11/site-packages/mlir/lib/../include/mlir-c/Dialect/Linalg.h:34:20: error: must use 'struct' tag to refer to type 'MlirLinalgContractionDimensions'
MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
^
struct
/opt/hostedtoolcache/Python/3.11.11/x64/lib/python3.11/site-packages/mlir/lib/../include/mlir-c/Dialect/Linalg.h:50:20: error: must use 'struct' tag to refer to type 'MlirLinalgConvolutionDimensions'
MLIR_CAPI_EXPORTED MlirLinalgConvolutionDimensions
^
struct |
Thanks for pointing that out. I've just submitted a PR to fix it. |
This PR is after #135253 and #134935 to fix the error reported by #135253 (comment). This PR Adds typedef declarations for `MlirLinalgContractionDimensions` and `MlirLinalgConvolutionDimensions` in the C API to ensure compatibility with pure C code. I confirm that this fix resolves the reported error based on my testing. Signed-off-by: Bangtian Liu <[email protected]>
…ge (#135380) This PR is after #135253 and #134935 to fix the error reported by llvm/llvm-project#135253 (comment). This PR Adds typedef declarations for `MlirLinalgContractionDimensions` and `MlirLinalgConvolutionDimensions` in the C API to ensure compatibility with pure C code. I confirm that this fix resolves the reported error based on my testing. Signed-off-by: Bangtian Liu <[email protected]>
…utionOpInterface and linalg::inferConvolutionDims (llvm#135253) This PR is mainly about exposing the python bindings for `linalg::isaConvolutionOpInterface` and `linalg::inferConvolutionDims`. --------- Signed-off-by: Bangtian Liu <[email protected]>
…5380) This PR is after llvm#135253 and llvm#134935 to fix the error reported by llvm#135253 (comment). This PR Adds typedef declarations for `MlirLinalgContractionDimensions` and `MlirLinalgConvolutionDimensions` in the C API to ensure compatibility with pure C code. I confirm that this fix resolves the reported error based on my testing. Signed-off-by: Bangtian Liu <[email protected]>
This PR is mainly about exposing the python bindings for
linalg::isaConvolutionOpInterface
andlinalg::inferConvolutionDims
.