[mlir][ArmNeon] Implements unrolling patterns for LowerContractionToSMMLAPattern #84848

Merged: 2 commits, Mar 19, 2024

@@ -16,7 +16,9 @@
#include "mlir/Dialect/ArmNeon/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -36,19 +38,17 @@ static Type matchContainerType(Type element, Type container) {
return element;
}

/// Lowering from a single vector::contractOp directly to the arm neon smmla
/// intrinsic. The shapes of the contract and intrinsic must match.
/// Lowering from a vector::contractOp to the arm neon smmla intrinsic. This
/// will tile any vector.contract into multiple smmla instructions with
/// unrolling so long as [2, 2, 8] is a divisor of its shape. If no unrolling
/// is necessary, a single smmla instruction is emitted.
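///
/// For example (an illustrative sketch, not part of this diff): a contract
/// with lhs vector<4x8xi8>, rhs vector<4x8xi8>, and acc vector<4x4xi32> has
/// (M, N, K) = (4, 4, 8), so it is decomposed into four [2, 2, 8] tiles, each
/// lowered to one arm_neon.intr.smmla call (see the new tests below).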
class LowerContractionToSMMLAPattern
: public OpRewritePattern<vector::ContractionOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value lhs = op.getLhs();
Value rhs = op.getRhs();
Value res = op.getAcc();

// Check index maps that represent M N K in contract.
auto indexingMaps = op.getIndexingMapsArray();
if (llvm::any_of(indexingMaps, [](mlir::AffineMap affineMap) {
@@ -57,7 +57,6 @@ class LowerContractionToSMMLAPattern
})) {
return failure();
}

// Check iterator types for contract.
auto iteratorTypes = op.getIteratorTypesArray();
if (iteratorTypes.size() != 3 ||
@@ -66,22 +65,24 @@
iteratorTypes[2] != vector::IteratorType::reduction) {
return failure();
}

// Check the tile size by mapping the dimensions of the contract.
// Infer tile sizes from the operands; note that the RHS is not transposed.
mlir::VectorType lhsType = op.getLhsType();
mlir::VectorType rhsType = op.getRhsType();
auto dimM = lhsType.getDimSize(0);
auto dimN = rhsType.getDimSize(0);
auto dimK = lhsType.getDimSize(1);
if (rhsType.getDimSize(1) != dimK || dimM != 2 || dimN != 2 || dimK != 8) {

// The unrolling pattern can handle any input shape that is a multiple of
// [2, 2, 8] by tiling.
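// E.g. (illustrative): dims (M, N, K) = (4, 4, 8) or (2, 2, 16) pass this
// check, while (4, 4, 12) is rejected because 12 % 8 != 0.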
if (dimM % 2 != 0 || dimN % 2 != 0 || dimK % 8 != 0) {
return failure();
}

// Check that the Lhs and Rhs inputs of the contract are defined by extsi ops.
arith::ExtSIOp origLhsExtOp =
dyn_cast_or_null<arith::ExtSIOp>(lhs.getDefiningOp());
dyn_cast_or_null<arith::ExtSIOp>(op.getLhs().getDefiningOp());
arith::ExtSIOp origRhsExtOp =
dyn_cast_or_null<arith::ExtSIOp>(rhs.getDefiningOp());
dyn_cast_or_null<arith::ExtSIOp>(op.getRhs().getDefiningOp());
if (!origLhsExtOp || !origRhsExtOp) {
return failure();
}
@@ -113,26 +114,73 @@ class LowerContractionToSMMLAPattern
return failure();
}

// Collapse to 1D vectors required by smmla intrinsic
auto collapsedInputType = VectorType::get(
{16}, extsiLhs.getType().cast<ShapedType>().getElementType());
auto collapsedOutputType =
VectorType::get({4}, res.getType().cast<ShapedType>().getElementType());
auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
extsiLhs.getLoc(), collapsedInputType, extsiLhs);
auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
extsiRhs.getLoc(), collapsedInputType, extsiRhs);
auto collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
res.getLoc(), collapsedOutputType, res);

// Replace the contract with a neon op
auto smmlaOp = rewriter.createOrFold<arm_neon::SmmlaOp>(
op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
collapsedRhs);

// Reshape output back to 2D
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
smmlaOp);
// Initial accumulator for the final result. When tiling is performed, each
// tile's partial result is inserted into this full-size vector.
Value result = rewriter.create<arith::ConstantOp>(
loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType()));

SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll();
SmallVector<int64_t> smmlaShape{2, 2, 8};
SmallVector<int64_t> loopOrder{0, 1, 2};
for (SmallVector<int64_t> offsets :
StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
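// StaticTileOffsetRange enumerates the starting offsets of each [2, 2, 8]
// tile of the iteration space, stepping dims in loopOrder. E.g.
// (illustrative) for unrolledSize [4, 4, 8] it yields [0, 0, 0],
// [0, 2, 0], [2, 0, 0], [2, 2, 0].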

// Helper to compute the new shape of each operand and extract the slice.
auto extractOperand = [&](Value operand, AffineMap permutationMap,
ArrayRef<int64_t> operandOffsets) {
SmallVector<int64_t> operandShape =
applyPermutationMap(permutationMap, ArrayRef<int64_t>(smmlaShape));
SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
return rewriter.createOrFold<vector::ExtractStridedSliceOp>(
loc, operand, operandOffsets, operandShape, operandStrides);
};
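// For instance (illustrative sketch): with the LHS indexing map
// (d0, d1, d2) -> (d0, d2), applyPermutationMap projects the tile shape
// [2, 2, 8] to the operand shape [2, 8], and an offset [m, n, k] to [m, k].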

// Extract tiled lhs, rhs, and acc
AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0];
SmallVector<int64_t> lhsOffsets =
applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
Value tiledLhs = extractOperand(extsiLhs, lhsPermutationMap, lhsOffsets);
AffineMap rhsPermutationMap = op.getIndexingMapsArray()[1];
SmallVector<int64_t> rhsOffsets =
applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
Value tiledRhs = extractOperand(extsiRhs, rhsPermutationMap, rhsOffsets);
AffineMap accPermutationMap = op.getIndexingMapsArray()[2];
SmallVector<int64_t> accOffsets =
applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
Value tiledAcc =
extractOperand(op.getAcc(), accPermutationMap, accOffsets);

// Collapse tiled operands to 1D vectors required by smmla intrinsic
auto collapsedInputType = VectorType::get(
tiledLhs.getType().cast<ShapedType>().getNumElements(),
tiledLhs.getType().cast<ShapedType>().getElementType());
auto collapsedOutputType = VectorType::get(
{4}, tiledAcc.getType().cast<ShapedType>().getElementType());
auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
tiledLhs.getLoc(), collapsedInputType, tiledLhs);
auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
tiledRhs.getLoc(), collapsedInputType, tiledRhs);
auto collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
tiledAcc.getLoc(), collapsedOutputType, tiledAcc);

// Emit the smmla intrinsic op for this tile.
auto smmlaOp = rewriter.createOrFold<arm_neon::SmmlaOp>(
op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
collapsedRhs);

// Reshape output back to 2D
Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
smmlaOp.getLoc(), tiledAcc.getType(), smmlaOp);

// Insert the tiled result back into the non-tiled result of the
// contract op.
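// E.g. (illustrative) the tile computed at offsets [m, n, k] lands at
// [m, n] in the final MxN result, per the acc map (d0, d1, d2) -> (d0, d1).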
SmallVector<int64_t> strides(
tiledRes.getType().cast<ShapedType>().getRank(), 1);
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, tiledRes, result, accOffsets, strides);
}

rewriter.replaceOp(op, result);
return success();
}
};
mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir (94 additions, 0 deletions)
@@ -40,3 +40,97 @@ func.func @test_lower_vector_arm_neon_without_extsi(%lhs: vector<2x8xi32>, %rhs:
%res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
return %res : vector<2x2xi32>
}

// -----

// CHECK-LABEL: test_lower_vector_arm_neon_unroll
// CHECK-SAME: %[[VAL_0:.*]]: vector<4x8xi8>, %[[VAL_1:.*]]: vector<4x8xi8>, %[[VAL_2:.*]]: vector<4x4xi32>
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant dense<0> : vector<4x4xi32>
// CHECK-DAG: %[[VAL_4:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
// CHECK-DAG: %[[VAL_5:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
// CHECK-DAG: %[[VAL_6:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
// CHECK-DAG: %[[VAL_7:.*]] = vector.shape_cast %[[VAL_4]] : vector<2x8xi8> to vector<16xi8>
// CHECK-DAG: %[[VAL_8:.*]] = vector.shape_cast %[[VAL_5]] : vector<2x8xi8> to vector<16xi8>
// CHECK-DAG: %[[VAL_9:.*]] = vector.shape_cast %[[VAL_6]] : vector<2x2xi32> to vector<4xi32>
// CHECK-DAG: %[[VAL_10:.*]] = arm_neon.intr.smmla %[[VAL_9]], %[[VAL_7]], %[[VAL_8]] : vector<16xi8> to vector<4xi32>
// CHECK-DAG: %[[VAL_11:.*]] = vector.shape_cast %[[VAL_10]] : vector<4xi32> to vector<2x2xi32>
// CHECK-DAG: %[[VAL_12:.*]] = vector.insert_strided_slice %[[VAL_11]], %[[VAL_3]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
// CHECK-DAG: %[[VAL_13:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
// CHECK-DAG: %[[VAL_14:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
// CHECK-DAG: %[[VAL_15:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
// CHECK-DAG: %[[VAL_16:.*]] = vector.shape_cast %[[VAL_13]] : vector<2x8xi8> to vector<16xi8>
// CHECK-DAG: %[[VAL_17:.*]] = vector.shape_cast %[[VAL_14]] : vector<2x8xi8> to vector<16xi8>
// CHECK-DAG: %[[VAL_18:.*]] = vector.shape_cast %[[VAL_15]] : vector<2x2xi32> to vector<4xi32>
// CHECK-DAG: %[[VAL_19:.*]] = arm_neon.intr.smmla %[[VAL_18]], %[[VAL_16]], %[[VAL_17]] : vector<16xi8> to vector<4xi32>
// CHECK-DAG: %[[VAL_20:.*]] = vector.shape_cast %[[VAL_19]] : vector<4xi32> to vector<2x2xi32>
// CHECK-DAG: %[[VAL_21:.*]] = vector.insert_strided_slice %[[VAL_20]], %[[VAL_12]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
// CHECK-DAG: %[[VAL_22:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
// CHECK-DAG: %[[VAL_23:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
// CHECK-DAG: %[[VAL_24:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
// CHECK-DAG: %[[VAL_25:.*]] = vector.shape_cast %[[VAL_22]] : vector<2x8xi8> to vector<16xi8>
// CHECK-DAG: %[[VAL_26:.*]] = vector.shape_cast %[[VAL_23]] : vector<2x8xi8> to vector<16xi8>
// CHECK-DAG: %[[VAL_27:.*]] = vector.shape_cast %[[VAL_24]] : vector<2x2xi32> to vector<4xi32>
// CHECK-DAG: %[[VAL_28:.*]] = arm_neon.intr.smmla %[[VAL_27]], %[[VAL_25]], %[[VAL_26]] : vector<16xi8> to vector<4xi32>
// CHECK-DAG: %[[VAL_29:.*]] = vector.shape_cast %[[VAL_28]] : vector<4xi32> to vector<2x2xi32>
// CHECK-DAG: %[[VAL_30:.*]] = vector.insert_strided_slice %[[VAL_29]], %[[VAL_21]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
// CHECK-DAG: %[[VAL_31:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
// CHECK-DAG: %[[VAL_32:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
// CHECK-DAG: %[[VAL_33:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
// CHECK-DAG: %[[VAL_34:.*]] = vector.shape_cast %[[VAL_31]] : vector<2x8xi8> to vector<16xi8>
// CHECK-DAG: %[[VAL_35:.*]] = vector.shape_cast %[[VAL_32]] : vector<2x8xi8> to vector<16xi8>
// CHECK-DAG: %[[VAL_36:.*]] = vector.shape_cast %[[VAL_33]] : vector<2x2xi32> to vector<4xi32>
// CHECK-DAG: %[[VAL_37:.*]] = arm_neon.intr.smmla %[[VAL_36]], %[[VAL_34]], %[[VAL_35]] : vector<16xi8> to vector<4xi32>
// CHECK-DAG: %[[VAL_38:.*]] = vector.shape_cast %[[VAL_37]] : vector<4xi32> to vector<2x2xi32>
// CHECK-DAG: %[[VAL_39:.*]] = vector.insert_strided_slice %[[VAL_38]], %[[VAL_30]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
// CHECK-DAG: return %[[VAL_39]] : vector<4x4xi32>
// CHECK-DAG: }
func.func @test_lower_vector_arm_neon_unroll(%lhs: vector<4x8xi8>, %rhs: vector<4x8xi8>, %acc : vector<4x4xi32>) -> vector<4x4xi32> {
%lhs_extsi = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
%rhs_extsi = arith.extsi %rhs : vector<4x8xi8> to vector<4x8xi32>
%res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<4x8xi32>, vector<4x8xi32> into vector<4x4xi32>
return %res : vector<4x4xi32>
}

// -----

// CHECK-LABEL: func.func @test_lower_vector_arm_neon_mixed_unroll(
// CHECK-SAME: %[[VAL_0:.*]]: vector<4x8xi8>,
// CHECK-SAME: %[[VAL_1:.*]]: vector<2x8xi4>,
// CHECK-SAME: %[[VAL_2:.*]]: vector<4x2xi32>) -> vector<4x2xi32> {
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant dense<0> : vector<4x2xi32>
// CHECK-DAG: %[[VAL_4:.*]] = arith.extsi %[[VAL_1]] : vector<2x8xi4> to vector<2x8xi8>
// CHECK-DAG: %[[VAL_5:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
// CHECK-DAG: %[[VAL_6:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xi32> to vector<2x2xi32>
// CHECK-DAG: %[[VAL_7:.*]] = vector.shape_cast %[[VAL_5]] : vector<2x8xi8> to vector<16xi8>
// CHECK-DAG: %[[VAL_8:.*]] = vector.shape_cast %[[VAL_4]] : vector<2x8xi8> to vector<16xi8>
// CHECK-DAG: %[[VAL_9:.*]] = vector.shape_cast %[[VAL_6]] : vector<2x2xi32> to vector<4xi32>
// CHECK-DAG: %[[VAL_10:.*]] = arm_neon.intr.smmla %[[VAL_9]], %[[VAL_7]], %[[VAL_8]] : vector<16xi8> to vector<4xi32>
// CHECK-DAG: %[[VAL_11:.*]] = vector.shape_cast %[[VAL_10]] : vector<4xi32> to vector<2x2xi32>
// CHECK-DAG: %[[VAL_12:.*]] = vector.insert_strided_slice %[[VAL_11]], %[[VAL_3]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xi32> into vector<4x2xi32>
// CHECK-DAG: %[[VAL_13:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
// CHECK-DAG: %[[VAL_14:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xi32> to vector<2x2xi32>
// CHECK-DAG: %[[VAL_15:.*]] = vector.shape_cast %[[VAL_13]] : vector<2x8xi8> to vector<16xi8>
// CHECK-DAG: %[[VAL_16:.*]] = vector.shape_cast %[[VAL_4]] : vector<2x8xi8> to vector<16xi8>
// CHECK-DAG: %[[VAL_17:.*]] = vector.shape_cast %[[VAL_14]] : vector<2x2xi32> to vector<4xi32>
// CHECK-DAG: %[[VAL_18:.*]] = arm_neon.intr.smmla %[[VAL_17]], %[[VAL_15]], %[[VAL_16]] : vector<16xi8> to vector<4xi32>
// CHECK-DAG: %[[VAL_19:.*]] = vector.shape_cast %[[VAL_18]] : vector<4xi32> to vector<2x2xi32>
// CHECK-DAG: %[[VAL_20:.*]] = vector.insert_strided_slice %[[VAL_19]], %[[VAL_12]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xi32> into vector<4x2xi32>
// CHECK-DAG: return %[[VAL_20]] : vector<4x2xi32>
// CHECK-DAG: }
func.func @test_lower_vector_arm_neon_mixed_unroll(%lhs: vector<4x8xi8>, %rhs: vector<2x8xi4>, %acc : vector<4x2xi32>) -> vector<4x2xi32> {
%lhs_extsi = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
%rhs_extsi = arith.extsi %rhs : vector<2x8xi4> to vector<2x8xi32>
%res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<4x8xi32>, vector<2x8xi32> into vector<4x2xi32>
return %res : vector<4x2xi32>
}

// -----

// CHECK-LABEL: func.func @test_lower_vector_arm_neon_unroll_incompatible_shape(
// CHECK-DAG: %[[result:.*]] = vector.contract
func.func @test_lower_vector_arm_neon_unroll_incompatible_shape(%lhs: vector<4x12xi8>, %rhs: vector<4x12xi8>, %acc : vector<4x4xi32>) -> vector<4x4xi32> {
%lhs_extsi = arith.extsi %lhs : vector<4x12xi8> to vector<4x12xi32>
%rhs_extsi = arith.extsi %rhs : vector<4x12xi8> to vector<4x12xi32>
%res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<4x12xi32>, vector<4x12xi32> into vector<4x4xi32>
return %res : vector<4x4xi32>
}