
[mlir][CAPI][python] expose the python bindings for linalg::isaConvolutionOpInterface and linalg::inferConvolutionDims #135253

Merged · 2 commits · Apr 11, 2025
18 changes: 17 additions & 1 deletion mlir/include/mlir-c/Dialect/Linalg.h
@@ -22,7 +22,7 @@ extern "C" {
MLIR_CAPI_EXPORTED void
mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp);

-MLIR_CAPI_EXPORTED bool mlirLinalgIsContractionOp(MlirOperation op);
+MLIR_CAPI_EXPORTED bool mlirLinalgIsAContractionOp(MlirOperation op);

struct MlirLinalgContractionDimensions {
MlirAttribute batch;
@@ -34,6 +34,22 @@ struct MlirLinalgContractionDimensions {
MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
mlirLinalgInferContractionDimensions(MlirOperation op);

MLIR_CAPI_EXPORTED bool mlirLinalgIsAConvolutionOp(MlirOperation op);

struct MlirLinalgConvolutionDimensions {
MlirAttribute batch;
MlirAttribute outputImage;
MlirAttribute outputChannel;
MlirAttribute filterLoop;
MlirAttribute inputChannel;
MlirAttribute depth;
MlirAttribute strides;
MlirAttribute dilations;
};

MLIR_CAPI_EXPORTED MlirLinalgConvolutionDimensions
mlirLinalgInferConvolutionDimensions(MlirOperation op);

MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg);

#ifdef __cplusplus
63 changes: 62 additions & 1 deletion mlir/lib/Bindings/Python/DialectLinalg.cpp
@@ -28,6 +28,26 @@ InferContractionDimensions(MlirOperation op) {
return dims;
}

static std::optional<MlirLinalgConvolutionDimensions>
InferConvolutionDimensions(MlirOperation op) {
MlirLinalgConvolutionDimensions dims =
mlirLinalgInferConvolutionDimensions(op);

// Detect "empty" result. This occurs when `op` is not a convolution op,
// or when `linalg::inferConvolutionDims` fails.
if (mlirAttributeIsNull(dims.batch) &&
mlirAttributeIsNull(dims.outputImage) &&
mlirAttributeIsNull(dims.outputChannel) &&
mlirAttributeIsNull(dims.filterLoop) &&
mlirAttributeIsNull(dims.inputChannel) &&
mlirAttributeIsNull(dims.depth) && mlirAttributeIsNull(dims.strides) &&
mlirAttributeIsNull(dims.dilations)) {
return std::nullopt;
}

return dims;
}

static void populateDialectLinalgSubmodule(nb::module_ m) {
m.def(
"fill_builtin_region",
@@ -36,7 +56,7 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
"Fill the region for `op`, which is assumed to be a builtin named Linalg "
"op.");

m.def("isa_contraction_op", &mlirLinalgIsContractionOp,
m.def("isa_contraction_op", &mlirLinalgIsAContractionOp,
"Checks if the given operation is a Linalg contraction operation.",
nb::arg("op"));

@@ -59,6 +79,47 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
"Infers contraction dimensions (batch/m/n/k) for a Linalg contraction "
"op.",
nb::arg("op"));

m.def("isa_convolution_op", &mlirLinalgIsAConvolutionOp,
"Checks if the given operation is a Linalg convolution operation.",
nb::arg("op"));

nb::class_<MlirLinalgConvolutionDimensions>(m, "ConvolutionDimensions")
.def_prop_ro("batch",
[](const MlirLinalgConvolutionDimensions &self) {
return self.batch;
})
.def_prop_ro("output_image",
[](const MlirLinalgConvolutionDimensions &self) {
return self.outputImage;
})
.def_prop_ro("output_channel",
[](const MlirLinalgConvolutionDimensions &self) {
return self.outputChannel;
})
.def_prop_ro("filter_loop",
[](const MlirLinalgConvolutionDimensions &self) {
return self.filterLoop;
})
.def_prop_ro("input_channel",
[](const MlirLinalgConvolutionDimensions &self) {
return self.inputChannel;
})
.def_prop_ro("depth",
[](const MlirLinalgConvolutionDimensions &self) {
return self.depth;
})
.def_prop_ro("strides",
[](const MlirLinalgConvolutionDimensions &self) {
return self.strides;
})
.def_prop_ro("dilations",
[](const MlirLinalgConvolutionDimensions &self) {
return self.dilations;
});

m.def("infer_convolution_dimensions", &InferConvolutionDimensions,
"Infers convolution dimensions", nb::arg("op"));
}

NB_MODULE(_mlirDialectsLinalg, m) {
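For orientation, here is a minimal sketch of how the new bindings could be exercised once the in-tree mlir Python package is built with this patch; the IR string, variable names, and printed values below are illustrative, not part of the change:

from mlir.ir import Context, Module
from mlir.dialects import linalg

# Illustrative input: a named 2-D convolution parsed from textual IR. Any
# linalg op implementing ConvolutionOpInterface behaves the same way.
IR = """
func.func @f(%a: tensor<1x8x8x4xf32>, %b: tensor<3x3x4x16xf32>,
             %c: tensor<1x6x6x16xf32>) -> tensor<1x6x6x16xf32> {
  %0 = linalg.conv_2d_nhwc_hwcf
         ins(%a, %b : tensor<1x8x8x4xf32>, tensor<3x3x4x16xf32>)
         outs(%c : tensor<1x6x6x16xf32>) -> tensor<1x6x6x16xf32>
  return %0 : tensor<1x6x6x16xf32>
}
"""

with Context():
    module = Module.parse(IR)
    func_op = module.body.operations[0]
    conv_op = func_op.regions[0].blocks[0].operations[0].operation

    assert linalg.isa_convolution_op(conv_op)
    dims = linalg.infer_convolution_dimensions(conv_op)
    # Dimension fields are DenseI32ArrayAttr and strides/dilations are
    # DenseI64ArrayAttr; all of them convert cleanly to Python lists.
    print(list(dims.batch), list(dims.strides))  # [0] [1, 1]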
47 changes: 46 additions & 1 deletion mlir/lib/CAPI/Dialect/Linalg.cpp
@@ -41,7 +41,7 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) {
fun(b, *body, op->getAttrs());
}

-MLIR_CAPI_EXPORTED bool mlirLinalgIsContractionOp(MlirOperation op) {
+MLIR_CAPI_EXPORTED bool mlirLinalgIsAContractionOp(MlirOperation op) {
auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
// isaContractionOpInterface handles null linalgOp internally.
return linalg::isaContractionOpInterface(linalgOp);
@@ -75,4 +75,49 @@ mlirLinalgInferContractionDimensions(MlirOperation op) {
return result;
}

MLIR_CAPI_EXPORTED bool mlirLinalgIsAConvolutionOp(MlirOperation op) {
auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
if (!linalgOp)
return false;

return linalg::isaConvolutionOpInterface(linalgOp);
}

MLIR_CAPI_EXPORTED MlirLinalgConvolutionDimensions
mlirLinalgInferConvolutionDimensions(MlirOperation op) {
MlirLinalgConvolutionDimensions result{};
auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
if (!linalgOp)
return result;

FailureOr<linalg::ConvolutionDimensions> maybeDims =
linalg::inferConvolutionDims(linalgOp);
if (failed(maybeDims))
return result;

linalg::ConvolutionDimensions dims = *maybeDims;
MLIRContext *ctx = linalgOp.getContext();

auto toI32Attr =
[&ctx](const SmallVector<unsigned, 2> &vals) -> MlirAttribute {
return wrap(DenseI32ArrayAttr::get(ctx, llvm::to_vector_of<int32_t>(vals)));
};

auto toI64Attr =
[&ctx](const SmallVector<int64_t, 2> &vals) -> MlirAttribute {
return wrap(DenseI64ArrayAttr::get(ctx, vals));
};

result.batch = toI32Attr(dims.batch);
result.outputImage = toI32Attr(dims.outputImage);
result.outputChannel = toI32Attr(dims.outputChannel);
result.filterLoop = toI32Attr(dims.filterLoop);
result.inputChannel = toI32Attr(dims.inputChannel);
result.depth = toI32Attr(dims.depth);
result.strides = toI64Attr(dims.strides);
result.dilations = toI64Attr(dims.dilations);

return result;
}

MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect)
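The C entry point cannot signal failure other than by returning the zero-initialized struct, which the Python wrapper earlier in this patch maps to None. A small hypothetical helper (the function name and dict layout are illustrative) shows one way a caller might normalize both outcomes:

from mlir.dialects import linalg

def conv_dims_as_lists(op):
    # Returns None when `op` is not a convolution: the C API hands back an
    # all-null struct, which the binding layer turns into std::nullopt/None.
    dims = linalg.infer_convolution_dimensions(op)
    if dims is None:
        return None
    return {
        "batch": list(dims.batch),
        "output_image": list(dims.output_image),
        "output_channel": list(dims.output_channel),
        "filter_loop": list(dims.filter_loop),
        "input_channel": list(dims.input_channel),
        "depth": list(dims.depth),
        "strides": list(dims.strides),
        "dilations": list(dims.dilations),
    }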
66 changes: 65 additions & 1 deletion mlir/test/python/dialects/linalg/utils.py
@@ -52,7 +52,7 @@ def contraction_fn(a, b, c):
dims = linalg.infer_contraction_dimensions(contraction_op)
assert dims is not None

-# Expect m=[0], n=[1], k=[2] as per standard matmul
+# Expect m=[0], n=[1], k=[2] as per standard matmul.
assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}"
assert list(dims.n) == [1], f"Expected n=[1], got {list(dims.n)}"
assert list(dims.k) == [2], f"Expected k=[2], got {list(dims.k)}"
@@ -95,3 +95,67 @@ def dynamic_contraction_fn(a, b, c):
assert (
list(dims.batch) == []
), f"Expected batch=[], got {list(dims.batch)}"


@run
def test_infer_convolution_dimensions_from_ops():
with Context(), Location.unknown():
module = Module.create()
f32 = F32Type.get()

with InsertionPoint(module.body):
# === Static shapes ===
batch, h, w, c_in, kh, kw, c_out = 1, 8, 8, 4, 3, 3, 16
input_type = RankedTensorType.get((batch, h, w, c_in), f32)
filter_type = RankedTensorType.get((kh, kw, c_in, c_out), f32)
output_type = RankedTensorType.get(
(batch, h - kh + 1, w - kw + 1, c_out), f32
)

@func.FuncOp.from_py_func(input_type, filter_type, output_type)
def conv_fn(input, filter, output):
zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.0), result=f32)
filled = linalg.fill(zero, outs=[output])
fill_op = filled.owner

assert not linalg.isa_convolution_op(fill_op)
assert linalg.infer_convolution_dimensions(fill_op) is None

result = linalg.conv_2d_nhwc_hwcf(input, filter, outs=[filled])
conv_op = result.owner

assert linalg.isa_convolution_op(conv_op)
dims = linalg.infer_convolution_dimensions(conv_op)
assert dims is not None
assert list(dims.batch) == [0]
assert list(dims.output_image) == [1, 2]
assert list(dims.output_channel) == [3]
assert list(dims.filter_loop) == [4, 5]
assert list(dims.input_channel) == [6]
assert list(dims.depth) == []
assert list(dims.strides) == [1, 1]
assert list(dims.dilations) == [1, 1]

# === Dynamic shapes ===
dyn = ShapedType.get_dynamic_size()
dyn_input_type = RankedTensorType.get((batch, dyn, dyn, c_in), f32)
dyn_output_type = RankedTensorType.get((batch, dyn, dyn, c_out), f32)

@func.FuncOp.from_py_func(dyn_input_type, filter_type, dyn_output_type)
def dyn_conv_fn(input, filter, output):
zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.0), result=f32)
filled = linalg.fill(zero, outs=[output])
result = linalg.conv_2d_nhwc_hwcf(input, filter, outs=[filled])
conv_op = result.owner

assert linalg.isa_convolution_op(conv_op)
dims = linalg.infer_convolution_dimensions(conv_op)
assert dims is not None
assert list(dims.batch) == [0]
assert list(dims.output_image) == [1, 2]
assert list(dims.output_channel) == [3]
assert list(dims.filter_loop) == [4, 5]
assert list(dims.input_channel) == [6]
assert list(dims.depth) == []
assert list(dims.strides) == [1, 1]
assert list(dims.dilations) == [1, 1]