Skip to content

Commit f2bca9e

Browse files
authored
[MLIR][Linalg] Introduce broadcast/transpose semantic to batch_matmul (#122275)
Goals: 1. To add syntax and semantic to 'batch_matmul' without changing any of the existing syntax expectations for current usage. batch_matmul is still just batch_matmul. 2. Move the definition of batch_matmul from linalg OpDsl to tablegen ODS infra. Scope of this patch: To expose broadcast and transpose semantics on the 'batch_matmul'. The broadcast and transpose semantic are as follows: By default, 'linalg.batch_matmul' behavior will remain as is. Broadcast and Transpose semantics can be applied by specifying the explicit attribute 'indexing_maps' as shown below. This is a list attribute, so the list must include all the maps if specified. Example Transpose: ``` linalg.batch_matmul indexing_maps = [ affine_map< (d0, d1, d2, d3) -> (d0, d3, d1)>, //transpose affine_map< (d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map< (d0, d1, d2, d3) -> (d0, d1, d2)> ] ins (%arg0, %arg1: memref<2x5x3xf32>,memref<2x5x7xf32>) outs (%arg2: memref<2x3x7xf32>) ``` Example Broadcast: ``` linalg.batch_matmul indexing_maps = [ affine_map< (d0, d1, d2, d3) -> (d3)>, //broadcast affine_map< (d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map< (d0, d1, d2, d3) -> (d0, d1, d2)> ] ins (%arg0, %arg1: memref<5xf32>,memref<2x5x7xf32>) outs (%arg2: memref<2x3x7xf32>) ``` Example Broadcast and transpose: ``` linalg.batch_matmul indexing_maps = [ affine_map< (d0, d1, d2, d3) -> (d1, d3)>, //broadcast affine_map< (d0, d1, d2, d3) -> (d0, d2, d3)>, //transpose affine_map< (d0, d1, d2, d3) -> (d0, d1, d2)> ] ins (%arg0, %arg1: memref<3x5xf32>, memref<2x7x5xf32>) outs (%arg2: memref<2x3x7xf32>) ``` RFCs and related PR: https://discourse.llvm.org/t/rfc-linalg-opdsl-constant-list-attribute-definition/80149 https://discourse.llvm.org/t/rfc-op-explosion-in-linalg/82863 https://discourse.llvm.org/t/rfc-mlir-linalg-operation-tree/83586 #115319
1 parent f8e53a9 commit f2bca9e

File tree

12 files changed

+824
-100
lines changed

12 files changed

+824
-100
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -710,7 +710,10 @@ def LinalgStructuredInterface
710710
>,
711711
InterfaceMethod<
712712
/*desc=*/[{
713-
Return true if the user has supplied an explicit indexing maps for this op.
713+
Returns true if the user has supplied explicit indexing maps that are
714+
different from default indexing maps for this op. Returns `false` otherwise.
715+
Note, if the user define maps that are identical to the default maps,
716+
this method returns `false`.
714717
}],
715718
/*retTy=*/"bool",
716719
/*methodName=*/"hasUserDefinedMaps",

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

Lines changed: 0 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1472,75 +1472,6 @@ structured_op: !LinalgStructuredOpConfig
14721472
- !ScalarExpression
14731473
scalar_arg: rhs
14741474
--- !LinalgOpConfig
1475-
metadata: !LinalgOpMetadata
1476-
name: batch_matmul
1477-
cpp_class_name: BatchMatmulOp
1478-
doc: |-
1479-
Performs a batched matrix multiplication of two 3D inputs.
1480-
1481-
Numeric casting is performed on the operands to the inner multiply, promoting
1482-
them to the same data type as the accumulator/output.
1483-
implements:
1484-
- LinalgContractionOpInterface
1485-
structured_op: !LinalgStructuredOpConfig
1486-
args:
1487-
- !LinalgOperandDefConfig
1488-
name: A
1489-
kind: input_tensor
1490-
type_var: T1
1491-
shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
1492-
- !LinalgOperandDefConfig
1493-
name: B
1494-
kind: input_tensor
1495-
type_var: T2
1496-
shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)>
1497-
- !LinalgOperandDefConfig
1498-
name: C
1499-
kind: output_tensor
1500-
type_var: U
1501-
shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
1502-
indexing_maps: !LinalgIndexingMapsConfig
1503-
static_indexing_maps:
1504-
- affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
1505-
- affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
1506-
- affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d2)>
1507-
iterator_types:
1508-
- parallel
1509-
- parallel
1510-
- parallel
1511-
- reduction
1512-
assignments:
1513-
- !ScalarAssign
1514-
arg: C
1515-
value: !ScalarExpression
1516-
scalar_fn:
1517-
kind: binary
1518-
fn_name: add
1519-
operands:
1520-
- !ScalarExpression
1521-
scalar_arg: C
1522-
- !ScalarExpression
1523-
scalar_fn:
1524-
kind: binary
1525-
fn_name: mul
1526-
operands:
1527-
- !ScalarExpression
1528-
scalar_fn:
1529-
kind: type
1530-
fn_name: cast_signed
1531-
type_var: U
1532-
operands:
1533-
- !ScalarExpression
1534-
scalar_arg: A
1535-
- !ScalarExpression
1536-
scalar_fn:
1537-
kind: type
1538-
fn_name: cast_signed
1539-
type_var: U
1540-
operands:
1541-
- !ScalarExpression
1542-
scalar_arg: B
1543-
--- !LinalgOpConfig
15441475
metadata: !LinalgOpMetadata
15451476
name: batch_matmul_transpose_a
15461477
cpp_class_name: BatchMatmulTransposeAOp

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 124 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -674,8 +674,7 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
674674
static unsigned getNumRegionArgs();
675675
std::string getLibraryCallName();
676676
bool hasDynamicIndexingMaps();
677-
/// Check if the op has broadcast and/or transpose semantic. Returns true if the
678-
/// user defined indexing maps are not equal to default map.
677+
/// Returns true if the user defined indexing maps are not equal to default maps.
679678
bool hasUserDefinedMaps();
680679
}];
681680
}
@@ -816,6 +815,129 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
816815
}];
817816
}
818817

