Skip to content

Commit 24fc211

Browse files
committed
address reviewer comments
Signed-off-by: Bangtian Liu <[email protected]>
1 parent d151987 commit 24fc211

File tree

4 files changed

+121
-78
lines changed

4 files changed

+121
-78
lines changed

mlir/lib/Bindings/Python/DialectLinalg.cpp

Lines changed: 17 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9-
#include "mlir-c/BuiltinAttributes.h"
109
#include "mlir-c/Dialect/Linalg.h"
1110
#include "mlir-c/IR.h"
1211
#include "mlir/Bindings/Python/Nanobind.h"
@@ -15,35 +14,18 @@
1514
namespace nb = nanobind;
1615
using namespace mlir::python::nanobind_adaptors;
1716

18-
struct PyContractionDimensions {
19-
MlirLinalgContractionDimensions value;
20-
21-
PyContractionDimensions() = default;
22-
PyContractionDimensions(const MlirLinalgContractionDimensions &v)
23-
: value(v) {}
24-
};
25-
26-
static std::optional<PyContractionDimensions>
27-
mlirLinalgInferContractionDimensionsBinding(MlirOperation op) {
17+
static std::optional<MlirLinalgContractionDimensions>
18+
InferContractionDimensions(MlirOperation op) {
2819
MlirLinalgContractionDimensions dims =
2920
mlirLinalgInferContractionDimensions(op);
3021

31-
// Detect "empty" result.
22+
// Detect "empty" result. This occurs when `op` is not a contraction op,
23+
// or when `linalg::inferContractionDims` fails.
3224
if (mlirAttributeIsNull(dims.batch) && mlirAttributeIsNull(dims.m) &&
3325
mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k)) {
3426
return std::nullopt;
3527
}
36-
return PyContractionDimensions{dims};
37-
}
38-
39-
static std::vector<int32_t> convertDenseI32AttrToList(MlirAttribute attr) {
40-
std::vector<int32_t> result;
41-
int64_t size = mlirDenseArrayGetNumElements(attr);
42-
result.reserve(size);
43-
for (int64_t i = 0; i < size; ++i) {
44-
result.push_back(mlirDenseI32ArrayGetElement(attr, i));
45-
}
46-
return result;
28+
return dims;
4729
}
4830

4931
static void populateDialectLinalgSubmodule(nb::module_ m) {
@@ -58,25 +40,22 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
5840
"Checks if the given operation is a Linalg contraction operation.",
5941
nb::arg("op"));
6042

