Skip to content

[mlir][ArmSME] Support lowering masked vector.outerproduct ops to SME #69604

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Oct 31, 2023

Conversation

MacDue
Copy link
Member

@MacDue MacDue commented Oct 19, 2023

This patch adds support for lowering masked outer products to SME. This is done in two stages. First, vector.outerproducts (both masked and non-masked) are rewritten to arm_sme.outerproducts. The arm_sme.outerproduct op is close to vector.outerproduct, but supports masking on the operands rather than the result. It also limits the cases it handles to things that could be (directly) lowered to SME.

This currently requires that the source of the mask is a vector.create_mask op. E.g.:

%mask = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1>
%result = vector.mask %mask {
             vector.outerproduct %vecA, %vecB
              : vector<[4]xf32>, vector<[4]xf32>
          } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>

Is rewritten to:

%maskA = vector.create_mask %dimA : vector<[4]xi1>
%maskB = vector.create_mask %dimB : vector<[4]xi1>
%result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB)
              : vector<[4]xf32>, vector<[4]xf32>

(The same rewrite works for non-masked vector.outerproducts too)

The arm_sme.outerproduct can then be directly lowered to SME intrinsics.

@llvmbot
Copy link
Member

llvmbot commented Oct 19, 2023

@llvm/pr-subscribers-mlir-sme
@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Benjamin Maxwell (MacDue)

Changes

This patch adds support for lowering masked outer products to SME. This is done in two stages. First, vector.outerproducts (both masked and non-masked) are rewritten to arm_sme.outerproducts. The arm_sme.outerproduct op is close to vector.outerproduct, but supports masking on the operands rather than the result. It also limits the cases it handles to things that could be lowered to SME, but does not enforce that everything matches SME tiles at this level.

This currently requires that the source of the mask is a vector.create_mask op. E.g.:

%mask = vector.create_mask %dimA, %dimB : vector&lt;[4]x[4]xi1&gt;
%result = vector.mask %mask {
             vector.outerproduct %vecA, %vecB
              : vector&lt;[4]xf32&gt;, vector&lt;[4]xf32&gt;
          } : vector&lt;[4]x[4]xi1&gt; -&gt; vector&lt;[4]x[4]xf32&gt;

Is rewritten to:

%maskA = vector.create_mask %dimA : vector&lt;[4]xi1&gt;
%maskB = vector.create_mask %dimB : vector&lt;[4]xi1&gt;
%result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB)
              : vector&lt;[4]xf32&gt;, vector&lt;[4]xf32&gt;, vector&lt;[4]x[4]xf32&gt;

(The same rewrite works for non-masked vector.outerproducts too)

The arm_sme.outerproduct can then be directly lowered to SME intrinsics.


Patch is 36.03 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/69604.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td (+108)
  • (modified) mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp (+76-1)
  • (modified) mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp (+26-34)
  • (modified) mlir/test/Dialect/ArmSME/invalid.mlir (+29)
  • (modified) mlir/test/Dialect/ArmSME/roundtrip.mlir (+53)
  • (modified) mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir (+69-5)
  • (modified) mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir (+96)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir (+70)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir (+35)
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index dab54b63d8d22be..f60126e83603f47 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -79,6 +79,23 @@ def ArmSME_TileSliceLayoutAttr : EnumAttr<ArmSME_Dialect, TileSliceLayout,
   let defaultValue = "TileSliceLayout::Horizontal";
 }
 
+def CombiningKind : I32EnumAttr<"CombiningKind", "Kind of combining function", [
+  I32EnumAttrCase<"Add", 0, "add">,
+  I32EnumAttrCase<"Sub", 1, "sub">,
+]> {
+  let cppNamespace = "::mlir::arm_sme";
+  let genSpecializedAttr = 0;
+}
+
+/// An attribute that specifies how to combine a newly produced value with the
+/// accumulator. This is similar to vector::CombiningKindAttr, but limited to
+/// the functions that are valid for SME outer products.
+def ArmSME_CombiningKindAttr : EnumAttr<ArmSME_Dialect, CombiningKind,
+                                          "kind"> {
+  let assemblyFormat = "`<` $value `>`";
+  let defaultValue = "CombiningKind::Add";
+}
+
 //===----------------------------------------------------------------------===//
 // ArmSME op definitions
 //===----------------------------------------------------------------------===//
