Skip to content

Commit 113cd15

Browse files
committed
address reviewer comments
Signed-off-by: Bangtian Liu <[email protected]>
1 parent d151987 commit 113cd15

File tree

4 files changed

+119
-77
lines changed

4 files changed

+119
-77
lines changed

mlir/lib/Bindings/Python/DialectLinalg.cpp

Lines changed: 15 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9-
#include "mlir-c/BuiltinAttributes.h"
109
#include "mlir-c/Dialect/Linalg.h"
1110
#include "mlir-c/IR.h"
1211
#include "mlir/Bindings/Python/Nanobind.h"
@@ -15,16 +14,8 @@
1514
namespace nb = nanobind;
1615
using namespace mlir::python::nanobind_adaptors;
1716

18-
struct PyContractionDimensions {
19-
MlirLinalgContractionDimensions value;
20-
21-
PyContractionDimensions() = default;
22-
PyContractionDimensions(const MlirLinalgContractionDimensions &v)
23-
: value(v) {}
24-
};
25-
26-
static std::optional<PyContractionDimensions>
27-
mlirLinalgInferContractionDimensionsBinding(MlirOperation op) {
17+
static std::optional<MlirLinalgContractionDimensions>
18+
InferContractionDimensions(MlirOperation op) {
2819
MlirLinalgContractionDimensions dims =
2920
mlirLinalgInferContractionDimensions(op);
3021

@@ -33,17 +24,7 @@ mlirLinalgInferContractionDimensionsBinding(MlirOperation op) {
3324
mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k)) {
3425
return std::nullopt;
3526
}
36-
return PyContractionDimensions{dims};
37-
}
38-
39-
static std::vector<int32_t> convertDenseI32AttrToList(MlirAttribute attr) {
40-
std::vector<int32_t> result;
41-
int64_t size = mlirDenseArrayGetNumElements(attr);
42-
result.reserve(size);
43-
for (int64_t i = 0; i < size; ++i) {
44-
result.push_back(mlirDenseI32ArrayGetElement(attr, i));
45-
}
46-
return result;
27+
return dims;
4728
}
4829

4930
static void populateDialectLinalgSubmodule(nb::module_ m) {
@@ -58,25 +39,22 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
5839
"Checks if the given operation is a Linalg contraction operation.",
5940
nb::arg("op"));
6041