61-
nb::class_<PyContractionDimensions>(m, "ContractionDimensions")
43+
nb::class_<MlirLinalgContractionDimensions>(m, "ContractionDimensions")
6244
.def_prop_ro("batch",
63-
[](const PyContractionDimensions &self) {
64-
return convertDenseI32AttrToList(self.value.batch);
65-
})
66-
.def_prop_ro("m",
67-
[](const PyContractionDimensions &self) {
68-
return convertDenseI32AttrToList(self.value.m);
69-
})
70-
.def_prop_ro("n",
71-
[](const PyContractionDimensions &self) {
72-
return convertDenseI32AttrToList(self.value.n);
45+
[](const MlirLinalgContractionDimensions &self) {
46+
return self.batch;
7347
})
74-
.def_prop_ro("k", [](const PyContractionDimensions &self) {
75-
return convertDenseI32AttrToList(self.value.k);
48+
.def_prop_ro(
49+
"m",
50+
[](const MlirLinalgContractionDimensions &self) { return self.m; })
51+
.def_prop_ro(
52+
"n",
53+
[](const MlirLinalgContractionDimensions &self) { return self.n; })
54+
.def_prop_ro("k", [](const MlirLinalgContractionDimensions &self) {
55+
return self.k;
7656
});
7757

78-
m.def("infer_contraction_dimensions",
79-
&mlirLinalgInferContractionDimensionsBinding,
58+
m.def("infer_contraction_dimensions", &InferContractionDimensions,
8059
"Infers contraction dimensions (batch/m/n/k) for a Linalg contraction "
8160
"op.",
8261
nb::arg("op"));

mlir/lib/CAPI/Dialect/Linalg.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) {
4343

4444
MLIR_CAPI_EXPORTED bool mlirLinalgIsContractionOp(MlirOperation op) {
4545
auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
46+
// isaContractionOpInterface handles null linalgOp internally.
4647
return linalg::isaContractionOpInterface(linalgOp);
4748
}
4849

@@ -53,16 +54,17 @@ mlirLinalgInferContractionDimensions(MlirOperation op) {
5354
if (!linalgOp)
5455
return result;
5556

56-
auto maybeDims = linalg::inferContractionDims(linalgOp);
57+
FailureOr<linalg::ContractionDimensions> maybeDims =
58+
linalg::inferContractionDims(linalgOp);
5759
if (failed(maybeDims))
5860
return result;
5961

60-
linalg::ContractionDimensions contractionDims = maybeDims.value();
62+
linalg::ContractionDimensions contractionDims = *maybeDims;
6163
MLIRContext *ctx = linalgOp.getContext();
6264

63-
auto toAttr = [&](const SmallVector<unsigned, 2> &vals) -> MlirAttribute {
64-
SmallVector<int32_t> intVals(vals.begin(), vals.end());
65-
return wrap(DenseI32ArrayAttr::get(ctx, intVals));
65+
auto toAttr = [&ctx](const SmallVector<unsigned, 2> &vals) -> MlirAttribute {
66+
return wrap(
67+
DenseI32ArrayAttr::get(ctx, llvm::to_vector_of<int32_t, 2>(vals)));
6668
};
6769

6870
result.batch = toAttr(contractionDims.batch);

mlir/test/python/dialects/linalg/ops.py

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -606,38 +606,3 @@ def tensor_pack(src, dst):
606606
# CHECK: return %[[VAL_4]] : tensor<128x128xf32>
607607
# CHECK: }
608608
print(module)
609-
610-
611-
@run
612-
def test_infer_contraction_dimensions():
613-
with Context(), Location.unknown():
614-
module = ir.Module.parse(
615-
r"""
616-
module {
617-
func.func @matmul(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>)
618-
-> tensor<4x4xf32> {
619-
%cst = arith.constant 0.0 : f32
620-
%0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32>
621-
%1 = linalg.matmul ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>)
622-
outs(%0 : tensor<4x4xf32>) -> tensor<4x4xf32>
623-
return %1 : tensor<4x4xf32>
624-
}
625-
}
626-
"""
627-
)
628-
func_op = module.body.operations[0]
629-
body_block = func_op.regions[0].blocks[0]
630-
fill_op = body_block.operations[1]
631-
matmul_op = body_block.operations[2]
632-
633-
assert not linalg.isa_contraction_op(fill_op)
634-
assert linalg.isa_contraction_op(matmul_op)
635-
636-
dims = linalg.infer_contraction_dimensions(fill_op)
637-
assert dims is None
638-
dims = linalg.infer_contraction_dimensions(matmul_op)
639-
assert dims
640-
641-
assert dims.m == [0], f"Expected m=[0], got {dims.m}"
642-
assert dims.n == [1], f"Expected n=[1], got {dims.n}"
643-
assert dims.k == [2], f"Expected k=[2], got {dims.k}"
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# RUN: %PYTHON %s
2+
3+
from mlir.dialects import arith, func, linalg
4+
from mlir.dialects.linalg.opdsl.lang import *
5+
from mlir.ir import *
6+
7+
8+
def run(f):
9+
print("\nTEST:", f.__name__)
10+
f()
11+
return f
12+
13+
14+
@run
15+
def test_infer_contraction_dimensions_from_ops():
16+
with Context(), Location.unknown():
17+
module = Module.create()
18+
f32 = F32Type.get()
19+
with InsertionPoint(module.body):
20+
# === Static shapes ===
21+
m, n, k = 4, 4, 4
22+
a_type = RankedTensorType.get((m, k), f32)
23+
b_type = RankedTensorType.get((k, n), f32)
24+
c_type = RankedTensorType.get((m, n), f32)
25+
26+
@func.FuncOp.from_py_func(a_type, b_type, c_type)
27+
def contraction_fn(a, b, c):
28+
zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.0), result=f32)
29+
filled = linalg.fill(zero, outs=[c])
30+
fill_op = filled.owner
31+
32+
assert not linalg.isa_contraction_op(zero.operation)
33+
assert not linalg.isa_contraction_op(fill_op)
34+
assert linalg.infer_contraction_dimensions(fill_op) is None
35+
36+
dim_m = AffineDimExpr.get(0)
37+
dim_n = AffineDimExpr.get(1)
38+
dim_k = AffineDimExpr.get(2)
39+
40+
a_map = AffineMap.get(3, 0, [dim_m, dim_k])
41+
b_map = AffineMap.get(3, 0, [dim_k, dim_n])
42+
c_map = AffineMap.get(3, 0, [dim_m, dim_n])
43+
result = linalg.contract(
44+
a,
45+
b,
46+
outs=(filled,),
47+
indexing_maps=[a_map, b_map, c_map],
48+
)
49+
contraction_op = result.owner
50+
51+
assert linalg.isa_contraction_op(contraction_op)
52+
dims = linalg.infer_contraction_dimensions(contraction_op)
53+
assert dims is not None
54+
55+
# Expect m=[0], n=[1], k=[2] as per standard matmul
56+
assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}"
57+
assert list(dims.n) == [1], f"Expected n=[1], got {list(dims.n)}"
58+
assert list(dims.k) == [2], f"Expected k=[2], got {list(dims.k)}"
59+
assert (
60+
list(dims.batch) == []
61+
), f"Expected batch=[], got {list(dims.batch)}"
62+
63+
# === Dynamic shape case ===
64+
dyn = ShapedType.get_dynamic_size()
65+
a_dyn_type = RankedTensorType.get((4, dyn), f32)
66+
b_dyn_type = RankedTensorType.get((dyn, 4), f32)
67+
c_type = RankedTensorType.get((4, 4), f32)
68+
69+
@func.FuncOp.from_py_func(a_dyn_type, b_dyn_type, c_type)
70+
def dynamic_contraction_fn(a, b, c):
71+
zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.0), result=f32)
72+
filled = linalg.fill(zero, outs=[c])
73+
dim_m = AffineDimExpr.get(0)
74+
dim_n = AffineDimExpr.get(1)
75+
dim_k = AffineDimExpr.get(2)
76+
77+
a_map = AffineMap.get(3, 0, [dim_m, dim_k])
78+
b_map = AffineMap.get(3, 0, [dim_k, dim_n])
79+
c_map = AffineMap.get(3, 0, [dim_m, dim_n])
80+
81+
result = linalg.contract(
82+
a,
83+
b,
84+
outs=(filled,),
85+
indexing_maps=[a_map, b_map, c_map],
86+
)
87+
contraction_op = result.owner
88+
89+
assert linalg.isa_contraction_op(contraction_op)
90+
dims = linalg.infer_contraction_dimensions(contraction_op)
91+
assert dims is not None
92+
assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}"
93+
assert list(dims.n) == [1], f"Expected n=[1], got {list(dims.n)}"
94+
assert list(dims.k) == [2], f"Expected k=[2], got {list(dims.k)}"
95+
assert (
96+
list(dims.batch) == []
97+
), f"Expected batch=[], got {list(dims.batch)}"

0 commit comments

Comments
 (0)