@@ -507,4 +524,95 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
   }];
 }
 
+class HasMatchingMaskTypeConstraint<string operand, string maskGetter> :
+  TypesMatchWith<
+    "shape of `" # operand #  "Mask` matches `" # operand # "`",
+    operand, operand # "Mask",
+    "::llvm::cast<mlir::VectorType>($_self).cloneWith({}, IntegerType::get($_ctxt, 1))",
+    "!" # maskGetter # "() || std::equal_to<>()">;
+
+def OuterProductOp :
+  ArmSME_Op<"outerproduct", [Pure,
+    AttrSizedOperandSegments,
+    AllElementTypesMatch<["lhs", "rhs", "result"]>,
+    HasMatchingMaskTypeConstraint<"lhs", "getLhsMask">,
+    HasMatchingMaskTypeConstraint<"rhs", "getRhsMask">,
+    PredOpTrait<
+      "both `lhsMask` and `rhsMask` should be provided or neither",
+      CPred<"bool(getLhsMask()) == bool(getRhsMask())">>,
+    PredOpTrait<
+      "result type is derived from `lhs` and `rhs`",
+      CPred<
+        "getResultType() == VectorType::get({"
+          "getLhsType().getDimSize(0), getRhsType().getDimSize(0)},"
+          "getRhsType().getElementType(),"
+          "{getLhsType().getScalableDims()[0], getRhsType().getScalableDims()[0]})">>,
+    TypesMatchWith<"`result` and `acc` have the same type",
+      "result", "acc",
+      "::llvm::cast<mlir::Type>($_self)",
+      "!getAcc() || std::equal_to<>()">
+  ]>
+{
+  let summary = "Vector outerproduct with optional fused add";
+
+  let description = [{
+    This op is based on `vector.outerproduct` with the extra conditions that:
+
+    * AXPY operations are not supported
+    * The only combining functions are "add" and "sub"
+    * Masking is performed on the inputs (rather than the output)
+
+    This is meant as an intermediate op for lowering `vector.outerproduct` to
+    SME. Types are not restricted to SVE/SME vectors at this level.
+
+    Example 1: Unmasked outerproduct (without accumulator)
+    ```mlir
+    %result = arm_sme.outerproduct $lhs, $rhs
+                : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
+    ```
+
+    Example 2: Unmasked outerproduct (with accumulator)
+    ```mlir
+    %result = arm_sme.outerproduct $lhs, $rhs acc($accumulator)
+                : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
+    ```
+
+    Example 3: Masked outerproduct
+    ```mlir
+    %result = arm_sme.outerproduct $lhs, $rhs masks($lhsMask, $rhsMask)
+                : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
+    ```
+
+    Example 4: Masked outerproduct (with accumulator)
+    ```mlir
+    %result = arm_sme.outerproduct $lhs, $rhs acc($accumulator) masks($lhsMask, $rhsMask)
+                : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
+    ```
+  }];
+
+  let arguments = (ins
+    VectorOfRank<[1]>:$lhs, VectorOfRank<[1]>:$rhs,
+    Optional<VectorOfRankAndType<[1],[I1]>>:$lhsMask,
+    Optional<VectorOfRankAndType<[1],[I1]>>:$rhsMask,
+    Optional<VectorOfRank<[2]>>: $acc,
+    ArmSME_CombiningKindAttr:$kind);
+  let results = (outs VectorOfRank<[2]>:$result);
+
+  let assemblyFormat = [{
+    $lhs `,` $rhs
+    oilist(
+        `kind` `` $kind
+      | `acc` `` `(` $acc `)`
+      | `masks` `` `(` $lhsMask `,` $rhsMask `)`
+    ) attr-dict
+    `:` type($lhs) `,` type($rhs) `,` type($result)
+  }];
+
+  let extraClassDeclaration = [{
+    VectorType getLhsType() { return getLhs().getType(); }
+    VectorType getRhsType() { return getRhs().getType(); }
+    VectorType getResultType() { return getResult().getType(); }
+  }];
+}
+
 #endif // ARMSME_OPS
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index d06eb4f5b01c950..b7ec2603477d59e 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -427,6 +427,80 @@ struct TransposeOpToArmSMELowering
   }
 };
 
