@@ -680,6 +680,142 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
680
680
}];
681
681
}
682
682
683
//===----------------------------------------------------------------------===//
// Contract op.
//===----------------------------------------------------------------------===//

def ContractOp : LinalgStructuredBase_Op<"contract", [
       AttrSizedOperandSegments,
       LinalgContractionOpInterface]> {
  let summary = [{
    Perform a contraction on two inputs, accumulating into the third.
  }];
  let description = [{
    The semantics of contracting inputs `A` and `B` on top of `C` to produce
    output `D` is given by

      `D[H] = (SUM_{(I ∪ J) \ H} A[I] * B[J]) + C[H]`

    where `I`, `J`, and `H` are tuples of (pairwise distinct) dimension
    identifiers - meant to range over valid indices - corresponding to the
    results of the mandatory (projected permutation) `indexing_maps` for `A`,
    `B` and `C`. `SUM_{dims}` means reduce over all valid indices for the
    dimensions in the set `dims` (with `I`, `J`, and `H` treated as _sets_ of
    dim identifiers).

    The iteration space consists of all dimensions in `I`, `J` and `H`, i.e. the
    domain of each of the `affine_map`s. Like for einsums, the iteration type of
    each dim is inferred and is either:

    - reduction: the dim is used to index into `A` and `B` but not `C`. Per the
      above semantics, these dims will be contracted, i.e. reduced over.

    - parallel: the dim is used to index into `C` and at least one of `A` and
      `B`, and - deriving from matmul terminology - is either an "M-like" dim
      (if used on `A` and `C`), an "N-like" dim (if used on `B` and `C`) or a
      "batch"-dim (if used to index into `A`, `B`, and `C`).

    For example, batch-matmul is given by `I = ⟨ b, m, k ⟩`, `J = ⟨ b, k, n ⟩`,
    `H = ⟨ b, m, n ⟩` (with `k` as a contracting reduction-dimension while `m`,
    `n` and `b` have parallel iteration-type) and gets represented as:

    ```
    %D = linalg.contract
        indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
                         affine_map<(batch, m, n, k) -> (batch, k, n)>,
                         affine_map<(batch, m, n, k) -> (batch, m, n)>]
        ins(%A, %B: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
        outs(%C: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
    ```

    Note that by permuting dims in the `affine_map`s' results, accesses to
    the inputs and output can be arbitrarily transposed. Similarly, arbitrary
    broadcasts can be achieved through leaving out dims on either input operand.
    For example, the following is a variant of batch-matmul with a transposition
    applied to `A` while `B`'s 2D-matrix gets broadcasted along the batch dim:

    ```
    linalg.contract
        indexing_maps = [affine_map<(batch, m, n, k) -> (batch, k, m)>,
                         affine_map<(batch, m, n, k) -> (k, n)>,
                         affine_map<(batch, m, n, k) -> (batch, m, n)>]
        ins(%A, %B: memref<?x?x?xf32>, memref<?x?xf32>)
        outs(%C: memref<?x?x?xf32>)
    ```

    Numeric casting is performed on the operands to the inner multiplication,
    promoting/truncating them to the same data type as the accumulator/output.

    TODO: Allow control over the combining/accumulating op and possibly the
          multiplication op.
  }];

  let arguments = (ins
    Variadic<AnyType>:$inputs,
    Variadic<AnyShaped>:$outputs,
    AffineMapArrayAttr:$indexing_maps
  );
  let results = (outs Variadic<AnyShaped>:$result_tensors);
  // NB: The only reason this op has a region - and it gets populated at op
  //     build time - is that currently the LinalgOp interface exposes methods
  //     that assume a relevant region is available to be queried at any time.
  let regions = (region SizedRegion<1>:$combiner);

  let skipDefaultBuilders = 1;
  let builders = [
    OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
                   "ValueRange":$outputs, "ArrayAttr":$indexingMaps,
                   CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
      [{
        $_state.addAttribute("indexing_maps", indexingMaps);
        buildStructuredOp($_builder, $_state, resultTensorTypes, inputs,
                          outputs, attributes, regionBuilder);
      }]>,
    OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
                   "ArrayAttr":$indexingMaps,
                   CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
      [{
        $_state.addAttribute("indexing_maps", indexingMaps);
        buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
                          attributes, regionBuilder);
      }]>
  ];
  let hasCustomAssemblyFormat = 1;
  let hasFolder = 1;
  let hasVerifier = 1;

  let extraClassDeclaration = structuredOpsBaseDecls # [{
    // Declare/implement functions necessary for LinalgStructuredInterface.

    /// Infer iterator types for each dim in the domain of IndexingMaps.
    SmallVector<utils::IteratorType> getIteratorTypesArray();

    /// IndexingMaps always depends on attr associated to current Op instance.
    bool hasDynamicIndexingMaps() { return true; };
    bool hasUserDefinedMaps() { return true; };

    static unsigned getNumRegionArgs();

    static void regionBuilder(ImplicitLocOpBuilder &b,
                              Block &block, ArrayRef<NamedAttribute> attrs);

    static std::function<void(ImplicitLocOpBuilder &,
                              Block &, ArrayRef<NamedAttribute>)>
    getRegionBuilder() {
      return regionBuilder;
    }

    std::string getLibraryCallName() {
      return "op_has_no_registered_library_name";
    }

    // Implement function necessary for DestinationStyleOpInterface.
    ::mlir::MutableOperandRange getDpsInitsMutable() {
      return getOutputsMutable();
    }
  }];
}
683
819
//===----------------------------------------------------------------------===//
// Named Linalg ops, implemented as a declarative configurations of generic ops.
//===----------------------------------------------------------------------===//