Skip to content

Commit 2c55e73

Browse files
committed
[MLIR][Linalg] Introduce transpose/broadcast semantic to linalg.batch_reduce_matmul.
This patch exposes broadcast and transpose semantics on 'batch_reduce_matmul'. This is the last one in continuation of the other two variants of matmul ops. The broadcast and transpose semantics are as follows: Broadcast and Transpose semantics can be applied by specifying the explicit attribute 'indexing_maps' as shown below. This is a list attribute, so must include maps for all arguments if specified. Example Transpose: ``` linalg.batch_reduce_matmul indexing_maps = [ affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)> ] ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>) outs(%arg2: memref<3x7xf32>) ``` Example Broadcast: ``` linalg.batch_reduce_matmul indexing_maps = [ affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)> ] ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>) outs(%arg2: memref<3x7xf32>) ``` Example Broadcast and Transpose: ``` linalg.batch_reduce_matmul indexing_maps = [ affine_map<(d0, d1, d2, d3) -> (d1, d3)>, // broadcast affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose affine_map<(d0, d1, d2, d3) -> (d1, d2)> ] ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>) outs(%arg2: memref<3x7xf32>) ``` RFCs and related PR: https://discourse.llvm.org/t/rfc-linalg-opdsl-constant-list-attribute-definition/80149 https://discourse.llvm.org/t/rfc-op-explosion-in-linalg/82863 https://discourse.llvm.org/t/rfc-mlir-linalg-operation-tree/83586 llvm#115319 llvm#122275
1 parent 6f6af49 commit 2c55e73

File tree

6 files changed

+761
-98
lines changed

6 files changed

+761
-98
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

Lines changed: 0 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1717,76 +1717,6 @@ structured_op: !LinalgStructuredOpConfig
17171717
- !ScalarExpression
17181718
scalar_arg: BZp
17191719
--- !LinalgOpConfig
1720-
metadata: !LinalgOpMetadata
1721-
name: batch_reduce_matmul
1722-
cpp_class_name: BatchReduceMatmulOp
1723-
doc: |-
1724-
Performs a batch-reduce matrix multiplication of two 3D inputs.
1725-
The partial multiplication results are reduced into a 2D output.
1726-
1727-
Numeric casting is performed on the operands to the inner multiply, promoting
1728-
them to the same data type as the accumulator/output.
1729-
implements:
1730-
- LinalgContractionOpInterface
1731-
structured_op: !LinalgStructuredOpConfig
1732-
args:
1733-
- !LinalgOperandDefConfig
1734-
name: A
1735-
kind: input_tensor
1736-
type_var: T1
1737-
shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
1738-
- !LinalgOperandDefConfig
1739-
name: B
1740-
kind: input_tensor
1741-
type_var: T2
1742-
shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)>
1743-
- !LinalgOperandDefConfig
1744-
name: C
1745-
kind: output_tensor
1746-
type_var: U
1747-
shape_map: affine_map<()[s0, s1, s2, s3] -> (s1, s3)>
1748-
indexing_maps: !LinalgIndexingMapsConfig
1749-
static_indexing_maps:
1750-
- affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
1751-
- affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
1752-
- affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d1, d2)>
1753-
iterator_types:
1754-
- reduction
1755-
- parallel
1756-
- parallel
1757-
- reduction
1758-
assignments:
1759-
- !ScalarAssign
1760-
arg: C
1761-
value: !ScalarExpression
1762-
scalar_fn:
1763-
kind: binary
1764-
fn_name: add
1765-
operands:
1766-
- !ScalarExpression
1767-
scalar_arg: C
1768-
- !ScalarExpression
1769-
scalar_fn:
1770-
kind: binary
1771-
fn_name: mul
1772-
operands:
1773-
- !ScalarExpression
1774-
scalar_fn:
1775-
kind: type
1776-
fn_name: cast_signed
1777-
type_var: U
1778-
operands:
1779-
- !ScalarExpression
1780-
scalar_arg: A
1781-
- !ScalarExpression
1782-
scalar_fn:
1783-
kind: type
1784-
fn_name: cast_signed
1785-
type_var: U
1786-
operands:
1787-
- !ScalarExpression
1788-
scalar_arg: B
1789-
--- !LinalgOpConfig
17901720
metadata: !LinalgOpMetadata
17911721
name: matvec
17921722
cpp_class_name: MatvecOp

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,6 +1065,137 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
10651065
}
10661066

10671067

