
Commit 3ad0148

shahidact and rengolin authored
[MLIR][Linalg] Re-land linalg.matmul move to ODS. + Remove/update failing obsolete OpDSL tests. (#115319)
The earlier PR (#104783), which introduced transpose and broadcast semantics to linalg.matmul, was reverted due to two failing OpDSL tests for linalg.matmul. Since linalg.matmul is now defined using TableGen ODS instead of the Python-based OpDSL, those tests started failing and needed to be removed or updated. This commit removes/updates the failing obsolete tests in the files below. All other files were part of the earlier PR and are simply cherry-picked.

"mlir/test/python/integration/dialects/linalg/opsrun.py"
"mlir/test/python/integration/dialects/transform.py"

---------

Co-authored-by: Renato Golin <[email protected]>
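For context, the user-visible syntax of the default op is meant to be unchanged by the move from OpDSL to ODS; a minimal sketch (shapes and function name are illustrative, not taken from this patch):

```mlir
// Default matmul with no explicit indexing_maps attribute: this form is
// intended to parse and verify identically before and after the move to ODS.
func.func @matmul_default(%a: tensor<3x5xf32>, %b: tensor<5x7xf32>,
                          %c: tensor<3x7xf32>) -> tensor<3x7xf32> {
  %0 = linalg.matmul ins(%a, %b : tensor<3x5xf32>, tensor<5x7xf32>)
                     outs(%c : tensor<3x7xf32>) -> tensor<3x7xf32>
  return %0 : tensor<3x7xf32>
}
```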
1 parent 21835ee commit 3ad0148

16 files changed: +959 −309 lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td

Lines changed: 10 additions & 0 deletions

```diff
@@ -708,6 +708,16 @@ def LinalgStructuredInterface
         return;
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return true if the user has supplied explicit indexing maps for this op.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"hasUserDefinedMaps",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{ return false; }]
+    >,
     //===------------------------------------------------------------------===//
     // Linalg generalization hooks.
     //===------------------------------------------------------------------===//
```
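As an illustrative sketch (shapes and names hypothetical): the new MatmulOp defined below overrides this hook to return true when the supplied `indexing_maps` differ from the default matmul maps, while ops that keep the default implementation report false.

```mlir
// MatmulOp's override of hasUserDefinedMaps() would return true for this op,
// since its explicit indexing_maps (A transposed) differ from the defaults.
func.func @matmul_transposed_a(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>,
                               %arg2: memref<3x7xf32>) {
  linalg.matmul indexing_maps = [
                  affine_map<(d0, d1, d2) -> (d2, d0)>, // A transposed
                  affine_map<(d0, d1, d2) -> (d2, d1)>,
                  affine_map<(d0, d1, d2) -> (d0, d1)>
                ]
                ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
                outs(%arg2 : memref<3x7xf32>)
  return
}
```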

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

Lines changed: 0 additions & 72 deletions

```diff
@@ -1065,78 +1065,6 @@ structured_op: !LinalgStructuredOpConfig
         - !ScalarExpression
           scalar_arg: rhs
 --- !LinalgOpConfig
-metadata: !LinalgOpMetadata
-  name: matmul
-  cpp_class_name: MatmulOp
-  doc: |-
-    Performs a matrix multiplication of two 2D inputs.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-  implements:
-  - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
-  args:
-  - !LinalgOperandDefConfig
-    name: A
-    kind: input_tensor
-    type_var: T1
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
-  - !LinalgOperandDefConfig
-    name: B
-    kind: input_tensor
-    type_var: T2
-    shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
-  - !LinalgOperandDefConfig
-    name: C
-    kind: output_tensor
-    type_var: U
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
-  - !LinalgOperandDefConfig
-    name: cast
-    kind: type_fn_attr
-    default_fn: cast_signed
-  indexing_maps: !LinalgIndexingMapsConfig
-    static_indexing_maps:
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
-  iterator_types:
-  - parallel
-  - parallel
-  - reduction
-  assignments:
-  - !ScalarAssign
-    arg: C
-    value: !ScalarExpression
-      scalar_fn:
-        kind: binary
-        fn_name: add
-        operands:
-        - !ScalarExpression
-          scalar_arg: C
-        - !ScalarExpression
-          scalar_fn:
-            kind: binary
-            fn_name: mul
-            operands:
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                attr_name: cast
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: A
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                attr_name: cast
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: B
---- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: quantized_matmul
   cpp_class_name: QuantizedMatmulOp
```
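The removed OpDSL entry encoded exactly the maps and iterator types of an ordinary matmul contraction; for reference, a sketch of the equivalent `linalg.generic` (example shapes, all f32, so the `cast_signed` promotion is a no-op here):

```mlir
// Indexing maps from the removed YAML: A -> (d0, d2), B -> (d2, d1),
// C -> (d0, d1), with iterator types [parallel, parallel, reduction].
#mapA = affine_map<(m, n, k) -> (m, k)>
#mapB = affine_map<(m, n, k) -> (k, n)>
#mapC = affine_map<(m, n, k) -> (m, n)>

func.func @matmul_as_generic(%A: tensor<3x5xf32>, %B: tensor<5x7xf32>,
                             %C: tensor<3x7xf32>) -> tensor<3x7xf32> {
  %0 = linalg.generic
         {indexing_maps = [#mapA, #mapB, #mapC],
          iterator_types = ["parallel", "parallel", "reduction"]}
         ins(%A, %B : tensor<3x5xf32>, tensor<5x7xf32>)
         outs(%C : tensor<3x7xf32>) {
  ^bb0(%a: f32, %b: f32, %c: f32):
    // C += A * B, accumulating into the output operand.
    %p = arith.mulf %a, %b : f32
    %s = arith.addf %c, %p : f32
    linalg.yield %s : f32
  } -> tensor<3x7xf32>
  return %0 : tensor<3x7xf32>
}
```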

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 134 additions & 0 deletions

````diff
@@ -554,6 +554,140 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Op definition for MatmulOp
+//===----------------------------------------------------------------------===//
+
+def MatmulOp : LinalgStructuredBase_Op<"matmul", [
+               AttrSizedOperandSegments,
+               LinalgContractionOpInterface]> {
+
+  let summary = [{
+    Performs a matrix multiplication of two 2D inputs without broadcast or transpose.
+  }];
+  let description = [{
+    Numeric casting is performed on the operands to the inner multiply,
+    promoting them to the same data type as the accumulator/output.
+
+    Broadcast and transpose semantics can be applied by specifying the explicit
+    attribute 'indexing_maps' as shown below. This is a list attribute, so the
+    list must include all the maps if specified.
+
+    Example Transpose:
+    ```
+    linalg.matmul indexing_maps = [
+                   affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
+                   affine_map<(d0, d1, d2) -> (d2, d1)>,
+                   affine_map<(d0, d1, d2) -> (d0, d1)>
+                  ]
+                  ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
+                  outs(%arg2: memref<3x7xf32>)
+    ```
+
+    Example Broadcast:
+    ```
+    linalg.matmul indexing_maps = [
+                   affine_map<(d0, d1, d2) -> (d2)>,     // broadcast
+                   affine_map<(d0, d1, d2) -> (d2, d1)>,
+                   affine_map<(d0, d1, d2) -> (d0, d1)>
+                  ]
+                  ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>)
+                  outs(%arg2: memref<3x7xf32>)
+    ```
+
+    Example Broadcast and Transpose:
+    ```
+    linalg.matmul indexing_maps = [
+                   affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
+                   affine_map<(d0, d1, d2) -> (d2)>,     // broadcast
+                   affine_map<(d0, d1, d2) -> (d0, d1)>
+                  ]
+                  ins(%arg0, %arg1 : memref<5x3xf32>, memref<7xf32>)
+                  outs(%arg2: memref<3x7xf32>)
+    ```
+  }];
+
+  let arguments = (ins
+    Variadic<AnyType>:$inputs,
+    Variadic<AnyShaped>:$outputs,
+    DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps,
+    DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
+  );
+  let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+  let regions = (region AnyRegion:$region);
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<
+      (ins "ValueRange":$inputs, "ValueRange":$outputs,
+           CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
+                          attributes, MatmulOp::getRegionBuilder());
+      }]>,
+    OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+           "ValueRange":$outputs,
+           CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildStructuredOp($_builder, $_state, resultTensorTypes,
+                          inputs, outputs, attributes, MatmulOp::getRegionBuilder());
+      }]>,
+    OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
+           CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addOperands(operands);
+        $_state.addAttributes(attributes);
+        $_state.addTypes(resultTensorTypes);
+        (void)$_state.addRegion();
+      }]>,
+    OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+           "ValueRange":$outputs,
+           "Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute("cast", cast);
+        buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs,
+                          attributes, MatmulOp::getRegionBuilder());
+      }]>
+  ];
+  let hasCustomAssemblyFormat = 1;
+  let hasFolder = 1;
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = structuredOpsBaseDecls # [{
+    SmallVector<utils::IteratorType> getIteratorTypesArray();
+
+    /// Implements the block region builder.
+    static void regionBuilder(ImplicitLocOpBuilder &b,
+                              Block &block, ArrayRef<NamedAttribute> attrs);
+
+    /// Returns a list of AffineMaps with the typical matmul indexing characteristic.
+    SmallVector<AffineMap> getDefaultIndexingMaps();
+
+    /// Returns true if the given broadcast map \p bcastMap is valid for this op.
+    bool isValidLhsRhsBroadcastMap(AffineMap bcastMap);
+
+    static std::function<void(ImplicitLocOpBuilder &,
+                              Block &, ArrayRef<NamedAttribute>)>
+    getRegionBuilder() {
+      return regionBuilder;
+    }
+
+    ::mlir::MutableOperandRange getDpsInitsMutable() {
+      return getOutputsMutable();
+    }
+
+    // Generic methods.
+    static unsigned getNumRegionArgs();
+    std::string getLibraryCallName();
+    bool hasDynamicIndexingMaps();
+    /// Checks if the op has broadcast and/or transpose semantics. Returns true
+    /// if the user-defined indexing maps are not equal to the default maps.
+    bool hasUserDefinedMaps();
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Named Linalg ops, implemented as a declarative configurations of generic ops.
 //===----------------------------------------------------------------------===//
````
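To complement the memref examples in the description, a hedged usage sketch of the same transpose form with tensor semantics, where the result is returned through `$result_tensors` (shapes illustrative):

```mlir
func.func @matmul_transpose_a(%a: tensor<5x3xf32>, %b: tensor<5x7xf32>,
                              %c: tensor<3x7xf32>) -> tensor<3x7xf32> {
  // A is consumed transposed via the first indexing map.
  %0 = linalg.matmul indexing_maps = [
                       affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose A
                       affine_map<(d0, d1, d2) -> (d2, d1)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>
                     ]
                     ins(%a, %b : tensor<5x3xf32>, tensor<5x7xf32>)
                     outs(%c : tensor<3x7xf32>) -> tensor<3x7xf32>
  return %0 : tensor<3x7xf32>
}
```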

mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp

Lines changed: 12 additions & 5 deletions

```diff
@@ -15,14 +15,21 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetOperations.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/raw_ostream.h"
 #include <algorithm>
 #include <numeric>
+#include <optional>
 
 using namespace mlir;
 using namespace mlir::linalg;
@@ -1211,7 +1218,6 @@ int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
 
 LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
   LinalgOp linalgOp = cast<LinalgOp>(op);
-
   // Mixed tensor/buffer operands are not allowed.
   if (!linalgOp.hasPureTensorSemantics() &&
       !linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
@@ -1231,6 +1237,8 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
          << ") to be equal to the number of input/output operands ("
          << linalgOp->getNumOperands() << ")";
 
+  // Set this flag if this op has user-defined maps. This is required to guard
+  // the below error conditions, which assume default indexing maps.
   for (OpOperand &opOperand : linalgOp->getOpOperands()) {
     AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
 
@@ -1247,13 +1255,13 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
            << " dim(s) to match the number of loops";
 
     int64_t rank = linalgOp.getRank(&opOperand);
+
     if (indexingMap.getNumResults() != rank)
       return op->emitOpError("expected operand rank (")
              << rank << ") to match the result rank of indexing_map #"
              << opOperand.getOperandNumber() << " ("
             << indexingMap.getNumResults() << ")";
   }
-
   SmallVector<unsigned> redDims;
   linalgOp.getReductionDims(redDims);
 
@@ -1263,9 +1271,8 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
   // Check if given shapes match to inferred shapes.
   SmallVector<int64_t, 4> endLoopRangeValues = linalgOp.getStaticLoopRanges();
   SmallVector<int64_t, 4> startLoopRangeValues(endLoopRangeValues.size(), 0);
-
-  // Verify only static cases since we can't get exact dimension sizes and loop
-  // ranges for dynamic cases in this stage.
+  // Verify only static cases since we can't get exact dimension sizes and
+  // loop ranges for dynamic cases in this stage.
   if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
     for (int64_t &range : endLoopRangeValues)
       range -= 1;
```