
[mlir][ArmNeon] Implements unrolling patterns for LowerContractionToSMMLAPattern #84848


Merged

merged 2 commits into llvm:main from arm_neon_unroll on Mar 19, 2024

Conversation

@KoolJBlack (Contributor) commented on Mar 11, 2024

This patch updates `LowerContractionToSMMLAPattern` to unroll larger vector contracts into multiple smmla instructions.

The pattern now accepts tiles of up to [8,8,8] (previously only [2,2,8]). The M and N dimensions must be powers of 2. `vector.extract_strided_slice`/`vector.insert_strided_slice` divide the contract into tiles that are processed one after another.
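
As a concrete illustration of the tiling, here is a minimal standalone sketch (not part of the patch; the function and variable names are invented) of the tile traversal the pattern performs for the 4x4x8 case exercised in the test below. `StaticTileOffsetRange` walks the contract's [M, N, K] iteration space tile by tile:

```cpp
// Sketch only: enumerate the [M, N, K] origin of each 2x2x8 smmla tile of a
// 4x4x8 vector.contract. Each offset triple selects one slice of the lhs,
// rhs, and accumulator (see the extractOperand lambda in the diff).
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

void dumpSmmlaTileOffsets() {
  SmallVector<int64_t> contractShape{4, 4, 8}; // [M, N, K] of the contract
  SmallVector<int64_t> smmlaShape{2, 2, 8};    // one smmla per tile
  SmallVector<int64_t> loopOrder{0, 1, 2};
  for (SmallVector<int64_t> offsets :
       StaticTileOffsetRange(contractShape, smmlaShape, loopOrder)) {
    // Prints the four tile origins: [0,0,0], [0,2,0], [2,0,0], [2,2,0],
    // i.e. this contract unrolls into four smmla instructions.
    llvm::interleaveComma(offsets, llvm::errs());
    llvm::errs() << "\n";
  }
}
```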

@KoolJBlack KoolJBlack marked this pull request as ready for review March 12, 2024 15:39
@llvmbot (Member) commented on Mar 12, 2024

@llvm/pr-subscribers-mlir-neon

@llvm/pr-subscribers-mlir

Author: Kojo Acquah (KoolJBlack)

Full diff: https://github.com/llvm/llvm-project/pull/84848.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp (+80-32)
  • (modified) mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir (+50)
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index 47c84708f3c38b..acb03927b5d23e 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
@@ -16,7 +16,9 @@
 #include "mlir/Dialect/ArmNeon/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -36,8 +38,10 @@ static Type matchContainerType(Type element, Type container) {
   return element;
 }
 
-/// Lowering from a single vector::contractOp directly to the arm neon smmla
-/// intrinsic. The shapes of the contract and intrinsic must match.
+/// Lowering from a vector::contractOp to the Arm Neon smmla intrinsic. This
+/// accepts up to an 8x8x8 vector contract, which is tiled into (up to 16)
+/// smmla instructions via unrolling. If no unrolling is necessary, a single
+/// smmla instruction is emitted.
 class LowerContractionToSMMLAPattern
     : public OpRewritePattern<vector::ContractionOp> {
 public:
@@ -45,10 +49,6 @@ class LowerContractionToSMMLAPattern
   LogicalResult matchAndRewrite(vector::ContractionOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    Value lhs = op.getLhs();
-    Value rhs = op.getRhs();
-    Value res = op.getAcc();
-
     // Check index maps that represent M N K in contract.
     auto indexingMaps = op.getIndexingMapsArray();
     if (llvm::any_of(indexingMaps, [](mlir::AffineMap affineMap) {
@@ -57,7 +57,6 @@ class LowerContractionToSMMLAPattern
         })) {
       return failure();
     }
-
     // Check iterator types for contract.
     auto iteratorTypes = op.getIteratorTypesArray();
     if (iteratorTypes.size() != 3 ||
@@ -66,22 +65,24 @@ class LowerContractionToSMMLAPattern
         iteratorTypes[2] != vector::IteratorType::reduction) {
       return failure();
     }
-
-    // Check the tile size by mapping the dimensions of the contract.
+    // Infer tile sizes from the operands; note that the RHS is not transposed.
     mlir::VectorType lhsType = op.getLhsType();
     mlir::VectorType rhsType = op.getRhsType();
     auto dimM = lhsType.getDimSize(0);
     auto dimN = rhsType.getDimSize(0);
     auto dimK = lhsType.getDimSize(1);
-    if (rhsType.getDimSize(1) != dimK || dimM != 2 || dimN != 2 || dimK != 8) {
+
+    // Unrolling patterns can handle [(2|4|8), (2|4|8), 8] shaped inputs for
+    // tiling.
+    if (dimM % 2 != 0 || dimM > 8 || dimN % 2 != 0 || dimN > 8 || dimK != 8) {
       return failure();
     }
 
     // Check two extsi inputs Rhs Lhs for contract.
     arith::ExtSIOp origLhsExtOp =
-        dyn_cast_or_null<arith::ExtSIOp>(lhs.getDefiningOp());
+        dyn_cast_or_null<arith::ExtSIOp>(op.getLhs().getDefiningOp());
     arith::ExtSIOp origRhsExtOp =
-        dyn_cast_or_null<arith::ExtSIOp>(rhs.getDefiningOp());
+        dyn_cast_or_null<arith::ExtSIOp>(op.getRhs().getDefiningOp());
     if (!origLhsExtOp || !origRhsExtOp) {
       return failure();
     }
@@ -113,26 +114,73 @@ class LowerContractionToSMMLAPattern
       return failure();
     }
 
-    // Collapse to 1D vectors required by smmla intrinsic
-    auto collapsedInputType = VectorType::get(
-        {16}, extsiLhs.getType().cast<ShapedType>().getElementType());
-    auto collapsedOutputType =
-        VectorType::get({4}, res.getType().cast<ShapedType>().getElementType());
-    auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
-        extsiLhs.getLoc(), collapsedInputType, extsiLhs);
-    auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
-        extsiRhs.getLoc(), collapsedInputType, extsiRhs);
-    auto collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
-        res.getLoc(), collapsedOutputType, res);
-
-    // Replace the contract with a neon op
-    auto smmlaOp = rewriter.createOrFold<arm_neon::SmmlaOp>(
-        op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
-        collapsedRhs);
-
-    // Reshape output back to 2D
-    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
-                                                     smmlaOp);
+    // Initial accumulator for the final result. This is the full-size result
+    // into which the tiled partial results are inserted.
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType()));
+
+    SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll();
+    SmallVector<int64_t> smmlaShape{2, 2, 8};
+    SmallVector<int64_t> loopOrder{0, 1, 2};
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
+
+      // Helper to compute the new shape of each operand and extract the slice.
+      auto extractOperand = [&](Value operand, AffineMap permutationMap,
+                                ArrayRef<int64_t> operandOffsets) {
+        SmallVector<int64_t> operandShape =
+            applyPermutationMap(permutationMap, ArrayRef<int64_t>(smmlaShape));
+        SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
+        return rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+            loc, operand, operandOffsets, operandShape, operandStrides);
+      };
+
+      // Extract tiled lhs, rhs, and acc
+      AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0];
+      SmallVector<int64_t> lhsOffsets =
+          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
+      auto tiledLhs = extractOperand(extsiLhs, lhsPermutationMap, lhsOffsets);
+      AffineMap rhsPermutationMap = op.getIndexingMapsArray()[1];
+      SmallVector<int64_t> rhsOffsets =
+          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
+      auto tiledRhs = extractOperand(extsiRhs, rhsPermutationMap, rhsOffsets);
+      AffineMap accPermutationMap = op.getIndexingMapsArray()[2];
+      SmallVector<int64_t> accOffsets =
+          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
+      auto tiledAcc =
+          extractOperand(op.getAcc(), accPermutationMap, accOffsets);
+
+      // Collapse tiled operands to 1D vectors required by smmla intrinsic
+      auto collapsedInputType = VectorType::get(
+          tiledLhs.getType().cast<ShapedType>().getNumElements(),
+          tiledLhs.getType().cast<ShapedType>().getElementType());
+      auto collapsedOutputType = VectorType::get(
+          {4}, tiledAcc.getType().cast<ShapedType>().getElementType());
+      auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
+          tiledLhs.getLoc(), collapsedInputType, tiledLhs);
+      auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
+          tiledRhs.getLoc(), collapsedInputType, tiledRhs);
+      auto collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
+          tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
+
+      // Insert contract op
+      auto smmlaOp = rewriter.createOrFold<arm_neon::SmmlaOp>(
+          op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
+          collapsedRhs);
+
+      // Reshape output back to 2D
+      Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
+          smmlaOp.getLoc(), tiledAcc.getType(), smmlaOp);
+
+      // Insert the tiled result back into the non-tiled result of the
+      // contract op.
+      SmallVector<int64_t> strides(
+          tiledRes.getType().cast<ShapedType>().getRank(), 1);
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+          loc, tiledRes, result, accOffsets, strides);
+    }
+
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
index cba7b00ba77a82..a4b873144b8b83 100644
--- a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
+++ b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
@@ -40,3 +40,53 @@ func.func @test_lower_vector_arm_neon_without_extsi(%lhs: vector<2x8xi32>, %rhs:
   %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
   return %res : vector<2x2xi32>
 }
+
+// -----
+
+// CHECK-LABEL: test_lower_vector_arm_neon_unroll
+// CHECK-SAME: %[[VAL_0:.*]]: vector<4x8xi8>, %[[VAL_1:.*]]: vector<4x8xi8>, %[[VAL_2:.*]]: vector<4x4xi32>
+// CHECK:  %[[VAL_3:.*]] = arith.constant dense<0> : vector<4x4xi32>
+// CHECK:  %[[VAL_4:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_5:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_6:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_7:.*]] = vector.shape_cast %[[VAL_4]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_8:.*]] = vector.shape_cast %[[VAL_5]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_9:.*]] = vector.shape_cast %[[VAL_6]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:  %[[VAL_10:.*]] = arm_neon.intr.smmla %[[VAL_9]], %[[VAL_7]], %[[VAL_8]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_11:.*]] = vector.shape_cast %[[VAL_10]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_12:.*]] = vector.insert_strided_slice %[[VAL_11]], %[[VAL_3]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
+// CHECK:  %[[VAL_13:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_14:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_15:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_16:.*]] = vector.shape_cast %[[VAL_13]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_17:.*]] = vector.shape_cast %[[VAL_14]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_18:.*]] = vector.shape_cast %[[VAL_15]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:  %[[VAL_19:.*]] = arm_neon.intr.smmla %[[VAL_18]], %[[VAL_16]], %[[VAL_17]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_20:.*]] = vector.shape_cast %[[VAL_19]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_21:.*]] = vector.insert_strided_slice %[[VAL_20]], %[[VAL_12]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
+// CHECK:  %[[VAL_22:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_23:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_24:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_25:.*]] = vector.shape_cast %[[VAL_22]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_26:.*]] = vector.shape_cast %[[VAL_23]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_27:.*]] = vector.shape_cast %[[VAL_24]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:  %[[VAL_28:.*]] = arm_neon.intr.smmla %[[VAL_27]], %[[VAL_25]], %[[VAL_26]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_29:.*]] = vector.shape_cast %[[VAL_28]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_30:.*]] = vector.insert_strided_slice %[[VAL_29]], %[[VAL_21]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
+// CHECK:  %[[VAL_31:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_32:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_33:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_34:.*]] = vector.shape_cast %[[VAL_31]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_35:.*]] = vector.shape_cast %[[VAL_32]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_36:.*]] = vector.shape_cast %[[VAL_33]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:  %[[VAL_37:.*]] = arm_neon.intr.smmla %[[VAL_36]], %[[VAL_34]], %[[VAL_35]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_38:.*]] = vector.shape_cast %[[VAL_37]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_39:.*]] = vector.insert_strided_slice %[[VAL_38]], %[[VAL_30]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
+// CHECK:  return %[[VAL_39]] : vector<4x4xi32>
+// CHECK:  }
+func.func @test_lower_vector_arm_neon_unroll(%lhs: vector<4x8xi8>, %rhs: vector<4x8xi8>, %acc : vector<4x4xi32>) -> vector<4x4xi32> {
+  %lhs_extsi = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
+  %rhs_extsi = arith.extsi %rhs : vector<4x8xi8> to vector<4x8xi32>
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<4x8xi32>, vector<4x8xi32> into vector<4x4xi32>
+  return %res : vector<4x4xi32>
+}
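
To make the offset bookkeeping in the pattern concrete: for each tile, the contract's [M, N, K] origin is pushed through every operand's indexing map to obtain that operand's slice offsets. A minimal sketch (invented names, not part of the patch) reproducing the offsets of the third smmla in the test above:

```cpp
// Sketch only: derive per-operand slice offsets from a tile's origin via
// applyPermutationMap, as the pattern does for lhs, rhs, and acc.
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

void dumpOperandOffsets() {
  MLIRContext ctx;
  AffineExpr d0, d1, d2;
  bindDims(&ctx, d0, d1, d2);
  // The contract's indexing maps: lhs (d0, d2), rhs (d1, d2), acc (d0, d1).
  AffineMap lhsMap = AffineMap::get(3, 0, {d0, d2}, &ctx);
  AffineMap rhsMap = AffineMap::get(3, 0, {d1, d2}, &ctx);
  AffineMap accMap = AffineMap::get(3, 0, {d0, d1}, &ctx);

  SmallVector<int64_t> offsets{2, 0, 0}; // tile origin [M, N, K]
  // lhs slice starts at [2, 0], rhs at [0, 0], acc at [2, 0], matching the
  // third extract_strided_slice group in the CHECK lines above.
  SmallVector<int64_t> lhsOffsets =
      applyPermutationMap(lhsMap, ArrayRef<int64_t>(offsets));
  SmallVector<int64_t> rhsOffsets =
      applyPermutationMap(rhsMap, ArrayRef<int64_t>(offsets));
  SmallVector<int64_t> accOffsets =
      applyPermutationMap(accMap, ArrayRef<int64_t>(offsets));
}
```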

@KoolJBlack KoolJBlack changed the title [mlir][ArmNeon] Implements unrolling patterns for LowerVectorToArmNeon LowerContractionToSMMLAPattern [mlir][ArmNeon] Implements unrolling patterns for LowerContractionToSMMLAPattern Mar 14, 2024
@dcaballe (Contributor) left a comment

Awesome! Looks great! Just a few comments.

@dcaballe (Contributor) left a comment

Awesome! Thanks!

@KoolJBlack KoolJBlack merged commit fe84369 into llvm:main Mar 19, 2024
@KoolJBlack KoolJBlack deleted the arm_neon_unroll branch March 19, 2024 17:09
chencha3 pushed a commit to chencha3/llvm-project that referenced this pull request Mar 23, 2024
…MMLAPattern (llvm#84848)

This patch updates `LowerContractionToSMMLAPattern` to unroll larger vector contracts into multiple smmla instructions. 

Now accepts up to [8,8,8] tiles (previously only [2,2,8]). The N/M dimensions must be powers of 2. `vector.extract_strided_slice`/`vector.insert_strided_slice` divides the contract into tiles to be processed in a row.