1068+
//===----------------------------------------------------------------------===//
1069+
// Op definition for BatchReduceMatmulOp
1070+
//===----------------------------------------------------------------------===//
1071+
1072+
def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
1073+
AttrSizedOperandSegments,
1074+
LinalgContractionOpInterface]> {
1075+
1076+
let summary = [{Performs a batch-reduce matrix multiplication of two 3D inputs.
1077+
The partial multiplication results are reduced into a 2D output.}];
1078+
let description = [{
1079+
Numeric casting is performed on the operands to the inner multiply, promoting
1080+
them to the same data type as the accumulator/output.
1081+
1082+
Broadcast and Transpose semantics can be applied by specifying the explicit attribute
1083+
'indexing_maps' as shown below. This is a list attribute, so must include maps for all
1084+
arguments if specified.
1085+
1086+
Example Transpose:
1087+
```
1088+
linalg.batch_reduce_matmul indexing_maps = [
1089+
affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
1090+
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
1091+
affine_map<(d0, d1, d2, d3) -> (d1, d2)>
1092+
]
1093+
ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>)
1094+
outs(%arg2: memref<3x7xf32>)
1095+
```
1096+
1097+
Example Broadcast:
1098+
```
1099+
linalg.batch_reduce_matmul indexing_maps = [
1100+
affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast
1101+
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
1102+
affine_map<(d0, d1, d2, d3) -> (d1, d2)>
1103+
]
1104+
ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
1105+
outs(%arg2: memref<3x7xf32>)
1106+
```
1107+
1108+
Example Broadcast and Transpose:
1109+
```
1110+
linalg.batch_reduce_matmul indexing_maps = [
1111+
affine_map<(d0, d1, d2, d3) -> (d1, d3)>, // broadcast
1112+
affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
1113+
affine_map<(d0, d1, d2, d3) -> (d1, d2)>
1114+
]
1115+
ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
1116+
outs(%arg2: memref<3x7xf32>)
1117+
```
1118+
}];
1119+
1120+
let arguments = (ins
1121+
Variadic<AnyType>:$inputs,
1122+
Variadic<AnyShaped>:$outputs,
1123+
DefaultValuedOptionalAttr<
1124+
AffineMapArrayAttr,
1125+
"BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext())"
1126+
>:$indexing_maps,
1127+
DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
1128+
);
1129+
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
1130+
let regions = (region AnyRegion:$region);
1131+
1132+
let skipDefaultBuilders = 1;
1133+
let builders = [
1134+
OpBuilder<
1135+
(ins "ValueRange":$inputs, "ValueRange":$outputs,
1136+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
1137+
[{
1138+
buildBatchReduceMatmulOp($_builder, $_state, std::nullopt, inputs, outputs,
1139+
attributes, BatchReduceMatmulOp::getRegionBuilder(),
1140+
BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
1141+
}]>,
1142+
OpBuilder<
1143+
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
1144+
"ValueRange":$outputs,
1145+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
1146+
[{
1147+
buildBatchReduceMatmulOp($_builder, $_state, resultTensorTypes,
1148+
inputs, outputs, attributes, BatchReduceMatmulOp::getRegionBuilder(),
1149+
BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
1150+
}]>,
1151+
OpBuilder<
1152+
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
1153+
"ValueRange":$outputs,
1154+
"Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
1155+
[{
1156+
$_state.addAttribute("cast", cast);
1157+
buildBatchReduceMatmulOp($_builder, $_state, resultTensorTypes, inputs, outputs,
1158+
attributes, BatchReduceMatmulOp::getRegionBuilder(),
1159+
BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
1160+
}]>
1161+
1162+
];
1163+
let hasCustomAssemblyFormat = 1;
1164+
let hasFolder = 1;
1165+
let hasVerifier = 1;
1166+
1167+
let extraClassDeclaration = structuredOpsBaseDecls # [{
1168+
SmallVector<utils::IteratorType> getIteratorTypesArray();
1169+
1170+
/// Implements the block region builder.
1171+
static void regionBuilder(ImplicitLocOpBuilder &b,
1172+
Block &block, ArrayRef<NamedAttribute> attrs);
1173+
1174+
/// Returns a list of AffineMap with the typical batch_reduce_matmul indexing characteristic.
1175+
static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
1176+
1177+
/// Returns true if the given broadcast map \p bcastMap is valid for this op.
1178+
bool isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS = true);
1179+
1180+
/// Returns the static `regionBuilder` function wrapped in a std::function
/// taking (ImplicitLocOpBuilder &, Block &, ArrayRef<NamedAttribute>).
static std::function<void(ImplicitLocOpBuilder &,
1181+
Block &, ArrayRef<NamedAttribute>)>
1182+
getRegionBuilder() {
1183+
return regionBuilder;
1184+
}
1185+
1186+
/// Forwards to getOutputsMutable(): the mutable range over this op's
/// `outputs` operands is what is returned as its "DpsInits".
::mlir::MutableOperandRange getDpsInitsMutable() {
1187+
return getOutputsMutable();
1188+
}
1189+
1190+
// Generic methods.
1191+
static unsigned getNumRegionArgs();
1192+
std::string getLibraryCallName();
1193+
/// Always returns true for this op.
/// NOTE(review): presumably signals that the indexing maps are read from the
/// 'indexing_maps' attribute rather than being fixed — confirm against callers.
bool hasDynamicIndexingMaps() { return true; }
1194+
/// Returns true if the user defined indexing maps are not equal to default maps.
1195+
bool hasUserDefinedMaps();
1196+
}];
1197+
}
1198+
10681199
//===----------------------------------------------------------------------===//
10691200
// Named Linalg ops, implemented as a declarative configurations of generic ops.
10701201
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)