Skip to content

Commit b5468fc

Browse files
committed
[MLIR][Linalg] Introduce transpose/broadcast semantics to linalg.batch_reduce_matmul.
This patch exposes broadcast and transpose semantics on 'batch_reduce_matmul'. This is the last one in continuation of the other two variants of matmul ops. The broadcast and transpose semantics are as follows: Broadcast and Transpose semantics can be applied by specifying the explicit attribute 'indexing_maps' as shown below. This is a list attribute, so must include maps for all arguments if specified. Example Transpose: ``` linalg.batch_reduce_matmul indexing_maps = [ affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)> ] ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>) outs(%arg2: memref<3x7xf32>) ``` Example Broadcast: ``` linalg.batch_reduce_matmul indexing_maps = [ affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)> ] ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>) outs(%arg2: memref<3x7xf32>) ``` Example Broadcast and Transpose: ``` linalg.batch_reduce_matmul indexing_maps = [ affine_map<(d0, d1, d2, d3) -> (d1, d3)>, // broadcast affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose affine_map<(d0, d1, d2, d3) -> (d1, d2)> ] ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>) outs(%arg2: memref<3x7xf32>) ``` RFCs and related PRs: https://discourse.llvm.org/t/rfc-linalg-opdsl-constant-list-attribute-definition/80149 https://discourse.llvm.org/t/rfc-op-explosion-in-linalg/82863 https://discourse.llvm.org/t/rfc-mlir-linalg-operation-tree/83586 llvm#115319 llvm#122275
1 parent ad704ff commit b5468fc

File tree

6 files changed

+762
-98
lines changed

6 files changed

+762
-98
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

Lines changed: 0 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1717,76 +1717,6 @@ structured_op: !LinalgStructuredOpConfig
17171717
- !ScalarExpression
17181718
scalar_arg: BZp
17191719
--- !LinalgOpConfig
1720-
metadata: !LinalgOpMetadata
1721-
name: batch_reduce_matmul
1722-
cpp_class_name: BatchReduceMatmulOp
1723-
doc: |-
1724-
Performs a batch-reduce matrix multiplication of two 3D inputs.
1725-
The partial multiplication results are reduced into a 2D output.
1726-
1727-
Numeric casting is performed on the operands to the inner multiply, promoting
1728-
them to the same data type as the accumulator/output.
1729-
implements:
1730-
- LinalgContractionOpInterface
1731-
structured_op: !LinalgStructuredOpConfig
1732-
args:
1733-
- !LinalgOperandDefConfig
1734-
name: A
1735-
kind: input_tensor
1736-
type_var: T1
1737-
shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
1738-
- !LinalgOperandDefConfig
1739-
name: B
1740-
kind: input_tensor
1741-
type_var: T2
1742-
shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)>
1743-
- !LinalgOperandDefConfig
1744-
name: C
1745-
kind: output_tensor
1746-
type_var: U
1747-
shape_map: affine_map<()[s0, s1, s2, s3] -> (s1, s3)>
1748-
indexing_maps: !LinalgIndexingMapsConfig
1749-
static_indexing_maps:
1750-
- affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
1751-
- affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
1752-
- affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d1, d2)>
1753-
iterator_types:
1754-
- reduction
1755-
- parallel
1756-
- parallel
1757-
- reduction
1758-
assignments:
1759-
- !ScalarAssign
1760-
arg: C
1761-
value: !ScalarExpression
1762-
scalar_fn:
1763-
kind: binary
1764-
fn_name: add
1765-
operands:
1766-
- !ScalarExpression
1767-
scalar_arg: C
1768-
- !ScalarExpression
1769-
scalar_fn:
1770-
kind: binary
1771-
fn_name: mul
1772-
operands:
1773-
- !ScalarExpression
1774-
scalar_fn:
1775-
kind: type
1776-
fn_name: cast_signed
1777-
type_var: U
1778-
operands:
1779-
- !ScalarExpression
1780-
scalar_arg: A
1781-
- !ScalarExpression
1782-
scalar_fn:
1783-
kind: type
1784-
fn_name: cast_signed
1785-
type_var: U
1786-
operands:
1787-
- !ScalarExpression
1788-
scalar_arg: B
1789-
--- !LinalgOpConfig
17901720
metadata: !LinalgOpMetadata
17911721
name: matvec
17921722
cpp_class_name: MatvecOp

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1054,6 +1054,137 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
10541054
}
10551055

10561056