+/// Lowers a masked `vector.outerproduct` to `arm_sme.outerproduct`.
+/// The 2-D mask of the `vector.outerproduct` (if from a `vector.create_mask`)
+/// is decomposed into two 1-D masks for the operands.
+///
+///  BEFORE:
+///  ```mlir
+///  %mask = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1>
+///  %result = vector.mask %mask {
+///               vector.outerproduct %vecA, %vecB
+///                : vector<[4]xf32>, vector<[4]xf32>
+///            } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///  %maskA = vector.create_mask %dimA : vector<[4]xi1>
+///  %maskB = vector.create_mask %dimB : vector<[4]xi1>
+///  %result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB)
+///              : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
+///  ```
+struct VectorOuterProductToArmSMELowering
+    : public OpRewritePattern<vector::OuterProductOp> {
+
+  using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::OuterProductOp outerProductOp,
+                                PatternRewriter &rewriter) const override {
+    // AXPY operation not suited for SME.
+    if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
+      return outerProductOp.emitError("AXPY operations not supported");
+
+    auto kind = outerProductOp.getKind();
+    if (kind != vector::CombiningKind::ADD)
+      return outerProductOp.emitError("unsupported kind");
+
+    Value lhsMask = {};
+    Value rhsMask = {};
+    Operation *rootOp = outerProductOp;
+    if (outerProductOp.isMasked()) {
+      auto maskingOp = outerProductOp.getMaskingOp();
+      rewriter.setInsertionPoint(maskingOp);
+      rootOp = maskingOp;
+
+      // Attempt to extract masks from vector.create_mask.
+      // TODO: Add support for other mask sources.
+      auto mask = maskingOp.getMask();
+      auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>();
+      if (!createMaskOp)
+        return failure();
+
+      auto maskType = createMaskOp.getVectorType();
+      if (maskType.getRank() != 2)
+        return failure();
+
+      auto loc = outerProductOp.getLoc();
+
+      Value lhsMaskDim = createMaskOp.getOperand(0);
+      Value rhsMaskDim = createMaskOp.getOperand(1);
+
+      VectorType operandMaskType = VectorType::Builder(maskType).dropDim(0);
+      lhsMask = rewriter.create<vector::CreateMaskOp>(loc, operandMaskType,
+                                                      lhsMaskDim);
+      rhsMask = rewriter.create<vector::CreateMaskOp>(loc, operandMaskType,
+                                                      rhsMaskDim);
+    }
+
+    rewriter.replaceOpWithNewOp<arm_sme::OuterProductOp>(
+        rootOp, outerProductOp.getResultVectorType(), outerProductOp.getLhs(),
+        outerProductOp.getRhs(), lhsMask, rhsMask, outerProductOp.getAcc());
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
@@ -434,5 +508,6 @@ void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
   patterns.add<BroadcastOpToArmSMELowering, ConstantOpToArmSMELowering,
                SplatOpToArmSMELowering, TransferReadPermutationToArmSMELowering,
                TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
-               VectorLoadToArmSMELowering, VectorStoreToArmSMELowering>(&ctx);
+               VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
+               VectorOuterProductToArmSMELowering>(&ctx);
 }
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 5e13707ea0aa2b9..0e5c996e518a472 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -458,13 +458,13 @@ struct MoveTileSliceToVectorArmSMELowering
 ///        vector<[4]xf32>) -> ()
 ///
 /// Currently only supports FMOPA and BFMOPA (non-widening).