61-
nb::class_<PyContractionDimensions>(m, "ContractionDimensions")
42+
nb::class_<MlirLinalgContractionDimensions>(m, "ContractionDimensions")
6243
.def_prop_ro("batch",
63-
[](const PyContractionDimensions &self) {
64-
return convertDenseI32AttrToList(self.value.batch);
65-
})
66-
.def_prop_ro("m",
67-
[](const PyContractionDimensions &self) {
68-
return convertDenseI32AttrToList(self.value.m);
69-
})
70-
.def_prop_ro("n",
71-
[](const PyContractionDimensions &self) {
72-
return convertDenseI32AttrToList(self.value.n);
44+
[](const MlirLinalgContractionDimensions &self) {
45+
return self.batch;
7346
})
74-
.def_prop_ro("k", [](const PyContractionDimensions &self) {
75-
return convertDenseI32AttrToList(self.value.k);
47+
.def_prop_ro(
48+
"m",
49+
[](const MlirLinalgContractionDimensions &self) { return self.m; })
50+
.def_prop_ro(
51+
"n",
52+
[](const MlirLinalgContractionDimensions &self) { return self.n; })
53+
.def_prop_ro("k", [](const MlirLinalgContractionDimensions &self) {
54+
return self.k;
7655
});
7756

78-
m.def("infer_contraction_dimensions",
79-
&mlirLinalgInferContractionDimensionsBinding,
57+
m.def("infer_contraction_dimensions", &InferContractionDimensions,
8058
"Infers contraction dimensions (batch/m/n/k) for a Linalg contraction "
8159
"op.",
8260
nb::arg("op"));

mlir/lib/CAPI/Dialect/Linalg.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) {
4343

4444
MLIR_CAPI_EXPORTED bool mlirLinalgIsContractionOp(MlirOperation op) {
4545
auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
46+
// isaContractionOpInterface handles null linalgOp internally.
4647
return linalg::isaContractionOpInterface(linalgOp);
4748
}
4849

@@ -53,16 +54,17 @@ mlirLinalgInferContractionDimensions(MlirOperation op) {
5354
if (!linalgOp)
5455
return result;
5556

56-
auto maybeDims = linalg::inferContractionDims(linalgOp);
57+
FailureOr<linalg::ContractionDimensions> maybeDims =
58+
linalg::inferContractionDims(linalgOp);
5759
if (failed(maybeDims))
5860
return result;
5961

60-
linalg::ContractionDimensions contractionDims = maybeDims.value();
62+
linalg::ContractionDimensions contractionDims = *maybeDims;
6163
MLIRContext *ctx = linalgOp.getContext();
6264

63-
auto toAttr = [&](const SmallVector<unsigned, 2> &vals) -> MlirAttribute {
64-
SmallVector<int32_t> intVals(vals.begin(), vals.end());
65-
return wrap(DenseI32ArrayAttr::get(ctx, intVals));
65+
auto toAttr = [&ctx](const SmallVector<unsigned, 2> &vals) -> MlirAttribute {
66+
return wrap(
67+
DenseI32ArrayAttr::get(ctx, llvm::to_vector_of<int32_t, 2>(vals)));
6668
};
6769

6870
result.batch = toAttr(contractionDims.batch);

mlir/test/python/dialects/linalg/ops.py

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -606,38 +606,3 @@ def tensor_pack(src, dst):
606606
# CHECK: return %[[VAL_4]] : tensor<128x128xf32>
607607
# CHECK: }
608608
print(module)
609-
610-
611-
@run
612-
def test_infer_contraction_dimensions():
613-
with Context(), Location.unknown():
614-
module = ir.Module.parse(
615-
r"""
616-
module {
617-
func.func @matmul(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>)
618-
-> tensor<4x4xf32> {
619-
%cst = arith.constant 0.0 : f32
620-
%0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32>
621-
%1 = linalg.matmul ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>)
622-
outs(%0 : tensor<4x4xf32>) -> tensor<4x4xf32>
623-
return %1 : tensor<4x4xf32>
624-
}
625-
}
626-
"""
627-
)
628-
func_op = module.body.operations[0]
629-
body_block = func_op.regions[0].blocks[0]
630-
fill_op = body_block.operations[1]
631-
matmul_op = body_block.operations[2]
632-
633-
assert not linalg.isa_contraction_op(fill_op)
634-
assert linalg.isa_contraction_op(matmul_op)
635-
636-
dims = linalg.infer_contraction_dimensions(fill_op)
637-
assert dims is None
638-
dims = linalg.infer_contraction_dimensions(matmul_op)
639-
assert dims
640-
641-
assert dims.m == [0], f"Expected m=[0], got {dims.m}"
642-
assert dims.n == [1], f"Expected n=[1], got {dims.n}"
643-
assert dims.k == [2], f"Expected k=[2], got {dims.k}"
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# RUN: %PYTHON %s
2+
3+
from mlir.dialects import arith, func, linalg
4+
from mlir.dialects.linalg.opdsl.lang import *
5+
from mlir.ir import *
6+
7+
8+
def run(f):
9+
print("\nTEST:", f.__name__)
10+
f()
11+
return f
12+
13+
14+
@run
15+
def test_infer_contraction_dimensions_from_ops():
16+
with Context(), Location.unknown():
17+
module = Module.create()
18+
f32 = F32Type.get()
19+
with InsertionPoint(module.body):
20+
# === Static shapes ===
21+
m, n, k = 4, 4, 4
22+
a_type = RankedTensorType.get((m, k), f32)
23+
b_type = RankedTensorType.get((k, n), f32)
24+
c_type = RankedTensorType.get((m, n), f32)
25+
26+
@func.FuncOp.from_py_func(a_type, b_type, c_type)
27+
def contraction_fn(a, b, c):
28+
zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.0), result=f32)
29+
filled = linalg.fill(zero, outs=[c])
30+
fill_op = filled.owner
31+
32+
assert not linalg.isa_contraction_op(zero.operation)
33+
assert not linalg.isa_contraction_op(fill_op)
34+
assert linalg.infer_contraction_dimensions(fill_op) is None
35+
36+
dim_m = AffineDimExpr.get(0)
37+
dim_n = AffineDimExpr.get(1)
38+
dim_k = AffineDimExpr.get(2)
39+
40+
a_map = AffineMap.get(3, 0, [dim_m, dim_k])
41+
b_map = AffineMap.get(3, 0, [dim_k, dim_n])
42+
c_map = AffineMap.get(3, 0, [dim_m, dim_n])
43+
result = linalg.contract(
44+
a,
45+
b,
46+
outs=(filled,),
47+
indexing_maps=[a_map, b_map, c_map],
48+
)
49+
contraction_op = result.owner
50+
51+
assert linalg.isa_contraction_op(contraction_op)
52+
dims = linalg.infer_contraction_dimensions(contraction_op)
53+
assert dims is not None
54+
55+
# Expect m=[0], n=[1], k=[2] as per standard matmul
56+
assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}"
57+
assert list(dims.n) == [1], f"Expected n=[1], got {list(dims.n)}"
58+
assert list(dims.k) == [2], f"Expected k=[2], got {list(dims.k)}"
59+
assert (
60+
list(dims.batch) == []
61+
), f"Expected batch=[], got {list(dims.batch)}"
62+
63+
# === Dynamic shape case ===
64+
dyn = ShapedType.get_dynamic_size()
65+
a_dyn_type = RankedTensorType.get((4, dyn), f32)
66+
b_dyn_type = RankedTensorType.get((dyn, 4), f32)
67+
c_type = RankedTensorType.get((4, 4), f32)
68+
69+
@func.FuncOp.from_py_func(a_dyn_type, b_dyn_type, c_type)
70+
def dynamic_contraction_fn(a, b, c):
71+
zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.0), result=f32)
72+
filled = linalg.fill(zero, outs=[c])
73+
dim_m = AffineDimExpr.get(0)
74+
dim_n = AffineDimExpr.get(1)
75+
dim_k = AffineDimExpr.get(2)
76+
77+
a_map = AffineMap.get(3, 0, [dim_m, dim_k])
78+
b_map = AffineMap.get(3, 0, [dim_k, dim_n])
79+
c_map = AffineMap.get(3, 0, [dim_m, dim_n])
80+
81+
result = linalg.contract(
82+
a,
83+
b,
84+
outs=(filled,),
85+
indexing_maps=[a_map, b_map, c_map],
86+
)
87+
contraction_op = result.owner
88+
89+
assert linalg.isa_contraction_op(contraction_op)
90+
dims = linalg.infer_contraction_dimensions(contraction_op)
91+
assert dims is not None
92+
assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}"
93+
assert list(dims.n) == [1], f"Expected n=[1], got {list(dims.n)}"
94+
assert list(dims.k) == [2], f"Expected k=[2], got {list(dims.k)}"
95+
assert (
96+
list(dims.batch) == []
97+
), f"Expected batch=[], got {list(dims.batch)}"

0 commit comments

Comments
 (0)