818+
//===----------------------------------------------------------------------===//
819+
// Op definition for BatchMatmulOp
820+
//===----------------------------------------------------------------------===//
821+
822+
def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSizedOperandSegments],
823+
/*extraInterfaces=*/[LinalgContractionOpInterface])> {
824+
825+
let summary = [{Performs a batched matrix multiplication of two 3D inputs.}];
826+
let description = [{Numeric casting is performed on the operands to the inner multiply, promoting
827+
them to the same data type as the accumulator/output.
828+
829+
Broadcast and Transpose semantics can be appiled by specifying the explicit attribute
830+
'indexing_maps' as shown below. This is a list attribute, so must include maps for all
831+
arguments if specified.
832+
833+
Example Transpose:
834+
```
835+
linalg.batch_matmul indexing_maps = [
836+
affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
837+
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
838+
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
839+
]
840+
ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>)
841+
outs(%arg2: memref<2x3x7xf32>)
842+
```
843+
844+
Example Broadcast:
845+
```
846+
linalg.batch_matmul indexing_maps = [
847+
affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast
848+
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
849+
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
850+
]
851+
ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
852+
outs(%arg2: memref<2x3x7xf32>)
853+
```
854+
855+
Example Broadcast and Transpose:
856+
```
857+
linalg.batch_matmul indexing_maps = [
858+
affine_map<(d0, d1, d2, d3) -> (d1, d3)>, // broadcast
859+
affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
860+
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
861+
]
862+
ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
863+
outs(%arg2: memref<2x3x7xf32>)
864+
```
865+
}];
866+
867+
let arguments = (ins
868+
Variadic<AnyType>:$inputs,
869+
Variadic<AnyShaped>:$outputs,
870+
DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
871+
);
872+
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
873+
let regions = (region AnyRegion:$region);
874+
875+
let skipDefaultBuilders = 1;
876+
let builders = [
877+
OpBuilder<
878+
(ins "ValueRange":$inputs, "ValueRange":$outputs,
879+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
880+
[{
881+
buildBatchMatmulOp($_builder, $_state, std::nullopt, inputs, outputs,
882+
attributes, BatchMatmulOp::getRegionBuilder(),
883+
BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
884+
}]>,
885+
OpBuilder<
886+
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
887+
"ValueRange":$outputs,
888+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
889+
[{
890+
buildBatchMatmulOp($_builder, $_state, resultTensorTypes,
891+
inputs, outputs, attributes, BatchMatmulOp::getRegionBuilder(),
892+
BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
893+
}]>,
894+
OpBuilder<
895+
(ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
896+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
897+
[{
898+
$_state.addOperands(operands);
899+
$_state.addAttributes(attributes);
900+
$_state.addTypes(resultTensorTypes);
901+
(void)$_state.addRegion(),
902+
BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext());
903+
}]>
904+
905+
];
906+
let hasCustomAssemblyFormat = 1;
907+
let hasFolder = 1;
908+
let hasVerifier = 1;
909+
910+
let extraClassDeclaration = structuredOpsBaseDecls # [{
911+
912+
SmallVector<utils::IteratorType> getIteratorTypesArray();
913+
static void regionBuilder(ImplicitLocOpBuilder &b,
914+
Block &block, ArrayRef<NamedAttribute> attrs);
915+
static std::function<void(ImplicitLocOpBuilder &,
916+
Block &, ArrayRef<NamedAttribute>)>
917+
getRegionBuilder() {
918+
return regionBuilder;
919+
}
920+
921+
/// Returns a list with default AffineMap(s), i.e. without broadcasts and transpositions.
922+
static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
923+
924+
/// Returns true if the given broadcast map \p bcastMap is valid for this op.
925+
bool isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS = true);
926+
927+
::mlir::MutableOperandRange getDpsInitsMutable() {
928+
return getOutputsMutable();
929+
}
930+
931+
// Generic methods.
932+
static unsigned getNumRegionArgs();
933+
bool hasDynamicIndexingMaps() { return true; }
934+
std::string getLibraryCallName();
935+
/// Returns true if the user defined indexing maps are not equal to default maps.
936+
bool hasUserDefinedMaps();
937+
}];
938+
}
939+
940+
819941
//===----------------------------------------------------------------------===//
820942
// Named Linalg ops, implemented as a declarative configurations of generic ops.
821943
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)