Skip to content

Commit 9466cbd

Browse files
authored
[mlir][CAPI][python] expose the python bindings for linalg::isaConvolutionOpInterface and linalg::inferConvolutionDims (#135253)
This PR is mainly about exposing the python bindings for `linalg::isaConvolutionOpInterface` and `linalg::inferConvolutionDims`. --------- Signed-off-by: Bangtian Liu <[email protected]>
1 parent 1d8966e commit 9466cbd

File tree

4 files changed

+190
-4
lines changed

4 files changed

+190
-4
lines changed

mlir/include/mlir-c/Dialect/Linalg.h

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ extern "C" {
2222
MLIR_CAPI_EXPORTED void
2323
mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp);
2424

25-
MLIR_CAPI_EXPORTED bool mlirLinalgIsContractionOp(MlirOperation op);
25+
MLIR_CAPI_EXPORTED bool mlirLinalgIsAContractionOp(MlirOperation op);
2626

2727
struct MlirLinalgContractionDimensions {
2828
MlirAttribute batch;
@@ -34,6 +34,22 @@ struct MlirLinalgContractionDimensions {
3434
MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
3535
mlirLinalgInferContractionDimensions(MlirOperation op);
3636

37+
MLIR_CAPI_EXPORTED bool mlirLinalgIsAConvolutionOp(MlirOperation op);
38+
39+
struct MlirLinalgConvolutionDimensions {
40+
MlirAttribute batch;
41+
MlirAttribute outputImage;
42+
MlirAttribute outputChannel;
43+
MlirAttribute filterLoop;
44+
MlirAttribute inputChannel;
45+
MlirAttribute depth;
46+
MlirAttribute strides;
47+
MlirAttribute dilations;
48+
};
49+
50+
MLIR_CAPI_EXPORTED MlirLinalgConvolutionDimensions
51+
mlirLinalgInferConvolutionDimensions(MlirOperation op);
52+
3753
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg);
3854

3955
#ifdef __cplusplus

mlir/lib/Bindings/Python/DialectLinalg.cpp

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,26 @@ InferContractionDimensions(MlirOperation op) {
2828
return dims;
2929
}
3030

31+
static std::optional<MlirLinalgConvolutionDimensions>
32+
InferConvolutionDimensions(MlirOperation op) {
33+
MlirLinalgConvolutionDimensions dims =
34+
mlirLinalgInferConvolutionDimensions(op);
35+
36+
// Detect "empty" result. This occurs when `op` is not a convolution op,
37+
// or when `linalg::inferConvolutionDims` fails.
38+
if (mlirAttributeIsNull(dims.batch) &&
39+
mlirAttributeIsNull(dims.outputImage) &&
40+
mlirAttributeIsNull(dims.outputChannel) &&
41+
mlirAttributeIsNull(dims.filterLoop) &&
42+
mlirAttributeIsNull(dims.inputChannel) &&
43+
mlirAttributeIsNull(dims.depth) && mlirAttributeIsNull(dims.strides) &&
44+
mlirAttributeIsNull(dims.dilations)) {
45+
return std::nullopt;
46+
}
47+
48+
return dims;
49+
}
50+
3151
static void populateDialectLinalgSubmodule(nb::module_ m) {
3252
m.def(
3353
"fill_builtin_region",
@@ -36,7 +56,7 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
3656
"Fill the region for `op`, which is assumed to be a builtin named Linalg "
3757
"op.");
3858

39-
m.def("isa_contraction_op", &mlirLinalgIsContractionOp,
59+
m.def("isa_contraction_op", &mlirLinalgIsAContractionOp,
4060
"Checks if the given operation is a Linalg contraction operation.",
4161
nb::arg("op"));
4262

@@ -59,6 +79,47 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
5979
"Infers contraction dimensions (batch/m/n/k) for a Linalg contraction "
6080
"op.",
6181
nb::arg("op"));
82+
83+
m.def("isa_convolution_op", &mlirLinalgIsAConvolutionOp,
84+
"Checks if the given operation is a Linalg convolution operation.",
85+
nb::arg("op"));
86+
87+
nb::class_<MlirLinalgConvolutionDimensions>(m, "ConvolutionDimensions")
88+
.def_prop_ro("batch",
89+
[](const MlirLinalgConvolutionDimensions &self) {
90+
return self.batch;
91+
})
92+
.def_prop_ro("output_image",
93+
[](const MlirLinalgConvolutionDimensions &self) {
94+
return self.outputImage;
95+
})
96+
.def_prop_ro("output_channel",
97+
[](const MlirLinalgConvolutionDimensions &self) {
98+
return self.outputChannel;
99+
})
100+
.def_prop_ro("filter_loop",
101+
[](const MlirLinalgConvolutionDimensions &self) {
102+
return self.filterLoop;
103+
})
104+
.def_prop_ro("input_channel",
105+
[](const MlirLinalgConvolutionDimensions &self) {
106+
return self.inputChannel;
107+
})
108+
.def_prop_ro("depth",
109+
[](const MlirLinalgConvolutionDimensions &self) {
110+
return self.depth;
111+
})
112+
.def_prop_ro("strides",
113+
[](const MlirLinalgConvolutionDimensions &self) {
114+
return self.strides;
115+
})
116+
.def_prop_ro("dilations",
117+
[](const MlirLinalgConvolutionDimensions &self) {
118+
return self.dilations;
119+
});
120+
121+
m.def("infer_convolution_dimensions", &InferConvolutionDimensions,
122+
"Infers convolution dimensions", nb::arg("op"));
62123
}
63124

64125
NB_MODULE(_mlirDialectsLinalg, m) {

mlir/lib/CAPI/Dialect/Linalg.cpp

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) {
4141
fun(b, *body, op->getAttrs());
4242
}
4343

44-
MLIR_CAPI_EXPORTED bool mlirLinalgIsContractionOp(MlirOperation op) {
44+
MLIR_CAPI_EXPORTED bool mlirLinalgIsAContractionOp(MlirOperation op) {
4545
auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
4646
// isaContractionOpInterface handles null linalgOp internally.
4747
return linalg::isaContractionOpInterface(linalgOp);
@@ -75,4 +75,49 @@ mlirLinalgInferContractionDimensions(MlirOperation op) {
7575
return result;
7676
}
7777

78+
MLIR_CAPI_EXPORTED bool mlirLinalgIsAConvolutionOp(MlirOperation op) {
79+
auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
80+
if (!linalgOp)
81+
return false;
82+
83+
return linalg::isaConvolutionOpInterface(linalgOp);
84+
}
85+
86+
MLIR_CAPI_EXPORTED MlirLinalgConvolutionDimensions
87+
mlirLinalgInferConvolutionDimensions(MlirOperation op) {
88+
MlirLinalgConvolutionDimensions result{};
89+
auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
90+
if (!linalgOp)
91+
return result;
92+
93+
FailureOr<linalg::ConvolutionDimensions> maybeDims =
94+
linalg::inferConvolutionDims(linalgOp);
95+
if (failed(maybeDims))
96+
return result;
97+
98+
linalg::ConvolutionDimensions dims = *maybeDims;
99+
MLIRContext *ctx = linalgOp.getContext();
100+
101+
auto toI32Attr =
102+
[&ctx](const SmallVector<unsigned, 2> &vals) -> MlirAttribute {
103+
return wrap(DenseI32ArrayAttr::get(ctx, llvm::to_vector_of<int32_t>(vals)));
104+
};
105+
106+
auto toI64Attr =
107+
[&ctx](const SmallVector<int64_t, 2> &vals) -> MlirAttribute {
108+
return wrap(DenseI64ArrayAttr::get(ctx, vals));
109+
};
110+
111+
result.batch = toI32Attr(dims.batch);
112+
result.outputImage = toI32Attr(dims.outputImage);
113+
result.outputChannel = toI32Attr(dims.outputChannel);
114+
result.filterLoop = toI32Attr(dims.filterLoop);
115+
result.inputChannel = toI32Attr(dims.inputChannel);
116+
result.depth = toI32Attr(dims.depth);
117+
result.strides = toI64Attr(dims.strides);
118+
result.dilations = toI64Attr(dims.dilations);
119+
120+
return result;
121+
}
122+
78123
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect)

mlir/test/python/dialects/linalg/utils.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def contraction_fn(a, b, c):
5252
dims = linalg.infer_contraction_dimensions(contraction_op)
5353
assert dims is not None
5454

55-
# Expect m=[0], n=[1], k=[2] as per standard matmul
55+
# Expect m=[0], n=[1], k=[2] as per standard matmul.
5656
assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}"
5757
assert list(dims.n) == [1], f"Expected n=[1], got {list(dims.n)}"
5858
assert list(dims.k) == [2], f"Expected k=[2], got {list(dims.k)}"
@@ -95,3 +95,67 @@ def dynamic_contraction_fn(a, b, c):
9595
assert (
9696
list(dims.batch) == []
9797
), f"Expected batch=[], got {list(dims.batch)}"
98+
99+
100+
@run
101+
def test_infer_convolution_dimensions_from_ops():
102+
with Context(), Location.unknown():
103+
module = Module.create()
104+
f32 = F32Type.get()
105+
106+
with InsertionPoint(module.body):
107+
# === Static shapes ===
108+
batch, h, w, c_in, kh, kw, c_out = 1, 8, 8, 4, 3, 3, 16
109+
input_type = RankedTensorType.get((batch, h, w, c_in), f32)
110+
filter_type = RankedTensorType.get((kh, kw, c_in, c_out), f32)
111+
output_type = RankedTensorType.get(
112+
(batch, h - kh + 1, w - kw + 1, c_out), f32
113+
)
114+
115+
@func.FuncOp.from_py_func(input_type, filter_type, output_type)
116+
def conv_fn(input, filter, output):
117+
zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.0), result=f32)
118+
filled = linalg.fill(zero, outs=[output])
119+
fill_op = filled.owner
120+
121+
assert not linalg.isa_convolution_op(fill_op)
122+
assert linalg.infer_convolution_dimensions(fill_op) is None
123+
124+
result = linalg.conv_2d_nhwc_hwcf(input, filter, outs=[filled])
125+
conv_op = result.owner
126+
127+
assert linalg.isa_convolution_op(conv_op)
128+
dims = linalg.infer_convolution_dimensions(conv_op)
129+
assert dims is not None
130+
assert list(dims.batch) == [0]
131+
assert list(dims.output_image) == [1, 2]
132+
assert list(dims.output_channel) == [3]
133+
assert list(dims.filter_loop) == [4, 5]
134+
assert list(dims.input_channel) == [6]
135+
assert list(dims.depth) == []
136+
assert list(dims.strides) == [1, 1]
137+
assert list(dims.dilations) == [1, 1]
138+
139+
# === Dynamic shapes ===
140+
dyn = ShapedType.get_dynamic_size()
141+
dyn_input_type = RankedTensorType.get((batch, dyn, dyn, c_in), f32)
142+
dyn_output_type = RankedTensorType.get((batch, dyn, dyn, c_out), f32)
143+
144+
@func.FuncOp.from_py_func(dyn_input_type, filter_type, dyn_output_type)
145+
def dyn_conv_fn(input, filter, output):
146+
zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.0), result=f32)
147+
filled = linalg.fill(zero, outs=[output])
148+
result = linalg.conv_2d_nhwc_hwcf(input, filter, outs=[filled])
149+
conv_op = result.owner
150+
151+
assert linalg.isa_convolution_op(conv_op)
152+
dims = linalg.infer_convolution_dimensions(conv_op)
153+
assert dims is not None
154+
assert list(dims.batch) == [0]
155+
assert list(dims.output_image) == [1, 2]
156+
assert list(dims.output_channel) == [3]
157+
assert list(dims.filter_loop) == [4, 5]
158+
assert list(dims.input_channel) == [6]
159+
assert list(dims.depth) == []
160+
assert list(dims.strides) == [1, 1]
161+
assert list(dims.dilations) == [1, 1]

0 commit comments

Comments
 (0)