-struct VectorOuterProductToArmSMELowering
-    : public ConvertOpToLLVMPattern<vector::OuterProductOp> {
-  using ConvertOpToLLVMPattern<vector::OuterProductOp>::ConvertOpToLLVMPattern;
+struct OuterProductToArmSMELowering
+    : public ConvertOpToLLVMPattern<arm_sme::OuterProductOp> {
+  using ConvertOpToLLVMPattern<arm_sme::OuterProductOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(vector::OuterProductOp outerProductOp,
-                  vector::OuterProductOp::Adaptor adaptor,
+  matchAndRewrite(arm_sme::OuterProductOp outerProductOp,
+                  arm_sme::OuterProductOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto isSupportedType = [](VectorType vectorType) {
       // TODO: the FP outer product instruction variants are predicated on
@@ -496,24 +496,13 @@ struct VectorOuterProductToArmSMELowering
       return true;
     };
 
-    auto resultVectorType = outerProductOp.getResultVectorType();
-    if (!isSupportedType(resultVectorType))
-      return outerProductOp.emitError("unsupported type");
-
-    vector::CombiningKind kind = outerProductOp.getKind();
-    if (kind != vector::CombiningKind::ADD)
-      // TODO: support subtract.
+    // TODO: Support CombiningKind::Sub for outer products.
+    if (outerProductOp.getKind() != CombiningKind::Add)
       return outerProductOp.emitError("unsupported kind");
 
-    auto maskableOp =
-        cast<vector::MaskableOpInterface>(outerProductOp.getOperation());
-    if (maskableOp.isMasked())
-      // TODO: support masking.
-      return outerProductOp.emitError("masking is currently unsupported");
-
-    if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
-      // AXPY operation not suited for SME.
-      return failure();
+    auto resultVectorType = outerProductOp.getResultType();
+    if (!isSupportedType(resultVectorType))
+      return outerProductOp.emitError("unsupported type");
 
     auto loc = outerProductOp.getLoc();
 
@@ -526,21 +515,24 @@ struct VectorOuterProductToArmSMELowering
     auto tileId = rewriter.create<arm_sme::CastVectorToTile>(
         loc, rewriter.getIntegerType(elementWidth), acc);
 
-    // Create all active predicate mask.
-    auto one = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI1Type(),
-        rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
-    auto predTy =
-        VectorType::get(resultVectorType.getShape()[0], rewriter.getI1Type(),
-                        /*scalableDims=*/{true});
-    auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
-
     auto tileI32 = castTileIDToI32(tileId, loc, rewriter);
 
+    Value lhsMask = outerProductOp.getLhsMask();
+    Value rhsMask = outerProductOp.getRhsMask();
+
+    if (!lhsMask || !rhsMask) {
+      auto predTy =
+          outerProductOp.getLhsType().cloneWith({}, rewriter.getI1Type());
+      Value allActiveMask = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(predTy, true));
+      lhsMask = allActiveMask;
+      rhsMask = allActiveMask;
+    }
+
     // Create 'arm_sme.intr.mopa' outer product intrinsic.
-    rewriter.create<arm_sme::aarch64_sme_mopa>(
-        loc, tileI32, allActiveMask, allActiveMask, outerProductOp.getLhs(),
-        outerProductOp.getRhs());
+    rewriter.create<arm_sme::aarch64_sme_mopa>(loc, tileI32, lhsMask, rhsMask,
+                                               outerProductOp.getLhs(),
+                                               outerProductOp.getRhs());
 
     // Create `CastTileToVectorOp` to use as the output.
     rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(
@@ -716,6 +708,6 @@ void mlir::populateArmSMELegalizeForLLVMExportPatterns(
   patterns.add<
       LoadTileSliceToArmSMELowering, MoveTileSliceToVectorArmSMELowering,
       MoveVectorToTileSliceToArmSMELowering, StoreTileSliceToArmSMELowering,
-      VectorOuterProductToArmSMELowering, ZeroOpConversion,
+      OuterProductToArmSMELowering, ZeroOpConversion,
       VectorExtractToArmSMELowering, VectorInsertToArmSMELowering>(converter);
 }
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index 431009b1b9ede2f..996c26834ae7fa6 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -97,3 +97,32 @@ func.func @arm_sme_move_tile_slice_to_vector__bad_result_type(%tile : vector<[4]
   %0 = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[4]x[4]xf32>
   return %0 : vector<[2]xf64>
 }
+
+// -----
+
+// expected-note@+1 {{prior use here}}
+func.func @arm_sme_outproduct__bad_mask_type(%vecA: vector<3xf32>, %vecB: vector<[2]xf32>, %maskA: vector<5xi1>, %maskB: vector<[2]xi1>)-> vector<3x[2]xf32>
+{
+  // expected-error@+1 {{use of value '%maskA' expects different type than prior uses}}
+  %0 = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB) : vector<3xf32>, vector<[2]xf32>, vector<3x[2]xf32>
+  return %0 : vector<3x[2]xf32>
+}
+
+// -----
+
+func.func @arm_sme_outproduct__bad_result_type(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x1xi16>
+{
+  // expected-error@+1 {{op failed to verify that result type is derived from `lhs` and `rhs`}}
+  %0 = arm_sme.outerproduct %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16>, vector<[2]x1xi16>
+  return %0 : vector<[2]x1xi16>
+}
+
+// -----
+
+// expected-note@+1 {{prior use here}}
+func.func @arm_sme_outproduct__bad_acc_type(%vecA: vector<7xi32>, %vecB: vector<6xi32>, %acc: vector<6x6xi32>) -> vector<7x6xi32>
+{
+  // expected-error@+1 {{use of value '%acc' expects different type than prior uses}}
+  %0 = arm_sme.outerproduct %vecA, %vecB acc(%acc) : vector<7xi32>, vector<6xi32>, vector<7x6xi32>
+  return %0 : vector<7x6xi32>
+}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index 427154158e797fd..6ee6d7a6149b940 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -1135,3 +1135,56 @@ func.func @arm_sme_move_tile_slice_to_vector_f64(%tile : vector<[2]x[2]xf64>, %t
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
   return %slice : vector<[2]xf64>
 }
+
+//===----------------------------------------------------------------------===//
+// arm_sme.outerproduct
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_outproduct(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[8]x[8]xi16> {
+  // CHECK: arm_sme.outerproduct {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16>, vector<[8]x[8]xi16>
+  %result = arm_sme.outerproduct %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16>, vector<[8]x[8]xi16>
+  return %result : vector<[8]x[8]xi16>
+}
+
+// -----
+
+func.func @arm_sme_outproduct_with_masking(
+  %vecA: vector<3xf32>, %vecB: vector<[2]xf32>, %maskA: vector<3xi1>, %maskB: vector<[2]xi1>
+) -> vector<3x[2]xf32> {
+  // CHECK: arm_sme.outerproduct {{.*}}, {{.*}} masks({{.*}}, {{.*}}) : vector<3xf32>, vector<[2]xf32>, vector<3x[2]xf32>
+  %result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB) : vector<3xf32>, vector<[2]xf32>, vector<3x[2]xf32>
+  return %result : vector<3x[2]xf32>
+}
+
+// -----
+
+func.func @arm_sme_outproduct_with_acc(
+  %vecA: vector<7xi32>, %vecB: vector<6xi32>, %acc: vector<7x6xi32>
+) -> vector<7x6xi32> {
+  // CHECK: arm_sme.outerproduct {{.*}}, {{.*}} acc({{.*}}) : vector<7xi32>, vector<6xi32>, vector<7x6xi32>
+  %result = arm_sme.outerproduct %vecA, %vecB acc(%acc) : vector<7xi32>, vector<6xi32>, vector<7x6xi32>
+  return %result : vector<7x6xi32>
+}
+
+// -----
+
+func.func @arm_sme_outproduct_with_kind(%vecA: vector<[2]xf64>, %vecB: vector<[2]xf64>) -> vector<[2]x[2]xf64>  {
+  // CHECK: arm_sme.outerproduct {{.*}}, {{.*}} kind<sub> : vector<[2]xf64>, vector<[2]xf64>, vector<[2]x[2]xf64>
+  %result = arm_sme.outerproduct %vecA, %vecB kind<sub> : vector<[2]xf64>, vector<[2]xf64>, vector<[2]x[2]xf64>
+  return %result : vector<[2]x[2]xf64>
+}
+
+// -----
+
+func.func @arm_sme_outproduct_with_everything(
+  %vecA: vector<[4]xf16>, %vecB: vector<4xf16>, %acc: vector<[4]x4xf16>,
+  %maskA: vector<[4]xi1>, %maskB: vector<4xi1>
+) -> vector<[4]x4xf16> {
+  // CHECK: arm_sme.outerproduct {{.*}}, {{.*}} kind<sub> acc({{.*}}) masks({{.*}}, {{.*}})
+  // CHECK-SAME: : vector<[4]xf16>, vector<4xf16>, vector<[4]x4xf16>
+  %result = arm_sme.outerproduct %vecA, %vecB kind<sub> acc(%acc) masks(%maskA, %maskB)
+              : vector<[4]xf16>, vector<4xf16>, vector<[4]x4xf16>
+  return %result : vector<[4]x4xf16>
+}
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
index 32f46d9fd817c9d..f615a9ef0231443 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
@@ -463,9 +463,75 @@ func.func @vector_outerproduct_no_accumulator(%lhs : vector<[2]xf64>, %rhs : vec
 
 // -----
 
+// CHECK-LABEL: @vector_outerproduct_masked_f32
+// CHECK-SAME: (%[[LHS:.*]]: vector<[4]xf32>, %[[RHS:.*]]: vector<[4]xf32>, %[[ACC:.*]]: vector<[4]x[4]xf32>,
+func.func @vector_outerproduct_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %dim0: index, %dim1: index) {
+  // CHECK: %[[LHS_MASK:.*]] = arith.cmpi slt, {{.*}} : vector<[4]xi32>
+  // CHECK: %[[RHS_MASK:.*]] = arith.cmpi slt, {{.*}} : vector<[4]xi32>
+  // CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[ACC]] : vector<[4]x[4]xf32> to i32
+  // CHECK: "arm_sme.intr.mopa"(%[[CAST_VECTOR_TO_TILE]], %[[LHS_MASK]], %[[RHS_MASK]], %[[LHS]], %[[RHS]]) : (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>)
+  %mask = vector.create_mask %dim0, %dim1 : vector<[4]x[4]xi1>
+  %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+  "prevent.dce"(%result) : (vector<[4]x[4]xf32>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_masked_f64
+// CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>, %[[ACC:.*]]: vector<[2]x[2]xf64>,
+func.func @vector_outerproduct_masked_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>, %dim0: index, %dim1: index) {
+  // CHECK: arith.cmpi slt, {{.*}} : vector<[2]xi32>
+  // CHECK: arith.cmpi slt, {{.*}} : vector<[2]xi32>
+  // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
+  %mask = vector.create_mask %dim0, %dim1 : vector<[2]x[2]xi1>
+  %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64> } : vector<[2]x[2]xi1> -> vector<[2]x[2]xf64>
+  "prevent.dce"(%result) : (vector<[2]x[2]xf64>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_masked_f16
+// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>, %[[ACC:.*]]: vector<[8]x[8]xf16>,
+func.func @vector_outerproduct_masked_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %acc : vector<[8]...
[truncated]

@MacDue
Copy link
Member Author

MacDue commented Oct 19, 2023

This patch has no direct dependencies, but for use in lowering a matmul #69456 is required.

@MacDue
Copy link
Member Author

MacDue commented Oct 19, 2023

cc @cfRod

@MacDue MacDue force-pushed the arm_sme_masked_outer_product branch from 6606ebc to 852a83b Compare October 20, 2023 09:16
Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is looking great! Mostly nits from me, but I need another round :)

vector.outerproduct %vector, %vector, %acc : vector<[4]xf32>, vector<[4]xf32>
} : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>

// Print the tile. Due to masking the result will be the top 2x3xf32 section.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wow, brilliant!

This op is based on `vector.outerproduct` with the extra conditions that:

* AXPY operations are not supported
* The only combining functions are "add" and "sub"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there any tests lowering to LLVM that work for sub?

Copy link
Member Author

@MacDue MacDue Oct 25, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We've not yet implemented sub (or the intrinsics required for it).

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool!

}

/// An attribute that specifies how to combine a newly produced value with the
/// accumulator. This is similar to vector::CombiningKindAttr, but limited to
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wondering if it would make sense using the vector one and check if the kind is supported during verification?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did try this, but including VectorAttributes.td within this file results in the generator creating duplicate definitions of attributes from the vector dialect (which fails to build).

@MacDue MacDue force-pushed the arm_sme_masked_outer_product branch from 852a83b to 21f0038 Compare October 25, 2023 16:01
Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is obviously great :) I've made a few suggestions and raised a few questions, but nothing major.

I do think that this could be split into two patches though:

  • introduce arm_sme.outerproduct (and rewire the lowering for vector.outerproduct,
  • add support for masking.

I'm not going to push for that, as Diego has already accepted and in general this very well crafted - the noise is not too bad. Just a note for our futures selves :)

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thank you for addressing my comments :)

Copy link
Collaborator

@c-rhodes c-rhodes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great work Ben, I've left a few comments all minor, otherwise LGTM, cheers

// CHECK-LABEL: @vector_outerproduct_masked_f64
// CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>, %[[ACC:.*]]: vector<[2]x[2]xf64>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
func.func @vector_outerproduct_masked_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>, %dim0: index, %dim1: index) {
%mask = vector.create_mask %dim0, %dim1 : vector<[2]x[2]xi1>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: move below CHECK lines? (and below)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This style is used in many places, and I find nicer to read :)

This patch adds support for lowering masked outer products to SME. This
is done in two stages. First, vector.outerproducts (both masked and
non-masked) are rewritten to arm_sme.outerproducts. The
arm_sme.outerproduct op is close to vector.outerproduct, but supports
masking on the operands rather than the result. It also limits the
cases it handles to things that could be lowered to SME, but does not
enforce that everything matches SME tiles at this level.

This currently requires that the source of the mask is a
vector.create_mask op. E.g.:

```mlir
%mask = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1>
%result = vector.mask %mask {
             vector.outerproduct %vecA, %vecB
              : vector<[4]xf32>, vector<[4]xf32>
          } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
```
Is rewritten to:
```
%maskA = vector.create_mask %dimA : vector<[4]xi1>
%maskB = vector.create_mask %dimB : vector<[4]xi1>
%result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB)
              : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
```
(The same rewrite works for non-masked vector.outerproducts too)

The arm_sme.outerproduct can then be directly lowered to SME intrinsics.
This also adds some no accumulator tests to test-outerproduct-f64.mlir
(and removes some now done TODOs, e.g. vector.splat support).
The arm_sme.outerproduct op has now been updated to match SME tile sizes,
which cleans things up a little. A few other fixups included.
@MacDue MacDue force-pushed the arm_sme_masked_outer_product branch from d8e98f1 to 452f4a1 Compare October 30, 2023 13:56
@MacDue MacDue merged commit e666295 into llvm:main Oct 31, 2023
@MacDue MacDue deleted the arm_sme_masked_outer_product branch October 31, 2023 09:06
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants