
Commit f7d61d4

[mlir][linalg] Extend the 'linalg.matmul' named op to define broadcast and transpose semantics.
Goals:
1. Add syntax to matmul without changing any of the existing syntax expectations for current usage: matmul is still just matmul.
2. Expose broadcast and transpose semantics on the three matmul variations: matmul, batch_matmul, and batch_reduce_matmul.

Scope of this patch: expose broadcast and transpose semantics on 'matmul' only.

The broadcast and transpose semantics are as follows: by default, the behavior of 'linalg.matmul' remains as is. Broadcast and transpose semantics can be applied by specifying the explicit attribute 'indexing_maps' as shown below. This is a list attribute, so the list must include all the maps if specified.

Example Transpose:

linalg.matmul indexing_maps = [
               affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
               affine_map<(d0, d1, d2) -> (d2, d1)>,
               affine_map<(d0, d1, d2) -> (d0, d1)>
              ]
              ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
              outs(%arg2: memref<3x7xf32>)

Example Broadcast:

linalg.matmul indexing_maps = [
               affine_map<(d0, d1, d2) -> (d2)>, // broadcast
               affine_map<(d0, d1, d2) -> (d2, d1)>,
               affine_map<(d0, d1, d2) -> (d0, d1)>
              ]
              ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>)
              outs(%arg2: memref<3x7xf32>)
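For readers constructing the op programmatically rather than in textual IR, a minimal C++ sketch of building the same 'indexing_maps' attribute for the transpose example follows. This is an illustrative assumption, not part of this patch; the helper name `setTransposedMaps` and the caller-provided `builder` and `matmulOp` are hypothetical.

```
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "llvm/ADT/SmallVector.h"

// Hypothetical sketch: attach the explicit 'indexing_maps' attribute for
// the transpose example above. 'builder' and 'matmulOp' are assumed to
// exist in the caller's context.
void setTransposedMaps(mlir::OpBuilder &builder, mlir::Operation *matmulOp) {
  mlir::MLIRContext *ctx = builder.getContext();
  mlir::AffineExpr d0, d1, d2;
  mlir::bindDims(ctx, d0, d1, d2);
  llvm::SmallVector<mlir::AffineMap> maps = {
      mlir::AffineMap::get(3, 0, {d2, d0}, ctx), // transposed LHS
      mlir::AffineMap::get(3, 0, {d2, d1}, ctx), // RHS
      mlir::AffineMap::get(3, 0, {d0, d1}, ctx)  // output
  };
  matmulOp->setAttr("indexing_maps", builder.getAffineMapArrayAttr(maps));
}
```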
1 parent 9014920 commit f7d61d4

File tree

14 files changed (+943 / -182 lines)


mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td

10 additions & 0 deletions

@@ -684,6 +684,16 @@ def LinalgStructuredInterface
         return;
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return true if the user has supplied explicit indexing maps for this op.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"hasUserDefinedMaps",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{ return false; }]
+    >,
     //===------------------------------------------------------------------===//
     // Linalg generalization hooks.
     //===------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

0 additions & 72 deletions

@@ -1065,78 +1065,6 @@ structured_op: !LinalgStructuredOpConfig
       - !ScalarExpression
         scalar_arg: rhs
 --- !LinalgOpConfig
-metadata: !LinalgOpMetadata
-  name: matmul
-  cpp_class_name: MatmulOp
-  doc: |-
-    Performs a matrix multiplication of two 2D inputs.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-  implements:
-  - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
-  args:
-  - !LinalgOperandDefConfig
-    name: A
-    kind: input_tensor
-    type_var: T1
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
-  - !LinalgOperandDefConfig
-    name: B
-    kind: input_tensor
-    type_var: T2
-    shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
-  - !LinalgOperandDefConfig
-    name: C
-    kind: output_tensor
-    type_var: U
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
-  - !LinalgOperandDefConfig
-    name: cast
-    kind: type_fn_attr
-    default_fn: cast_signed
-  indexing_maps: !LinalgIndexingMapsConfig
-    static_indexing_maps:
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
-  iterator_types:
-  - parallel
-  - parallel
-  - reduction
-  assignments:
-  - !ScalarAssign
-    arg: C
-    value: !ScalarExpression
-      scalar_fn:
-        kind: binary
-        fn_name: add
-        operands:
-        - !ScalarExpression
-          scalar_arg: C
-        - !ScalarExpression
-          scalar_fn:
-            kind: binary
-            fn_name: mul
-            operands:
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                attr_name: cast
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: A
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                attr_name: cast
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: B
---- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: quantized_matmul
   cpp_class_name: QuantizedMatmulOp

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

134 additions & 0 deletions

@@ -535,6 +535,140 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Op definition for MatmulOp
+//===----------------------------------------------------------------------===//
+
+def MatmulOp : LinalgStructuredBase_Op<"matmul", [
+               AttrSizedOperandSegments,
+               LinalgContractionOpInterface]> {
+
+  let summary = [{
+    Performs a matrix multiplication of two 2D inputs without broadcast or transpose.
+  }];
+  let description = [{
+    Numeric casting is performed on the operands to the inner multiply,
+    promoting them to the same data type as the accumulator/output.
+
+    Broadcast and transpose semantics can be applied by specifying the explicit
+    attribute 'indexing_maps' as shown below. This is a list attribute, so the
+    list must include all the maps if specified.
+
+    Example Transpose:
+    ```
+    linalg.matmul indexing_maps = [
+                   affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
+                   affine_map<(d0, d1, d2) -> (d2, d1)>,
+                   affine_map<(d0, d1, d2) -> (d0, d1)>
+                  ]
+                  ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
+                  outs(%arg2: memref<3x7xf32>)
+    ```
+
+    Example Broadcast:
+    ```
+    linalg.matmul indexing_maps = [
+                   affine_map<(d0, d1, d2) -> (d2)>, // broadcast
+                   affine_map<(d0, d1, d2) -> (d2, d1)>,
+                   affine_map<(d0, d1, d2) -> (d0, d1)>
+                  ]
+                  ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>)
+                  outs(%arg2: memref<3x7xf32>)
+    ```
+
+    Example Broadcast and Transpose:
+    ```
+    linalg.matmul indexing_maps = [
+                   affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
+                   affine_map<(d0, d1, d2) -> (d2)>,     // broadcast
+                   affine_map<(d0, d1, d2) -> (d0, d1)>
+                  ]
+                  ins(%arg0, %arg1 : memref<5x3xf32>, memref<7xf32>)
+                  outs(%arg2: memref<3x7xf32>)
+    ```
+  }];
+
+  let arguments = (ins
+    Variadic<AnyType>:$inputs,
+    Variadic<AnyShaped>:$outputs,
+    DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps,
+    DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
+  );
+  let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+  let regions = (region AnyRegion:$region);
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<
+      (ins "ValueRange":$inputs, "ValueRange":$outputs,
+           CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
+                          attributes, MatmulOp::getRegionBuilder());
+      }]>,
+    OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+           "ValueRange":$outputs,
+           CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildStructuredOp($_builder, $_state, resultTensorTypes,
+                          inputs, outputs, attributes, MatmulOp::getRegionBuilder());
+      }]>,
+    OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
+           CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addOperands(operands);
+        $_state.addAttributes(attributes);
+        $_state.addTypes(resultTensorTypes);
+        (void)$_state.addRegion();
+      }]>,
+    OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+           "ValueRange":$outputs,
+           "Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute("cast", cast);
+        buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs,
+                          attributes, MatmulOp::getRegionBuilder());
+      }]>
+  ];
+  let hasCustomAssemblyFormat = 1;
+  let hasFolder = 1;
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = structuredOpsBaseDecls # [{
+    SmallVector<utils::IteratorType> getIteratorTypesArray();
+
+    /// Implements the block region builder.
+    static void regionBuilder(ImplicitLocOpBuilder &b,
+                              Block &block, ArrayRef<NamedAttribute> attrs);
+
+    /// Returns a list of AffineMaps with the typical matmul indexing
+    /// characteristics.
+    SmallVector<AffineMap> getDefaultIndexingMaps();
+
+    /// Returns true if the given broadcast map \p bcastMap is valid for this op.
+    bool isValidLhsRhsBroadcastMap(AffineMap bcastMap);
+
+    static std::function<void(ImplicitLocOpBuilder &,
+                              Block &, ArrayRef<NamedAttribute>)>
+    getRegionBuilder() {
+      return regionBuilder;
+    }
+
+    ::mlir::MutableOperandRange getDpsInitsMutable() {
+      return getOutputsMutable();
+    }
+
+    // Generic methods.
+    static unsigned getNumRegionArgs();
+    std::string getLibraryCallName();
+    bool hasDynamicIndexingMaps();
+    /// Check if the op has broadcast and/or transpose semantics. Returns true
+    /// if the user-defined indexing maps are not equal to the default maps.
+    bool hasUserDefinedMaps();
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Named Linalg ops, implemented as a declarative configurations of generic ops.
 //===----------------------------------------------------------------------===//
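As a reading aid for the `getDefaultIndexingMaps()` declaration above, here is a plausible shape for what it returns: the canonical matmul maps named throughout this patch. This is an assumption-labeled sketch written as a free function, not the committed member implementation.

```
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "llvm/ADT/SmallVector.h"

// Sketch (assumed, not the committed code): the canonical matmul indexing
// maps that getDefaultIndexingMaps() is documented to return.
llvm::SmallVector<mlir::AffineMap> defaultMatmulMaps(mlir::MLIRContext *ctx) {
  mlir::AffineExpr d0, d1, d2;
  mlir::bindDims(ctx, d0, d1, d2);
  return {mlir::AffineMap::get(3, 0, {d0, d2}, ctx),  // A: (d0, d2)
          mlir::AffineMap::get(3, 0, {d2, d1}, ctx),  // B: (d2, d1)
          mlir::AffineMap::get(3, 0, {d0, d1}, ctx)}; // C: (d0, d1)
}
```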

mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp

12 additions & 5 deletions

@@ -15,13 +15,20 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetOperations.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/raw_ostream.h"
 #include <algorithm>
+#include <optional>
 
 using namespace mlir;
 using namespace mlir::linalg;
@@ -1142,7 +1149,6 @@ int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
 
 LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
   LinalgOp linalgOp = cast<LinalgOp>(op);
-
   // Mixed tensor/buffer operands are not allowed.
   if (!linalgOp.hasPureTensorSemantics() &&
       !linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
@@ -1162,6 +1168,8 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
         << ") to be equal to the number of input/output operands ("
         << linalgOp->getNumOperands() << ")";
 
+  // Set this flag if this op has user-defined maps. This is required to guard
+  // the error conditions below, which assume default indexing maps.
   for (OpOperand &opOperand : linalgOp->getOpOperands()) {
     AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
 
@@ -1178,13 +1186,13 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
           << " dim(s) to match the number of loops";
 
     int64_t rank = linalgOp.getRank(&opOperand);
+
     if (indexingMap.getNumResults() != rank)
       return op->emitOpError("expected operand rank (")
              << rank << ") to match the result rank of indexing_map #"
             << opOperand.getOperandNumber() << " ("
             << indexingMap.getNumResults() << ")";
   }
-
   SmallVector<unsigned> redDims;
   linalgOp.getReductionDims(redDims);
 
@@ -1194,9 +1202,8 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
   // Check if given shapes match to inferred shapes.
   SmallVector<int64_t, 4> endLoopRangeValues = linalgOp.getStaticLoopRanges();
   SmallVector<int64_t, 4> startLoopRangeValues(endLoopRangeValues.size(), 0);
-
-  // Verify only static cases since we can't get exact dimension sizes and loop
-  // ranges for dynamic cases in this stage.
+  // Verify only static cases since we can't get exact dimension sizes and
+  // loop ranges for dynamic cases in this stage.
   if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
     for (int64_t &range : endLoopRangeValues)
       range -= 1;
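The comment added in the verifier above describes guarding checks that presuppose the default maps. A simplified, assumption-labeled sketch of that control flow (not the literal committed code) follows; it only illustrates how the flag and the interface hook interact.

```
// Simplified sketch (assumed): checks that presuppose the default indexing
// maps are skipped when the op carries user-defined maps; the op-specific
// verifier is expected to validate explicit maps instead.
bool userMaps = linalgOp.hasUserDefinedMaps();
for (mlir::OpOperand &opOperand : linalgOp->getOpOperands()) {
  if (userMaps)
    break; // defer default-map shape/rank checks to the op's own verifier
  mlir::AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand);
  // ... default-map rank and shape checks from the hunk above ...
}
```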
