[mlir][ArmNeon] Implements unrolling patterns for LowerContractionToSMMLAPattern #84848
Conversation
@llvm/pr-subscribers-mlir-neon @llvm/pr-subscribers-mlir

Author: Kojo Acquah (KoolJBlack)

Changes

This patch updates `LowerContractionToSMMLAPattern` to unroll larger vector contracts into multiple smmla instructions. It now accepts up to [8,8,8] tiles (previously only [2,2,8]). The N/M dimensions must be powers of 2. `vector.extract_strided_slice`/`vector.insert_strided_slice` divide the contract into tiles to be processed in a row.

Full diff: https://github.com/llvm/llvm-project/pull/84848.diff

2 Files Affected:
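Before the diff itself, here is a minimal sketch of how this pattern is typically driven from a pass, assuming the populate entry point declared in `mlir/Dialect/ArmNeon/Transforms.h`; the wrapper function below is hypothetical and not part of this patch.

// Minimal sketch (hypothetical wrapper): register the pattern and run the
// greedy rewrite driver over a region of IR. Assumes the populate entry
// point from mlir/Dialect/ArmNeon/Transforms.h.
#include "mlir/Dialect/ArmNeon/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static LogicalResult lowerContractsToSMMLA(Operation *rootOp) {
  RewritePatternSet patterns(rootOp->getContext());
  // Registers LowerContractionToSMMLAPattern, the pattern updated here.
  arm_neon::populateLowerContractionToSMMLAPatternPatterns(patterns);
  // Greedily rewrites every matching vector.contract nested under rootOp.
  return applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}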
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index 47c84708f3c38b..acb03927b5d23e 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
@@ -16,7 +16,9 @@
#include "mlir/Dialect/ArmNeon/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -36,8 +38,10 @@ static Type matchContainerType(Type element, Type container) {
return element;
}
-/// Lowering from a single vector::contractOp directly to the arm neon smmla
-/// intrinsic. The shapes of the contract and intrinsic must match.
+/// Lowering from a vector::contractOp to the arm neon smmla intrinsic. This
+/// supports up to an 8x8x8 vector contract, which is tiled into (up to 16)
+/// smmla instructions via unrolling. If no unrolling is necessary, a single
+/// smmla instruction is emitted.
class LowerContractionToSMMLAPattern
: public OpRewritePattern<vector::ContractionOp> {
public:
@@ -45,10 +49,6 @@ class LowerContractionToSMMLAPattern
LogicalResult matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- Value lhs = op.getLhs();
- Value rhs = op.getRhs();
- Value res = op.getAcc();
-
// Check index maps that represent M N K in contract.
auto indexingMaps = op.getIndexingMapsArray();
if (llvm::any_of(indexingMaps, [](mlir::AffineMap affineMap) {
@@ -57,7 +57,6 @@ class LowerContractionToSMMLAPattern
})) {
return failure();
}
-
// Check iterator types for contract.
auto iteratorTypes = op.getIteratorTypesArray();
if (iteratorTypes.size() != 3 ||
@@ -66,22 +65,24 @@ class LowerContractionToSMMLAPattern
iteratorTypes[2] != vector::IteratorType::reduction) {
return failure();
}
-
- // Check the tile size by mapping the dimensions of the contract.
+    // Infer tile sizes from the operands. Note: RHS is not transposed.
mlir::VectorType lhsType = op.getLhsType();
mlir::VectorType rhsType = op.getRhsType();
auto dimM = lhsType.getDimSize(0);
auto dimN = rhsType.getDimSize(0);
auto dimK = lhsType.getDimSize(1);
- if (rhsType.getDimSize(1) != dimK || dimM != 2 || dimN != 2 || dimK != 8) {
+
+ // Unrolling patterns can handle [(2|4|8), (2|4|8), 8] shaped inputs for
+ // tiling.
+ if (dimM % 2 != 0 || dimM > 8 || dimN % 2 != 0 || dimN > 8 || dimK != 8) {
return failure();
}
// Check two extsi inputs Rhs Lhs for contract.
arith::ExtSIOp origLhsExtOp =
- dyn_cast_or_null<arith::ExtSIOp>(lhs.getDefiningOp());
+ dyn_cast_or_null<arith::ExtSIOp>(op.getLhs().getDefiningOp());
arith::ExtSIOp origRhsExtOp =
- dyn_cast_or_null<arith::ExtSIOp>(rhs.getDefiningOp());
+ dyn_cast_or_null<arith::ExtSIOp>(op.getRhs().getDefiningOp());
if (!origLhsExtOp || !origRhsExtOp) {
return failure();
}
@@ -113,26 +114,73 @@ class LowerContractionToSMMLAPattern
return failure();
}
- // Collapse to 1D vectors required by smmla intrinsic
- auto collapsedInputType = VectorType::get(
- {16}, extsiLhs.getType().cast<ShapedType>().getElementType());
- auto collapsedOutputType =
- VectorType::get({4}, res.getType().cast<ShapedType>().getElementType());
- auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
- extsiLhs.getLoc(), collapsedInputType, extsiLhs);
- auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
- extsiRhs.getLoc(), collapsedInputType, extsiRhs);
- auto collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
- res.getLoc(), collapsedOutputType, res);
-
- // Replace the contract with a neon op
- auto smmlaOp = rewriter.createOrFold<arm_neon::SmmlaOp>(
- op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
- collapsedRhs);
-
- // Reshape output back to 2D
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
- smmlaOp);
+    // Initial accumulator for the final result. This is the full-size,
+    // un-tiled result that each tile's partial result is inserted into.
+ Value result = rewriter.create<arith::ConstantOp>(
+ loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType()));
+
+ SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll();
+ SmallVector<int64_t> smmlaShape{2, 2, 8};
+ SmallVector<int64_t> loopOrder{0, 1, 2};
+ for (SmallVector<int64_t> offsets :
+ StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
+
+ // Helper to compute the new shape of each operand and extract the slice.
+ auto extractOperand = [&](Value operand, AffineMap permutationMap,
+ ArrayRef<int64_t> operandOffsets) {
+ SmallVector<int64_t> operandShape =
+ applyPermutationMap(permutationMap, ArrayRef<int64_t>(smmlaShape));
+ SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
+ return rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+ loc, operand, operandOffsets, operandShape, operandStrides);
+ };
+
+ // Extract tiled lhs, rhs, and acc
+ AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0];
+ SmallVector<int64_t> lhsOffsets =
+ applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
+ auto tiledLhs = extractOperand(extsiLhs, lhsPermutationMap, lhsOffsets);
+ AffineMap rhsPermutationMap = op.getIndexingMapsArray()[1];
+ SmallVector<int64_t> rhsOffsets =
+ applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
+ auto tiledRhs = extractOperand(extsiRhs, rhsPermutationMap, rhsOffsets);
+ AffineMap accPermutationMap = op.getIndexingMapsArray()[2];
+ SmallVector<int64_t> accOffsets =
+ applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
+ auto tiledAcc =
+ extractOperand(op.getAcc(), accPermutationMap, accOffsets);
+
+ // Collapse tiled operands to 1D vectors required by smmla intrinsic
+ auto collapsedInputType = VectorType::get(
+ tiledLhs.getType().cast<ShapedType>().getNumElements(),
+ tiledLhs.getType().cast<ShapedType>().getElementType());
+ auto collapsedOutputType = VectorType::get(
+ {4}, tiledAcc.getType().cast<ShapedType>().getElementType());
+ auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
+ tiledLhs.getLoc(), collapsedInputType, tiledLhs);
+ auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
+ tiledRhs.getLoc(), collapsedInputType, tiledRhs);
+ auto collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
+ tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
+
+      // Emit the smmla intrinsic for this tile
+ auto smmlaOp = rewriter.createOrFold<arm_neon::SmmlaOp>(
+ op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
+ collapsedRhs);
+
+ // Reshape output back to 2D
+ Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
+ smmlaOp.getLoc(), tiledAcc.getType(), smmlaOp);
+
+      // Insert the tiled result back into the un-tiled result of the
+      // contract op.
+ SmallVector<int64_t> strides(
+ tiledRes.getType().cast<ShapedType>().getRank(), 1);
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+ loc, tiledRes, result, accOffsets, strides);
+ }
+
+ rewriter.replaceOp(op, result);
return success();
}
};
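To make the loop order concrete, here is a small standalone sketch (plain C++, not the MLIR `StaticTileOffsetRange` API) of the tile offsets the unrolling visits for the 4x4x8 contract exercised by the test below; the shapes are taken from that test, and the loop nest is an illustrative re-implementation.

// Hypothetical re-implementation of the tile-offset walk: enumerate
// [2, 2, 8]-aligned offsets over a [4, 4, 8] contract in loop order
// {0, 1, 2}, matching the insert offsets in the CHECK lines below.
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t shape[3] = {4, 4, 8}; // unrolled contract shape [M, N, K]
  const int64_t tile[3] = {2, 2, 8};  // native smmla tile shape
  for (int64_t m = 0; m < shape[0]; m += tile[0])
    for (int64_t n = 0; n < shape[1]; n += tile[1])
      for (int64_t k = 0; k < shape[2]; k += tile[2])
        // Prints [0, 0, 0], [0, 2, 0], [2, 0, 0], [2, 2, 0].
        std::printf("tile offset = [%lld, %lld, %lld]\n", (long long)m,
                    (long long)n, (long long)k);
  return 0;
}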
diff --git a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
index cba7b00ba77a82..a4b873144b8b83 100644
--- a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
+++ b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
@@ -40,3 +40,53 @@ func.func @test_lower_vector_arm_neon_without_extsi(%lhs: vector<2x8xi32>, %rhs:
%res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
return %res : vector<2x2xi32>
}
+
+// -----
+
+// CHECK-LABEL: test_lower_vector_arm_neon_unroll
+// CHECK-SAME: %[[VAL_0:.*]]: vector<4x8xi8>, %[[VAL_1:.*]]: vector<4x8xi8>, %[[VAL_2:.*]]: vector<4x4xi32>
+// CHECK: %[[VAL_3:.*]] = arith.constant dense<0> : vector<4x4xi32>
+// CHECK: %[[VAL_4:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_5:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_6:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_7:.*]] = vector.shape_cast %[[VAL_4]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_8:.*]] = vector.shape_cast %[[VAL_5]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_9:.*]] = vector.shape_cast %[[VAL_6]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[VAL_10:.*]] = arm_neon.intr.smmla %[[VAL_9]], %[[VAL_7]], %[[VAL_8]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_11:.*]] = vector.shape_cast %[[VAL_10]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_12:.*]] = vector.insert_strided_slice %[[VAL_11]], %[[VAL_3]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
+// CHECK: %[[VAL_13:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_14:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_15:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_16:.*]] = vector.shape_cast %[[VAL_13]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_17:.*]] = vector.shape_cast %[[VAL_14]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_18:.*]] = vector.shape_cast %[[VAL_15]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[VAL_19:.*]] = arm_neon.intr.smmla %[[VAL_18]], %[[VAL_16]], %[[VAL_17]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_20:.*]] = vector.shape_cast %[[VAL_19]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_21:.*]] = vector.insert_strided_slice %[[VAL_20]], %[[VAL_12]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
+// CHECK: %[[VAL_22:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_23:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_24:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_25:.*]] = vector.shape_cast %[[VAL_22]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_26:.*]] = vector.shape_cast %[[VAL_23]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_27:.*]] = vector.shape_cast %[[VAL_24]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[VAL_28:.*]] = arm_neon.intr.smmla %[[VAL_27]], %[[VAL_25]], %[[VAL_26]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_29:.*]] = vector.shape_cast %[[VAL_28]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_30:.*]] = vector.insert_strided_slice %[[VAL_29]], %[[VAL_21]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
+// CHECK: %[[VAL_31:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_32:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_33:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_34:.*]] = vector.shape_cast %[[VAL_31]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_35:.*]] = vector.shape_cast %[[VAL_32]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_36:.*]] = vector.shape_cast %[[VAL_33]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[VAL_37:.*]] = arm_neon.intr.smmla %[[VAL_36]], %[[VAL_34]], %[[VAL_35]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_38:.*]] = vector.shape_cast %[[VAL_37]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_39:.*]] = vector.insert_strided_slice %[[VAL_38]], %[[VAL_30]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
+// CHECK: return %[[VAL_39]] : vector<4x4xi32>
+// CHECK: }
+func.func @test_lower_vector_arm_neon_unroll(%lhs: vector<4x8xi8>, %rhs: vector<4x8xi8>, %acc : vector<4x4xi32>) -> vector<4x4xi32> {
+ %lhs_extsi = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
+ %rhs_extsi = arith.extsi %rhs : vector<4x8xi8> to vector<4x8xi32>
+ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<4x8xi32>, vector<4x8xi32> into vector<4x4xi32>
+ return %res : vector<4x4xi32>
+}
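The per-operand slice offsets in the CHECK lines above fall out of projecting each [m, n, k] tile offset through that operand's indexing map. The standalone sketch below (a hypothetical helper standing in for `applyPermutationMap`) shows the projection for the third tile, offset [2, 0, 0].

// Hypothetical stand-in for applyPermutationMap: select, in order, the
// iteration dims that an operand's indexing map reads.
#include <cstdint>
#include <cstdio>
#include <vector>

static std::vector<int64_t> project(const std::vector<int64_t> &dims,
                                    const std::vector<int64_t> &offsets) {
  std::vector<int64_t> out;
  for (int64_t d : dims)
    out.push_back(offsets[d]);
  return out;
}

int main() {
  std::vector<int64_t> offsets = {2, 0, 0}; // third tile: [m, n, k]
  auto lhs = project({0, 2}, offsets); // lhs map (d0, d2) -> [2, 0]
  auto rhs = project({1, 2}, offsets); // rhs map (d1, d2) -> [0, 0]
  auto acc = project({0, 1}, offsets); // acc map (d0, d1) -> [2, 0]
  std::printf("lhs=[%lld,%lld] rhs=[%lld,%lld] acc=[%lld,%lld]\n",
              (long long)lhs[0], (long long)lhs[1], (long long)rhs[0],
              (long long)rhs[1], (long long)acc[0], (long long)acc[1]);
  return 0;
}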
Awesome! Looks great! Just a few comments.
Awesome! Thanks!