1057+
//===----------------------------------------------------------------------===//
1058+
// Op definition for BatchReduceMatmulOp
1059+
//===----------------------------------------------------------------------===//
1060+
1061+
def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
1062+
AttrSizedOperandSegments,
1063+
LinalgContractionOpInterface]> {
1064+
1065+
let summary = [{Performs a batch-reduce matrix multiplication of two 3D inputs.
1066+
The partial multiplication results are reduced into a 2D output.}];
1067+
let description = [{
1068+
Numeric casting is performed on the operands to the inner multiply, promoting
1069+
them to the same data type as the accumulator/output.
1070+
1071+
Broadcast and Transpose semantics can be applied by specifying the explicit attribute
1072+
'indexing_maps' as shown below. This is a list attribute, so must include maps for all
1073+
arguments if specified.
1074+
1075+
Example Transpose:
1076+
```
1077+
linalg.batch_reduce_matmul indexing_maps = [
1078+
affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
1079+
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
1080+
affine_map<(d0, d1, d2, d3) -> (d1, d2)>
1081+
]
1082+
ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>)
1083+
outs(%arg2: memref<3x7xf32>)
1084+
```
1085+
1086+
Example Broadcast:
1087+
```
1088+
linalg.batch_reduce_matmul indexing_maps = [
1089+
affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast
1090+
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
1091+
affine_map<(d0, d1, d2, d3) -> (d1, d2)>
1092+
]
1093+
ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
1094+
outs(%arg2: memref<3x7xf32>)
1095+
```
1096+
1097+
Example Broadcast and Transpose:
1098+
```
1099+
linalg.batch_reduce_matmul indexing_maps = [
1100+
affine_map<(d0, d1, d2, d3) -> (d1, d3)>, // broadcast
1101+
affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
1102+
affine_map<(d0, d1, d2, d3) -> (d1, d2)>
1103+
]
1104+
ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
1105+
outs(%arg2: memref<3x7xf32>)
1106+
```
1107+
}];
1108+
1109+
let arguments = (ins
1110+
Variadic<AnyType>:$inputs,
1111+
Variadic<AnyShaped>:$outputs,
1112+
DefaultValuedOptionalAttr<
1113+
AffineMapArrayAttr,
1114+
"BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext())"
1115+
>:$indexing_maps,
1116+
DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
1117+
);
1118+
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
1119+
let regions = (region AnyRegion:$region);
1120+
1121+
let skipDefaultBuilders = 1;
1122+
let builders = [
1123+
OpBuilder<
1124+
(ins "ValueRange":$inputs, "ValueRange":$outputs,
1125+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
1126+
[{
1127+
buildBatchReduceMatmulOp($_builder, $_state, std::nullopt, inputs, outputs,
1128+
attributes, BatchReduceMatmulOp::getRegionBuilder(),
1129+
BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
1130+
}]>,
1131+
OpBuilder<
1132+
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
1133+
"ValueRange":$outputs,
1134+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
1135+
[{
1136+
buildBatchReduceMatmulOp($_builder, $_state, resultTensorTypes,
1137+
inputs, outputs, attributes, BatchReduceMatmulOp::getRegionBuilder(),
1138+
BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
1139+
}]>,
1140+
OpBuilder<
1141+
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
1142+
"ValueRange":$outputs,
1143+
"Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
1144+
[{
1145+
$_state.addAttribute("cast", cast);
1146+
buildBatchReduceMatmulOp($_builder, $_state, resultTensorTypes, inputs, outputs,
1147+
attributes, BatchReduceMatmulOp::getRegionBuilder(),
1148+
BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
1149+
}]>
1150+
1151+
];
1152+
let hasCustomAssemblyFormat = 1;
1153+
let hasFolder = 1;
1154+
let hasVerifier = 1;
1155+
1156+
let extraClassDeclaration = structuredOpsBaseDecls # [{
1157+
SmallVector<utils::IteratorType> getIteratorTypesArray();
1158+
1159+
/// Implements the block region builder.
1160+
static void regionBuilder(ImplicitLocOpBuilder &b,
1161+
Block &block, ArrayRef<NamedAttribute> attrs);
1162+
1163+
/// Returns a list of AffineMap with the typical batch_reduce_matmul indexing characteristic.
1164+
static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
1165+
1166+
/// Returns true if the given broadcast map \p bcastMap is valid for this op.
1167+
bool isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS = true);
1168+
1169+
static std::function<void(ImplicitLocOpBuilder &,
1170+
Block &, ArrayRef<NamedAttribute>)>
1171+
getRegionBuilder() {
1172+
return regionBuilder;
1173+
}
1174+
1175+
::mlir::MutableOperandRange getDpsInitsMutable() {
1176+
return getOutputsMutable();
1177+
}
1178+
1179+
// Generic methods.
1180+
static unsigned getNumRegionArgs();
1181+
std::string getLibraryCallName();
1182+
bool hasDynamicIndexingMaps() { return true; };
1183+
/// Returns true if the user defined indexing maps are not equal to default maps.
1184+
bool hasUserDefinedMaps();
1185+
}];
1186+
}
1187+
10571188
//===----------------------------------------------------------------------===//
10581189
// Named Linalg ops, implemented as a declarative configurations of generic ops.
10591190
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)