Commit 0d4efa2

[MLIR][Linalg] Introduce linalg.contract (#123618)
A new op that allows for representing arbitrary contractions on operands of arbitrary rank, with arbitrary transposes and arbitrary broadcasts specified through its indexing_maps attribute. Supports the expected lowerings to linalg.generic and to vector.contract. Corresponding RFC is here: https://discourse.llvm.org/t/mlir-rfc-introduce-linalg-contract/83589
1 parent 88e0014 commit 0d4efa2

10 files changed, +963 -35 lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 136 additions & 0 deletions
@@ -680,6 +680,142 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Contract op.
+//===----------------------------------------------------------------------===//
+
+def ContractOp : LinalgStructuredBase_Op<"contract", [
+               AttrSizedOperandSegments,
+               LinalgContractionOpInterface]> {
+  let summary = [{
+    Perform a contraction on two inputs, accumulating into the third.
+  }];
+  let description = [{
+    The semantics of contracting inputs `A` and `B` on top of `C` to produce
+    output `D` is given by
+
+      `D[H] = (SUM_{(I ∪ J) \ H} A[I] * B[J]) + C[H]`
+
+    where `I`, `J`, and `H` are tuples of (pairwise distinct) dimension
+    identifiers - meant to range over valid indices - corresponding to the
+    results of the mandatory (projected permutation) `indexing_maps` for `A`,
+    `B` and `C`. `SUM_{dims}` means reduce over all valid indices for the
+    dimensions in the set `dims` (with `I`, `J`, and `H` treated as _sets_ of
+    dim identifiers).
+
+    The iteration space consists of all dimensions in `I`, `J` and `H`, i.e. the
+    domain of each of the `affine_map`s. Like for einsums, the iteration type of
+    each dim is inferred and is either:
+
+    - reduction: the dim is used to index into `A` and `B` but not `C`. Per the
+      above semantics, these dims will be contracted, i.e. reduced over.
+
+    - parallel: the dim is used to index into `C` and at least one of `A` and
+      `B`, and - deriving from matmul terminology - is either an "M-like" dim
+      (if used on `A` and `C`), an "N-like" dim (if used on `B` and `C`) or a
+      "batch"-dim (if used to index into `A`, `B`, and `C`).
+
+    For example, batch-matmul is given by `I = ⟨ b, m, k ⟩`, `J = ⟨ b, k, n ⟩`,
+    `H = ⟨ b, m, n ⟩` (with `k` as a contracting reduction-dimension while `m`,
+    `n` and `b` have parallel iteration-type) and gets represented as:
+
+    ```
+    %D = linalg.contract
+        indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
+                         affine_map<(batch, m, n, k) -> (batch, k, n)>,
+                         affine_map<(batch, m, n, k) -> (batch, m, n)>]
+        ins(%A, %B: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+        outs(%C: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+    ```
+
+    Note that by permuting dims in the `affine_map`s' results, accesses to
+    the inputs and output can be arbitrarily transposed. Similarly, arbitrary
+    broadcasts can be achieved through leaving out dims on either input operand.
+    For example, the following is a variant of batch-matmul with a transposition
+    applied to `A` while `B`'s 2D-matrix gets broadcasted along the batch dim:
+
+    ```
+    linalg.contract
+        indexing_maps = [affine_map<(batch, m, n, k) -> (batch, k, m)>,
+                         affine_map<(batch, m, n, k) -> (k, n)>,
+                         affine_map<(batch, m, n, k) -> (batch, m, n)>]
+        ins(%A, %B: memref<?x?x?xf32>, memref<?x?xf32>)
+        outs(%C: memref<?x?x?xf32>)
+    ```
+
+    Numeric casting is performed on the operands to the inner multiplication,
+    promoting/truncating them to the same data type as the accumulator/output.
+
+    TODO: Allow control over the combining/accumulating op and possibly the
+          multiplication op.
+  }];
+
+  let arguments = (ins
+    Variadic<AnyType>:$inputs,
+    Variadic<AnyShaped>:$outputs,
+    AffineMapArrayAttr:$indexing_maps
+  );
+  let results = (outs Variadic<AnyShaped>:$result_tensors);
+  // NB: The only reason this op has a region - and it gets populated at op build
+  //     time - is that currently the LinalgOp interface exposes methods that
+  //     assume a relevant region is available to be queried at any time.
+  let regions = (region SizedRegion<1>:$combiner);
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+      "ValueRange":$outputs, "ArrayAttr":$indexingMaps,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+    [{
+      $_state.addAttribute("indexing_maps", indexingMaps);
+      buildStructuredOp($_builder, $_state, resultTensorTypes, inputs,
+                        outputs, attributes, regionBuilder);
+    }]>,
+    OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
+      "ArrayAttr":$indexingMaps,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+    [{
+      $_state.addAttribute("indexing_maps", indexingMaps);
+      buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
+                        attributes, regionBuilder);
+    }]>
+  ];
+  let hasCustomAssemblyFormat = 1;
+  let hasFolder = 1;
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = structuredOpsBaseDecls # [{
+    // Declare/implement functions necessary for LinalgStructuredInterface.
+
+    /// Infer iterator types for each dim in the domain of IndexingMaps.
+    SmallVector<utils::IteratorType> getIteratorTypesArray();
+
+    /// IndexingMaps always depends on attr associated to current Op instance.
+    bool hasDynamicIndexingMaps() { return true; };
+    bool hasUserDefinedMaps() { return true; };
+
+    static unsigned getNumRegionArgs();
+
+    static void regionBuilder(ImplicitLocOpBuilder &b,
+                              Block &block, ArrayRef<NamedAttribute> attrs);
+
+    static std::function<void(ImplicitLocOpBuilder &,
+                              Block &, ArrayRef<NamedAttribute>)>
+    getRegionBuilder() {
+      return regionBuilder;
+    }
+
+    std::string getLibraryCallName() {
+      return "op_has_no_registered_library_name";
+    }
+
+    // Implement function necessary for DestinationStyleOpInterface.
+    ::mlir::MutableOperandRange getDpsInitsMutable() {
+      return getOutputsMutable();
+    }
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Named Linalg ops, implemented as a declarative configurations of generic ops.
 //===----------------------------------------------------------------------===//
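
The iterator-type inference and the contraction formula from the op description can be modelled in a few lines of plain Python. This is only an illustrative sketch, not the MLIR implementation: the function names `infer_iterator_types` and `contract`, the encoding of each indexing map as a tuple of dim names (the map's results), and the use of dicts keyed by index tuples as operands are all assumptions made for the example. It follows the classification rules as worded in the description and omits the numeric casting step.

```python
from itertools import product


def infer_iterator_types(dims, a_map, b_map, c_map):
    """Classify each iteration dim as 'parallel' or 'reduction'.

    Each map is the tuple of dim names in an affine_map's results, e.g.
    ("batch", "m", "k") stands for (batch, m, n, k) -> (batch, m, k).
    """
    kinds = {}
    for d in dims:
        in_a, in_b, in_c = d in a_map, d in b_map, d in c_map
        if in_c and (in_a or in_b):
            # Indexes the output and an input: M-like, N-like or batch dim.
            kinds[d] = "parallel"
        elif in_a and in_b:
            # Indexes both inputs but not the output: contracted dim.
            kinds[d] = "reduction"
        else:
            raise ValueError(f"dim {d!r} fits neither iterator type")
    return kinds


def contract(sizes, a_map, b_map, c_map, A, B, C):
    """Evaluate D[H] = (sum over dims of I and J not in H of A[I] * B[J]) + C[H].

    `sizes` maps each dim name to its extent. A, B and C are dicts keyed by
    index tuples laid out in the order given by the corresponding map.
    """
    dims = list(sizes)
    D = dict(C)  # start from the accumulator values
    # Walk the whole iteration space; every point that shares the same
    # output index H adds one product term to D[H].
    for point in product(*(range(sizes[d]) for d in dims)):
        env = dict(zip(dims, point))
        i = tuple(env[d] for d in a_map)
        j = tuple(env[d] for d in b_map)
        h = tuple(env[d] for d in c_map)
        D[h] += A[i] * B[j]
    return D


# The transposed-A, broadcast-B variant of batch-matmul from the description.
A_MAP = ("batch", "k", "m")
B_MAP = ("k", "n")
C_MAP = ("batch", "m", "n")
```

With these maps, `infer_iterator_types(("batch", "m", "n", "k"), A_MAP, B_MAP, C_MAP)` marks `k` as the only reduction dim. Because `batch` is missing from `B_MAP`, the same `B` entries are reused for every batch index, which is the broadcast behaviour the description refers to.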