
Commit f796bc6

[MLIR][Linalg] Expose linalg.matmul and linalg.contract via Python API (#126377)
Now that linalg.matmul is in tablegen, "hand write" the Python wrapper that OpDSL used to derive. Similarly, add a Python wrapper for the new linalg.contract op.

This required the following miscellaneous fixes:

1) Make linalg.matmul's parsing and printing consistent w.r.t. whether indexing_maps occurs before or after the operands, i.e. per the test cases it comes _before_.
2) The tablegen for linalg.contract did not state that it accepts an optional cast attr.
3) In ODS's C++-generating code, expand partial support for `$_builder` access in `Attr::defaultValue` to full support. This enables access to the current `MlirContext` when constructing the default value (as is required when the default value consists of affine maps).
1 parent 771f6b9 commit f796bc6
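
To illustrate what this change enables, here is a minimal usage sketch of the newly exposed linalg.matmul wrapper. It assumes an upstream build of the MLIR Python bindings with the func and linalg dialects available; the tensor shapes and the structure of the snippet are illustrative only and are not taken from the commit's tests.

# Minimal sketch: drive the hand-written linalg.matmul wrapper from Python.
# Assumes the upstream `mlir` Python package; shapes are illustrative.
from mlir import ir
from mlir.dialects import func, linalg

with ir.Context(), ir.Location.unknown():
    module = ir.Module.create()
    f32 = ir.F32Type.get()
    with ir.InsertionPoint(module.body):

        @func.FuncOp.from_py_func(
            ir.RankedTensorType.get((4, 8), f32),
            ir.RankedTensorType.get((8, 16), f32),
            ir.RankedTensorType.get((4, 16), f32),
        )
        def matmul_fn(lhs, rhs, init):
            # Uses the default indexing maps; an explicit list of
            # AffineMapAttr can be passed via `indexing_maps=...`.
            return linalg.matmul(lhs, rhs, outs=[init])

    # Print the module's textual form.
    print(module)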

9 files changed (+316, -32 lines)


mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 6 additions & 2 deletions
@@ -606,7 +606,10 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
   let arguments = (ins
     Variadic<AnyType>:$inputs,
     Variadic<AnyShaped>:$outputs,
-    DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps,
+    DefaultValuedOptionalAttr<
+      AffineMapArrayAttr,
+      "MatmulOp::getDefaultIndexingMaps($_builder.getContext())"
+    >:$indexing_maps,
     DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
   );
   let results = (outs Variadic<AnyRankedTensor>:$result_tensors);

@@ -752,7 +755,8 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
   let arguments = (ins
     Variadic<AnyType>:$inputs,
     Variadic<AnyShaped>:$outputs,
-    AffineMapArrayAttr:$indexing_maps
+    AffineMapArrayAttr:$indexing_maps,
+    DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
   );
   let results = (outs Variadic<AnyShaped>:$result_tensors);
   // NB: The only reason this op has a region - and it get populated at op build
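
For context on the new default value above: matmul's default maps are the standard (m, k) x (k, n) -> (m, n) contraction. Below is a hedged sketch that builds the same three maps through the Python bindings; it mirrors, but is not taken from, MatmulOp::getDefaultIndexingMaps.

# Sketch of the maps the default value is expected to correspond to,
# built via the Python bindings purely for illustration.
from mlir import ir

with ir.Context():
    d0, d1, d2 = (ir.AffineDimExpr.get(i) for i in range(3))  # m, n, k
    lhs_map = ir.AffineMap.get(3, 0, [d0, d2])  # (d0, d1, d2) -> (d0, d2)
    rhs_map = ir.AffineMap.get(3, 0, [d2, d1])  # (d0, d1, d2) -> (d2, d1)
    out_map = ir.AffineMap.get(3, 0, [d0, d1])  # (d0, d1, d2) -> (d0, d1)
    print(lhs_map, rhs_map, out_map)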

mlir/include/mlir/IR/CommonAttrConstraints.td

Lines changed: 3 additions & 0 deletions
@@ -50,6 +50,9 @@ class Attr<Pred condition, string summary = ""> :
 
   // Default value for attribute.
   // Requires a constBuilderCall defined.
+  //
+  // Format: `$_builder` will be expanded to the relevant builder, e.g. to allow
+  // access to the current context.
   string defaultValue = ?;
 
   // The value type of this attribute. This corresponds to the mlir::Type that

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 5 additions & 5 deletions
@@ -3666,11 +3666,6 @@ ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
 }
 
 void MatmulOp::print(OpAsmPrinter &p) {
-  SmallVector<StringRef, 3> elidedAttrs = {
-      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
-  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
-                         elidedAttrs);
-
   SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
       MatmulOp::getDefaultIndexingMaps(getContext()),
       [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });

@@ -3680,6 +3675,11 @@ void MatmulOp::print(OpAsmPrinter &p) {
                           [&](Attribute attr) { p.printAttribute(attr); });
     p << "]";
   }
+
+  SmallVector<StringRef, 3> elidedAttrs = {
+      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
+  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+                         elidedAttrs);
 }
 
 /// Verify the user defined indexing maps.

mlir/python/mlir/dialects/linalg/__init__.py

Lines changed: 46 additions & 0 deletions
@@ -147,3 +147,49 @@ def __init__(
 
 
 generic = region_op(GenericOp_, terminator=YieldOp)
+
+
+def matmul(
+    *ins: Union[Operation, OpView, Value],
+    outs: Sequence[Union[Operation, OpView, Value]],
+    indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
+    cast: Optional[Union[TypeFn, Attribute]] = None,
+):
+    ins = [_get_op_result_or_value(input) for input in ins]
+    if len(outs) > 1:
+        raise ValueError(f"{outs=} must have length 1.")
+    init = _get_op_result_or_value(outs[0])
+    result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
+
+    op = MatmulOp(
+        result_tensors=result_types,
+        inputs=ins,
+        outputs=[init],
+        indexing_maps=indexing_maps,
+        cast=cast,
+    )
+    fill_builtin_region(op.operation)
+    return op
+
+
+def contract(
+    *ins: Union[Operation, OpView, Value],
+    outs: Sequence[Union[Operation, OpView, Value]],
+    indexing_maps: Sequence[AffineMapAttr],
+    cast: Optional[Union[TypeFn, Attribute]] = None,
+):
+    ins = [_get_op_result_or_value(input) for input in ins]
+    if len(outs) > 1:
+        raise ValueError(f"{outs=} must have length 1.")
+    init = _get_op_result_or_value(outs[0])
+    result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
+
+    op = ContractOp(
+        result_tensors=result_types,
+        inputs=ins,
+        outputs=[init],
+        indexing_maps=indexing_maps,
+        cast=cast,
+    )
+    fill_builtin_region(op.operation)
+    return op
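
The two wrappers differ mainly in that contract requires indexing_maps while matmul can fall back to its defaults. Below is a hedged usage sketch of the contract wrapper; emit_contract is a hypothetical helper, and it assumes an active context and insertion point such as the ones set up in the matmul sketch near the top of this page.

# Hypothetical helper showing the new linalg.contract wrapper; it expresses a
# plain matmul as a contraction via explicit AffineMapAttr indexing maps.
from mlir import ir
from mlir.dialects import linalg


def emit_contract(lhs, rhs, init):
    d0, d1, d2 = (ir.AffineDimExpr.get(i) for i in range(3))
    maps = [
        ir.AffineMapAttr.get(ir.AffineMap.get(3, 0, exprs))
        for exprs in ([d0, d2], [d2, d1], [d0, d1])
    ]
    return linalg.contract(lhs, rhs, outs=[init], indexing_maps=maps)

Note that, per the result_types logic in the wrapper above, the returned op has a tensor result only when the init operand is a ranked tensor; on memrefs it yields no results.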

mlir/test/Dialect/Linalg/named-ops.mlir

Lines changed: 10 additions & 6 deletions
@@ -1269,7 +1269,7 @@ func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5
 // CHECK-SAME: %[[VAL_0:.*]]: memref<3x5xf32>,
 // CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
 // CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
-// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
 // CHECK: return
 // CHECK: }

@@ -1294,7 +1294,7 @@ func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7
 // CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>,
 // CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
 // CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
-// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
 // CHECK: return
 // CHECK: }

@@ -1315,6 +1315,7 @@ func.func @matmul_bcast_a(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %arg2: m
 // CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 // CHECK-LABEL: func @matmul_bcast_a
 // CHECK: linalg.matmul
+// CHECK-SAME: indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
 // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5x7xf32>)
 // CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)

@@ -1335,6 +1336,7 @@ func.func @matmul_bcast_a_dim1(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %ar
 // CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 // CHECK-LABEL: func @matmul_bcast_a_dim1
 // CHECK: linalg.matmul
+// CHECK-SAME: indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
 // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5x7xf32>)
 // CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)

@@ -1355,6 +1357,7 @@ func.func @matmul_bcast_b(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %arg2: m
 // CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 // CHECK-LABEL: func @matmul_bcast_b
 // CHECK: linalg.matmul
+// CHECK-SAME: indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
 // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<5xf32>)
 // CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)

@@ -1376,7 +1379,7 @@ func.func @matmul_bcast_a_b(%arg0: memref<5xf32>, %arg1: memref<5xf32>, %arg2: m
 // CHECK-LABEL: func.func @matmul_bcast_a_b(
 // CHECK-SAME: %[[VAL_0:.*]]: memref<5xf32>, %[[VAL_1:.*]]: memref<5xf32>,
 // CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
-// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_1]]]
+// CHECK: linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_1]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
 // CHECK: return
 // CHECK: }

@@ -1397,6 +1400,7 @@ func.func @matmul_bcast_b_dim1(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %ar
 // CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 // CHECK-LABEL: func @matmul_bcast_b_dim1
 // CHECK: linalg.matmul
+// CHECK-SAME: indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
 // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<5xf32>)
 // CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)

@@ -1420,7 +1424,7 @@ func.func @dynamic_matmul_bcast_a(%arg0: memref<?xf32>, %arg1: memref<?x?xf32>,
 // CHECK-SAME: %[[VAL_0:.*]]: memref<?xf32>,
 // CHECK-SAME: %[[VAL_1:.*]]: memref<?x?xf32>,
 // CHECK-SAME: %[[VAL_2:.*]]: memref<?x?xf32>) {
-// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<?xf32>, memref<?x?xf32>) outs(%[[VAL_2]] : memref<?x?xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<?xf32>, memref<?x?xf32>) outs(%[[VAL_2]] : memref<?x?xf32>)
 // CHECK: return
 // CHECK: }

@@ -1444,7 +1448,7 @@ func.func @matmul_bcast_a_transpose_b(%arg0: memref<5xf32>, %arg1: memref<7x5xf3
 // CHECK-SAME: %[[VAL_0:.*]]: memref<5xf32>,
 // CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
 // CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
-// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
 // CHECK: return
 // CHECK: }

@@ -1468,7 +1472,7 @@ func.func @matmul_bcast_b_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5xf3
 // CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>,
 // CHECK-SAME: %[[VAL_1:.*]]: memref<5xf32>,
 // CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
-// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
 // CHECK: return
 // CHECK: }
